001 package aima.probability; 002 003 import java.util.ArrayList; 004 import java.util.Hashtable; 005 import java.util.List; 006 007 import aima.probability.reasoning.HiddenMarkovModel; 008 import aima.probability.reasoning.Particle; 009 import aima.probability.reasoning.ParticleSet; 010 import aima.util.Matrix; 011 import aima.util.Util; 012 013 /** 014 * @author Ravi Mohan 015 * 016 */ 017 018 public class RandomVariable { 019 private String name; 020 021 private Hashtable<String, Double> distribution; 022 023 private List<String> states; 024 025 public RandomVariable(List<String> states) { 026 this("HiddenState", states); 027 } 028 029 public RandomVariable(String name, List<String> states) { 030 this.name = name; 031 this.states = states; 032 this.distribution = new Hashtable<String, Double>(); 033 int numberOfStates = states.size(); 034 double initialProbability = 1.0 / numberOfStates; 035 for (String s : states) { 036 distribution.put(s, initialProbability); 037 } 038 } 039 040 private RandomVariable(String name, List<String> states, 041 Hashtable<String, Double> distribution) { 042 this.name = name; 043 this.states = states; 044 this.distribution = distribution; 045 } 046 047 public void setProbabilityOf(String state, Double probability) { 048 if (states.contains(state)) { 049 distribution.put(state, probability); 050 } else { 051 throw new RuntimeException(state + " is an invalid state"); 052 } 053 } 054 055 public double getProbabilityOf(String state) { 056 if (states.contains(state)) { 057 return distribution.get(state); 058 } else { 059 throw new RuntimeException(state + " is an invalid state"); 060 } 061 } 062 063 public List<String> states() { 064 return states; 065 } 066 067 public RandomVariable duplicate() { 068 Hashtable<String, Double> probs = new Hashtable<String, Double>(); 069 for (String key : distribution.keySet()) { 070 probs.put(key, distribution.get(key)); 071 } 072 return new RandomVariable(name, states, probs); 073 074 } 075 076 public void normalize() { 077 List<Double> probs = new ArrayList<Double>(); 078 for (String s : states) { 079 probs.add(distribution.get(s)); 080 } 081 List<Double> newProbs = Util.normalize(probs); 082 for (int i = 0; i < states.size(); i++) { 083 distribution.put(states.get(i), newProbs.get(i)); 084 } 085 } 086 087 public Matrix asMatrix() { 088 Matrix m = new Matrix(states.size(), 1); 089 for (int i = 0; i < states.size(); i++) { 090 m.set(i, 0, distribution.get(states.get(i))); 091 } 092 return m; 093 094 } 095 096 public void updateFrom(Matrix aMatrix) { 097 for (int i = 0; i < states.size(); i++) { 098 distribution.put(states.get(i), aMatrix.get(i, 0)); 099 } 100 101 } 102 103 public RandomVariable createUnitBelief() { 104 RandomVariable result = duplicate(); 105 for (String s : states()) { 106 result.setProbabilityOf(s, 1.0); 107 } 108 return result; 109 } 110 111 @Override 112 public String toString() { 113 return asMatrix().toString(); 114 } 115 116 public ParticleSet toParticleSet(HiddenMarkovModel hmm, 117 Randomizer randomizer, int numberOfParticles) { 118 ParticleSet result = new ParticleSet(hmm); 119 for (int i = 0; i < numberOfParticles; i++) { 120 double rvalue = randomizer.nextDouble(); 121 String state = getStateForRandomNumber(rvalue); 122 result.add(new Particle(state, 0)); 123 } 124 return result; 125 } 126 127 private String getStateForRandomNumber(double rvalue) { 128 double total = 0.0; 129 for (String s : states) { 130 total = total + distribution.get(s); 131 if (total >= rvalue) { 132 return s; 133 } 134 } 135 throw new RuntimeException("cannot handle " + rvalue); 136 } 137 138 }