001    package aima.learning.reinforcement;
002    
003    import java.util.ArrayList;
004    import java.util.Hashtable;
005    import java.util.List;
006    import java.util.Set;
007    
008    import aima.probability.decision.MDPPolicy;
009    import aima.util.Pair;
010    import aima.util.Util;
011    
012    /**
013     * @author Ravi Mohan
014     * 
015     */
016    
017    public class QTable<STATE_TYPE, ACTION_TYPE> {
018    
019            Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> table;
020    
021            private List<ACTION_TYPE> allPossibleActions;
022    
023            public QTable(List<ACTION_TYPE> allPossibleActions) {
024                    this.table = new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();
025                    this.allPossibleActions = allPossibleActions;
026            }
027    
028            public Double getQValue(STATE_TYPE state, ACTION_TYPE action) {
029                    Pair<STATE_TYPE, ACTION_TYPE> stateActionPair = new Pair<STATE_TYPE, ACTION_TYPE>(
030                                    state, action);
031                    if (!(table.keySet().contains(stateActionPair))) {
032                            return 0.0;
033                    } else {
034                            return table.get(stateActionPair);
035                    }
036            }
037    
038            public Pair<ACTION_TYPE, Double> maxDiff(STATE_TYPE startState,
039                            ACTION_TYPE action, STATE_TYPE endState) {
040                    Double maxDiff = 0.0;
041                    ACTION_TYPE maxAction = null;
042                    // randomly choose an action so that it doesn't return the same action
043                    // every time if all Q(a,s) are zero
044                    maxAction = Util.selectRandomlyFromList(allPossibleActions);
045                    maxDiff = getQValue(endState, maxAction)
046                                    - getQValue(startState, action);
047    
048                    for (ACTION_TYPE anAction : allPossibleActions) {
049                            Double diff = getQValue(endState, anAction)
050                                            - getQValue(startState, action);
051                            if (diff > maxDiff) {
052                                    maxAction = anAction;
053                                    maxDiff = diff;
054                            }
055                    }
056    
057                    return new Pair<ACTION_TYPE, Double>(maxAction, maxDiff);
058            }
059    
060            public void setQValue(STATE_TYPE state, ACTION_TYPE action, Double d) {
061                    Pair<STATE_TYPE, ACTION_TYPE> stateActionPair = new Pair<STATE_TYPE, ACTION_TYPE>(
062                                    state, action);
063                    table.put(stateActionPair, d);
064            }
065    
066            public ACTION_TYPE upDateQ(STATE_TYPE startState, ACTION_TYPE action,
067                            STATE_TYPE endState, double alpha, double reward, double phi) {
068                    double oldQValue = getQValue(startState, action);
069                    Pair<ACTION_TYPE, Double> actionAndMaxDiffValue = maxDiff(startState,
070                                    action, endState);
071                    double addedValue = alpha
072                                    * (reward + (phi * actionAndMaxDiffValue.getSecond()));
073                    setQValue(startState, action, oldQValue + addedValue);
074                    return actionAndMaxDiffValue.getFirst();
075            }
076    
077            public void normalize() {
078                    Double maxValue = findMaximumValue();
079                    if (maxValue != 0.0) {
080                            for (Pair<STATE_TYPE, ACTION_TYPE> key : table.keySet()) {
081                                    Double presentValue = table.get(key);
082                                    table.put(key, presentValue / maxValue);
083                            }
084                    }
085            }
086    
087            private Double findMaximumValue() {
088                    Set<Pair<STATE_TYPE, ACTION_TYPE>> keys = table.keySet();
089                    if (keys.size() > 0) {
090                            Double maxValue = table.get(keys.toArray()[0]);
091                            for (Pair<STATE_TYPE, ACTION_TYPE> key : keys) {
092                                    Double v = table.get(key);
093                                    if (v > maxValue) {
094                                            maxValue = v;
095                                    }
096                            }
097                            return maxValue;
098    
099                    } else {
100                            return 0.0;
101                    }
102            }
103    
104            public MDPPolicy<STATE_TYPE, ACTION_TYPE> getPolicy() {
105                    MDPPolicy<STATE_TYPE, ACTION_TYPE> policy = new MDPPolicy<STATE_TYPE, ACTION_TYPE>();
106                    List<STATE_TYPE> startingStatesRecorded = getAllStartingStates();
107    
108                    for (STATE_TYPE state : startingStatesRecorded) {
109                            ACTION_TYPE action = getRecordedActionWithMaximumQValue(state);
110                            policy.setAction(state, action);
111                    }
112                    return policy;
113            }
114    
115            private ACTION_TYPE getRecordedActionWithMaximumQValue(STATE_TYPE state) {
116                    Double maxValue = Double.NEGATIVE_INFINITY;
117                    ACTION_TYPE action = null;
118                    for (Pair<STATE_TYPE, ACTION_TYPE> stateActionPair : table.keySet()) {
119                            if (stateActionPair.getFirst().equals(state)) {
120                                    ACTION_TYPE ac = stateActionPair.getSecond();
121                                    Double value = table.get(stateActionPair);
122                                    if (value > maxValue) {
123                                            maxValue = value;
124                                            action = ac;
125                                    }
126                            }
127                    }
128                    return action;
129            }
130    
131            private List<STATE_TYPE> getAllStartingStates() {
132                    List<STATE_TYPE> states = new ArrayList<STATE_TYPE>();
133                    for (Pair<STATE_TYPE, ACTION_TYPE> stateActionPair : table.keySet()) {
134                            STATE_TYPE state = stateActionPair.getFirst();
135                            if (!(states).contains(state)) {
136                                    states.add(state);
137                            }
138                    }
139                    return states;
140            }
141    
142            @Override
143            public String toString() {
144                    return table.toString();
145            }
146    
147    }