001    package aima.test.probreasoningtest;
002    
003    import java.util.ArrayList;
004    import java.util.List;
005    
006    import junit.framework.TestCase;
007    import aima.probability.RandomVariable;
008    import aima.probability.reasoning.FixedLagSmoothing;
009    import aima.probability.reasoning.HMMFactory;
010    import aima.probability.reasoning.HiddenMarkovModel;
011    import aima.probability.reasoning.HmmConstants;
012    
013    /**
014     * @author Ravi Mohan
015     * 
016     */
017    public class HMMTest extends TestCase {
018            private HiddenMarkovModel robotHmm, rainmanHmm;
019    
020            private static final double TOLERANCE = 0.001;
021    
022            @Override
023            public void setUp() {
024                    robotHmm = HMMFactory.createRobotHMM();
025                    rainmanHmm = HMMFactory.createRainmanHMM();
026            }
027    
028            public void testRobotHMMInitialization() {
029    
030                    assertEquals(0.5, robotHmm.prior().getProbabilityOf(
031                                    HmmConstants.DOOR_OPEN));
032                    assertEquals(0.5, robotHmm.prior().getProbabilityOf(
033                                    HmmConstants.DOOR_CLOSED));
034            }
035    
036            public void testRainmanHmmInitialization() {
037    
038                    assertEquals(0.5, rainmanHmm.prior().getProbabilityOf(
039                                    HmmConstants.RAINING));
040                    assertEquals(0.5, rainmanHmm.prior().getProbabilityOf(
041                                    HmmConstants.NOT_RAINING));
042            }
043    
044            public void testForwardMessagingWorksForFiltering() {
045                    RandomVariable afterOneStep = robotHmm.forward(robotHmm.prior(),
046                                    HmmConstants.DO_NOTHING, HmmConstants.SEE_DOOR_OPEN);
047                    assertEquals(0.75, afterOneStep
048                                    .getProbabilityOf(HmmConstants.DOOR_OPEN), TOLERANCE);
049                    assertEquals(0.25, afterOneStep
050                                    .getProbabilityOf(HmmConstants.DOOR_CLOSED), TOLERANCE);
051    
052                    RandomVariable afterTwoSteps = robotHmm.forward(afterOneStep,
053                                    HmmConstants.PUSH_DOOR, HmmConstants.SEE_DOOR_OPEN);
054                    assertEquals(0.983, afterTwoSteps
055                                    .getProbabilityOf(HmmConstants.DOOR_OPEN), TOLERANCE);
056                    assertEquals(0.017, afterTwoSteps
057                                    .getProbabilityOf(HmmConstants.DOOR_CLOSED), TOLERANCE);
058            }
059    
060            public void testRecursiveBackwardMessageCalculationIsCorrect() {
061                    RandomVariable afterOneStep = rainmanHmm.forward(rainmanHmm.prior(),
062                                    HmmConstants.DO_NOTHING, HmmConstants.SEE_UMBRELLA);
063                    RandomVariable afterTwoSteps = rainmanHmm.forward(afterOneStep,
064                                    HmmConstants.DO_NOTHING, HmmConstants.SEE_UMBRELLA);
065    
066                    RandomVariable postSequence = afterTwoSteps.duplicate()
067                                    .createUnitBelief();
068    
069                    RandomVariable smoothed = rainmanHmm.calculate_next_backward_message(
070                                    afterOneStep, postSequence, HmmConstants.SEE_UMBRELLA);
071                    assertEquals(0.883, smoothed.getProbabilityOf(HmmConstants.RAINING),
072                                    TOLERANCE);
073                    assertEquals(0.117,
074                                    smoothed.getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
075    
076            }
077    
078            public void testForwardBackwardOnRainmanHmm() {
079                    List<String> perceptions = new ArrayList<String>();
080                    perceptions.add(HmmConstants.SEE_UMBRELLA);
081                    perceptions.add(HmmConstants.SEE_UMBRELLA);
082    
083                    List<RandomVariable> results = rainmanHmm.forward_backward(perceptions);
084                    assertEquals(3, results.size());
085    
086                    assertNull(results.get(0));
087                    RandomVariable smoothedDayOne = results.get(1);
088                    assertEquals(0.982, smoothedDayOne
089                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
090                    assertEquals(0.018, smoothedDayOne
091                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
092    
093                    RandomVariable smoothedDayTwo = results.get(2);
094                    assertEquals(0.883, smoothedDayTwo
095                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
096                    assertEquals(0.117, smoothedDayTwo
097                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
098    
099            }
100    
101            public void testForwardBackwardOnRainmanHmmFor3daysData() {
102                    List<String> perceptions = new ArrayList<String>();
103                    perceptions.add(HmmConstants.SEE_UMBRELLA);
104                    perceptions.add(HmmConstants.SEE_UMBRELLA);
105                    perceptions.add(HmmConstants.SEE_NO_UMBRELLA);
106    
107                    List<RandomVariable> results = rainmanHmm.forward_backward(perceptions);
108                    assertEquals(4, results.size());
109                    assertNull(results.get(0));
110    
111                    RandomVariable smoothedDayOne = results.get(1);
112                    assertEquals(0.964, smoothedDayOne
113                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
114                    assertEquals(0.036, smoothedDayOne
115                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
116    
117                    RandomVariable smoothedDayTwo = results.get(2);
118                    assertEquals(0.484, smoothedDayTwo
119                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
120                    assertEquals(0.516, smoothedDayTwo
121                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
122    
123                    RandomVariable smoothedDayThree = results.get(3);
124                    assertEquals(0.190, smoothedDayThree
125                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
126                    assertEquals(0.810, smoothedDayThree
127                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
128            }
129    
130            public void xtestForwardBackwardAndFixedLagSmoothingGiveSameResults() {
131    
132                    // test disabled pending algorithm clarification
133                    List<String> perceptions = new ArrayList<String>();
134    
135                    String dayOnePerception = HmmConstants.SEE_UMBRELLA;
136                    String dayTwoPerception = HmmConstants.SEE_UMBRELLA;
137                    String dayThreePerception = HmmConstants.SEE_NO_UMBRELLA;
138    
139                    perceptions.add(dayOnePerception);
140                    perceptions.add(dayTwoPerception);
141                    perceptions.add(dayThreePerception);
142    
143                    List<RandomVariable> fbResults = rainmanHmm
144                                    .forward_backward(perceptions);
145                    assertEquals(4, fbResults.size());
146    
147                    RandomVariable fbDayOneResult = fbResults.get(1);
148                    System.out.println(fbDayOneResult);
149    
150                    FixedLagSmoothing fls = new FixedLagSmoothing(rainmanHmm, 2);
151    
152                    assertNull(fls.smooth(dayOnePerception));
153                    System.out.println(fls.smooth(dayTwoPerception));
154                    RandomVariable flsDayoneResult = fls.smooth(dayThreePerception);
155                    System.out.println(flsDayoneResult);
156    
157            }
158    
159            public void testOneStepFixedLagSmoothingOnRainManHmm() {
160                    FixedLagSmoothing fls = new FixedLagSmoothing(rainmanHmm, 1);
161    
162                    RandomVariable smoothedDayZero = fls.smooth(HmmConstants.SEE_UMBRELLA); // see
163                    // umbrella
164                    // on
165                    // day
166                    // one
167                    assertEquals(0.627, smoothedDayZero
168                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
169    
170                    RandomVariable smoothedDayOne = fls.smooth(HmmConstants.SEE_UMBRELLA); // see
171                    // umbrella
172                    // on
173                    // day
174                    // two
175                    assertEquals(0.883, smoothedDayOne
176                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
177                    assertEquals(0.117, smoothedDayOne
178                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
179    
180                    RandomVariable smoothedDayTwo = fls
181                                    .smooth(HmmConstants.SEE_NO_UMBRELLA); // see no umbrella on
182                    // day three
183                    assertEquals(0.799, smoothedDayTwo
184                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
185                    assertEquals(0.201, smoothedDayTwo
186                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
187            }
188    
189            public void testOneStepFixedLagSmoothingOnRainManHmmWithDifferingEvidence() {
190                    FixedLagSmoothing fls = new FixedLagSmoothing(rainmanHmm, 1);
191    
192                    RandomVariable smoothedDayZero = fls.smooth(HmmConstants.SEE_UMBRELLA);// see
193                    // umbrella
194                    // on
195                    // day
196                    // one
197                    assertEquals(0.627, smoothedDayZero
198                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
199    
200                    RandomVariable smoothedDayOne = fls
201                                    .smooth(HmmConstants.SEE_NO_UMBRELLA);// no umbrella on day
202                    // two
203                    assertEquals(0.702, smoothedDayOne
204                                    .getProbabilityOf(HmmConstants.RAINING), TOLERANCE);
205                    assertEquals(0.297, smoothedDayOne
206                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
207            }
208    
209            public void testTwoStepFixedLagSmoothingOnRainManHmm() {
210                    FixedLagSmoothing fls = new FixedLagSmoothing(rainmanHmm, 2);
211    
212                    RandomVariable smoothedOne = fls.smooth(HmmConstants.SEE_UMBRELLA); // see
213                    // umbrella
214                    // on
215                    // day
216                    // one
217                    assertNull(smoothedOne);
218    
219                    smoothedOne = fls.smooth(HmmConstants.SEE_UMBRELLA); // see
220                    // umbrella
221                    // on
222                    // day
223                    // two
224                    assertEquals(0.653, smoothedOne.getProbabilityOf(HmmConstants.RAINING),
225                                    TOLERANCE);
226                    assertEquals(0.346, smoothedOne
227                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
228    
229                    RandomVariable smoothedTwo = fls.smooth(HmmConstants.SEE_UMBRELLA);// see
230                    // umbrella
231                    // on
232                    // day
233                    // 3
234                    assertEquals(0.894, smoothedTwo.getProbabilityOf(HmmConstants.RAINING),
235                                    TOLERANCE);
236                    assertEquals(0.105, smoothedTwo
237                                    .getProbabilityOf(HmmConstants.NOT_RAINING), TOLERANCE);
238    
239            }
240    }