package aima.probability.decision;

import java.util.List;

import aima.probability.Randomizer;
import aima.util.Pair;

/**
 * A Markov Decision Process assembled from an {@link MDPSource}, offering value
 * iteration and policy iteration over its states.
 * 
 * @author Ravi Mohan
 */
public class MDP<STATE_TYPE, ACTION_TYPE> {
    private STATE_TYPE initialState;

    private MDPTransitionModel<STATE_TYPE, ACTION_TYPE> transitionModel;

    private MDPRewardFunction<STATE_TYPE> rewardFunction;

    private List<STATE_TYPE> nonFinalStates, terminalStates;

    private MDPSource<STATE_TYPE, ACTION_TYPE> source;

    public MDP(MDPSource<STATE_TYPE, ACTION_TYPE> source) {
        this.initialState = source.getInitialState();
        this.transitionModel = source.getTransitionModel();
        this.rewardFunction = source.getRewardFunction();
        this.nonFinalStates = source.getNonFinalStates();
        this.terminalStates = source.getFinalStates();
        this.source = source;
    }

    /**
     * Returns an MDP that shares this MDP's source but has an empty transition
     * model and a reward function containing only the initial state's reward.
     */
    public MDP<STATE_TYPE, ACTION_TYPE> emptyMdp() {
        MDP<STATE_TYPE, ACTION_TYPE> mdp = new MDP<STATE_TYPE, ACTION_TYPE>(
                source);
        mdp.rewardFunction = new MDPRewardFunction<STATE_TYPE>();
        mdp.rewardFunction.setReward(initialState, rewardFunction
                .getRewardFor(initialState));
        mdp.transitionModel = new MDPTransitionModel<STATE_TYPE, ACTION_TYPE>(
                terminalStates);
        return mdp;
    }

    /**
     * Value iteration with the standard termination test: iterate until the
     * largest utility change in a sweep, delta, falls below
     * error * (1 - gamma) / gamma.
     */
    public MDPUtilityFunction<STATE_TYPE> valueIteration(double gamma,
            double error, double delta) {
        MDPUtilityFunction<STATE_TYPE> U = initialUtilityFunction();
        MDPUtilityFunction<STATE_TYPE> U_dash = initialUtilityFunction();
        double delta_max = (error * (1 - gamma)) / gamma;
        do {
            U = U_dash.copy();
            // delta tracks the largest utility change seen in this sweep
            delta = 0.0;
            for (STATE_TYPE s : nonFinalStates) {
                Pair<ACTION_TYPE, Double> highestUtilityTransition = transitionModel
                        .getTransitionWithMaximumExpectedUtility(s, U);
                double utility = rewardFunction.getRewardFor(s)
                        + (gamma * highestUtilityTransition.getSecond());
                U_dash.setUtility(s, utility);
                if (Math.abs(U_dash.getUtility(s) - U.getUtility(s)) > delta) {
                    delta = Math.abs(U_dash.getUtility(s) - U.getUtility(s));
                }
            }
        } while (delta > delta_max);
        return U;
    }

    public MDPUtilityFunction<STATE_TYPE> valueIterationForFixedIterations(
            int numberOfIterations, double gamma) {
        MDPUtilityFunction<STATE_TYPE> utilityFunction = initialUtilityFunction();

        for (int i = 0; i < numberOfIterations; i++) {
            utilityFunction = valueIterateOnce(gamma, utilityFunction)
                    .getFirst();
        }

        return utilityFunction;
    }

    public MDPUtilityFunction<STATE_TYPE> valueIterationTillMaximumUtilityGrowthFallsBelowErrorMargin(
            double gamma, double errorMargin) {
        double maxUtilityGrowth;
        MDPUtilityFunction<STATE_TYPE> utilityFunction = initialUtilityFunction();
        do {
            Pair<MDPUtilityFunction<STATE_TYPE>, Double> result = valueIterateOnce(
                    gamma, utilityFunction);
            utilityFunction = result.getFirst();
            maxUtilityGrowth = result.getSecond();
        } while (maxUtilityGrowth > errorMargin);

        return utilityFunction;
    }
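    /**
     * Applies one Bellman update to every non-final state:
     * U'(s) = R(s) + gamma * (expected utility of the best action in s),
     * while terminal states keep their current utilities.
     *
     * @return the updated utility function, paired with the largest absolute
     *         change in any state's utility seen during this sweep.
     */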
    public Pair<MDPUtilityFunction<STATE_TYPE>, Double> valueIterateOnce(
            double gamma, MDPUtilityFunction<STATE_TYPE> presentUtilityFunction) {
        double maxUtilityGrowth = 0.0;
        MDPUtilityFunction<STATE_TYPE> newUtilityFunction = new MDPUtilityFunction<STATE_TYPE>();

        for (STATE_TYPE s : nonFinalStates) {
            double utility = valueIterateOnceForGivenState(gamma,
                    presentUtilityFunction, s);

            double differenceInUtility = Math.abs(utility
                    - presentUtilityFunction.getUtility(s));
            if (differenceInUtility > maxUtilityGrowth) {
                maxUtilityGrowth = differenceInUtility;
            }
            newUtilityFunction.setUtility(s, utility);
        }

        // Terminal states are not updated; carry their utilities over unchanged.
        for (STATE_TYPE state : terminalStates) {
            newUtilityFunction.setUtility(state, presentUtilityFunction
                    .getUtility(state));
        }

        return new Pair<MDPUtilityFunction<STATE_TYPE>, Double>(
                newUtilityFunction, maxUtilityGrowth);
    }

    private double valueIterateOnceForGivenState(double gamma,
            MDPUtilityFunction<STATE_TYPE> presentUtilityFunction,
            STATE_TYPE state) {
        Pair<ACTION_TYPE, Double> highestUtilityTransition = transitionModel
                .getTransitionWithMaximumExpectedUtility(state,
                        presentUtilityFunction);
        return rewardFunction.getRewardFor(state)
                + (gamma * highestUtilityTransition.getSecond());
    }

    public MDPPolicy<STATE_TYPE, ACTION_TYPE> policyIteration(double gamma) {
        MDPUtilityFunction<STATE_TYPE> U = initialUtilityFunction();
        MDPPolicy<STATE_TYPE, ACTION_TYPE> pi = randomPolicy();
        boolean unchanged;
        do {
            unchanged = true;

            U = policyEvaluation(pi, U, gamma, 3);
            for (STATE_TYPE s : nonFinalStates) {
                Pair<ACTION_TYPE, Double> maxTransit = transitionModel
                        .getTransitionWithMaximumExpectedUtility(s, U);
                Pair<ACTION_TYPE, Double> maxPolicyTransit = transitionModel
                        .getTransitionWithMaximumExpectedUtilityUsingPolicy(pi,
                                s, U);

                // If some action beats the current policy's action, switch to it.
                if (maxTransit.getSecond() > maxPolicyTransit.getSecond()) {
                    pi.setAction(s, maxTransit.getFirst());
                    unchanged = false;
                }
            }
        } while (!unchanged);
        return pi;
    }
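    /**
     * Simplified policy evaluation, as used in modified policy iteration: rather
     * than solving exactly for the utilities of the given policy, it applies a
     * fixed number of Bellman updates under that policy to the supplied utility
     * function.
     */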
    public MDPUtilityFunction<STATE_TYPE> policyEvaluation(
            MDPPolicy<STATE_TYPE, ACTION_TYPE> pi,
            MDPUtilityFunction<STATE_TYPE> U, double gamma, int iterations) {
        MDPUtilityFunction<STATE_TYPE> U_dash = U.copy();
        for (int i = 0; i < iterations; i++) {
            U_dash = valueIterateOnceWith(gamma, pi, U_dash);
        }
        return U_dash;
    }

    private MDPUtilityFunction<STATE_TYPE> valueIterateOnceWith(double gamma,
            MDPPolicy<STATE_TYPE, ACTION_TYPE> pi,
            MDPUtilityFunction<STATE_TYPE> U) {
        MDPUtilityFunction<STATE_TYPE> U_dash = U.copy();
        for (STATE_TYPE s : nonFinalStates) {
            Pair<ACTION_TYPE, Double> highestPolicyTransition = transitionModel
                    .getTransitionWithMaximumExpectedUtilityUsingPolicy(pi, s,
                            U);
            double utility = rewardFunction.getRewardFor(s)
                    + (gamma * highestPolicyTransition.getSecond());
            U_dash.setUtility(s, utility);
        }
        return U_dash;
    }

    public MDPPolicy<STATE_TYPE, ACTION_TYPE> randomPolicy() {
        MDPPolicy<STATE_TYPE, ACTION_TYPE> policy = new MDPPolicy<STATE_TYPE, ACTION_TYPE>();
        for (STATE_TYPE s : nonFinalStates) {
            policy.setAction(s, transitionModel.randomActionFor(s));
        }
        return policy;
    }

    public MDPUtilityFunction<STATE_TYPE> initialUtilityFunction() {
        return rewardFunction.asUtilityFunction();
    }

    public STATE_TYPE getInitialState() {
        return initialState;
    }

    public double getRewardFor(STATE_TYPE state) {
        return rewardFunction.getRewardFor(state);
    }

    public void setReward(STATE_TYPE state, double reward) {
        rewardFunction.setReward(state, reward);
    }

    public void setTransitionProbability(
            MDPTransition<STATE_TYPE, ACTION_TYPE> transition,
            double probability) {
        transitionModel.setTransitionProbability(transition.getInitialState(),
                transition.getAction(), transition.getDestinationState(),
                probability);
    }

    public double getTransitionProbability(
            MDPTransition<STATE_TYPE, ACTION_TYPE> transition) {
        return transitionModel.getTransitionProbability(transition
                .getInitialState(), transition.getAction(), transition
                .getDestinationState());
    }

    public MDPPerception<STATE_TYPE> execute(STATE_TYPE state,
            ACTION_TYPE action, Randomizer r) {
        return source.execute(state, action, r);
    }

    public boolean isTerminalState(STATE_TYPE state) {
        return terminalStates.contains(state);
    }

    public List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsWith(
            STATE_TYPE initialState, ACTION_TYPE action) {
        return transitionModel.getTransitionsWithStartingStateAndAction(
                initialState, action);
    }

    public List<ACTION_TYPE> getAllActions() {
        return source.getAllActions();
    }

    @Override
    public String toString() {
        return "initialState = " + initialState.toString()
                + "\n rewardFunction = " + rewardFunction.toString()
                + "\n transitionModel = " + transitionModel.toString()
                + "\n states = " + nonFinalStates.toString();
    }
}
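/*
 * Illustrative usage sketch (editorial addition, not part of the original class).
 * Any MDPSource implementation can be plugged in; the 4x3 cell world from the AIMA
 * examples is one candidate if it is available in this codebase. The discount factor
 * and error margin below are arbitrary example values.
 */
class MdpUsageSketch {

    static <S, A> void solve(MDPSource<S, A> source) {
        MDP<S, A> mdp = new MDP<S, A>(source);

        // Value iteration: stop once no state's utility changes by more than 1e-5 in a sweep.
        MDPUtilityFunction<S> utilities = mdp
                .valueIterationTillMaximumUtilityGrowthFallsBelowErrorMargin(0.99, 1e-5);

        // Policy iteration with the same discount factor, for comparison.
        MDPPolicy<S, A> policy = mdp.policyIteration(0.99);

        System.out.println(utilities);
        System.out.println(policy);
    }
}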