package aima.learning.reinforcement;

import java.util.Hashtable;
import java.util.List;

import aima.probability.decision.MDP;
import aima.probability.decision.MDPPerception;
import aima.util.FrequencyCounter;
import aima.util.Pair;
import aima.util.Util;

/**
 * An exploratory Q-learning agent: it learns an action-value function
 * Q(s, a) directly from percepts, without first learning a model of the
 * environment's transitions.
 *
 * @author Ravi Mohan
 */
public class QLearningAgent<STATE_TYPE, ACTION_TYPE> extends
		MDPAgent<STATE_TYPE, ACTION_TYPE> {

	/** Discount factor used for all Q updates. */
	private static final double GAMMA = 0.8;

	private Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> Q;

	private FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>> stateActionCount;

	private Double previousReward;

	private QTable<STATE_TYPE, ACTION_TYPE> qTable;

	private int actionCounter;

	public QLearningAgent(MDP<STATE_TYPE, ACTION_TYPE> mdp) {
		super(mdp);
		Q = new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();
		qTable = new QTable<STATE_TYPE, ACTION_TYPE>(mdp.getAllActions());
		stateActionCount = new FrequencyCounter<Pair<STATE_TYPE, ACTION_TYPE>>();
		actionCounter = 0;
	}

	@Override
	public ACTION_TYPE decideAction(MDPPerception<STATE_TYPE> perception) {
		currentState = perception.getState();
		currentReward = perception.getReward();

		if (startingTrial()) {
			// First percept of a trial: there is no previous (s, a, r)
			// to learn from, so just choose an action and remember the
			// current state and reward for the next step.
			ACTION_TYPE chosenAction = selectRandomAction();
			updateLearnerState(chosenAction);
			return previousAction;
		}

		if (mdp.isTerminalState(currentState)) {
			// Learn from the final transition, then reset the trial state.
			// No action is taken in a terminal state, so return null.
			incrementStateActionCount(previousState, previousAction);
			updateQ(GAMMA);
			previousAction = null;
			previousState = null;
			previousReward = null;
			return null;
		} else {
			incrementStateActionCount(previousState, previousAction);
			ACTION_TYPE chosenAction = updateQ(GAMMA);
			updateLearnerState(chosenAction);
			return previousAction;
		}
	}

	private void updateLearnerState(ACTION_TYPE chosenAction) {
		// previousAction = actionMaximizingLearningFunction();
		previousAction = chosenAction;
		previousState = currentState;
		previousReward = currentReward;
	}

	/**
	 * Applies the Q-learning update
	 * Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
	 * via the QTable and returns the action chosen for the current state.
	 */
	private ACTION_TYPE updateQ(double gamma) {
		actionCounter++;
		double alpha = calculateProbabilityOf(previousState, previousAction);
		return qTable.upDateQ(previousState, previousAction, currentState,
				alpha, currentReward, gamma);
	}

	/**
	 * Returns the fraction of visits to <code>state</code> on which
	 * <code>action</code> was the action taken; used as the learning rate
	 * alpha. The (state, action) pair is counted before this is called,
	 * so the denominator is never zero.
	 */
	private double calculateProbabilityOf(STATE_TYPE state, ACTION_TYPE action) {
		double den = 0.0;
		double num = 0.0;
		for (Pair<STATE_TYPE, ACTION_TYPE> stateActionPair : stateActionCount
				.getStates()) {
			if (stateActionPair.getFirst().equals(state)) {
				den += 1;
				if (stateActionPair.getSecond().equals(action)) {
					num += 1;
				}
			}
		}
		return num / den;
	}

	private ACTION_TYPE actionMaximizingLearningFunction() {
		ACTION_TYPE maxAct = null;
		Double maxValue = Double.NEGATIVE_INFINITY;
		for (ACTION_TYPE action : mdp.getAllActions()) {
			Double qValue = qTable.getQValue(currentState, action);
			Double lfv = learningFunction(qValue);
			if (lfv > maxValue) {
				maxValue = lfv;
				maxAct = action;
			}
		}
		return maxAct;
	}

	/**
	 * A simple exploration function: every fourth call it returns the
	 * optimistic value 1.0 to encourage exploration; otherwise it returns
	 * the utility unchanged.
	 */
	private Double learningFunction(Double utility) {
		if (actionCounter > 3) {
			actionCounter = 0;
			return 1.0;
		} else {
			return utility;
		}
	}

	private ACTION_TYPE selectRandomAction() {
		List<ACTION_TYPE> allActions = mdp.getAllActions();
		return Util.selectRandomlyFromList(allActions);
	}

	private boolean startingTrial() {
		return (previousAction == null) && (previousState == null)
				&& (previousReward == null)
				&& (currentState.equals(mdp.getInitialState()));
	}

	private void incrementStateActionCount(STATE_TYPE state, ACTION_TYPE action) {
		Pair<STATE_TYPE, ACTION_TYPE> stateActionPair = new Pair<STATE_TYPE, ACTION_TYPE>(
				state, action);
		stateActionCount.incrementFor(stateActionPair);
	}

	public Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> getQ() {
		return Q;
	}

	public QTable<STATE_TYPE, ACTION_TYPE> getQTable() {
		return qTable;
	}
}
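
// ---------------------------------------------------------------------
// A minimal usage sketch (added for illustration; not part of the
// original class). Only getAllActions(), getInitialState(),
// isTerminalState(), getState() and getReward() are confirmed by the
// code above; the MyState/MyAction types, initialPercept(), and the
// step() environment call are hypothetical placeholders for whatever
// concrete MDP and environment loop the caller provides.
//
// MDP<MyState, MyAction> mdp = /* a concrete MDP implementation */;
// QLearningAgent<MyState, MyAction> agent =
//         new QLearningAgent<MyState, MyAction>(mdp);
//
// for (int trial = 0; trial < 1000; trial++) {
//     // Percept for the initial state (hypothetical helper).
//     MDPPerception<MyState> percept = initialPercept(mdp);
//     MyAction action = agent.decideAction(percept);
//     while (action != null) { // null signals a terminal state
//         // Environment transition for (state, action) (hypothetical).
//         percept = step(mdp, percept.getState(), action);
//         action = agent.decideAction(percept);
//     }
// }
//
// After training, agent.getQTable() exposes the learned Q values.
// ---------------------------------------------------------------------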