001    package aima.probability.reasoning;
002    
003    import java.util.Arrays;
004    import java.util.List;
005    
006    import aima.probability.RandomVariable;
007    import aima.util.Matrix;
008    
009    /**
010     * @author Ravi Mohan
011     * 
012     */
013    
014    public class HiddenMarkovModel {
015    
016            SensorModel sensorModel;
017    
018            TransitionModel transitionModel;
019    
020            private RandomVariable priorDistribution;
021    
022            public HiddenMarkovModel(RandomVariable priorDistribution,
023                            TransitionModel tm, SensorModel sm) {
024                    this.priorDistribution = priorDistribution;
025                    this.transitionModel = tm;
026                    this.sensorModel = sm;
027            }
028    
029            public RandomVariable prior() {
030                    return priorDistribution;
031            }
032    
033            public RandomVariable predict(RandomVariable aBelief, String action) {
034                    RandomVariable newBelief = aBelief.duplicate();
035    
036                    Matrix beliefMatrix = aBelief.asMatrix();
037                    Matrix transitionMatrix = transitionModel.asMatrix(action);
038                    Matrix predicted = transitionMatrix.transpose().times(beliefMatrix);
039                    newBelief.updateFrom(predicted);
040                    return newBelief;
041            }
042    
043            public RandomVariable perceptionUpdate(RandomVariable aBelief,
044                            String perception) {
045                    RandomVariable newBelief = aBelief.duplicate();
046    
047                    // one way - use matrices
048                    Matrix beliefMatrix = aBelief.asMatrix();
049                    Matrix o_matrix = sensorModel.asMatrix(perception);
050                    Matrix updated = o_matrix.times(beliefMatrix);
051                    newBelief.updateFrom(updated);
052                    newBelief.normalize();
053                    return newBelief;
054    
055                    // alternate way of doing this. clearer in intent.
056                    // for (String state : aBelief.states()){
057                    // double probabilityOfPerception= sensorModel.get(state,perception);
058                    // newBelief.setProbabilityOf(state,probabilityOfPerception *
059                    // aBelief.getProbabilityOf(state));
060                    // }
061            }
062    
063            public RandomVariable forward(RandomVariable aBelief, String action,
064                            String perception) {
065    
066                    return perceptionUpdate(predict(aBelief, action), perception);
067            }
068    
069            public RandomVariable forward(RandomVariable aBelief, String perception) {
070    
071                    return forward(aBelief, HmmConstants.DO_NOTHING, perception);
072            }
073    
074            public RandomVariable calculate_next_backward_message(
075                            RandomVariable forwardBelief,
076                            RandomVariable present_backward_message, String perception) {
077                    RandomVariable result = present_backward_message.duplicate();
078                    // System.out.println("fb :-calculating new backward message");
079                    // System.out.println("fb :-diagonal matrix from sens model = ");
080                    Matrix oMatrix = sensorModel.asMatrix(perception);
081                    // System.out.println(oMatrix);
082                    Matrix transitionMatrix = transitionModel.asMatrix();// action
083                    // should
084                    // be
085                    // passed
086                    // in
087                    // here?
088                    // System.out.println("fb :-present backward message = "
089                    // +present_backward_message);
090                    Matrix backwardMatrix = transitionMatrix.times(oMatrix
091                                    .times(present_backward_message.asMatrix()));
092                    Matrix resultMatrix = backwardMatrix.arrayTimes(forwardBelief
093                                    .asMatrix());
094                    result.updateFrom(resultMatrix);
095                    result.normalize();
096                    // System.out.println("fb :-normalized new backward message = "
097                    // +result);
098                    return result;
099            }
100    
101            public List<RandomVariable> forward_backward(List<String> perceptions) {
102                    RandomVariable forwardMessages[] = new RandomVariable[perceptions
103                                    .size() + 1];
104                    RandomVariable backwardMessage = priorDistribution.createUnitBelief();
105                    RandomVariable smoothedBeliefs[] = new RandomVariable[perceptions
106                                    .size() + 1];
107    
108                    forwardMessages[0] = priorDistribution;
109                    smoothedBeliefs[0] = null;
110    
111                    // populate forward messages
112                    for (int i = 0; i < perceptions.size(); i++) { // N.B i starts at 1,
113                            // not zero
114                            forwardMessages[i + 1] = forward(forwardMessages[i], perceptions
115                                            .get(i));
116                    }
117                    for (int i = perceptions.size(); i > 0; i--) {
118                            RandomVariable smoothed = priorDistribution.duplicate();
119                            smoothed.updateFrom(forwardMessages[i].asMatrix().arrayTimes(
120                                            backwardMessage.asMatrix()));
121                            smoothed.normalize();
122                            smoothedBeliefs[i] = smoothed;
123                            backwardMessage = calculate_next_backward_message(
124                                            forwardMessages[i], backwardMessage, perceptions.get(i - 1));
125                    }
126    
127                    return Arrays.asList(smoothedBeliefs);
128            }
129    
130            public SensorModel sensorModel() {
131    
132                    return sensorModel;
133            }
134    
135            public TransitionModel transitionModel() {
136    
137                    return transitionModel;
138            }
139    
140    }