001    package aima.test.probabilitytest;
002    
003    import java.util.Hashtable;
004    
005    import junit.framework.TestCase;
006    import aima.probability.BayesNet;
007    import aima.probability.BayesNetNode;
008    import aima.probability.EnumerationAsk;
009    import aima.probability.Query;
010    
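/**
 * Tests for the sampling-based inference methods of {@link BayesNet} (prior,
 * rejection, likelihood-weighting and MCMC sampling) on the wet-grass network,
 * plus a comparison of MCMC against exact inference by enumeration.
 *
 * @author Ravi Mohan
 */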
015    
016    public class ProbabilitySamplingTest extends TestCase {
017    
018            public void testPriorSample() {
019                    BayesNet net = createWetGrassNetwork();
020                    MockRandomizer r = new MockRandomizer(
021                                    new double[] { 0.5, 0.5, 0.5, 0.5 });
022                    Hashtable table = net.getPriorSample(r);
023                    assertEquals(4, table.keySet().size());
024                    assertEquals(Boolean.TRUE, table.get("Cloudy"));
025                    assertEquals(Boolean.FALSE, table.get("Sprinkler"));
026                    assertEquals(Boolean.TRUE, table.get("Rain"));
027                    assertEquals(Boolean.TRUE, table.get("WetGrass"));
028            }
029    
030            public void testRejectionSample() {
031                    BayesNet net = createWetGrassNetwork();
032                    MockRandomizer r = new MockRandomizer(new double[] { 0.1 });
033                    Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>();
034                    evidence.put("Sprinkler", Boolean.TRUE);
035                    double[] results = net.rejectionSample("Rain", evidence, 100, r);
036                    assertEquals(1.0, results[0], 0.001);
037                    assertEquals(0.0, results[1], 0.001);
038    
039            }
040    
041            public void testLikelihoodWeighting() {
042                    MockRandomizer r = new MockRandomizer(
043                                    new double[] { 0.5, 0.5, 0.5, 0.5 });
044                    BayesNet net = createWetGrassNetwork();
045                    Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>();
046                    evidence.put("Sprinkler", Boolean.TRUE);
047                    double[] results = net.likelihoodWeighting("Rain", evidence, 1000, r);
048                    // System.out.println(results[0] + " " + results[1]);
049                    assertEquals(1.0, results[0], 0.001);
050                    assertEquals(0.0, results[1], 0.001);
051            }
052    
053            public void testMCMCask() {
054                    BayesNet net = createWetGrassNetwork();
055                    MockRandomizer r = new MockRandomizer(
056                                    new double[] { 0.5, 0.5, 0.5, 0.5 });
057    
058                    Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>();
059                    evidence.put("Sprinkler", Boolean.TRUE);
060                    double[] results = net.mcmcAsk("Rain", evidence, 1, r);
061                    // System.out.println(results[0] + " " + results[1]);
062                    assertEquals(0.333, results[0], 0.001);
063                    assertEquals(0.666, results[1], 0.001);
064    
065            }
066    
067            public void testMCMCask2() {
068                    BayesNet net = createWetGrassNetwork();
069                    MockRandomizer r = new MockRandomizer(
070                                    new double[] { 0.5, 0.5, 0.5, 0.5 });
071    
072                    Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>();
073                    evidence.put("Sprinkler", Boolean.TRUE);
074                    double[] results = net.mcmcAsk("Rain", evidence, 1, r);
075                    // System.out.println(results[0] + " " + results[1]);
076                    assertEquals(0.333, results[0], 0.001);
077                    assertEquals(0.666, results[1], 0.001);
078    
079            }
080    
081            public void testEnumerationAskinMCMC() {
082                    BayesNet net = createWetGrassNetwork();
083                    MockRandomizer r = new MockRandomizer(
084                                    new double[] { 0.5, 0.5, 0.5, 0.5 });
085                    Hashtable<String, Boolean> evidence = new Hashtable<String, Boolean>();
086                    evidence.put("Rain", Boolean.TRUE);
087                    evidence.put("Sprinkler", Boolean.TRUE);
088                    Query q = new Query("Cloudy", new String[] { "Sprinkler", "Rain" },
089                                    new boolean[] { true, true });
090                    double[] results = EnumerationAsk.ask(q, net);
091                    double[] results2 = net.mcmcAsk("Cloudy", evidence, 1000);
092                    // System.out.println(results[0] + " " + results[1]);
093                    // System.out.println(results2[0] + " " + results2[1]);
094    
095            }
096    
097            private BayesNet createWetGrassNetwork() {
098                    BayesNetNode cloudy = new BayesNetNode("Cloudy");
099                    BayesNetNode sprinkler = new BayesNetNode("Sprinkler");
100                    BayesNetNode rain = new BayesNetNode("Rain");
101                    BayesNetNode wetGrass = new BayesNetNode("WetGrass");
102    
103                    sprinkler.influencedBy(cloudy);
104                    rain.influencedBy(cloudy);
105                    wetGrass.influencedBy(rain, sprinkler);
106    
107                    cloudy.setProbability(true, 0.5);
108                    sprinkler.setProbability(true, 0.10);
109                    sprinkler.setProbability(false, 0.50);
110    
111                    rain.setProbability(true, 0.8);
112                    rain.setProbability(false, 0.2);
113    
114                    wetGrass.setProbability(true, true, 0.99);
115                    wetGrass.setProbability(true, false, 0.90);
116                    wetGrass.setProbability(false, true, 0.90);
117                    wetGrass.setProbability(false, false, 0.00);
118    
119                    BayesNet net = new BayesNet(cloudy);
120                    return net;
121            }
122    }