001 package aima.probability.decision; 002 003 import java.util.ArrayList; 004 import java.util.Hashtable; 005 import java.util.List; 006 007 import aima.util.Pair; 008 import aima.util.Util; 009 010 /** 011 * @author Ravi Mohan 012 * 013 */ 014 015 public class MDPTransitionModel<STATE_TYPE, ACTION_TYPE> { 016 017 private Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double> transitionToProbability = new Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double>(); 018 019 private List<STATE_TYPE> terminalStates; 020 021 public MDPTransitionModel(List<STATE_TYPE> terminalStates) { 022 this.terminalStates = terminalStates; 023 024 } 025 026 public void setTransitionProbability(STATE_TYPE initialState, 027 ACTION_TYPE action, STATE_TYPE finalState, double probability) { 028 if (!(isTerminal(initialState))) { 029 MDPTransition<STATE_TYPE, ACTION_TYPE> t = new MDPTransition<STATE_TYPE, ACTION_TYPE>( 030 initialState, action, finalState); 031 transitionToProbability.put(t, probability); 032 } 033 } 034 035 public double getTransitionProbability(STATE_TYPE initialState, 036 ACTION_TYPE action, STATE_TYPE finalState) { 037 MDPTransition<STATE_TYPE, ACTION_TYPE> key = new MDPTransition<STATE_TYPE, ACTION_TYPE>( 038 initialState, action, finalState); 039 if (transitionToProbability.keySet().contains(key)) { 040 return transitionToProbability.get(key); 041 } else { 042 return 0.0; 043 } 044 } 045 046 @Override 047 public String toString() { 048 StringBuffer buf = new StringBuffer(); 049 for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : transitionToProbability 050 .keySet()) { 051 buf.append(transition.toString() + " -> " 052 + transitionToProbability.get(transition) + " \n"); 053 } 054 return buf.toString(); 055 } 056 057 public Pair<ACTION_TYPE, Double> getTransitionWithMaximumExpectedUtility( 058 STATE_TYPE s, MDPUtilityFunction<STATE_TYPE> uf) { 059 060 if ((isTerminal(s))) { 061 return new Pair<ACTION_TYPE, Double>(null, 0.0); 062 } 063 064 List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitionsStartingWithS = getTransitionsStartingWith(s); 065 Hashtable<ACTION_TYPE, Double> actionsToUtilities = getExpectedUtilityForSelectedTransitions( 066 transitionsStartingWithS, uf); 067 068 return getActionWithMaximumUtility(actionsToUtilities); 069 070 } 071 072 public Pair<ACTION_TYPE, Double> getTransitionWithMaximumExpectedUtilityUsingPolicy( 073 MDPPolicy<STATE_TYPE, ACTION_TYPE> policy, STATE_TYPE s, 074 MDPUtilityFunction<STATE_TYPE> uf) { 075 if ((isTerminal(s))) { 076 return new Pair<ACTION_TYPE, Double>(null, 0.0); 077 } 078 List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitionsWithStartingStateSAndActionFromPolicy = getTransitionsWithStartingStateAndAction( 079 s, policy.getAction(s)); 080 Hashtable<ACTION_TYPE, Double> actionsToUtilities = getExpectedUtilityForSelectedTransitions( 081 transitionsWithStartingStateSAndActionFromPolicy, uf); 082 083 return getActionWithMaximumUtility(actionsToUtilities); 084 085 } 086 087 private boolean isTerminal(STATE_TYPE s) { 088 return terminalStates.contains(s); 089 } 090 091 private Pair<ACTION_TYPE, Double> getActionWithMaximumUtility( 092 Hashtable<ACTION_TYPE, Double> actionsToUtilities) { 093 Pair<ACTION_TYPE, Double> highest = new Pair<ACTION_TYPE, Double>(null, 094 Double.MIN_VALUE); 095 for (ACTION_TYPE key : actionsToUtilities.keySet()) { 096 Double value = actionsToUtilities.get(key); 097 if (value > highest.getSecond()) { 098 highest = new Pair<ACTION_TYPE, Double>(key, value); 099 } 100 } 101 return highest; 102 } 103 104 private Hashtable<ACTION_TYPE, Double> getExpectedUtilityForSelectedTransitions( 105 106 List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitions, 107 MDPUtilityFunction<STATE_TYPE> uf) { 108 Hashtable<ACTION_TYPE, Double> actionsToUtilities = new Hashtable<ACTION_TYPE, Double>(); 109 for (MDPTransition<STATE_TYPE, ACTION_TYPE> triplet : transitions) { 110 STATE_TYPE s = triplet.getInitialState(); 111 ACTION_TYPE action = triplet.getAction(); 112 STATE_TYPE destinationState = triplet.getDestinationState(); 113 double probabilityOfTransition = getTransitionProbability(s, 114 action, destinationState); 115 double expectedUtility = (probabilityOfTransition * uf 116 .getUtility(destinationState)); 117 Double presentValue = actionsToUtilities.get(action); 118 119 if (presentValue == null) { 120 actionsToUtilities.put(action, expectedUtility); 121 } else { 122 actionsToUtilities.put(action, presentValue + expectedUtility); 123 } 124 } 125 return actionsToUtilities; 126 } 127 128 private List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsStartingWith( 129 STATE_TYPE s) { 130 List<MDPTransition<STATE_TYPE, ACTION_TYPE>> result = new ArrayList<MDPTransition<STATE_TYPE, ACTION_TYPE>>(); 131 for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : transitionToProbability 132 .keySet()) { 133 if (transition.getInitialState().equals(s)) { 134 result.add(transition); 135 } 136 } 137 return result; 138 } 139 140 public List<MDPTransition<STATE_TYPE, ACTION_TYPE>> getTransitionsWithStartingStateAndAction( 141 STATE_TYPE s, ACTION_TYPE a) { 142 List<MDPTransition<STATE_TYPE, ACTION_TYPE>> result = new ArrayList<MDPTransition<STATE_TYPE, ACTION_TYPE>>(); 143 for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : transitionToProbability 144 .keySet()) { 145 if ((transition.getInitialState().equals(s)) 146 && (transition.getAction().equals(a))) { 147 result.add(transition); 148 } 149 } 150 return result; 151 } 152 153 public ACTION_TYPE randomActionFor(STATE_TYPE s) { 154 List<MDPTransition<STATE_TYPE, ACTION_TYPE>> transitions = getTransitionsStartingWith(s); 155 MDPTransition<STATE_TYPE, ACTION_TYPE> randomTransition = Util 156 .selectRandomlyFromList(transitions); 157 return transitions.get(0).getAction(); 158 // return randomTransition.getAction(); 159 } 160 }