package aima.learning.reinforcement;

import java.util.Hashtable;
import java.util.List;

import aima.probability.decision.MDP;
import aima.probability.decision.MDPPerception;
import aima.probability.decision.MDPPolicy;
import aima.probability.decision.MDPTransition;
import aima.probability.decision.MDPUtilityFunction;
import aima.util.Pair;

/**
 * A passive ADP (Adaptive Dynamic Programming) agent. The agent follows a
 * fixed policy, learns the transition model and rewards of the underlying MDP
 * from its percepts, and updates the utility of the most recently visited
 * state by value determination (passive reinforcement learning, AIMA Ch. 21).
 *
 * @author Ravi Mohan
 */
public class PassiveADPAgent<STATE_TYPE, ACTION_TYPE> extends
        MDPAgent<STATE_TYPE, ACTION_TYPE> {

    /** The fixed policy the agent follows while learning. */
    private MDPPolicy<STATE_TYPE, ACTION_TYPE> policy;

    /** The learned utility estimates U(s). */
    private MDPUtilityFunction<STATE_TYPE> utilityFunction;

    /** N[s,a]: how often action a has been taken in state s. */
    private Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> nsa;

    /** N[s,a,s']: how often taking a in s has led to s'. */
    private Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double> nsasdash;

    public PassiveADPAgent(MDP<STATE_TYPE, ACTION_TYPE> mdp,
            MDPPolicy<STATE_TYPE, ACTION_TYPE> policy) {
        super(mdp.emptyMdp());
        this.policy = policy;
        this.utilityFunction = new MDPUtilityFunction<STATE_TYPE>();
        this.nsa = new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();
        this.nsasdash = new Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double>();
    }

    @Override
    public ACTION_TYPE decideAction(MDPPerception<STATE_TYPE> perception) {

        if (!(utilityFunction.hasUtilityFor(perception.getState()))) { // if perceptionState is new
            utilityFunction.setUtility(perception.getState(), perception
                    .getReward());
            mdp.setReward(perception.getState(), perception.getReward());
        }
        if (previousState != null) {
            // Update the counts N[s,a] and N[s,a,s'] for the transition just observed.
            Double oldValue1 = nsa.get(new Pair<STATE_TYPE, ACTION_TYPE>(
                    previousState, previousAction));
            if (oldValue1 == null) {
                nsa.put(new Pair<STATE_TYPE, ACTION_TYPE>(previousState,
                        previousAction), 1.0);
            } else {
                nsa.put(new Pair<STATE_TYPE, ACTION_TYPE>(previousState,
                        previousAction), oldValue1 + 1);
            }
            Double oldValue2 = nsasdash
                    .get(new MDPTransition<STATE_TYPE, ACTION_TYPE>(
                            previousState, previousAction, currentState));
            if (oldValue2 == null) {
                nsasdash.put(new MDPTransition<STATE_TYPE, ACTION_TYPE>(
                        previousState, previousAction, currentState), 1.0);
            } else {
                nsasdash.put(new MDPTransition<STATE_TYPE, ACTION_TYPE>(
                        previousState, previousAction, currentState),
                        oldValue2 + 1);
            }
            // Re-estimate the transition model: P(s'|s,a) = N[s,a,s'] / N[s,a].
            for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : nsasdash
                    .keySet()) {
                if (nsasdash.get(transition) != 0.0) {
                    double newValue = nsasdash.get(transition)
                            / nsa.get(new Pair<STATE_TYPE, ACTION_TYPE>(
                                    transition.getInitialState(), transition
                                            .getAction()));
                    mdp.setTransitionProbability(transition, newValue);
                }
            }
            // One value-determination step for the previous state under the
            // fixed policy, with no discounting (gamma = 1).
            List<MDPTransition<STATE_TYPE, ACTION_TYPE>> validTransitions = mdp
                    .getTransitionsWith(previousState, policy
                            .getAction(previousState));
            utilityFunction = valueDetermination(validTransitions, 1);
        }

        // Remember where we are and which action the fixed policy prescribes
        // next; reaching a terminal state resets the trial.
        if (mdp.isTerminalState(currentState)) {
            previousState = null;
            previousAction = null;
        } else {
            previousState = currentState;
            previousAction = policy.getAction(currentState);
        }
        return previousAction;
    }

    /**
     * Computes U(s) = R(s) + gamma * sum over s' of P(s'|s,pi(s)) * U(s') for
     * the initial state of the given transitions, leaving all other utilities
     * unchanged.
     */
    private MDPUtilityFunction<STATE_TYPE> valueDetermination(
            List<MDPTransition<STATE_TYPE, ACTION_TYPE>> validTransitions,
            double gamma) {
        MDPUtilityFunction<STATE_TYPE> uf = utilityFunction.copy();
        double additional = 0.0;
        if (validTransitions.size() > 0) {
            STATE_TYPE initState = validTransitions.get(0).getInitialState();
            double reward = mdp.getRewardFor(initState);
            for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : validTransitions) {
                additional += mdp.getTransitionProbability(transition)
                        * utilityFunction.getUtility(transition
                                .getDestinationState());
            }
            uf.setUtility(initState, reward + (gamma * additional));
        }

        return uf;
    }

    public MDPUtilityFunction<STATE_TYPE> getUtilityFunction() {
        return utilityFunction;
    }

}
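/*
 * Usage (a minimal sketch, not part of the original source). How percepts are
 * produced depends on the surrounding environment code; "mdp", "fixedPolicy",
 * "initialState" and "perceive(...)" below are hypothetical stand-ins for
 * whatever supplies MDPPerception instances for the world being learned.
 *
 *   PassiveADPAgent<STATE, ACTION> agent =
 *           new PassiveADPAgent<STATE, ACTION>(mdp, fixedPolicy);
 *   MDPPerception<STATE> percept = perceive(initialState);
 *   while (percept != null) {
 *       ACTION action = agent.decideAction(percept); // learn from the percept
 *       percept = perceive(action);                  // act and observe again
 *   }
 *   MDPUtilityFunction<STATE> learnedUtilities = agent.getUtilityFunction();
 */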