package aima.test.learningtest;

import java.util.Hashtable;

import junit.framework.TestCase;
import aima.learning.reinforcement.PassiveADPAgent;
import aima.learning.reinforcement.PassiveTDAgent;
import aima.learning.reinforcement.QLearningAgent;
import aima.learning.reinforcement.QTable;
import aima.probability.Randomizer;
import aima.probability.decision.MDP;
import aima.probability.decision.MDPFactory;
import aima.probability.decision.MDPPerception;
import aima.probability.decision.MDPPolicy;
import aima.probability.decision.MDPUtilityFunction;
import aima.probability.decision.cellworld.CellWorld;
import aima.probability.decision.cellworld.CellWorldPosition;
import aima.test.probabilitytest.MockRandomizer;
import aima.util.Pair;

/**
 * Tests the reinforcement-learning agents (passive ADP, passive TD and
 * Q-learning) on the standard 4x3 cell world MDP from AIMA. All stochastic
 * behaviour is driven by a {@link MockRandomizer} with a fixed value sequence
 * so the learned utilities are deterministic and can be asserted exactly.
 *
 * @author Ravi Mohan
 */
public class ReinforcementLearningTest extends TestCase {
	MDP<CellWorldPosition, String> fourByThree;

	MDPPolicy<CellWorldPosition, String> policy;

	@Override
	public void setUp() {
		fourByThree = MDPFactory.createFourByThreeMDP();

		// Fixed policy for the passive agents: an optimal route toward the
		// +1 terminal square at (3,4). (2,4) and (3,4) are terminal states
		// and so carry no action; (2,2) is the wall.
		policy = new MDPPolicy<CellWorldPosition, String>();

		policy.setAction(new CellWorldPosition(1, 1), CellWorld.UP);
		policy.setAction(new CellWorldPosition(1, 2), CellWorld.LEFT);
		policy.setAction(new CellWorldPosition(1, 3), CellWorld.LEFT);
		policy.setAction(new CellWorldPosition(1, 4), CellWorld.LEFT);

		policy.setAction(new CellWorldPosition(2, 1), CellWorld.UP);
		policy.setAction(new CellWorldPosition(2, 3), CellWorld.UP);

		policy.setAction(new CellWorldPosition(3, 1), CellWorld.RIGHT);
		policy.setAction(new CellWorldPosition(3, 2), CellWorld.RIGHT);
		policy.setAction(new CellWorldPosition(3, 3), CellWorld.RIGHT);
	}

	/**
	 * Runs 100 trials of the passive ADP agent under the fixed policy and
	 * checks the learned utility of each visited state.
	 */
	public void testPassiveADPAgent() {

		PassiveADPAgent<CellWorldPosition, String> agent = new PassiveADPAgent<CellWorldPosition, String>(
				fourByThree, policy);

		// Randomizer r = new JavaRandomizer();
		Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
				0.3, 0.7, 0.4, 0.6, 0.5 });
		MDPUtilityFunction<CellWorldPosition> uf = null;
		for (int i = 0; i < 100; i++) {
			agent.executeTrial(r);
			uf = agent.getUtilityFunction();

		}

		assertEquals(0.676, uf.getUtility(new CellWorldPosition(1, 1)), 0.001);
		assertEquals(0.626, uf.getUtility(new CellWorldPosition(1, 2)), 0.001);
		assertEquals(0.573, uf.getUtility(new CellWorldPosition(1, 3)), 0.001);
		assertEquals(0.519, uf.getUtility(new CellWorldPosition(1, 4)), 0.001);

		assertEquals(0.746, uf.getUtility(new CellWorldPosition(2, 1)), 0.001);
		assertEquals(0.865, uf.getUtility(new CellWorldPosition(2, 3)), 0.001);
		// assertEquals(-1.0, uf.getUtility(new
		// CellWorldPosition(2,4)),0.001);//the pseudo random generator never
		// gets to this square

		assertEquals(0.796, uf.getUtility(new CellWorldPosition(3, 1)), 0.001);
		assertEquals(0.906, uf.getUtility(new CellWorldPosition(3, 3)), 0.001);
		assertEquals(1.0, uf.getUtility(new CellWorldPosition(3, 4)), 0.001);
	}

	/**
	 * Runs 200 trials of the passive TD agent under the fixed policy and
	 * checks the learned utility of each visited state. TD needs more trials
	 * than ADP to converge, hence 200 rather than 100.
	 */
	public void testPassiveTDAgent() {
		PassiveTDAgent<CellWorldPosition, String> agent = new PassiveTDAgent<CellWorldPosition, String>(
				fourByThree, policy);
		// Randomizer r = new JavaRandomizer();
		Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
				0.3, 0.7, 0.4, 0.6, 0.5 });
		MDPUtilityFunction<CellWorldPosition> uf = null;
		for (int i = 0; i < 200; i++) {
			agent.executeTrial(r);
			uf = agent.getUtilityFunction();
			// System.out.println(uf);

		}

		assertEquals(0.662, uf.getUtility(new CellWorldPosition(1, 1)), 0.001);
		assertEquals(0.610, uf.getUtility(new CellWorldPosition(1, 2)), 0.001);
		assertEquals(0.553, uf.getUtility(new CellWorldPosition(1, 3)), 0.001);
		assertEquals(0.496, uf.getUtility(new CellWorldPosition(1, 4)), 0.001);

		assertEquals(0.735, uf.getUtility(new CellWorldPosition(2, 1)), 0.001);
		assertEquals(0.835, uf.getUtility(new CellWorldPosition(2, 3)), 0.001);
		// assertEquals(-1.0, uf.getUtility(new
		// CellWorldPosition(2,4)),0.001);//the pseudo random generator never
		// gets to this square

		assertEquals(0.789, uf.getUtility(new CellWorldPosition(3, 1)), 0.001);
		assertEquals(0.889, uf.getUtility(new CellWorldPosition(3, 3)), 0.001);
		assertEquals(1.0, uf.getUtility(new CellWorldPosition(3, 4)), 0.001);
	}

	/**
	 * Disabled (note the "x" prefix, which hides it from JUnit 3's
	 * reflection-based test discovery): exploratory run of the Q-learning
	 * agent that prints the Q-table and derived policy rather than asserting.
	 */
	public void xtestQLearningAgent() {
		// Was a raw QLearningAgent; parameterized for type safety and
		// consistency with the other Q-learning tests below.
		QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
				fourByThree);
		Randomizer r = new MockRandomizer(new double[] { 0.1, 0.9, 0.2, 0.8,
				0.3, 0.7, 0.4, 0.6, 0.5 });

		// Randomizer r = new JavaRandomizer();
		Hashtable<Pair<CellWorldPosition, String>, Double> q = null;
		QTable<CellWorldPosition, String> qTable = null;
		for (int i = 0; i < 100; i++) {
			qla.executeTrial(r);
			q = qla.getQ();
			qTable = qla.getQTable();

		}
		// qTable.normalize();
		System.out.println(qTable);
		System.out.println(qTable.getPolicy());

		// first step

	}

	/**
	 * Verifies the first two steps of a Q-learning agent when the action goes
	 * the intended way (randomizer value 0.7 is below the 0.8 transition
	 * threshold): the Q value of the starting state/action pair is only
	 * updated when the SECOND perception arrives.
	 */
	public void testFirstStepsOfQLAAgentUnderNormalProbability() {
		QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
				fourByThree);

		Randomizer alwaysLessThanEightyPercent = new MockRandomizer(
				new double[] { 0.7 });
		CellWorldPosition startingPosition = new CellWorldPosition(1, 4);
		String action = qla.decideAction(new MDPPerception<CellWorldPosition>(
				startingPosition, -0.04));
		assertEquals(CellWorld.LEFT, action);
		// Deltas added to all double comparisons: the no-delta form relies on
		// autoboxed exact equality and is deprecated in junit.framework.
		assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action),
				0.001);

		qla.execute(action, alwaysLessThanEightyPercent);
		assertEquals(new CellWorldPosition(1, 3), qla.getCurrentState());
		assertEquals(-0.04, qla.getCurrentReward(), 0.001);
		assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action),
				0.001);
		// Called for its side effect: perceiving the next state triggers the
		// Q-table update for (startingPosition, action). The returned action
		// itself is not needed here.
		qla.decideAction(new MDPPerception<CellWorldPosition>(
				new CellWorldPosition(1, 3), -0.04));

		assertEquals(-0.04, qla.getQTable().getQValue(startingPosition, action),
				0.001);

	}

	/**
	 * Verifies Q-learning behaviour when the very first step lands in the -1
	 * terminal state (2,4): the agent returns no further action and the
	 * terminal reward is written straight into the Q-table.
	 */
	public void testFirstStepsOfQLAAgentWhenFirstStepTerminates() {
		QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
				fourByThree);

		CellWorldPosition startingPosition = new CellWorldPosition(1, 4);
		String action = qla.decideAction(new MDPPerception<CellWorldPosition>(
				startingPosition, -0.04));
		assertEquals(CellWorld.LEFT, action);

		Randomizer betweenEightyANdNinetyPercent = new MockRandomizer(
				new double[] { 0.85 }); // to force left to become an "up"
		qla.execute(action, betweenEightyANdNinetyPercent);
		assertEquals(new CellWorldPosition(2, 4), qla.getCurrentState());
		assertEquals(-1.0, qla.getCurrentReward(), 0.001);
		assertEquals(0.0, qla.getQTable().getQValue(startingPosition, action),
				0.001);
		String action2 = qla.decideAction(new MDPPerception<CellWorldPosition>(
				new CellWorldPosition(2, 4), -1));
		assertNull(action2);
		assertEquals(-1.0, qla.getQTable().getQValue(startingPosition, action),
				0.001);
	}

}