/*
 * Created on Jul 25, 2005
 *
 */
package aima.learning.learners;

import java.util.Iterator;
import java.util.List;

import aima.learning.framework.DataSet;
import aima.learning.framework.Example;
import aima.learning.framework.Learner;
import aima.learning.inductive.ConstantDecisonTree;
import aima.learning.inductive.DecisionTree;
import aima.util.Util;

/**
 * A {@link Learner} that induces a decision tree from a {@link DataSet} using
 * the DECISION-TREE-LEARNING algorithm (AIMA): recursively split on the
 * attribute with the greatest information gain until the examples are
 * uniformly classified or the attributes are exhausted.
 *
 * @author Ravi Mohan
 */
public class DecisionTreeLearner implements Learner {
    /** The induced (or injected) decision tree; null until trained. */
    private DecisionTree tree;

    /** Classification returned when the tree cannot classify an example. */
    private String defaultValue;

    public DecisionTreeLearner() {
        this.defaultValue = "Unable To Classify";
    }

    // used when you have to test a non induced tree (eg: for testing)
    public DecisionTreeLearner(DecisionTree tree, String defaultValue) {
        this.tree = tree;
        this.defaultValue = defaultValue;
    }

    /**
     * Induces a decision tree from the given data set over all non-target
     * attributes; the result is retrievable via {@link #getDecisionTree()}.
     *
     * @param ds the training examples
     */
    public void train(DataSet ds) {
        List<String> attributes = ds.getNonTargetAttributes();
        this.tree = decisionTreeLearning(ds, attributes,
                new ConstantDecisonTree(defaultValue));
    }

    /**
     * Classifies a single example with the current tree.
     *
     * @param e the example to classify
     * @return the predicted target value
     */
    public String predict(Example e) {
        return (String) tree.predict(e);
    }

    /**
     * Evaluates the current tree against every example in the data set.
     *
     * @param ds the test examples
     * @return a two-element array: {correctly classified, misclassified}
     */
    public int[] test(DataSet ds) {
        int[] results = new int[] { 0, 0 };

        for (Example e : ds.examples) {
            if (e.targetValue().equals(tree.predict(e))) {
                results[0]++;
            } else {
                results[1]++;
            }
        }
        return results;
    }

    /**
     * Recursive DECISION-TREE-LEARNING. Base cases: an empty data set returns
     * the supplied default tree; uniformly classified examples return that
     * classification; exhausted attributes return the majority value.
     * Otherwise, split on the highest-gain attribute and recurse per value.
     *
     * @param ds             examples remaining at this node
     * @param attributeNames attributes still available for splitting
     * @param defaultTree    answer for branches that end up with no examples
     * @return the (sub)tree induced for this node
     */
    private DecisionTree decisionTreeLearning(DataSet ds,
            List<String> attributeNames, ConstantDecisonTree defaultTree) {
        if (ds.size() == 0) {
            return defaultTree;
        }
        if (allExamplesHaveSameClassification(ds)) {
            return new ConstantDecisonTree(ds.getExample(0).targetValue());
        }
        if (attributeNames.size() == 0) {
            return majorityValue(ds);
        }
        String chosenAttribute = chooseAttribute(ds, attributeNames);

        DecisionTree tree = new DecisionTree(chosenAttribute);
        // majority value of the parent set becomes the default for any child
        // branch that receives no matching examples
        ConstantDecisonTree m = majorityValue(ds);

        List<String> values = ds.getPossibleAttributeValues(chosenAttribute);
        for (String v : values) {
            DataSet filtered = ds.matchingDataSet(chosenAttribute, v);
            List<String> newAttribs = Util.removeFrom(attributeNames,
                    chosenAttribute);
            DecisionTree subTree = decisionTreeLearning(filtered, newAttribs, m);
            tree.addNode(v, subTree);
        }

        return tree;
    }

    /**
     * Builds a constant tree that answers with the majority classification of
     * the given data set (delegates the vote to {@link MajorityLearner}).
     */
    private ConstantDecisonTree majorityValue(DataSet ds) {
        Learner learner = new MajorityLearner();
        learner.train(ds);
        return new ConstantDecisonTree(learner.predict(ds.getExample(0)));
    }

    /**
     * Selects the attribute with the greatest information gain; on a tie (or
     * when every gain is zero) the earliest attribute in the list wins.
     */
    private String chooseAttribute(DataSet ds, List<String> attributeNames) {
        double greatestGain = 0.0;
        String attributeWithGreatestGain = attributeNames.get(0);
        for (String attr : attributeNames) {
            double gain = ds.calculateGainFor(attr);
            if (gain > greatestGain) {
                greatestGain = gain;
                attributeWithGreatestGain = attr;
            }
        }

        return attributeWithGreatestGain;
    }

    /**
     * Returns true iff every example shares the target value of the first
     * example. Callers guarantee ds is non-empty (checked in
     * decisionTreeLearning before this is reached).
     */
    private boolean allExamplesHaveSameClassification(DataSet ds) {
        String classification = ds.getExample(0).targetValue();
        // iterate ds.examples directly, consistent with test(DataSet)
        for (Example element : ds.examples) {
            if (!(element.targetValue().equals(classification))) {
                return false;
            }
        }
        return true;
    }

    /** @return the tree induced by {@link #train(DataSet)} (null if untrained). */
    public DecisionTree getDecisionTree() {
        return tree;
    }
}