001 package aima.learning.reinforcement; 002 003 import java.util.ArrayList; 004 import java.util.Hashtable; 005 import java.util.List; 006 import java.util.Set; 007 008 import aima.probability.decision.MDPPolicy; 009 import aima.util.Pair; 010 import aima.util.Util; 011 012 /** 013 * @author Ravi Mohan 014 * 015 */ 016 017 public class QTable<STATE_TYPE, ACTION_TYPE> { 018 019 Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> table; 020 021 private List<ACTION_TYPE> allPossibleActions; 022 023 public QTable(List<ACTION_TYPE> allPossibleActions) { 024 this.table = new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>(); 025 this.allPossibleActions = allPossibleActions; 026 } 027 028 public Double getQValue(STATE_TYPE state, ACTION_TYPE action) { 029 Pair<STATE_TYPE, ACTION_TYPE> stateActionPair = new Pair<STATE_TYPE, ACTION_TYPE>( 030 state, action); 031 if (!(table.keySet().contains(stateActionPair))) { 032 return 0.0; 033 } else { 034 return table.get(stateActionPair); 035 } 036 } 037 038 public Pair<ACTION_TYPE, Double> maxDiff(STATE_TYPE startState, 039 ACTION_TYPE action, STATE_TYPE endState) { 040 Double maxDiff = 0.0; 041 ACTION_TYPE maxAction = null; 042 // randomly choose an action so that it doesn't return the same action 043 // every time if all Q(a,s) are zero 044 maxAction = Util.selectRandomlyFromList(allPossibleActions); 045 maxDiff = getQValue(endState, maxAction) 046 - getQValue(startState, action); 047 048 for (ACTION_TYPE anAction : allPossibleActions) { 049 Double diff = getQValue(endState, anAction) 050 - getQValue(startState, action); 051 if (diff > maxDiff) { 052 maxAction = anAction; 053 maxDiff = diff; 054 } 055 } 056 057 return new Pair<ACTION_TYPE, Double>(maxAction, maxDiff); 058 } 059 060 public void setQValue(STATE_TYPE state, ACTION_TYPE action, Double d) { 061 Pair<STATE_TYPE, ACTION_TYPE> stateActionPair = new Pair<STATE_TYPE, ACTION_TYPE>( 062 state, action); 063 table.put(stateActionPair, d); 064 } 065 066 public ACTION_TYPE upDateQ(STATE_TYPE startState, ACTION_TYPE action, 067 STATE_TYPE endState, double alpha, double reward, double phi) { 068 double oldQValue = getQValue(startState, action); 069 Pair<ACTION_TYPE, Double> actionAndMaxDiffValue = maxDiff(startState, 070 action, endState); 071 double addedValue = alpha 072 * (reward + (phi * actionAndMaxDiffValue.getSecond())); 073 setQValue(startState, action, oldQValue + addedValue); 074 return actionAndMaxDiffValue.getFirst(); 075 } 076 077 public void normalize() { 078 Double maxValue = findMaximumValue(); 079 if (maxValue != 0.0) { 080 for (Pair<STATE_TYPE, ACTION_TYPE> key : table.keySet()) { 081 Double presentValue = table.get(key); 082 table.put(key, presentValue / maxValue); 083 } 084 } 085 } 086 087 private Double findMaximumValue() { 088 Set<Pair<STATE_TYPE, ACTION_TYPE>> keys = table.keySet(); 089 if (keys.size() > 0) { 090 Double maxValue = table.get(keys.toArray()[0]); 091 for (Pair<STATE_TYPE, ACTION_TYPE> key : keys) { 092 Double v = table.get(key); 093 if (v > maxValue) { 094 maxValue = v; 095 } 096 } 097 return maxValue; 098 099 } else { 100 return 0.0; 101 } 102 } 103 104 public MDPPolicy<STATE_TYPE, ACTION_TYPE> getPolicy() { 105 MDPPolicy<STATE_TYPE, ACTION_TYPE> policy = new MDPPolicy<STATE_TYPE, ACTION_TYPE>(); 106 List<STATE_TYPE> startingStatesRecorded = getAllStartingStates(); 107 108 for (STATE_TYPE state : startingStatesRecorded) { 109 ACTION_TYPE action = getRecordedActionWithMaximumQValue(state); 110 policy.setAction(state, action); 111 } 112 return policy; 113 } 114 115 private ACTION_TYPE getRecordedActionWithMaximumQValue(STATE_TYPE state) { 116 Double maxValue = Double.NEGATIVE_INFINITY; 117 ACTION_TYPE action = null; 118 for (Pair<STATE_TYPE, ACTION_TYPE> stateActionPair : table.keySet()) { 119 if (stateActionPair.getFirst().equals(state)) { 120 ACTION_TYPE ac = stateActionPair.getSecond(); 121 Double value = table.get(stateActionPair); 122 if (value > maxValue) { 123 maxValue = value; 124 action = ac; 125 } 126 } 127 } 128 return action; 129 } 130 131 private List<STATE_TYPE> getAllStartingStates() { 132 List<STATE_TYPE> states = new ArrayList<STATE_TYPE>(); 133 for (Pair<STATE_TYPE, ACTION_TYPE> stateActionPair : table.keySet()) { 134 STATE_TYPE state = stateActionPair.getFirst(); 135 if (!(states).contains(state)) { 136 states.add(state); 137 } 138 } 139 return states; 140 } 141 142 @Override 143 public String toString() { 144 return table.toString(); 145 } 146 147 }