001 /* 002 * Created on Feb 17, 2005 003 * 004 */ 005 package aima.probability.demos; 006 007 import java.util.ArrayList; 008 import java.util.Hashtable; 009 import java.util.List; 010 011 import aima.learning.reinforcement.PassiveADPAgent; 012 import aima.learning.reinforcement.PassiveTDAgent; 013 import aima.learning.reinforcement.QLearningAgent; 014 import aima.learning.reinforcement.QTable; 015 import aima.probability.BayesNet; 016 import aima.probability.BayesNetNode; 017 import aima.probability.EnumerateJointAsk; 018 import aima.probability.EnumerationAsk; 019 import aima.probability.JavaRandomizer; 020 import aima.probability.ProbabilityDistribution; 021 import aima.probability.Query; 022 import aima.probability.RandomVariable; 023 import aima.probability.Randomizer; 024 import aima.probability.decision.MDP; 025 import aima.probability.decision.MDPFactory; 026 import aima.probability.decision.MDPPolicy; 027 import aima.probability.decision.MDPUtilityFunction; 028 import aima.probability.decision.cellworld.CellWorld; 029 import aima.probability.decision.cellworld.CellWorldPosition; 030 import aima.probability.reasoning.HMMFactory; 031 import aima.probability.reasoning.HiddenMarkovModel; 032 import aima.probability.reasoning.HmmConstants; 033 import aima.probability.reasoning.ParticleSet; 034 import aima.util.Pair; 035 036 /** 037 * @author Ravi Mohan 038 * 039 */ 040 041 public class ProbabilityDemo { 042 043 public static void main(String[] args) { 044 enumerationJointAskDemo(); 045 enumerationAskDemo(); 046 priorSampleDemo(); 047 rejectionSamplingDemo(); 048 likelihoodWeightingDemo(); 049 mcmcAskDemo(); 050 051 forwardBackWardDemo(); 052 particleFilterinfDemo(); 053 054 valueIterationDemo(); 055 policyIterationDemo(); 056 057 passiveADPgentDemo(); 058 passiveTDAgentDemo(); 059 qLearningAgentDemo(); 060 } 061 062 private static void forwardBackWardDemo() { 063 064 System.out.println("\nForward BackWard Demo\n"); 065 066 HiddenMarkovModel rainmanHmm = HMMFactory.createRainmanHMM(); 067 System.out 068 .println("Creating a Hdden Markov Model to represent the model in Fig 15.5 "); 069 List<String> perceptions = new ArrayList<String>(); 070 perceptions.add(HmmConstants.SEE_UMBRELLA); 071 perceptions.add(HmmConstants.SEE_UMBRELLA); 072 073 List<RandomVariable> results = rainmanHmm.forward_backward(perceptions); 074 075 RandomVariable smoothedDayOne = results.get(1); 076 System.out.println("Smoothed Probability Of Raining on Day One = " 077 + smoothedDayOne.getProbabilityOf(HmmConstants.RAINING)); 078 System.out.println("Smoothed Probability Of NOT Raining on Day One =" 079 + smoothedDayOne.getProbabilityOf(HmmConstants.NOT_RAINING)); 080 081 RandomVariable smoothedDayTwo = results.get(2); 082 System.out.println("Smoothed Probability Of Raining on Day Two = " 083 + smoothedDayTwo.getProbabilityOf(HmmConstants.RAINING)); 084 System.out.println("Smoothed Probability Of NOT Raining on Day Two = " 085 + smoothedDayTwo.getProbabilityOf(HmmConstants.NOT_RAINING)); 086 087 } 088 089 private static void particleFilterinfDemo() { 090 System.out.println("\nParticle Filtering Demo\n"); 091 HiddenMarkovModel rainman = HMMFactory.createRainmanHMM(); 092 Randomizer r = new JavaRandomizer(); 093 ParticleSet starting = rainman.prior().toParticleSet(rainman, r, 1000); 094 System.out.println("at the beginning, " 095 + starting.numberOfParticlesWithState(HmmConstants.RAINING) 096 + " particles 0f 1000 indicate status == 'raining' "); 097 System.out.println("at the beginning, " 098 + starting.numberOfParticlesWithState(HmmConstants.NOT_RAINING) 099 + " particles of 1000 indicate status == 'NOT raining' "); 100 101 System.out 102 .println("\n Filtering Particle Set.On perception == 'SEE_UMBRELLA' ..\n"); 103 ParticleSet afterSeeingUmbrella = starting.filter( 104 HmmConstants.SEE_UMBRELLA, r); 105 System.out.println("after filtering " 106 + afterSeeingUmbrella 107 .numberOfParticlesWithState(HmmConstants.RAINING) 108 + " particles of 1000 indicate status == 'raining' "); 109 System.out.println("after filtering " 110 + afterSeeingUmbrella 111 .numberOfParticlesWithState(HmmConstants.NOT_RAINING) 112 + " particles of 1000 indicate status == 'NOT raining' "); 113 114 } 115 116 private static void valueIterationDemo() { 117 118 System.out.println("\nValue Iteration Demo\n"); 119 System.out.println("creating an MDP to represent the 4 X 3 world"); 120 MDP<CellWorldPosition, String> fourByThreeMDP = MDPFactory 121 .createFourByThreeMDP(); 122 123 System.out.println("Beginning Value Iteration"); 124 MDPUtilityFunction<CellWorldPosition> uf = fourByThreeMDP 125 .valueIterationTillMAximumUtilityGrowthFallsBelowErrorMargin(1, 126 0.00001); 127 for (int i = 1; i <= 3; i++) { 128 for (int j = 1; j <= 4; j++) { 129 if (!((i == 2) && (j == 2))) { 130 printUtility(uf, i, j); 131 } 132 133 } 134 } 135 136 } 137 138 private static void printUtility(MDPUtilityFunction<CellWorldPosition> uf, 139 int i, int j) { 140 System.out.println("Utility of (" + i + " , " + j + " ) " 141 + uf.getUtility(new CellWorldPosition(i, j))); 142 143 } 144 145 private static void policyIterationDemo() { 146 147 System.out.println("\nPolicy Iteration Demo\n"); 148 System.out.println("\nValue Iteration Demo\n"); 149 System.out.println("creating an MDP to represent the 4 X 3 world"); 150 MDP<CellWorldPosition, String> fourByThreeMDP = MDPFactory 151 .createFourByThreeMDP(); 152 MDPPolicy<CellWorldPosition, String> policy = fourByThreeMDP 153 .policyIteration(1); 154 for (int i = 1; i <= 3; i++) { 155 for (int j = 1; j <= 4; j++) { 156 if (!((i == 2) && (j == 2))) { 157 printPolicy(i, j, policy); 158 } 159 } 160 161 } 162 } 163 164 private static void printPolicy(int i, int j, 165 MDPPolicy<CellWorldPosition, String> policy) { 166 System.out.println("Reccomended Action for (" + i + " , " + j 167 + " ) = " + policy.getAction(new CellWorldPosition(i, j))); 168 169 } 170 171 private static void passiveADPgentDemo() { 172 System.out.println("\nPassive ADP Agent Demo\n"); 173 System.out.println("creating an MDP to represent the 4 X 3 world"); 174 MDP<CellWorldPosition, String> fourByThree = MDPFactory 175 .createFourByThreeMDP(); 176 ; 177 178 MDPPolicy<CellWorldPosition, String> policy = new MDPPolicy<CellWorldPosition, String>(); 179 System.out 180 .println("Creating a policy to reflect the policy in Fig 17.3"); 181 policy.setAction(new CellWorldPosition(1, 1), CellWorld.UP); 182 policy.setAction(new CellWorldPosition(1, 2), CellWorld.LEFT); 183 policy.setAction(new CellWorldPosition(1, 3), CellWorld.LEFT); 184 policy.setAction(new CellWorldPosition(1, 4), CellWorld.LEFT); 185 186 policy.setAction(new CellWorldPosition(2, 1), CellWorld.UP); 187 policy.setAction(new CellWorldPosition(2, 3), CellWorld.UP); 188 189 policy.setAction(new CellWorldPosition(3, 1), CellWorld.RIGHT); 190 policy.setAction(new CellWorldPosition(3, 2), CellWorld.RIGHT); 191 policy.setAction(new CellWorldPosition(3, 3), CellWorld.RIGHT); 192 193 PassiveADPAgent<CellWorldPosition, String> agent = new PassiveADPAgent<CellWorldPosition, String>( 194 fourByThree, policy); 195 196 Randomizer r = new JavaRandomizer(); 197 System.out 198 .println("Deriving Utility Function using the Passive ADP Agent From 100 trials in the 4 by 3 world"); 199 MDPUtilityFunction<CellWorldPosition> uf = null; 200 for (int i = 0; i < 100; i++) { 201 agent.executeTrial(r); 202 uf = agent.getUtilityFunction(); 203 204 } 205 206 for (int i = 1; i <= 3; i++) { 207 for (int j = 1; j <= 4; j++) { 208 if (!((i == 2) && (j == 2))) { 209 printUtility(uf, i, j); 210 } 211 212 } 213 } 214 215 } 216 217 private static void passiveTDAgentDemo() { 218 System.out.println("\nPassive TD Agent Demo\n"); 219 System.out.println("creating an MDP to represent the 4 X 3 world"); 220 MDP<CellWorldPosition, String> fourByThree = MDPFactory 221 .createFourByThreeMDP(); 222 ; 223 224 MDPPolicy<CellWorldPosition, String> policy = new MDPPolicy<CellWorldPosition, String>(); 225 System.out 226 .println("Creating a policy to reflect the policy in Fig 17.3"); 227 policy.setAction(new CellWorldPosition(1, 1), CellWorld.UP); 228 policy.setAction(new CellWorldPosition(1, 2), CellWorld.LEFT); 229 policy.setAction(new CellWorldPosition(1, 3), CellWorld.LEFT); 230 policy.setAction(new CellWorldPosition(1, 4), CellWorld.LEFT); 231 232 policy.setAction(new CellWorldPosition(2, 1), CellWorld.UP); 233 policy.setAction(new CellWorldPosition(2, 3), CellWorld.UP); 234 235 policy.setAction(new CellWorldPosition(3, 1), CellWorld.RIGHT); 236 policy.setAction(new CellWorldPosition(3, 2), CellWorld.RIGHT); 237 policy.setAction(new CellWorldPosition(3, 3), CellWorld.RIGHT); 238 PassiveTDAgent<CellWorldPosition, String> agent = new PassiveTDAgent<CellWorldPosition, String>( 239 fourByThree, policy); 240 Randomizer r = new JavaRandomizer(); 241 System.out 242 .println("Deriving Utility Function in the Passive ADP Agent From 200 trials in the 4 by 3 world"); 243 MDPUtilityFunction<CellWorldPosition> uf = null; 244 for (int i = 0; i < 200; i++) { 245 agent.executeTrial(r); 246 uf = agent.getUtilityFunction(); 247 // System.out.println(uf); 248 249 } 250 for (int i = 1; i <= 3; i++) { 251 for (int j = 1; j <= 4; j++) { 252 if (!((i == 2) && (j == 2))) { 253 printUtility(uf, i, j); 254 } 255 256 } 257 } 258 259 } 260 261 private static void qLearningAgentDemo() { 262 System.out.println("\nQ Learning Agent Demo Demo\n"); 263 System.out.println("creating an MDP to represent the 4 X 3 world"); 264 MDP<CellWorldPosition, String> fourByThree = MDPFactory 265 .createFourByThreeMDP(); 266 ; 267 QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>( 268 fourByThree); 269 Randomizer r = new JavaRandomizer(); 270 271 // Randomizer r = new JavaRandomizer(); 272 Hashtable<Pair<CellWorldPosition, String>, Double> q = null; 273 QTable<CellWorldPosition, String> qTable = null; 274 System.out.println("After 100 trials in the 4 by 3 world"); 275 for (int i = 0; i < 100; i++) { 276 qla.executeTrial(r); 277 q = qla.getQ(); 278 qTable = qla.getQTable(); 279 280 } 281 System.out.println("Final Q table" + qTable); 282 283 } 284 285 public static void enumerationJointAskDemo() { 286 System.out.println("\nEnumerationJointAsk Demo\n"); 287 ProbabilityDistribution jp = new ProbabilityDistribution("ToothAche", 288 "Cavity", "Catch"); 289 jp.set(true, true, true, 0.108); 290 jp.set(true, true, false, 0.012); 291 jp.set(false, true, true, 0.072); 292 jp.set(false, true, false, 0.008); 293 jp.set(true, false, true, 0.016); 294 jp.set(true, false, false, 0.064); 295 jp.set(false, false, true, 0.144); 296 jp.set(false, false, false, 0.008); 297 298 Query q = new Query("Cavity", new String[] { "ToothAche" }, 299 new boolean[] { true }); 300 double[] probs = EnumerateJointAsk.ask(q, jp); 301 System.out 302 .println("Using the full joint distribution of page 475 of Aima 2nd Edition"); 303 System.out 304 .println("Probability distribution of ToothAche using Enumeration joint ask is " 305 + string(probs)); 306 } 307 308 private static void priorSampleDemo() { 309 System.out.println("\nPriorSample Demo\n"); 310 BayesNet net = createWetGrassNetwork(); 311 System.out 312 .println("Using the Bayesian Network from page 510 of AIMA 2nd Edition generates"); 313 Hashtable table = net.getPriorSample(); 314 System.out.println(table.toString()); 315 } 316 317 private static void rejectionSamplingDemo() { 318 BayesNet net = createWetGrassNetwork(); 319 Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>(); 320 evidence.put("Sprinkler", Boolean.TRUE); 321 double[] results = net.rejectionSample("Rain", evidence, 100); 322 System.out.println("\nRejectionSampling Demo\n"); 323 System.out 324 .println("Using the Bayesian Network from page 510 of AIMA 2nd Edition "); 325 System.out 326 .println("and querying for P(Rain|Sprinkler=true) using 100 samples gives"); 327 System.out.println(string(results)); 328 329 } 330 331 private static void likelihoodWeightingDemo() { 332 BayesNet net = createWetGrassNetwork(); 333 Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>(); 334 evidence.put("Sprinkler", Boolean.TRUE); 335 double[] results = net.likelihoodWeighting("Rain", evidence, 100); 336 System.out.println("\nLikelihoodWeighting Demo\n"); 337 System.out 338 .println("Using the Bayesian Network from page 510 of AIMA 2nd Edition "); 339 System.out 340 .println("and querying for P(Rain|Sprinkler=true) using 100 samples gives"); 341 System.out.println(string(results)); 342 343 } 344 345 private static void mcmcAskDemo() { 346 BayesNet net = createWetGrassNetwork(); 347 Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>(); 348 evidence.put("Sprinkler", Boolean.TRUE); 349 double[] results = net.mcmcAsk("Rain", evidence, 100); 350 System.out.println("\nMCMCAsk Demo\n"); 351 System.out 352 .println("Using the Bayesian Network from page 510 of AIMA 2nd Edition "); 353 System.out 354 .println("and querying for P(Rain|Sprinkler=true) using 100 samples gives"); 355 System.out.println(string(results)); 356 357 } 358 359 public static void enumerationAskDemo() { 360 System.out.println("\nEnumerationAsk Demo\n"); 361 Query q = new Query("Burglary", 362 new String[] { "JohnCalls", "MaryCalls" }, new boolean[] { 363 true, true }); 364 double[] probs = EnumerationAsk.ask(q, createBurglaryNetwork()); 365 System.out 366 .println("Using the Burglary BayesNet from page 494 of AIMA 2nd Edition"); 367 System.out 368 .println("Querying the probability of Burglary|JohnCalls=true, MaryCalls=true gives " 369 + string(probs)); 370 371 } 372 373 private static BayesNet createBurglaryNetwork() { 374 BayesNetNode burglary = new BayesNetNode("Burglary"); 375 BayesNetNode earthquake = new BayesNetNode("EarthQuake"); 376 BayesNetNode alarm = new BayesNetNode("Alarm"); 377 BayesNetNode johnCalls = new BayesNetNode("JohnCalls"); 378 BayesNetNode maryCalls = new BayesNetNode("MaryCalls"); 379 380 alarm.influencedBy(burglary, earthquake); 381 johnCalls.influencedBy(alarm); 382 maryCalls.influencedBy(alarm); 383 384 burglary.setProbability(true, 0.001);// TODO behaviour changes if 385 // root node 386 earthquake.setProbability(true, 0.002); 387 388 alarm.setProbability(true, true, 0.95); 389 alarm.setProbability(true, false, 0.94); 390 alarm.setProbability(false, true, 0.29); 391 alarm.setProbability(false, false, 0.001); 392 393 johnCalls.setProbability(true, 0.90); 394 johnCalls.setProbability(false, 0.05); 395 396 maryCalls.setProbability(true, 0.70); 397 maryCalls.setProbability(false, 0.01); 398 399 BayesNet net = new BayesNet(burglary, earthquake); 400 return net; 401 } 402 403 private static BayesNet createWetGrassNetwork() { 404 BayesNetNode cloudy = new BayesNetNode("Cloudy"); 405 BayesNetNode sprinkler = new BayesNetNode("Sprinkler"); 406 BayesNetNode rain = new BayesNetNode("Rain"); 407 BayesNetNode wetGrass = new BayesNetNode("WetGrass"); 408 409 sprinkler.influencedBy(cloudy); 410 rain.influencedBy(cloudy); 411 wetGrass.influencedBy(rain, sprinkler); 412 413 cloudy.setProbability(true, 0.5); 414 sprinkler.setProbability(true, 0.10); 415 sprinkler.setProbability(false, 0.50); 416 417 rain.setProbability(true, 0.8); 418 rain.setProbability(false, 0.2); 419 420 wetGrass.setProbability(true, true, 0.99); 421 wetGrass.setProbability(true, false, 0.90); 422 wetGrass.setProbability(false, true, 0.90); 423 wetGrass.setProbability(false, false, 0.00); 424 425 BayesNet net = new BayesNet(cloudy); 426 return net; 427 } 428 429 private static String string(double[] probs) { 430 return " [ " + probs[0] + " , " + probs[1] + " ] "; 431 } 432 433 }