001    /*
002     * Created on Feb 3, 2005
003     *
004     */
005    package aima.probability;
006    
007    import java.util.ArrayList;
008    import java.util.Hashtable;
009    import java.util.Iterator;
010    import java.util.List;
011    
012    import aima.util.Util;
013    
014    /**
015     * @author Ravi Mohan
016     * 
017     */
018    
019    public class BayesNet {
020            private List<BayesNetNode> roots = new ArrayList<BayesNetNode>();
021    
022            private List<BayesNetNode> variableNodes;
023    
024            public BayesNet(BayesNetNode root) {
025                    roots.add(root);
026            }
027    
028            public BayesNet(BayesNetNode root1, BayesNetNode root2) {
029                    this(root1);
030                    roots.add(root2);
031            }
032    
033            public BayesNet(BayesNetNode root1, BayesNetNode root2, BayesNetNode root3) {
034                    this(root1, root2);
035                    roots.add(root3);
036            }
037    
038            public BayesNet(List<BayesNetNode> rootNodes) {
039                    roots = rootNodes;
040            }
041    
042            public List<String> getVariables() {
043                    variableNodes = getVariableNodes();
044                    List<String> variables = new ArrayList<String>();
045                    for (BayesNetNode variableNode : variableNodes) {
046                            variables.add(variableNode.getVariable());
047                    }
048                    return variables;
049            }
050    
051            private List<BayesNetNode> getVariableNodes() {
052                    // TODO dicey initalisation works fine but unclear . clarify
053                    if (variableNodes == null) {
054                            List<BayesNetNode> newVariableNodes = new ArrayList<BayesNetNode>();
055                            List<BayesNetNode> parents = roots;
056                            List<BayesNetNode> traversedParents = new ArrayList<BayesNetNode>();
057    
058                            while (parents.size() != 0) {
059                                    List<BayesNetNode> newParents = new ArrayList<BayesNetNode>();
060                                    for (BayesNetNode parent : parents) {
061                                            // if parent unseen till now
062                                            if (!(traversedParents.contains(parent))) {
063                                                    newVariableNodes.add(parent);
064                                                    // add any unseen children to next generation of parents
065                                                    List<BayesNetNode> children = parent.getChildren();
066                                                    for (BayesNetNode child : children) {
067                                                            if (!newParents.contains(child)) {
068                                                                    newParents.add(child);
069                                                            }
070                                                    }
071                                                    traversedParents.add(parent);
072                                            }
073                                    }
074    
075                                    parents = newParents;
076                            }
077                            variableNodes = newVariableNodes;
078                    }
079    
080                    return variableNodes;
081            }
082    
083            private BayesNetNode getNodeOf(String y) {
084                    List<BayesNetNode> variableNodes = getVariableNodes();
085                    for (BayesNetNode node : variableNodes) {
086                            if (node.getVariable().equals(y)) {
087                                    return node;
088                            }
089                    }
090                    return null;
091            }
092    
093            public double probabilityOf(String Y, Boolean value,
094                            Hashtable<String, Boolean> evidence) {
095                    BayesNetNode y = getNodeOf(Y);
096                    if (y == null) {
097                            throw new RuntimeException("Unable to find a node with variable "
098                                            + Y);
099                    } else {
100                            List<BayesNetNode> parentNodes = y.getParents();
101                            if (parentNodes.size() == 0) {// root nodes
102                                    Hashtable<String, Boolean> YTable = new Hashtable<String, Boolean>();
103                                    YTable.put(Y, value);
104    
105                                    double prob = y.probabilityOf(YTable);
106                                    return prob;
107    
108                            } else {// non rootnodes
109                                    Hashtable<String, Boolean> parentValues = new Hashtable<String, Boolean>();
110                                    for (BayesNetNode parent : parentNodes) {
111                                            parentValues.put(parent.getVariable(), evidence.get(parent
112                                                            .getVariable()));
113                                    }
114                                    double prob = y.probabilityOf(parentValues);
115                                    if (value.equals(Boolean.TRUE)) {
116                                            return prob;
117                                    } else {
118                                            return (1.0 - prob);
119                                    }
120    
121                            }
122    
123                    }
124            }
125    
126            public Hashtable getPriorSample(Randomizer r) {
127                    Hashtable<String, Boolean> h = new Hashtable<String, Boolean>();
128                    List<BayesNetNode> variableNodes = getVariableNodes();
129                    for (BayesNetNode node : variableNodes) {
130                            h.put(node.getVariable(), node.isTrueFor(r.nextDouble(), h));
131                    }
132                    return h;
133            }
134    
135            public Hashtable getPriorSample() {
136                    return getPriorSample(new JavaRandomizer());
137            }
138    
139            public double[] rejectionSample(String X, Hashtable evidence,
140                            int numberOfSamples, Randomizer r) {
141                    double[] retval = new double[2];
142                    for (int i = 0; i < numberOfSamples; i++) {
143                            Hashtable sample = getPriorSample(r);
144                            if (consistent(sample, evidence)) {
145                                    boolean queryValue = ((Boolean) sample.get(X)).booleanValue();
146                                    if (queryValue) {
147                                            retval[0] += 1;
148                                    } else {
149                                            retval[1] += 1;
150                                    }
151                            }
152                    }
153                    return Util.normalize(retval);
154            }
155    
156            private boolean consistent(Hashtable sample, Hashtable evidence) {
157                    Iterator iter = evidence.keySet().iterator();
158                    while (iter.hasNext()) {
159                            String key = (String) iter.next();
160                            Boolean value = (Boolean) evidence.get(key);
161                            if (!(value.equals(sample.get(key)))) {
162                                    return false;
163                            }
164                    }
165                    return true;
166            }
167    
168            public double[] likelihoodWeighting(String X,
169                            Hashtable<String, Boolean> evidence, int numberOfSamples,
170                            Randomizer r) {
171                    double[] retval = new double[2];
172                    for (int i = 0; i < numberOfSamples; i++) {
173                            Hashtable<String, Boolean> x = new Hashtable<String, Boolean>();
174                            double w = 1.0;
175                            List<BayesNetNode> variableNodes = getVariableNodes();
176                            for (BayesNetNode node : variableNodes) {
177                                    if (evidence.get(node.getVariable()) != null) {
178                                            w *= node.probabilityOf(x);
179                                            x.put(node.getVariable(), evidence.get(node.getVariable()));
180                                    } else {
181                                            x
182                                                            .put(node.getVariable(), node.isTrueFor(r
183                                                                            .nextDouble(), x));
184                                    }
185                            }
186                            boolean queryValue = (x.get(X)).booleanValue();
187                            if (queryValue) {
188                                    retval[0] += w;
189                            } else {
190                                    retval[1] += w;
191                            }
192    
193                    }
194                    return Util.normalize(retval);
195            }
196    
197            public double[] mcmcAsk(String X, Hashtable<String, Boolean> evidence,
198                            int numberOfVariables, Randomizer r) {
199                    double[] retval = new double[2];
200                    List nonEvidenceVariables = nonEvidenceVariables(evidence, X);
201                    Hashtable<String, Boolean> event = createRandomEvent(
202                                    nonEvidenceVariables, evidence, r);
203                    for (int j = 0; j < numberOfVariables; j++) {
204                            Iterator iter = nonEvidenceVariables.iterator();
205                            while (iter.hasNext()) {
206                                    String variable = (String) iter.next();
207                                    BayesNetNode node = getNodeOf(variable);
208                                    List<BayesNetNode> markovBlanket = markovBlanket(node);
209                                    Hashtable mb = createMBValues(markovBlanket, event);
210                                    // event.put(node.getVariable(), node.isTrueFor(
211                                    // r.getProbability(), mb));
212                                    event.put(node.getVariable(), truthValue(rejectionSample(node
213                                                    .getVariable(), mb, 100, r), r));
214                                    boolean queryValue = (event.get(X)).booleanValue();
215                                    if (queryValue) {
216                                            retval[0] += 1;
217                                    } else {
218                                            retval[1] += 1;
219                                    }
220                            }
221                    }
222                    return Util.normalize(retval);
223            }
224    
225            private Boolean truthValue(double[] ds, Randomizer r) {
226                    double value = r.nextDouble();
227                    if (value < ds[0]) {
228                            return Boolean.TRUE;
229                    } else {
230                            return Boolean.FALSE;
231                    }
232    
233            }
234    
235            private Hashtable<String, Boolean> createRandomEvent(
236                            List nonEvidenceVariables, Hashtable<String, Boolean> evidence,
237                            Randomizer r) {
238                    Hashtable<String, Boolean> table = new Hashtable<String, Boolean>();
239                    List<String> variables = getVariables();
240                    for (String variable : variables) {
241    
242                            if (nonEvidenceVariables.contains(variable)) {
243                                    Boolean value = r.nextDouble() <= 0.5 ? Boolean.TRUE
244                                                    : Boolean.FALSE;
245                                    table.put(variable, value);
246                            } else {
247                                    table.put(variable, evidence.get(variable));
248                            }
249                    }
250                    return table;
251            }
252    
253            private List nonEvidenceVariables(Hashtable<String, Boolean> evidence,
254                            String query) {
255                    List<String> nonEvidenceVariables = new ArrayList<String>();
256                    List<String> variables = getVariables();
257                    for (String variable : variables) {
258    
259                            if (!(evidence.keySet().contains(variable))) {
260                                    nonEvidenceVariables.add(variable);
261                            }
262                    }
263                    return nonEvidenceVariables;
264            }
265    
266            private List<BayesNetNode> markovBlanket(BayesNetNode node) {
267                    return markovBlanket(node, new ArrayList<BayesNetNode>());
268            }
269    
270            private List<BayesNetNode> markovBlanket(BayesNetNode node,
271                            List<BayesNetNode> soFar) {
272                    // parents
273                    List<BayesNetNode> parents = node.getParents();
274                    for (BayesNetNode parent : parents) {
275                            if (!soFar.contains(parent)) {
276                                    soFar.add(parent);
277                            }
278                    }
279                    // children
280                    List<BayesNetNode> children = node.getChildren();
281                    for (BayesNetNode child : children) {
282                            if (!soFar.contains(child)) {
283                                    soFar.add(child);
284                                    List<BayesNetNode> childsParents = child.getParents();
285                                    for (BayesNetNode childsParent : childsParents) {
286                                            ;
287                                            if ((!soFar.contains(childsParent))
288                                                            && (!(childsParent.equals(node)))) {
289                                                    soFar.add(childsParent);
290                                            }
291                                    }// childsParents
292                            }// end contains child
293    
294                    }// end child
295    
296                    return soFar;
297            }
298    
299            private Hashtable createMBValues(List<BayesNetNode> markovBlanket,
300                            Hashtable<String, Boolean> event) {
301                    Hashtable<String, Boolean> table = new Hashtable<String, Boolean>();
302                    for (BayesNetNode node : markovBlanket) {
303                            table.put(node.getVariable(), event.get(node.getVariable()));
304                    }
305                    return table;
306            }
307    
308            public double[] mcmcAsk(String X, Hashtable<String, Boolean> evidence,
309                            int numberOfVariables) {
310                    return mcmcAsk(X, evidence, numberOfVariables, new JavaRandomizer());
311            }
312    
313            public double[] likelihoodWeighting(String X,
314                            Hashtable<String, Boolean> evidence, int numberOfSamples) {
315                    return likelihoodWeighting(X, evidence, numberOfSamples,
316                                    new JavaRandomizer());
317            }
318    
319            public double[] rejectionSample(String X,
320                            Hashtable<String, Boolean> evidence, int numberOfSamples) {
321                    return rejectionSample(X, evidence, numberOfSamples,
322                                    new JavaRandomizer());
323            }
324    
325    }