001    package aima.probability;
002    
003    import java.util.ArrayList;
004    import java.util.Hashtable;
005    import java.util.List;
006    
007    import aima.probability.reasoning.HiddenMarkovModel;
008    import aima.probability.reasoning.Particle;
009    import aima.probability.reasoning.ParticleSet;
010    import aima.util.Matrix;
011    import aima.util.Util;
012    
013    /**
014     * @author Ravi Mohan
015     * 
016     */
017    
018    public class RandomVariable {
019            private String name;
020    
021            private Hashtable<String, Double> distribution;
022    
023            private List<String> states;
024    
025            public RandomVariable(List<String> states) {
026                    this("HiddenState", states);
027            }
028    
029            public RandomVariable(String name, List<String> states) {
030                    this.name = name;
031                    this.states = states;
032                    this.distribution = new Hashtable<String, Double>();
033                    int numberOfStates = states.size();
034                    double initialProbability = 1.0 / numberOfStates;
035                    for (String s : states) {
036                            distribution.put(s, initialProbability);
037                    }
038            }
039    
040            private RandomVariable(String name, List<String> states,
041                            Hashtable<String, Double> distribution) {
042                    this.name = name;
043                    this.states = states;
044                    this.distribution = distribution;
045            }
046    
047            public void setProbabilityOf(String state, Double probability) {
048                    if (states.contains(state)) {
049                            distribution.put(state, probability);
050                    } else {
051                            throw new RuntimeException(state + "  is an invalid state");
052                    }
053            }
054    
055            public double getProbabilityOf(String state) {
056                    if (states.contains(state)) {
057                            return distribution.get(state);
058                    } else {
059                            throw new RuntimeException(state + "  is an invalid state");
060                    }
061            }
062    
063            public List<String> states() {
064                    return states;
065            }
066    
067            public RandomVariable duplicate() {
068                    Hashtable<String, Double> probs = new Hashtable<String, Double>();
069                    for (String key : distribution.keySet()) {
070                            probs.put(key, distribution.get(key));
071                    }
072                    return new RandomVariable(name, states, probs);
073    
074            }
075    
076            public void normalize() {
077                    List<Double> probs = new ArrayList<Double>();
078                    for (String s : states) {
079                            probs.add(distribution.get(s));
080                    }
081                    List<Double> newProbs = Util.normalize(probs);
082                    for (int i = 0; i < states.size(); i++) {
083                            distribution.put(states.get(i), newProbs.get(i));
084                    }
085            }
086    
087            public Matrix asMatrix() {
088                    Matrix m = new Matrix(states.size(), 1);
089                    for (int i = 0; i < states.size(); i++) {
090                            m.set(i, 0, distribution.get(states.get(i)));
091                    }
092                    return m;
093    
094            }
095    
096            public void updateFrom(Matrix aMatrix) {
097                    for (int i = 0; i < states.size(); i++) {
098                            distribution.put(states.get(i), aMatrix.get(i, 0));
099                    }
100    
101            }
102    
103            public RandomVariable createUnitBelief() {
104                    RandomVariable result = duplicate();
105                    for (String s : states()) {
106                            result.setProbabilityOf(s, 1.0);
107                    }
108                    return result;
109            }
110    
111            @Override
112            public String toString() {
113                    return asMatrix().toString();
114            }
115    
116            public ParticleSet toParticleSet(HiddenMarkovModel hmm,
117                            Randomizer randomizer, int numberOfParticles) {
118                    ParticleSet result = new ParticleSet(hmm);
119                    for (int i = 0; i < numberOfParticles; i++) {
120                            double rvalue = randomizer.nextDouble();
121                            String state = getStateForRandomNumber(rvalue);
122                            result.add(new Particle(state, 0));
123                    }
124                    return result;
125            }
126    
127            private String getStateForRandomNumber(double rvalue) {
128                    double total = 0.0;
129                    for (String s : states) {
130                            total = total + distribution.get(s);
131                            if (total >= rvalue) {
132                                    return s;
133                            }
134                    }
135                    throw new RuntimeException("cannot handle " + rvalue);
136            }
137    
138    }