/*
 * Created on Jul 25, 2005
 *
 */
package aima.learning.learners;

import java.util.List;

import aima.learning.framework.DataSet;
import aima.learning.framework.Example;
import aima.learning.framework.Learner;
import aima.learning.inductive.ConstantDecisonTree;
import aima.learning.inductive.DecisionTree;
import aima.util.Util;

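/*
 * Example usage (a minimal sketch; assumes a training DataSet "ds" and a
 * separate test DataSet "testDs" have already been constructed elsewhere):
 *
 *   DecisionTreeLearner learner = new DecisionTreeLearner();
 *   learner.train(ds);
 *   String classification = learner.predict(ds.getExample(0));
 *   int[] results = learner.test(testDs); // results[0] correct, results[1] incorrect
 */
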
/**
 * A decision-tree learner: induces a DecisionTree from a DataSet by
 * recursively splitting on the attribute with the greatest information gain.
 * 
 * @author Ravi Mohan
 */

public class DecisionTreeLearner implements Learner {
	private DecisionTree tree;

	private String defaultValue;

	public DecisionTreeLearner() {
		this.defaultValue = "Unable To Classify";
	}

	// used to evaluate a tree that was constructed directly rather than
	// induced from data (e.g. in tests)
	public DecisionTreeLearner(DecisionTree tree, String defaultValue) {
		this.tree = tree;
		this.defaultValue = defaultValue;
	}

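	/**
	 * Induces a decision tree from the given DataSet, splitting on the
	 * non-target attributes; a constant tree built from defaultValue serves
	 * as the fallback for branches that match no examples.
	 */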
	public void train(DataSet ds) {
		List<String> attributes = ds.getNonTargetAttributes();
		this.tree = decisionTreeLearning(ds, attributes,
				new ConstantDecisonTree(defaultValue));
	}

	public String predict(Example e) {
		return (String) tree.predict(e);
	}

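	/**
	 * Classifies every example in the given DataSet with the induced tree.
	 * 
	 * @return a two-element array: index 0 holds the number of correctly
	 *         classified examples, index 1 the number misclassified
	 */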
	public int[] test(DataSet ds) {
		int[] results = new int[] { 0, 0 };

		for (Example e : ds.examples) {
			if (e.targetValue().equals(tree.predict(e))) {
				results[0] = results[0] + 1;
			} else {
				results[1] = results[1] + 1;
			}
		}
		return results;
	}

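	/**
	 * The recursive induction step: returns the default tree when no examples
	 * remain, a constant leaf when all examples share one classification, and
	 * the majority value when no attributes are left to split on; otherwise it
	 * splits on the highest-gain attribute and recurses on each of its values,
	 * passing the majority value of the current examples as the new default.
	 */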
	private DecisionTree decisionTreeLearning(DataSet ds,
			List<String> attributeNames, ConstantDecisonTree defaultTree) {
		if (ds.size() == 0) {
			return defaultTree;
		}
		if (allExamplesHaveSameClassification(ds)) {
			return new ConstantDecisonTree(ds.getExample(0).targetValue());
		}
		if (attributeNames.size() == 0) {
			return majorityValue(ds);
		}
		String chosenAttribute = chooseAttribute(ds, attributeNames);

		DecisionTree tree = new DecisionTree(chosenAttribute);
		ConstantDecisonTree m = majorityValue(ds);

		List<String> values = ds.getPossibleAttributeValues(chosenAttribute);
		for (String v : values) {
			DataSet filtered = ds.matchingDataSet(chosenAttribute, v);
			List<String> newAttribs = Util.removeFrom(attributeNames,
					chosenAttribute);
			DecisionTree subTree = decisionTreeLearning(filtered, newAttribs, m);
			tree.addNode(v, subTree);
		}

		return tree;
	}

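	/**
	 * Wraps the most common target value in the DataSet (as determined by a
	 * MajorityLearner) in a constant leaf tree.
	 */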
	private ConstantDecisonTree majorityValue(DataSet ds) {
		Learner learner = new MajorityLearner();
		learner.train(ds);
		return new ConstantDecisonTree(learner.predict(ds.getExample(0)));
	}

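	/**
	 * Returns the attribute with the greatest gain, as reported by
	 * DataSet.calculateGainFor; falls back to the first attribute when no
	 * attribute yields a gain above zero.
	 */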
	private String chooseAttribute(DataSet ds, List<String> attributeNames) {
		double greatestGain = 0.0;
		String attributeWithGreatestGain = attributeNames.get(0);
		for (String attr : attributeNames) {
			double gain = ds.calculateGainFor(attr);
			if (gain > greatestGain) {
				greatestGain = gain;
				attributeWithGreatestGain = attr;
			}
		}
		return attributeWithGreatestGain;
	}

	private boolean allExamplesHaveSameClassification(DataSet ds) {
		String classification = ds.getExample(0).targetValue();
		for (Example e : ds.examples) {
			if (!(e.targetValue().equals(classification))) {
				return false;
			}
		}
		return true;
	}

	public DecisionTree getDecisionTree() {
		return tree;
	}
}