001 package aima.learning.inductive; 002 003 import java.util.ArrayList; 004 import java.util.Hashtable; 005 import java.util.List; 006 007 import aima.learning.framework.DataSet; 008 import aima.learning.framework.Example; 009 import aima.util.Util; 010 011 /** 012 * @author Ravi Mohan 013 * 014 */ 015 public class DecisionTree { 016 private String attributeName; 017 018 // each node modelled as a hash of attribute_value/decisiontree 019 private Hashtable<String, DecisionTree> nodes; 020 021 protected DecisionTree() { 022 023 } 024 025 public DecisionTree(String attributeName) { 026 this.attributeName = attributeName; 027 nodes = new Hashtable<String, DecisionTree>(); 028 029 } 030 031 public void addLeaf(String attributeValue, String decision) { 032 nodes.put(attributeValue, new ConstantDecisonTree(decision)); 033 } 034 035 public void addNode(String attributeValue, DecisionTree tree) { 036 nodes.put(attributeValue, tree); 037 } 038 039 public Object predict(Example e) { 040 String attrValue = e.getAttributeValueAsString(attributeName); 041 if (nodes.containsKey(attrValue)) { 042 return nodes.get(attrValue).predict(e); 043 } else { 044 throw new RuntimeException("no node exists for attribute value " 045 + attrValue); 046 } 047 } 048 049 public static DecisionTree getStumpFor(DataSet ds, String attributeName, 050 String attributeValue, String returnValueIfMatched, 051 List<String> unmatchedValues, String returnValueIfUnmatched) { 052 DecisionTree dt = new DecisionTree(attributeName); 053 dt.addLeaf(attributeValue, returnValueIfMatched); 054 for (String unmatchedValue : unmatchedValues) { 055 dt.addLeaf(unmatchedValue, returnValueIfUnmatched); 056 } 057 return dt; 058 } 059 060 public static List<DecisionTree> getStumpsFor(DataSet ds, 061 String returnValueIfMatched, String returnValueIfUnmatched) { 062 List<String> attributes = ds.getNonTargetAttributes(); 063 List<DecisionTree> trees = new ArrayList<DecisionTree>(); 064 for (String attribute : attributes) { 065 List<String> values = ds.getPossibleAttributeValues(attribute); 066 for (String value : values) { 067 List<String> unmatchedValues = Util.removeFrom(ds 068 .getPossibleAttributeValues(attribute), value); 069 070 DecisionTree tree = getStumpFor(ds, attribute, value, 071 returnValueIfMatched, unmatchedValues, 072 returnValueIfUnmatched); 073 trees.add(tree); 074 075 } 076 } 077 return trees; 078 } 079 080 /** 081 * @return Returns the attributeName. 082 */ 083 public String getAttributeName() { 084 return attributeName; 085 } 086 087 @Override 088 public String toString() { 089 return toString(1, new StringBuffer()); 090 } 091 092 public String toString(int depth, StringBuffer buf) { 093 094 if (attributeName != null) { 095 buf.append(Util.ntimes("\t", depth)); 096 buf.append(Util.ntimes("***", 1)); 097 buf.append(attributeName + " \n"); 098 for (String attributeValue : nodes.keySet()) { 099 buf.append(Util.ntimes("\t", depth + 1)); 100 buf.append("+" + attributeValue); 101 buf.append("\n"); 102 DecisionTree child = nodes.get(attributeValue); 103 buf.append(child.toString(depth + 1, new StringBuffer())); 104 } 105 } 106 107 return buf.toString(); 108 } 109 110 }