001 /* 002 * Created on Feb 3, 2005 003 * 004 */ 005 package aima.probability; 006 007 import java.util.ArrayList; 008 import java.util.Hashtable; 009 import java.util.Iterator; 010 import java.util.List; 011 012 import aima.util.Util; 013 014 /** 015 * @author Ravi Mohan 016 * 017 */ 018 019 public class BayesNet { 020 private List<BayesNetNode> roots = new ArrayList<BayesNetNode>(); 021 022 private List<BayesNetNode> variableNodes; 023 024 public BayesNet(BayesNetNode root) { 025 roots.add(root); 026 } 027 028 public BayesNet(BayesNetNode root1, BayesNetNode root2) { 029 this(root1); 030 roots.add(root2); 031 } 032 033 public BayesNet(BayesNetNode root1, BayesNetNode root2, BayesNetNode root3) { 034 this(root1, root2); 035 roots.add(root3); 036 } 037 038 public BayesNet(List<BayesNetNode> rootNodes) { 039 roots = rootNodes; 040 } 041 042 public List<String> getVariables() { 043 variableNodes = getVariableNodes(); 044 List<String> variables = new ArrayList<String>(); 045 for (BayesNetNode variableNode : variableNodes) { 046 variables.add(variableNode.getVariable()); 047 } 048 return variables; 049 } 050 051 private List<BayesNetNode> getVariableNodes() { 052 // TODO dicey initalisation works fine but unclear . clarify 053 if (variableNodes == null) { 054 List<BayesNetNode> newVariableNodes = new ArrayList<BayesNetNode>(); 055 List<BayesNetNode> parents = roots; 056 List<BayesNetNode> traversedParents = new ArrayList<BayesNetNode>(); 057 058 while (parents.size() != 0) { 059 List<BayesNetNode> newParents = new ArrayList<BayesNetNode>(); 060 for (BayesNetNode parent : parents) { 061 // if parent unseen till now 062 if (!(traversedParents.contains(parent))) { 063 newVariableNodes.add(parent); 064 // add any unseen children to next generation of parents 065 List<BayesNetNode> children = parent.getChildren(); 066 for (BayesNetNode child : children) { 067 if (!newParents.contains(child)) { 068 newParents.add(child); 069 } 070 } 071 traversedParents.add(parent); 072 } 073 } 074 075 parents = newParents; 076 } 077 variableNodes = newVariableNodes; 078 } 079 080 return variableNodes; 081 } 082 083 private BayesNetNode getNodeOf(String y) { 084 List<BayesNetNode> variableNodes = getVariableNodes(); 085 for (BayesNetNode node : variableNodes) { 086 if (node.getVariable().equals(y)) { 087 return node; 088 } 089 } 090 return null; 091 } 092 093 public double probabilityOf(String Y, Boolean value, 094 Hashtable<String, Boolean> evidence) { 095 BayesNetNode y = getNodeOf(Y); 096 if (y == null) { 097 throw new RuntimeException("Unable to find a node with variable " 098 + Y); 099 } else { 100 List<BayesNetNode> parentNodes = y.getParents(); 101 if (parentNodes.size() == 0) {// root nodes 102 Hashtable<String, Boolean> YTable = new Hashtable<String, Boolean>(); 103 YTable.put(Y, value); 104 105 double prob = y.probabilityOf(YTable); 106 return prob; 107 108 } else {// non rootnodes 109 Hashtable<String, Boolean> parentValues = new Hashtable<String, Boolean>(); 110 for (BayesNetNode parent : parentNodes) { 111 parentValues.put(parent.getVariable(), evidence.get(parent 112 .getVariable())); 113 } 114 double prob = y.probabilityOf(parentValues); 115 if (value.equals(Boolean.TRUE)) { 116 return prob; 117 } else { 118 return (1.0 - prob); 119 } 120 121 } 122 123 } 124 } 125 126 public Hashtable getPriorSample(Randomizer r) { 127 Hashtable<String, Boolean> h = new Hashtable<String, Boolean>(); 128 List<BayesNetNode> variableNodes = getVariableNodes(); 129 for (BayesNetNode node : variableNodes) { 130 h.put(node.getVariable(), node.isTrueFor(r.nextDouble(), h)); 131 } 132 return h; 133 } 134 135 public Hashtable getPriorSample() { 136 return getPriorSample(new JavaRandomizer()); 137 } 138 139 public double[] rejectionSample(String X, Hashtable evidence, 140 int numberOfSamples, Randomizer r) { 141 double[] retval = new double[2]; 142 for (int i = 0; i < numberOfSamples; i++) { 143 Hashtable sample = getPriorSample(r); 144 if (consistent(sample, evidence)) { 145 boolean queryValue = ((Boolean) sample.get(X)).booleanValue(); 146 if (queryValue) { 147 retval[0] += 1; 148 } else { 149 retval[1] += 1; 150 } 151 } 152 } 153 return Util.normalize(retval); 154 } 155 156 private boolean consistent(Hashtable sample, Hashtable evidence) { 157 Iterator iter = evidence.keySet().iterator(); 158 while (iter.hasNext()) { 159 String key = (String) iter.next(); 160 Boolean value = (Boolean) evidence.get(key); 161 if (!(value.equals(sample.get(key)))) { 162 return false; 163 } 164 } 165 return true; 166 } 167 168 public double[] likelihoodWeighting(String X, 169 Hashtable<String, Boolean> evidence, int numberOfSamples, 170 Randomizer r) { 171 double[] retval = new double[2]; 172 for (int i = 0; i < numberOfSamples; i++) { 173 Hashtable<String, Boolean> x = new Hashtable<String, Boolean>(); 174 double w = 1.0; 175 List<BayesNetNode> variableNodes = getVariableNodes(); 176 for (BayesNetNode node : variableNodes) { 177 if (evidence.get(node.getVariable()) != null) { 178 w *= node.probabilityOf(x); 179 x.put(node.getVariable(), evidence.get(node.getVariable())); 180 } else { 181 x 182 .put(node.getVariable(), node.isTrueFor(r 183 .nextDouble(), x)); 184 } 185 } 186 boolean queryValue = (x.get(X)).booleanValue(); 187 if (queryValue) { 188 retval[0] += w; 189 } else { 190 retval[1] += w; 191 } 192 193 } 194 return Util.normalize(retval); 195 } 196 197 public double[] mcmcAsk(String X, Hashtable<String, Boolean> evidence, 198 int numberOfVariables, Randomizer r) { 199 double[] retval = new double[2]; 200 List nonEvidenceVariables = nonEvidenceVariables(evidence, X); 201 Hashtable<String, Boolean> event = createRandomEvent( 202 nonEvidenceVariables, evidence, r); 203 for (int j = 0; j < numberOfVariables; j++) { 204 Iterator iter = nonEvidenceVariables.iterator(); 205 while (iter.hasNext()) { 206 String variable = (String) iter.next(); 207 BayesNetNode node = getNodeOf(variable); 208 List<BayesNetNode> markovBlanket = markovBlanket(node); 209 Hashtable mb = createMBValues(markovBlanket, event); 210 // event.put(node.getVariable(), node.isTrueFor( 211 // r.getProbability(), mb)); 212 event.put(node.getVariable(), truthValue(rejectionSample(node 213 .getVariable(), mb, 100, r), r)); 214 boolean queryValue = (event.get(X)).booleanValue(); 215 if (queryValue) { 216 retval[0] += 1; 217 } else { 218 retval[1] += 1; 219 } 220 } 221 } 222 return Util.normalize(retval); 223 } 224 225 private Boolean truthValue(double[] ds, Randomizer r) { 226 double value = r.nextDouble(); 227 if (value < ds[0]) { 228 return Boolean.TRUE; 229 } else { 230 return Boolean.FALSE; 231 } 232 233 } 234 235 private Hashtable<String, Boolean> createRandomEvent( 236 List nonEvidenceVariables, Hashtable<String, Boolean> evidence, 237 Randomizer r) { 238 Hashtable<String, Boolean> table = new Hashtable<String, Boolean>(); 239 List<String> variables = getVariables(); 240 for (String variable : variables) { 241 242 if (nonEvidenceVariables.contains(variable)) { 243 Boolean value = r.nextDouble() <= 0.5 ? Boolean.TRUE 244 : Boolean.FALSE; 245 table.put(variable, value); 246 } else { 247 table.put(variable, evidence.get(variable)); 248 } 249 } 250 return table; 251 } 252 253 private List nonEvidenceVariables(Hashtable<String, Boolean> evidence, 254 String query) { 255 List<String> nonEvidenceVariables = new ArrayList<String>(); 256 List<String> variables = getVariables(); 257 for (String variable : variables) { 258 259 if (!(evidence.keySet().contains(variable))) { 260 nonEvidenceVariables.add(variable); 261 } 262 } 263 return nonEvidenceVariables; 264 } 265 266 private List<BayesNetNode> markovBlanket(BayesNetNode node) { 267 return markovBlanket(node, new ArrayList<BayesNetNode>()); 268 } 269 270 private List<BayesNetNode> markovBlanket(BayesNetNode node, 271 List<BayesNetNode> soFar) { 272 // parents 273 List<BayesNetNode> parents = node.getParents(); 274 for (BayesNetNode parent : parents) { 275 if (!soFar.contains(parent)) { 276 soFar.add(parent); 277 } 278 } 279 // children 280 List<BayesNetNode> children = node.getChildren(); 281 for (BayesNetNode child : children) { 282 if (!soFar.contains(child)) { 283 soFar.add(child); 284 List<BayesNetNode> childsParents = child.getParents(); 285 for (BayesNetNode childsParent : childsParents) { 286 ; 287 if ((!soFar.contains(childsParent)) 288 && (!(childsParent.equals(node)))) { 289 soFar.add(childsParent); 290 } 291 }// childsParents 292 }// end contains child 293 294 }// end child 295 296 return soFar; 297 } 298 299 private Hashtable createMBValues(List<BayesNetNode> markovBlanket, 300 Hashtable<String, Boolean> event) { 301 Hashtable<String, Boolean> table = new Hashtable<String, Boolean>(); 302 for (BayesNetNode node : markovBlanket) { 303 table.put(node.getVariable(), event.get(node.getVariable())); 304 } 305 return table; 306 } 307 308 public double[] mcmcAsk(String X, Hashtable<String, Boolean> evidence, 309 int numberOfVariables) { 310 return mcmcAsk(X, evidence, numberOfVariables, new JavaRandomizer()); 311 } 312 313 public double[] likelihoodWeighting(String X, 314 Hashtable<String, Boolean> evidence, int numberOfSamples) { 315 return likelihoodWeighting(X, evidence, numberOfSamples, 316 new JavaRandomizer()); 317 } 318 319 public double[] rejectionSample(String X, 320 Hashtable<String, Boolean> evidence, int numberOfSamples) { 321 return rejectionSample(X, evidence, numberOfSamples, 322 new JavaRandomizer()); 323 } 324 325 }