/*
 * Created on Feb 17, 2005
 *
 */
package aima.probability.demos;

import java.util.ArrayList;
import java.util.Hashtable;
import java.util.List;

import aima.learning.reinforcement.PassiveADPAgent;
import aima.learning.reinforcement.PassiveTDAgent;
import aima.learning.reinforcement.QLearningAgent;
import aima.learning.reinforcement.QTable;
import aima.probability.BayesNet;
import aima.probability.BayesNetNode;
import aima.probability.EnumerateJointAsk;
import aima.probability.EnumerationAsk;
import aima.probability.JavaRandomizer;
import aima.probability.ProbabilityDistribution;
import aima.probability.Query;
import aima.probability.RandomVariable;
import aima.probability.Randomizer;
import aima.probability.decision.MDP;
import aima.probability.decision.MDPFactory;
import aima.probability.decision.MDPPolicy;
import aima.probability.decision.MDPUtilityFunction;
import aima.probability.decision.cellworld.CellWorld;
import aima.probability.decision.cellworld.CellWorldPosition;
import aima.probability.reasoning.HMMFactory;
import aima.probability.reasoning.HiddenMarkovModel;
import aima.probability.reasoning.HmmConstants;
import aima.probability.reasoning.ParticleSet;
import aima.util.Pair;

/**
 * @author Ravi Mohan
 * 
 */
public class ProbabilityDemo {

    public static void main(String[] args) {
        enumerationJointAskDemo();
        enumerationAskDemo();
        priorSampleDemo();
        rejectionSamplingDemo();
        likelihoodWeightingDemo();
        mcmcAskDemo();

        forwardBackwardDemo();
        particleFilteringDemo();

        valueIterationDemo();
        policyIterationDemo();

        passiveADPAgentDemo();
        passiveTDAgentDemo();
        qLearningAgentDemo();
    }

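    /**
     * Demonstrates smoothing with the forward-backward algorithm on the
     * umbrella/rain HMM of Fig 15.5, printing the smoothed probabilities of
     * rain on days one and two after observing the umbrella on both days.
     */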
    private static void forwardBackwardDemo() {
        System.out.println("\nForward-Backward Demo\n");

        HiddenMarkovModel rainmanHmm = HMMFactory.createRainmanHMM();
        System.out.println("Creating a Hidden Markov Model to represent the model in Fig 15.5");
        List<String> perceptions = new ArrayList<String>();
        perceptions.add(HmmConstants.SEE_UMBRELLA);
        perceptions.add(HmmConstants.SEE_UMBRELLA);

        List<RandomVariable> results = rainmanHmm.forward_backward(perceptions);

        RandomVariable smoothedDayOne = results.get(1);
        System.out.println("Smoothed Probability Of Raining on Day One = "
                + smoothedDayOne.getProbabilityOf(HmmConstants.RAINING));
        System.out.println("Smoothed Probability Of NOT Raining on Day One = "
                + smoothedDayOne.getProbabilityOf(HmmConstants.NOT_RAINING));

        RandomVariable smoothedDayTwo = results.get(2);
        System.out.println("Smoothed Probability Of Raining on Day Two = "
                + smoothedDayTwo.getProbabilityOf(HmmConstants.RAINING));
        System.out.println("Smoothed Probability Of NOT Raining on Day Two = "
                + smoothedDayTwo.getProbabilityOf(HmmConstants.NOT_RAINING));
    }

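    /**
     * Demonstrates particle filtering on the same umbrella/rain HMM: an
     * initial set of 1000 particles is drawn from the prior, then filtered on
     * the perception SEE_UMBRELLA, and the particle counts for the two states
     * are printed before and after filtering.
     */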
    private static void particleFilteringDemo() {
        System.out.println("\nParticle Filtering Demo\n");
        HiddenMarkovModel rainman = HMMFactory.createRainmanHMM();
        Randomizer r = new JavaRandomizer();
        ParticleSet starting = rainman.prior().toParticleSet(rainman, r, 1000);
        System.out.println("At the beginning, "
                + starting.numberOfParticlesWithState(HmmConstants.RAINING)
                + " particles of 1000 indicate status == 'raining'");
        System.out.println("At the beginning, "
                + starting.numberOfParticlesWithState(HmmConstants.NOT_RAINING)
                + " particles of 1000 indicate status == 'NOT raining'");

        System.out.println("\nFiltering the particle set on perception == 'SEE_UMBRELLA' ...\n");
        ParticleSet afterSeeingUmbrella = starting.filter(HmmConstants.SEE_UMBRELLA, r);
        System.out.println("After filtering, "
                + afterSeeingUmbrella.numberOfParticlesWithState(HmmConstants.RAINING)
                + " particles of 1000 indicate status == 'raining'");
        System.out.println("After filtering, "
                + afterSeeingUmbrella.numberOfParticlesWithState(HmmConstants.NOT_RAINING)
                + " particles of 1000 indicate status == 'NOT raining'");
    }

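    /**
     * Demonstrates value iteration on the 4 x 3 cell world MDP, iterating
     * until the maximum utility change falls below the given error margin,
     * then printing the utility of every non-blocked cell.
     */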
    private static void valueIterationDemo() {
        System.out.println("\nValue Iteration Demo\n");
        System.out.println("creating an MDP to represent the 4 X 3 world");
        MDP<CellWorldPosition, String> fourByThreeMDP = MDPFactory.createFourByThreeMDP();

        System.out.println("Beginning Value Iteration");
        MDPUtilityFunction<CellWorldPosition> uf = fourByThreeMDP
                .valueIterationTillMAximumUtilityGrowthFallsBelowErrorMargin(1, 0.00001);
        for (int i = 1; i <= 3; i++) {
            for (int j = 1; j <= 4; j++) {
                if (!((i == 2) && (j == 2))) {
                    printUtility(uf, i, j);
                }
            }
        }
    }

    private static void printUtility(MDPUtilityFunction<CellWorldPosition> uf, int i, int j) {
        System.out.println("Utility of (" + i + " , " + j + " ) = "
                + uf.getUtility(new CellWorldPosition(i, j)));
    }

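    /**
     * Demonstrates policy iteration on the same 4 x 3 cell world MDP and
     * prints the recommended action for every non-blocked cell.
     */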
    private static void policyIterationDemo() {
        System.out.println("\nPolicy Iteration Demo\n");
        System.out.println("creating an MDP to represent the 4 X 3 world");
        MDP<CellWorldPosition, String> fourByThreeMDP = MDPFactory.createFourByThreeMDP();
        MDPPolicy<CellWorldPosition, String> policy = fourByThreeMDP.policyIteration(1);
        for (int i = 1; i <= 3; i++) {
            for (int j = 1; j <= 4; j++) {
                if (!((i == 2) && (j == 2))) {
                    printPolicy(i, j, policy);
                }
            }
        }
    }

    private static void printPolicy(int i, int j, MDPPolicy<CellWorldPosition, String> policy) {
        System.out.println("Recommended Action for (" + i + " , " + j + " )  =  "
                + policy.getAction(new CellWorldPosition(i, j)));
    }

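    /**
     * Demonstrates a passive ADP (adaptive dynamic programming) agent: given
     * the fixed policy of Fig 17.3 for the 4 x 3 world, the agent runs 100
     * trials and the learned utility function is printed for every
     * non-blocked cell.
     */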
    private static void passiveADPAgentDemo() {
        System.out.println("\nPassive ADP Agent Demo\n");
        System.out.println("creating an MDP to represent the 4 X 3 world");
        MDP<CellWorldPosition, String> fourByThree = MDPFactory.createFourByThreeMDP();

        MDPPolicy<CellWorldPosition, String> policy = new MDPPolicy<CellWorldPosition, String>();
        System.out.println("Creating a policy to reflect the policy in Fig 17.3");
        policy.setAction(new CellWorldPosition(1, 1), CellWorld.UP);
        policy.setAction(new CellWorldPosition(1, 2), CellWorld.LEFT);
        policy.setAction(new CellWorldPosition(1, 3), CellWorld.LEFT);
        policy.setAction(new CellWorldPosition(1, 4), CellWorld.LEFT);

        policy.setAction(new CellWorldPosition(2, 1), CellWorld.UP);
        policy.setAction(new CellWorldPosition(2, 3), CellWorld.UP);

        policy.setAction(new CellWorldPosition(3, 1), CellWorld.RIGHT);
        policy.setAction(new CellWorldPosition(3, 2), CellWorld.RIGHT);
        policy.setAction(new CellWorldPosition(3, 3), CellWorld.RIGHT);

        PassiveADPAgent<CellWorldPosition, String> agent = new PassiveADPAgent<CellWorldPosition, String>(
                fourByThree, policy);

        Randomizer r = new JavaRandomizer();
        System.out.println("Deriving Utility Function using the Passive ADP Agent from 100 trials in the 4 by 3 world");
        MDPUtilityFunction<CellWorldPosition> uf = null;
        for (int i = 0; i < 100; i++) {
            agent.executeTrial(r);
            uf = agent.getUtilityFunction();
        }

        for (int i = 1; i <= 3; i++) {
            for (int j = 1; j <= 4; j++) {
                if (!((i == 2) && (j == 2))) {
                    printUtility(uf, i, j);
                }
            }
        }
    }

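    /**
     * Demonstrates a passive TD (temporal-difference) agent with the same
     * fixed policy from Fig 17.3: the agent runs 200 trials and the learned
     * utility function is printed for every non-blocked cell.
     */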
    private static void passiveTDAgentDemo() {
        System.out.println("\nPassive TD Agent Demo\n");
        System.out.println("creating an MDP to represent the 4 X 3 world");
        MDP<CellWorldPosition, String> fourByThree = MDPFactory.createFourByThreeMDP();

        MDPPolicy<CellWorldPosition, String> policy = new MDPPolicy<CellWorldPosition, String>();
        System.out.println("Creating a policy to reflect the policy in Fig 17.3");
        policy.setAction(new CellWorldPosition(1, 1), CellWorld.UP);
        policy.setAction(new CellWorldPosition(1, 2), CellWorld.LEFT);
        policy.setAction(new CellWorldPosition(1, 3), CellWorld.LEFT);
        policy.setAction(new CellWorldPosition(1, 4), CellWorld.LEFT);

        policy.setAction(new CellWorldPosition(2, 1), CellWorld.UP);
        policy.setAction(new CellWorldPosition(2, 3), CellWorld.UP);

        policy.setAction(new CellWorldPosition(3, 1), CellWorld.RIGHT);
        policy.setAction(new CellWorldPosition(3, 2), CellWorld.RIGHT);
        policy.setAction(new CellWorldPosition(3, 3), CellWorld.RIGHT);
        PassiveTDAgent<CellWorldPosition, String> agent = new PassiveTDAgent<CellWorldPosition, String>(
                fourByThree, policy);
        Randomizer r = new JavaRandomizer();
        System.out.println("Deriving Utility Function using the Passive TD Agent from 200 trials in the 4 by 3 world");
        MDPUtilityFunction<CellWorldPosition> uf = null;
        for (int i = 0; i < 200; i++) {
            agent.executeTrial(r);
            uf = agent.getUtilityFunction();
            // System.out.println(uf);
        }
        for (int i = 1; i <= 3; i++) {
            for (int j = 1; j <= 4; j++) {
                if (!((i == 2) && (j == 2))) {
                    printUtility(uf, i, j);
                }
            }
        }
    }

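    /**
     * Demonstrates an active Q-learning agent on the 4 x 3 world: the agent
     * runs 100 trials and the final Q table of state-action values is
     * printed.
     */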
    private static void qLearningAgentDemo() {
        System.out.println("\nQ Learning Agent Demo\n");
        System.out.println("creating an MDP to represent the 4 X 3 world");
        MDP<CellWorldPosition, String> fourByThree = MDPFactory.createFourByThreeMDP();
        QLearningAgent<CellWorldPosition, String> qla = new QLearningAgent<CellWorldPosition, String>(
                fourByThree);
        Randomizer r = new JavaRandomizer();

        Hashtable<Pair<CellWorldPosition, String>, Double> q = null;
        QTable<CellWorldPosition, String> qTable = null;
        System.out.println("After 100 trials in the 4 by 3 world");
        for (int i = 0; i < 100; i++) {
            qla.executeTrial(r);
            q = qla.getQ();
            qTable = qla.getQTable();
        }
        System.out.println("Final Q table: " + qTable);
    }

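    /**
     * Demonstrates inference by enumeration over a full joint distribution:
     * the toothache/cavity/catch distribution from page 475 of AIMA 2nd
     * Edition is built explicitly and queried for P(Cavity | ToothAche = true).
     */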
    public static void enumerationJointAskDemo() {
        System.out.println("\nEnumerationJointAsk Demo\n");
        ProbabilityDistribution jp = new ProbabilityDistribution("ToothAche",
                "Cavity", "Catch");
        jp.set(true, true, true, 0.108);
        jp.set(true, true, false, 0.012);
        jp.set(false, true, true, 0.072);
        jp.set(false, true, false, 0.008);
        jp.set(true, false, true, 0.016);
        jp.set(true, false, false, 0.064);
        jp.set(false, false, true, 0.144);
        jp.set(false, false, false, 0.008);

        Query q = new Query("Cavity", new String[] { "ToothAche" },
                new boolean[] { true });
        double[] probs = EnumerateJointAsk.ask(q, jp);
        System.out.println("Using the full joint distribution from page 475 of AIMA 2nd Edition");
        System.out.println("the probability distribution of Cavity given ToothAche = true, computed by EnumerateJointAsk, is "
                + string(probs));
    }

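    /**
     * Demonstrates prior (forward) sampling: one complete atomic event is
     * sampled from the wet-grass Bayesian network of page 510 of AIMA 2nd
     * Edition and printed.
     */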
    private static void priorSampleDemo() {
        System.out.println("\nPriorSample Demo\n");
        BayesNet net = createWetGrassNetwork();
        System.out.println("Using the Bayesian Network from page 510 of AIMA 2nd Edition, prior sampling generates");
        Hashtable table = net.getPriorSample();
        System.out.println(table.toString());
    }

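    /**
     * Demonstrates rejection sampling on the wet-grass network: 100 samples
     * are drawn and those inconsistent with the evidence Sprinkler = true are
     * rejected, yielding an estimate of P(Rain | Sprinkler = true).
     */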
    private static void rejectionSamplingDemo() {
        BayesNet net = createWetGrassNetwork();
        Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>();
        evidence.put("Sprinkler", Boolean.TRUE);
        double[] results = net.rejectionSample("Rain", evidence, 100);
        System.out.println("\nRejectionSampling Demo\n");
        System.out.println("Using the Bayesian Network from page 510 of AIMA 2nd Edition");
        System.out.println("and querying for P(Rain|Sprinkler=true) using 100 samples gives");
        System.out.println(string(results));
    }

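    /**
     * Demonstrates likelihood weighting on the wet-grass network: 100
     * weighted samples consistent with Sprinkler = true are used to estimate
     * P(Rain | Sprinkler = true).
     */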
    private static void likelihoodWeightingDemo() {
        BayesNet net = createWetGrassNetwork();
        Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>();
        evidence.put("Sprinkler", Boolean.TRUE);
        double[] results = net.likelihoodWeighting("Rain", evidence, 100);
        System.out.println("\nLikelihoodWeighting Demo\n");
        System.out.println("Using the Bayesian Network from page 510 of AIMA 2nd Edition");
        System.out.println("and querying for P(Rain|Sprinkler=true) using 100 samples gives");
        System.out.println(string(results));
    }

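    /**
     * Demonstrates MCMC inference on the wet-grass network, again estimating
     * P(Rain | Sprinkler = true) from 100 samples.
     */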
    private static void mcmcAskDemo() {
        BayesNet net = createWetGrassNetwork();
        Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>();
        evidence.put("Sprinkler", Boolean.TRUE);
        double[] results = net.mcmcAsk("Rain", evidence, 100);
        System.out.println("\nMCMCAsk Demo\n");
        System.out.println("Using the Bayesian Network from page 510 of AIMA 2nd Edition");
        System.out.println("and querying for P(Rain|Sprinkler=true) using 100 samples gives");
        System.out.println(string(results));
    }

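    /**
     * Demonstrates exact inference by enumeration on the burglary network
     * from page 494 of AIMA 2nd Edition, querying
     * P(Burglary | JohnCalls = true, MaryCalls = true).
     */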
    public static void enumerationAskDemo() {
        System.out.println("\nEnumerationAsk Demo\n");
        Query q = new Query("Burglary",
                new String[] { "JohnCalls", "MaryCalls" }, new boolean[] {
                        true, true });
        double[] probs = EnumerationAsk.ask(q, createBurglaryNetwork());
        System.out.println("Using the Burglary BayesNet from page 494 of AIMA 2nd Edition");
        System.out.println("Querying the probability of Burglary|JohnCalls=true, MaryCalls=true gives "
                + string(probs));
    }

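    /**
     * Builds the burglary/earthquake/alarm network from page 494 of AIMA 2nd
     * Edition, with the conditional probability values taken from the book.
     */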
    private static BayesNet createBurglaryNetwork() {
        BayesNetNode burglary = new BayesNetNode("Burglary");
        BayesNetNode earthquake = new BayesNetNode("EarthQuake");
        BayesNetNode alarm = new BayesNetNode("Alarm");
        BayesNetNode johnCalls = new BayesNetNode("JohnCalls");
        BayesNetNode maryCalls = new BayesNetNode("MaryCalls");

        alarm.influencedBy(burglary, earthquake);
        johnCalls.influencedBy(alarm);
        maryCalls.influencedBy(alarm);

        // TODO: behaviour changes if this is a root node
        burglary.setProbability(true, 0.001);
        earthquake.setProbability(true, 0.002);

        alarm.setProbability(true, true, 0.95);
        alarm.setProbability(true, false, 0.94);
        alarm.setProbability(false, true, 0.29);
        alarm.setProbability(false, false, 0.001);

        johnCalls.setProbability(true, 0.90);
        johnCalls.setProbability(false, 0.05);

        maryCalls.setProbability(true, 0.70);
        maryCalls.setProbability(false, 0.01);

        return new BayesNet(burglary, earthquake);
    }

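    /**
     * Builds the cloudy/sprinkler/rain/wet-grass network from page 510 of
     * AIMA 2nd Edition, with the conditional probability values taken from
     * the book.
     */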
    private static BayesNet createWetGrassNetwork() {
        BayesNetNode cloudy = new BayesNetNode("Cloudy");
        BayesNetNode sprinkler = new BayesNetNode("Sprinkler");
        BayesNetNode rain = new BayesNetNode("Rain");
        BayesNetNode wetGrass = new BayesNetNode("WetGrass");

        sprinkler.influencedBy(cloudy);
        rain.influencedBy(cloudy);
        wetGrass.influencedBy(rain, sprinkler);

        cloudy.setProbability(true, 0.5);
        sprinkler.setProbability(true, 0.10);
        sprinkler.setProbability(false, 0.50);

        rain.setProbability(true, 0.8);
        rain.setProbability(false, 0.2);

        wetGrass.setProbability(true, true, 0.99);
        wetGrass.setProbability(true, false, 0.90);
        wetGrass.setProbability(false, true, 0.90);
        wetGrass.setProbability(false, false, 0.00);

        return new BayesNet(cloudy);
    }

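    /**
     * Formats a two-element probability array for printing.
     */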
    private static String string(double[] probs) {
        return " [ " + probs[0] + " , " + probs[1] + " ] ";
    }

}