001    package aima.learning.reinforcement;
002    
import java.util.HashMap;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;

import aima.probability.decision.MDP;
import aima.probability.decision.MDPPerception;
import aima.probability.decision.MDPPolicy;
import aima.probability.decision.MDPTransition;
import aima.probability.decision.MDPUtilityFunction;
import aima.util.Pair;
012    
013    /**
014     * @author Ravi Mohan
015     * 
016     */
017    
018    public class PassiveADPAgent<STATE_TYPE, ACTION_TYPE> extends
019                    MDPAgent<STATE_TYPE, ACTION_TYPE> {
020            private MDPPolicy<STATE_TYPE, ACTION_TYPE> policy;
021    
022            private MDPUtilityFunction<STATE_TYPE> utilityFunction;
023    
024            private Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double> nsa;
025    
026            private Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double> nsasdash;
027    
028            public PassiveADPAgent(MDP<STATE_TYPE, ACTION_TYPE> mdp,
029                            MDPPolicy<STATE_TYPE, ACTION_TYPE> policy) {
030                    super(mdp.emptyMdp());
031                    this.policy = policy;
032                    this.utilityFunction = new MDPUtilityFunction<STATE_TYPE>();
033                    this.nsa = new Hashtable<Pair<STATE_TYPE, ACTION_TYPE>, Double>();
034                    this.nsasdash = new Hashtable<MDPTransition<STATE_TYPE, ACTION_TYPE>, Double>();
035    
036            }
037    
038            @Override
039            public ACTION_TYPE decideAction(MDPPerception<STATE_TYPE> perception) {
040    
041                    if (!(utilityFunction.hasUtilityFor(perception.getState()))) { // if
042                            // perceptionState
043                            // is
044                            // new
045                            utilityFunction.setUtility(perception.getState(), perception
046                                            .getReward());
047                            mdp.setReward(perception.getState(), perception.getReward());
048                    }
049                    if (!(previousState == null)) {
050                            Double oldValue1 = nsa.get(new Pair<STATE_TYPE, ACTION_TYPE>(
051                                            previousState, previousAction));
052                            if (oldValue1 == null) {
053                                    nsa.put(new Pair<STATE_TYPE, ACTION_TYPE>(previousState,
054                                                    previousAction), 1.0);
055                            } else {
056                                    nsa.put(new Pair<STATE_TYPE, ACTION_TYPE>(previousState,
057                                                    previousAction), oldValue1 + 1);
058                            }
059                            Double oldValue2 = nsasdash
060                                            .get(new MDPTransition<STATE_TYPE, ACTION_TYPE>(
061                                                            previousState, previousAction, currentState));
062                            if (oldValue2 == null) {
063                                    nsasdash.put(new MDPTransition<STATE_TYPE, ACTION_TYPE>(
064                                                    previousState, previousAction, currentState), 1.0);
065    
066                            } else {
067                                    nsasdash.put(new MDPTransition<STATE_TYPE, ACTION_TYPE>(
068                                                    previousState, previousAction, currentState),
069                                                    oldValue2 + 1);
070                            }
071                            for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : nsasdash
072                                            .keySet()) {
073                                    if (nsasdash.get(transition) != 0.0) {
074                                            double newValue = nsasdash.get(transition)
075                                                            / nsa.get(new Pair<STATE_TYPE, ACTION_TYPE>(
076                                                                            transition.getInitialState(), transition
077                                                                                            .getAction()));
078                                            mdp.setTransitionProbability(transition, newValue);
079                                    }
080                            }
081                            List<MDPTransition<STATE_TYPE, ACTION_TYPE>> validTransitions = mdp
082                                            .getTransitionsWith(previousState, policy
083                                                            .getAction(previousState));
084                            utilityFunction = valueDetermination(validTransitions, 1);
085                    }
086    
087                    if (mdp.isTerminalState(currentState)) {
088                            previousState = null;
089                            previousAction = null;
090                    } else {
091                            previousState = currentState;
092                            previousAction = policy.getAction(currentState);
093                    }
094                    return previousAction;
095            }
096    
097            private MDPUtilityFunction<STATE_TYPE> valueDetermination(
098                            List<MDPTransition<STATE_TYPE, ACTION_TYPE>> validTransitions,
099                            double gamma) {
100                    MDPUtilityFunction<STATE_TYPE> uf = utilityFunction.copy();
101                    double additional = 0.0;
102                    if (validTransitions.size() > 0) {
103                            STATE_TYPE initState = validTransitions.get(0).getInitialState();
104                            double reward = mdp.getRewardFor(initState);
105                            for (MDPTransition<STATE_TYPE, ACTION_TYPE> transition : validTransitions) {
106                                    additional += mdp.getTransitionProbability(transition)
107                                                    * utilityFunction.getUtility(transition
108                                                                    .getDestinationState());
109                            }
110                            uf.setUtility(initState, reward + (gamma * additional));
111                    }
112    
113                    return uf;
114            }
115    
116            public MDPUtilityFunction<STATE_TYPE> getUtilityFunction() {
117    
118                    return utilityFunction;
119            }
120    
121    }