package aima.probability.decision;

import java.util.List;

import aima.probability.Randomizer;
import aima.util.Pair;

/**
 * A Markov Decision Process: an initial state, a transition model and a
 * reward function over states of type STATE_TYPE and actions of type
 * ACTION_TYPE, obtained from an {@link MDPSource}. Utilities and policies
 * can be computed by value iteration or policy iteration (AIMA2e Ch. 17).
 * 
 * @author Ravi Mohan
 */

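// A minimal usage sketch (hypothetical: MyState, MyAction and source are
// placeholders for any MDPSource implementation and its state/action types):
//
//   MDP<MyState, MyAction> mdp = new MDP<MyState, MyAction>(source);
//   MDPUtilityFunction<MyState> u = mdp.valueIteration(0.8, 0.00001, 0.0);
//   MDPPolicy<MyState, MyAction> pi = mdp.policyIteration(0.8);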
public class MDP<STATE_TYPE, ACTION_TYPE> {
        private STATE_TYPE initialState;

        private MDPTransitionModel<STATE_TYPE, ACTION_TYPE> transitionModel;

        private MDPRewardFunction<STATE_TYPE> rewardFunction;

        private List<STATE_TYPE> nonFinalStates, terminalStates;

        private MDPSource<STATE_TYPE, ACTION_TYPE> source;

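        /**
         * Constructs an MDP from the given source, which supplies the
         * initial state, the transition model, the reward function and the
         * state lists.
         */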
        public MDP(MDPSource<STATE_TYPE, ACTION_TYPE> source) {
                this.initialState = source.getInitialState();
                this.transitionModel = source.getTransitionModel();
                this.rewardFunction = source.getRewardFunction();
                this.nonFinalStates = source.getNonFinalStates();
                this.terminalStates = source.getFinalStates();
                this.source = source;
        }

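        /**
         * Creates an MDP over the same source whose transition model
         * contains no transitions and whose reward function rewards only the
         * initial state; useful for agents that learn the model from
         * experience.
         */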
        public MDP<STATE_TYPE, ACTION_TYPE> emptyMdp() {
                MDP<STATE_TYPE, ACTION_TYPE> mdp = new MDP<STATE_TYPE, ACTION_TYPE>(
                                source);
                mdp.rewardFunction = new MDPRewardFunction<STATE_TYPE>();
                mdp.rewardFunction.setReward(initialState, rewardFunction
                                .getRewardFor(initialState));
                mdp.transitionModel = new MDPTransitionModel<STATE_TYPE, ACTION_TYPE>(
                                terminalStates);
                return mdp;
        }

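        /**
         * Value iteration (AIMA2e Fig. 17.4). Repeatedly applies the Bellman
         * update
         * 
         * U'(s) = R(s) + gamma * max_a SUM_s' T(s, a, s') * U(s')
         * 
         * to every non-final state until the largest utility change in a
         * sweep, delta, falls below error * (1 - gamma) / gamma.
         * 
         * @param gamma the discount factor, 0 < gamma < 1
         * @param error the maximum permissible error in any state's utility
         * @param delta working variable that tracks the largest utility
         *        change per sweep; its initial value is ignored
         * @return the converged utility function
         */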
        public MDPUtilityFunction<STATE_TYPE> valueIteration(double gamma,
                        double error, double delta) {
                MDPUtilityFunction<STATE_TYPE> U = initialUtilityFunction();
                MDPUtilityFunction<STATE_TYPE> U_dash = initialUtilityFunction();
                // convergence threshold from AIMA2e Fig. 17.4:
                // stop once delta < error * (1 - gamma) / gamma
                double delta_max = (error * (1 - gamma)) / gamma;
                do {
                        U = U_dash.copy();
                        delta = 0.0;
                        for (STATE_TYPE s : nonFinalStates) {
                                Pair<ACTION_TYPE, Double> highestUtilityTransition = transitionModel
                                                .getTransitionWithMaximumExpectedUtility(s, U);
                                double utility = rewardFunction.getRewardFor(s)
                                                + (gamma * highestUtilityTransition.getSecond());
                                U_dash.setUtility(s, utility);
                                if ((Math.abs(U_dash.getUtility(s) - U.getUtility(s))) > delta) {
                                        delta = Math.abs(U_dash.getUtility(s) - U.getUtility(s));
                                }
                        }
                } while (delta > delta_max); // keep sweeping while utilities still change significantly
                return U;
        }

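        /**
         * Runs a fixed number of value-iteration sweeps and returns the
         * resulting utility function, whether or not it has converged.
         */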
        public MDPUtilityFunction<STATE_TYPE> valueIterationForFixedIterations(
                        int numberOfIterations, double gamma) {
                MDPUtilityFunction<STATE_TYPE> utilityFunction = initialUtilityFunction();

                for (int i = 0; i < numberOfIterations; i++) {
                        utilityFunction = valueIterateOnce(gamma, utilityFunction)
                                        .getFirst();
                }

                return utilityFunction;
        }

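        /**
         * Runs value-iteration sweeps until the largest utility change
         * produced by a single sweep no longer exceeds the given error
         * margin.
         */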
        public MDPUtilityFunction<STATE_TYPE> valueIterationTillMaximumUtilityGrowthFallsBelowErrorMargin(
                        double gamma, double errorMargin) {
                double maxUtilityGrowth = 0.0;
                MDPUtilityFunction<STATE_TYPE> utilityFunction = initialUtilityFunction();
                do {
                        Pair<MDPUtilityFunction<STATE_TYPE>, Double> result = valueIterateOnce(
                                        gamma, utilityFunction);
                        utilityFunction = result.getFirst();
                        maxUtilityGrowth = result.getSecond();
                } while (maxUtilityGrowth > errorMargin);

                return utilityFunction;
        }

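        /**
         * Performs a single value-iteration sweep over all non-final states.
         * 
         * @return a pair of the updated utility function and the largest
         *         utility change observed during the sweep
         */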
        public Pair<MDPUtilityFunction<STATE_TYPE>, Double> valueIterateOnce(
                        double gamma, MDPUtilityFunction<STATE_TYPE> presentUtilityFunction) {
                double maxUtilityGrowth = 0.0;
                MDPUtilityFunction<STATE_TYPE> newUtilityFunction = new MDPUtilityFunction<STATE_TYPE>();

                for (STATE_TYPE s : nonFinalStates) {
                        double utility = valueIterateOnceForGivenState(gamma,
                                        presentUtilityFunction, s);

                        double differenceInUtility = Math.abs(utility
                                        - presentUtilityFunction.getUtility(s));
                        if (differenceInUtility > maxUtilityGrowth) {
                                maxUtilityGrowth = differenceInUtility;
                        }
                        newUtilityFunction.setUtility(s, utility);
                }

                // terminal-state utilities never change, so copy them across once
                for (STATE_TYPE state : terminalStates) {
                        newUtilityFunction.setUtility(state, presentUtilityFunction
                                        .getUtility(state));
                }

                return new Pair<MDPUtilityFunction<STATE_TYPE>, Double>(
                                newUtilityFunction, maxUtilityGrowth);
        }

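        /**
         * Applies the Bellman update to a single state: its reward plus the
         * discounted expected utility of the best available action.
         */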
        private double valueIterateOnceForGivenState(double gamma,
                        MDPUtilityFunction<STATE_TYPE> presentUtilityFunction,
                        STATE_TYPE state) {
                Pair<ACTION_TYPE, Double> highestUtilityTransition = transitionModel
                                .getTransitionWithMaximumExpectedUtility(state,
                                                presentUtilityFunction);
                return rewardFunction.getRewardFor(state)
                                + (gamma * highestUtilityTransition.getSecond());
        }

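        /**
         * Policy iteration (AIMA2e Fig. 17.7), in its modified form:
         * alternates a truncated policy evaluation (three sweeps per round
         * here) with greedy policy improvement, starting from a random
         * policy, until no state's action changes.
         */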
        public MDPPolicy<STATE_TYPE, ACTION_TYPE> policyIteration(double gamma) {
                MDPUtilityFunction<STATE_TYPE> U = initialUtilityFunction();
                MDPPolicy<STATE_TYPE, ACTION_TYPE> pi = randomPolicy();
                boolean unchanged = false;
                do {
                        unchanged = true;

                        U = policyEvaluation(pi, U, gamma, 3);
                        for (STATE_TYPE s : nonFinalStates) {
                                Pair<ACTION_TYPE, Double> maxTransit = transitionModel
                                                .getTransitionWithMaximumExpectedUtility(s, U);
                                Pair<ACTION_TYPE, Double> maxPolicyTransit = transitionModel
                                                .getTransitionWithMaximumExpectedUtilityUsingPolicy(pi,
                                                                s, U);

                                if (maxTransit.getSecond() > maxPolicyTransit.getSecond()) {
                                        pi.setAction(s, maxTransit.getFirst());
                                        unchanged = false;
                                }
                        }
                } while (!unchanged);
                return pi;
        }

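        /**
         * Simplified policy evaluation: runs the given number of sweeps of
         * the update U'(s) = R(s) + gamma * SUM_s' T(s, pi(s), s') * U(s'),
         * keeping the policy's action fixed in each state.
         */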
        public MDPUtilityFunction<STATE_TYPE> policyEvaluation(
                        MDPPolicy<STATE_TYPE, ACTION_TYPE> pi,
                        MDPUtilityFunction<STATE_TYPE> U, double gamma, int iterations) {
                MDPUtilityFunction<STATE_TYPE> U_dash = U.copy();
                for (int i = 0; i < iterations; i++) {
                        U_dash = valueIterateOnceWith(gamma, pi, U_dash);
                }
                return U_dash;
        }

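        /**
         * A single evaluation sweep in which each non-final state is updated
         * using the action prescribed by the given policy rather than the
         * maximizing action.
         */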
        private MDPUtilityFunction<STATE_TYPE> valueIterateOnceWith(double gamma,
                        MDPPolicy<STATE_TYPE, ACTION_TYPE> pi,
                        MDPUtilityFunction<STATE_TYPE> U) {
                MDPUtilityFunction<STATE_TYPE> U_dash = U.copy();
                for (STATE_TYPE s : nonFinalStates) {
                        Pair<ACTION_TYPE, Double> highestPolicyTransition = transitionModel
                                        .getTransitionWithMaximumExpectedUtilityUsingPolicy(pi, s,
                                                        U);
                        double utility = rewardFunction.getRewardFor(s)
                                        + (gamma * highestPolicyTransition.getSecond());
                        U_dash.setUtility(s, utility);
                }
                return U_dash;
        }

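        /**
         * Builds an initial policy that assigns a random applicable action
         * to every non-final state.
         */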
        public MDPPolicy<STATE_TYPE, ACTION_TYPE> randomPolicy() {
                MDPPolicy<STATE_TYPE, ACTION_TYPE> policy = new MDPPolicy<STATE_TYPE, ACTION_TYPE>();
                for (STATE_TYPE s : nonFinalStates) {
                        policy.setAction(s, transitionModel.randomActionFor(s));
                }
                return policy;
        }

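        /**
         * Returns the initial utility estimate, in which each state's
         * utility is its immediate reward.
         */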
        public MDPUtilityFunction<STATE_TYPE> initialUtilityFunction() {
                return rewardFunction.asUtilityFunction();
        }

        public STATE_TYPE getInitialState() {
                return initialState;
        }

        public double getRewardFor(STATE_TYPE state) {
                return rewardFunction.getRewardFor(state);
        }

        public void setReward(STATE_TYPE state, double reward) {
                rewardFunction.setReward(state, reward);
        }

        public void setTransitionProbability(
                        MDPTransition<STATE_TYPE, ACTION_TYPE> transition,
                        double probability) {
                transitionModel.setTransitionProbability(transition.getInitialState(),
                                transition.getAction(), transition.getDestinationState(),
                                probability);
        }

        public double getTransitionProbability(
                        MDPTransition<STATE_TYPE, ACTION_TYPE> transition) {
                return transitionModel.getTransitionProbability(transition
                                .getInitialState(), transition.getAction(), transition
                                .getDestinationState());
        }

        public MDPPerception<STATE_TYPE> execute(STATE_TYPE state,
                        ACTION_TYPE action, Randomizer r) {
                return source.execute(state, action, r);
        }

        public boolean isTerminalState(STATE_TYPE state) {
                return terminalStates.contains(state);
        }

        public List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsWith(
                        STATE_TYPE initialState, ACTION_TYPE action) {
                return transitionModel.getTransitionsWithStartingStateAndAction(
                                initialState, action);
        }

        public List<ACTION_TYPE> getAllActions() {
                return source.getAllActions();
        }

        @Override
        public String toString() {
                return "initialState = " + initialState.toString()
                                + "\n rewardFunction = " + rewardFunction.toString()
                                + "\n transitionModel = " + transitionModel.toString()
                                + "\n nonFinalStates = " + nonFinalStates.toString();
        }

}