001    package aima.probability.decision;
002    
003    import java.util.ArrayList;
004    import java.util.Hashtable;
005    import java.util.List;
006    
007    import aima.util.Pair;
008    import aima.util.Util;
009    
010    /**
011     * @author Ravi Mohan
012     * 
013     */
014    
015    public class MDPTransitionModel<STATE_TYPE, ACTION_TYPE> {
016    
017            private Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double> transitionToProbability = new Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double>();
018    
019            private List<STATE_TYPE> terminalStates;
020    
021            public MDPTransitionModel(List<STATE_TYPE> terminalStates) {
022                    this.terminalStates = terminalStates;
023    
024            }
025    
026            public void setTransitionProbability(STATE_TYPE initialState,
027                            ACTION_TYPE action, STATE_TYPE finalState, double probability) {
028                    if (!(isTerminal(initialState))) {
029                            MDPTransition<STATE_TYPE, ACTION_TYPE> t = new MDPTransition<STATE_TYPE, ACTION_TYPE>(
030                                            initialState, action, finalState);
031                            transitionToProbability.put(t, probability);
032                    }
033            }
034    
035            public double getTransitionProbability(STATE_TYPE initialState,
036                            ACTION_TYPE action, STATE_TYPE finalState) {
037                    MDPTransition<STATE_TYPE, ACTION_TYPE> key = new MDPTransition<STATE_TYPE, ACTION_TYPE>(
038                                    initialState, action, finalState);
039                    if (transitionToProbability.keySet().contains(key)) {
040                            return transitionToProbability.get(key);
041                    } else {
042                            return 0.0;
043                    }
044            }
045    
046            @Override
047            public String toString() {
048                    StringBuffer buf = new StringBuffer();
049                    for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : transitionToProbability
050                                    .keySet()) {
051                            buf.append(transition.toString() + " -> "
052                                            + transitionToProbability.get(transition) + " \n");
053                    }
054                    return buf.toString();
055            }
056    
057            public Pair<ACTION_TYPE, Double> getTransitionWithMaximumExpectedUtility(
058                            STATE_TYPE s, MDPUtilityFunction<STATE_TYPE> uf) {
059    
060                    if ((isTerminal(s))) {
061                            return new Pair<ACTION_TYPE, Double>(null, 0.0);
062                    }
063    
064                    List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitionsStartingWithS = getTransitionsStartingWith(s);
065                    Hashtable<ACTION_TYPE, Double> actionsToUtilities = getExpectedUtilityForSelectedTransitions(
066                                    transitionsStartingWithS, uf);
067    
068                    return getActionWithMaximumUtility(actionsToUtilities);
069    
070            }
071    
072            public Pair<ACTION_TYPE, Double> getTransitionWithMaximumExpectedUtilityUsingPolicy(
073                            MDPPolicy<STATE_TYPE, ACTION_TYPE> policy, STATE_TYPE s,
074                            MDPUtilityFunction<STATE_TYPE> uf) {
075                    if ((isTerminal(s))) {
076                            return new Pair<ACTION_TYPE, Double>(null, 0.0);
077                    }
078                    List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitionsWithStartingStateSAndActionFromPolicy = getTransitionsWithStartingStateAndAction(
079                                    s, policy.getAction(s));
080                    Hashtable<ACTION_TYPE, Double> actionsToUtilities = getExpectedUtilityForSelectedTransitions(
081                                    transitionsWithStartingStateSAndActionFromPolicy, uf);
082    
083                    return getActionWithMaximumUtility(actionsToUtilities);
084    
085            }
086    
087            private boolean isTerminal(STATE_TYPE s) {
088                    return terminalStates.contains(s);
089            }
090    
091            private Pair<ACTION_TYPE, Double> getActionWithMaximumUtility(
092                            Hashtable<ACTION_TYPE, Double> actionsToUtilities) {
093                    Pair<ACTION_TYPE, Double> highest = new Pair<ACTION_TYPE, Double>(null,
094                                    Double.MIN_VALUE);
095                    for (ACTION_TYPE key : actionsToUtilities.keySet()) {
096                            Double value = actionsToUtilities.get(key);
097                            if (value > highest.getSecond()) {
098                                    highest = new Pair<ACTION_TYPE, Double>(key, value);
099                            }
100                    }
101                    return highest;
102            }
103    
104            private Hashtable<ACTION_TYPE, Double> getExpectedUtilityForSelectedTransitions(
105    
106            List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitions,
107                            MDPUtilityFunction<STATE_TYPE> uf) {
108                    Hashtable<ACTION_TYPE, Double> actionsToUtilities = new Hashtable<ACTION_TYPE, Double>();
109                    for (MDPTransition<STATE_TYPE, ACTION_TYPE> triplet : transitions) {
110                            STATE_TYPE s = triplet.getInitialState();
111                            ACTION_TYPE action = triplet.getAction();
112                            STATE_TYPE destinationState = triplet.getDestinationState();
113                            double probabilityOfTransition = getTransitionProbability(s,
114                                            action, destinationState);
115                            double expectedUtility = (probabilityOfTransition * uf
116                                            .getUtility(destinationState));
117                            Double presentValue = actionsToUtilities.get(action);
118    
119                            if (presentValue == null) {
120                                    actionsToUtilities.put(action, expectedUtility);
121                            } else {
122                                    actionsToUtilities.put(action, presentValue + expectedUtility);
123                            }
124                    }
125                    return actionsToUtilities;
126            }
127    
128            private List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsStartingWith(
129                            STATE_TYPE s) {
130                    List<MDPTransition<STATE_TYPE, ACTION_TYPE>> result = new ArrayList<MDPTransition<STATE_TYPE, ACTION_TYPE>>();
131                    for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : transitionToProbability
132                                    .keySet()) {
133                            if (transition.getInitialState().equals(s)) {
134                                    result.add(transition);
135                            }
136                    }
137                    return result;
138            }
139    
140            public List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsWithStartingStateAndAction(
141                            STATE_TYPE s, ACTION_TYPE a) {
142                    List<MDPTransition<STATE_TYPE, ACTION_TYPE>> result = new ArrayList<MDPTransition<STATE_TYPE, ACTION_TYPE>>();
143                    for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : transitionToProbability
144                                    .keySet()) {
145                            if ((transition.getInitialState().equals(s))
146                                            && (transition.getAction().equals(a))) {
147                                    result.add(transition);
148                            }
149                    }
150                    return result;
151            }
152    
153            public ACTION_TYPE randomActionFor(STATE_TYPE s) {
154                    List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitions = getTransitionsStartingWith(s);
155                    MDPTransition<STATE_TYPE, ACTION_TYPE> randomTransition = Util
156                                    .selectRandomlyFromList(transitions);
157                    return transitions.get(0).getAction();
158                    // return randomTransition.getAction();
159            }
160    }