001    package aima.learning.inductive;
002    
003    import java.util.ArrayList;
004    import java.util.Hashtable;
005    import java.util.List;
006    
007    import aima.learning.framework.DataSet;
008    import aima.learning.framework.Example;
009    import aima.util.Util;
010    
011    /**
012     * @author Ravi Mohan
013     * 
014     */
015    public class DecisionTree {
016            private String attributeName;
017    
018            // each node modelled as a hash of attribute_value/decisiontree
019            private Hashtable<String, DecisionTree> nodes;
020    
021            protected DecisionTree() {
022    
023            }
024    
025            public DecisionTree(String attributeName) {
026                    this.attributeName = attributeName;
027                    nodes = new Hashtable<String, DecisionTree>();
028    
029            }
030    
031            public void addLeaf(String attributeValue, String decision) {
032                    nodes.put(attributeValue, new ConstantDecisonTree(decision));
033            }
034    
035            public void addNode(String attributeValue, DecisionTree tree) {
036                    nodes.put(attributeValue, tree);
037            }
038    
039            public Object predict(Example e) {
040                    String attrValue = e.getAttributeValueAsString(attributeName);
041                    if (nodes.containsKey(attrValue)) {
042                            return nodes.get(attrValue).predict(e);
043                    } else {
044                            throw new RuntimeException("no node exists for attribute value "
045                                            + attrValue);
046                    }
047            }
048    
049            public static DecisionTree getStumpFor(DataSet ds, String attributeName,
050                            String attributeValue, String returnValueIfMatched,
051                            List<String> unmatchedValues, String returnValueIfUnmatched) {
052                    DecisionTree dt = new DecisionTree(attributeName);
053                    dt.addLeaf(attributeValue, returnValueIfMatched);
054                    for (String unmatchedValue : unmatchedValues) {
055                            dt.addLeaf(unmatchedValue, returnValueIfUnmatched);
056                    }
057                    return dt;
058            }
059    
060            public static List<DecisionTree> getStumpsFor(DataSet ds,
061                            String returnValueIfMatched, String returnValueIfUnmatched) {
062                    List<String> attributes = ds.getNonTargetAttributes();
063                    List<DecisionTree> trees = new ArrayList<DecisionTree>();
064                    for (String attribute : attributes) {
065                            List<String> values = ds.getPossibleAttributeValues(attribute);
066                            for (String value : values) {
067                                    List<String> unmatchedValues = Util.removeFrom(ds
068                                                    .getPossibleAttributeValues(attribute), value);
069    
070                                    DecisionTree tree = getStumpFor(ds, attribute, value,
071                                                    returnValueIfMatched, unmatchedValues,
072                                                    returnValueIfUnmatched);
073                                    trees.add(tree);
074    
075                            }
076                    }
077                    return trees;
078            }
079    
080            /**
081             * @return Returns the attributeName.
082             */
083            public String getAttributeName() {
084                    return attributeName;
085            }
086    
087            @Override
088            public String toString() {
089                    return toString(1, new StringBuffer());
090            }
091    
092            public String toString(int depth, StringBuffer buf) {
093    
094                    if (attributeName != null) {
095                            buf.append(Util.ntimes("\t", depth));
096                            buf.append(Util.ntimes("***", 1));
097                            buf.append(attributeName + " \n");
098                            for (String attributeValue : nodes.keySet()) {
099                                    buf.append(Util.ntimes("\t", depth + 1));
100                                    buf.append("+" + attributeValue);
101                                    buf.append("\n");
102                                    DecisionTree child = nodes.get(attributeValue);
103                                    buf.append(child.toString(depth + 1, new StringBuffer()));
104                            }
105                    }
106    
107                    return buf.toString();
108            }
109    
110    }