001    package aima.test.learningtest;
002    
003    import java.util.Hashtable;
004    
005    import junit.framework.TestCase;
006    import aima.learning.reinforcement.PassiveADPAgent;
007    import aima.learning.reinforcement.PassiveTDAgent;
008    import aima.learning.reinforcement.QLearningAgent;
009    import aima.learning.reinforcement.QTable;
010    import aima.probability.Randomizer;
011    import aima.probability.decision.MDP;
012    import aima.probability.decision.MDPFactory;
013    import aima.probability.decision.MDPPerception;
014    import aima.probability.decision.MDPPolicy;
015    import aima.probability.decision.MDPUtilityFunction;
016    import aima.probability.decision.cellworld.CellWorld;
017    import aima.probability.decision.cellworld.CellWorldPosition;
018    import aima.test.probabilitytest.MockRandomizer;
019    import aima.util.Pair;
020    
021    /**
022     * @author Ravi Mohan
023     * 
024     */
025    public class ReinforcementLearningTest extends TestCase {
026            MDP<CellWorldPosition, String> fourByThree;
027    
028            MDPPolicy<CellWorldPosition, String> policy;
029    
030            @Override
031            public void setUp() {
032                    fourByThree = MDPFactory.createFourByThreeMDP();
033    
034                    policy = new MDPPolicy<CellWorldPosition, String>();
035    
036                    policy.setAction(new CellWorldPosition(1, 1), CellWorld.UP);
037                    policy.setAction(new CellWorldPosition(1, 2), CellWorld.LEFT);
038                    policy.setAction(new CellWorldPosition(1, 3), CellWorld.LEFT);
039                    policy.setAction(new CellWorldPosition(1, 4), CellWorld.LEFT);
040    
041                    policy.setAction(new CellWorldPosition(2, 1), CellWorld.UP);
042                    policy.setAction(new CellWorldPosition(2, 3), CellWorld.UP);
043    
044                    policy.setAction(new CellWorldPosition(3, 1), CellWorld.RIGHT);
045                    policy.setAction(new CellWorldPosition(3, 2), CellWorld.RIGHT);
046                    policy.setAction(new CellWorldPosition(3, 3), CellWorld.RIGHT);
047            }
048    
049            public void testPassiveADPAgent() {
050    
051                    PassiveADPAgent<CellWorldPosition, String> agent = new PassiveADPAgent<CellWorldPosition, String>(
052                                    fourByThree, policy);
053    
054                    // Randomizer r = new JavaRandomizer();
055                    Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
056                                    0.3, 0.7, 0.4, 0.6, 0.5 });
057                    MDPUtilityFunction<CellWorldPosition> uf = null;
058                    for (int i = 0; i < 100; i++) {
059                            agent.executeTrial(r);
060                            uf = agent.getUtilityFunction();
061    
062                    }
063    
064                    assertEquals(0.676, uf.getUtility(new CellWorldPosition(1, 1)), 0.001);
065                    assertEquals(0.626, uf.getUtility(new CellWorldPosition(1, 2)), 0.001);
066                    assertEquals(0.573, uf.getUtility(new CellWorldPosition(1, 3)), 0.001);
067                    assertEquals(0.519, uf.getUtility(new CellWorldPosition(1, 4)), 0.001);
068    
069                    assertEquals(0.746, uf.getUtility(new CellWorldPosition(2, 1)), 0.001);
070                    assertEquals(0.865, uf.getUtility(new CellWorldPosition(2, 3)), 0.001);
071                    // assertEquals(-1.0, uf.getUtility(new
072                    // CellWorldPosition(2,4)),0.001);//the pseudo random genrator never
073                    // gets to this square
074    
075                    assertEquals(0.796, uf.getUtility(new CellWorldPosition(3, 1)), 0.001);
076                    assertEquals(0.906, uf.getUtility(new CellWorldPosition(3, 3)), 0.001);
077                    assertEquals(1.0, uf.getUtility(new CellWorldPosition(3, 4)), 0.001);
078            }
079    
080            public void testPassiveTDAgent() {
081                    PassiveTDAgent<CellWorldPosition, String> agent = new PassiveTDAgent<CellWorldPosition, String>(
082                                    fourByThree, policy);
083                    // Randomizer r = new JavaRandomizer();
084                    Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
085                                    0.3, 0.7, 0.4, 0.6, 0.5 });
086                    MDPUtilityFunction<CellWorldPosition> uf = null;
087                    for (int i = 0; i < 200; i++) {
088                            agent.executeTrial(r);
089                            uf = agent.getUtilityFunction();
090                            // System.out.println(uf);
091    
092                    }
093    
094                    assertEquals(0.662, uf.getUtility(new CellWorldPosition(1, 1)), 0.001);
095                    assertEquals(0.610, uf.getUtility(new CellWorldPosition(1, 2)), 0.001);
096                    assertEquals(0.553, uf.getUtility(new CellWorldPosition(1, 3)), 0.001);
097                    assertEquals(0.496, uf.getUtility(new CellWorldPosition(1, 4)), 0.001);
098    
099                    assertEquals(0.735, uf.getUtility(new CellWorldPosition(2, 1)), 0.001);
100                    assertEquals(0.835, uf.getUtility(new CellWorldPosition(2, 3)), 0.001);
101                    // assertEquals(-1.0, uf.getUtility(new
102                    // CellWorldPosition(2,4)),0.001);//the pseudo random genrator never
103                    // gets to this square
104    
105                    assertEquals(0.789, uf.getUtility(new CellWorldPosition(3, 1)), 0.001);
106                    assertEquals(0.889, uf.getUtility(new CellWorldPosition(3, 3)), 0.001);
107                    assertEquals(1.0, uf.getUtility(new CellWorldPosition(3, 4)), 0.001);
108            }
109    
110            public void xtestQLearningAgent() {
111                    QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent(
112                                    fourByThree);
113                    Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
114                                    0.3, 0.7, 0.4, 0.6, 0.5 });
115    
116                    // Randomizer r = new JavaRandomizer();
117                    Hashtable<Pair<CellWorldPosition, String>, Double> q = null;
118                    QTable<CellWorldPosition, String> qTable = null;
119                    for (int i = 0; i < 100; i++) {
120                            qla.executeTrial(r);
121                            q = qla.getQ();
122                            qTable = qla.getQTable();
123    
124                    }
125                    // qTable.normalize();
126                    System.out.println(qTable);
127                    System.out.println(qTable.getPolicy());
128    
129                    // first step
130    
131            }
132    
133            public void testFirstStepsOfQLAAgentUnderNormalProbability() {
134                    QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
135                                    fourByThree);
136    
137                    Randomizer alwaysLessThanEightyPercent = new MockRandomizer(
138                                    new double[] { 0.7 });
139                    CellWorldPosition startingPosition = new CellWorldPosition(1, 4);
140                    String action = qla.decideAction(new MDPPerception<CellWorldPosition>(
141                                    startingPosition, -0.04));
142                    assertEquals(CellWorld.LEFT, action);
143                    assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action));
144    
145                    qla.execute(action, alwaysLessThanEightyPercent);
146                    assertEquals(new CellWorldPosition(1, 3), qla.getCurrentState());
147                    assertEquals(-0.04, qla.getCurrentReward());
148                    assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action));
149                    String action2 = qla.decideAction(new MDPPerception<CellWorldPosition>(
150                                    new CellWorldPosition(1, 3), -0.04));
151    
152                    assertEquals(-0.04, qla.getQTable().getQValue(startingPosition, action));
153    
154            }
155    
156            public void testFirstStepsOfQLAAgentWhenFirstStepTerminates() {
157                    QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
158                                    fourByThree);
159    
160                    CellWorldPosition startingPosition = new CellWorldPosition(1, 4);
161                    String action = qla.decideAction(new MDPPerception<CellWorldPosition>(
162                                    startingPosition, -0.04));
163                    assertEquals(CellWorld.LEFT, action);
164    
165                    Randomizer betweenEightyANdNinetyPercent = new MockRandomizer(
166                                    new double[] { 0.85 }); // to force left to become an "up"
167                    qla.execute(action, betweenEightyANdNinetyPercent);
168                    assertEquals(new CellWorldPosition(2, 4), qla.getCurrentState());
169                    assertEquals(-1.0, qla.getCurrentReward());
170                    assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action));
171                    String action2 = qla.decideAction(new MDPPerception<CellWorldPosition>(
172                                    new CellWorldPosition(2, 4), -1));
173                    assertNull(action2);
174                    assertEquals(-1.0, qla.getQTable().getQValue(startingPosition, action));
175            }
176    
177    }