package aima.learning.reinforcement;

import java.util.Hashtable;
import java.util.List;

import aima.probability.decision.MDP;
import aima.probability.decision.MDPPerception;
import aima.util.FrequencyCounter;
import aima.util.Pair;
import aima.util.Util;

/**
 * A Q-learning agent: it learns an action-utility function Q(s, a) directly
 * from rewards, without learning a transition model of the environment.
 * 
 * @author Ravi Mohan
 */
public class QLearningAgent<STATE_TYPE, ACTION_TYPE> extends
        MDPAgent<STATE_TYPE, ACTION_TYPE> {

    private Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> Q;

    // records how often each (state, action) pair has been tried
    private FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>> stateActionCount;

    private Double previousReward;

    private QTable<STATE_TYPE, ACTION_TYPE> qTable;

    // steps taken since the exploration function last fired
    private int actionCounter;

    public QLearningAgent(MDP<STATE_TYPE, ACTION_TYPE> mdp) {
        super(mdp);
        Q = new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();
        qTable = new QTable<STATE_TYPE, ACTION_TYPE>(mdp.getAllActions());
        stateActionCount = new FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>>();
        actionCounter = 0;
    }

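    /**
     * The agent's perceive-act cycle. On the first perception of a trial it
     * just selects an action (there is no previous step to learn from);
     * afterwards it records the visit to (previousState, previousAction) and
     * performs a Q update with discount factor 0.8 before choosing the next
     * action. Returns null once a terminal state is reached.
     */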
    @Override
    public ACTION_TYPE decideAction(MDPPerception<STATE_TYPE> perception) {
        currentState = perception.getState();
        currentReward = perception.getReward();

        if (startingTrial()) {
            // first perception of a trial: nothing to update yet,
            // just pick an action and remember this step
            ACTION_TYPE chosenAction = selectRandomAction();
            updateLearnerState(chosenAction);
            return chosenAction;
        }

        if (mdp.isTerminalState(currentState)) {
            // terminal state: do a final Q update, then clear the
            // trial-local state so the next perception starts afresh
            incrementStateActionCount(previousState, previousAction);
            updateQ(0.8);
            previousAction = null;
            previousState = null;
            previousReward = null;
            return null; // no action to take in a terminal state
        }

        incrementStateActionCount(previousState, previousAction);
        ACTION_TYPE chosenAction = updateQ(0.8);
        updateLearnerState(chosenAction);
        return chosenAction;
    }

    private void updateLearnerState(ACTION_TYPE chosenAction) {
        // an exploratory alternative would be to use
        // actionMaximizingLearningFunction() instead of chosenAction
        previousAction = chosenAction;
        previousState = currentState;
        previousReward = currentReward;
    }

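    /**
     * Performs one Q-learning backup through the QTable, which is assumed
     * to implement the standard update
     * Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)),
     * and returns the action the table recommends taking next.
     */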
    private ACTION_TYPE updateQ(double gamma) {
        actionCounter++;
        // the learning rate alpha is derived from the observed
        // state-action frequencies
        double alpha = calculateProbabilityOf(previousState, previousAction);
        return qTable.upDateQ(previousState, previousAction, currentState,
                alpha, currentReward, gamma);
    }

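    /**
     * Estimates the learning rate alpha for a (state, action) pair: the
     * number of recorded pairs matching both state and action, divided by
     * the number matching the state. Note that the FrequencyCounter's key
     * set holds distinct pairs, so this works out to one over the number of
     * distinct actions tried in the state.
     */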
    private double calculateProbabilityOf(STATE_TYPE state, ACTION_TYPE action) {
        double den = 0.0;
        double num = 0.0;
        for (Pair<STATE_TYPE, ACTION_TYPE> stateActionPair : stateActionCount
                .getStates()) {
            if (stateActionPair.getFirst().equals(state)) {
                den += 1;
                if (stateActionPair.getSecond().equals(action)) {
                    num += 1;
                }
            }
        }
        return num / den;
    }

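    /**
     * Picks the action that maximizes the exploration function applied to
     * the current Q values. Currently unreferenced (see the comment in
     * updateLearnerState); kept as the exploratory action-selection
     * alternative.
     */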
    private ACTION_TYPE actionMaximizingLearningFunction() {
        ACTION_TYPE maxAct = null;
        Double maxValue = Double.NEGATIVE_INFINITY;
        for (ACTION_TYPE action : mdp.getAllActions()) {
            Double qValue = qTable.getQValue(currentState, action);
            Double lfv = learningFunction(qValue);
            if (lfv > maxValue) {
                maxValue = lfv;
                maxAct = action;
            }
        }
        return maxAct;
    }

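    /**
     * A crude exploration function f(u): once more than three actions have
     * been taken since the last reset, it returns the optimistic value 1.0
     * (and resets the counter), nudging the agent toward actions it
     * currently rates poorly; otherwise it returns the utility unchanged.
     */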
    private Double learningFunction(Double utility) {
        if (actionCounter > 3) {
            actionCounter = 0;
            return 1.0;
        } else {
            return utility;
        }
    }

    private ACTION_TYPE selectRandomAction() {
        List<ACTION_TYPE> allActions = mdp.getAllActions();
        return Util.selectRandomlyFromList(allActions);
    }

    private boolean startingTrial() {
        return (previousAction == null) && (previousState == null)
                && (previousReward == null)
                && (currentState.equals(mdp.getInitialState()));
    }

    private void incrementStateActionCount(STATE_TYPE state, ACTION_TYPE action) {
        Pair<STATE_TYPE, ACTION_TYPE> stateActionPair = new Pair<STATE_TYPE, ACTION_TYPE>(
                state, action);
        stateActionCount.incrementFor(stateActionPair);
    }

    public Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> getQ() {
        // note: this table is never written to by the agent; the learned
        // values live in the QTable returned by getQTable()
        return Q;
    }

    public QTable<STATE_TYPE, ACTION_TYPE> getQTable() {
        return qTable;
    }

}
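
/*
 * Usage sketch (an illustration, not part of the original class): it assumes
 * MDPPerception offers a (state, reward) constructor matching the
 * getState()/getReward() accessors used above, and that the caller drives
 * the environment loop. MyState and MyAction are hypothetical type
 * arguments.
 *
 *   MDP<MyState, MyAction> mdp = ...; // some concrete MDP
 *   QLearningAgent<MyState, MyAction> agent =
 *           new QLearningAgent<MyState, MyAction>(mdp);
 *   MDPPerception<MyState> perception =
 *           new MDPPerception<MyState>(mdp.getInitialState(), 0.0);
 *   while (true) {
 *       MyAction action = agent.decideAction(perception);
 *       if (action == null) break; // trial ended in a terminal state
 *       perception = ...;          // environment's response to the action
 *   }
 */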