001 /* 002 * Created on Jul 25, 2005 003 * 004 */ 005 package aima.test.learningtest; 006 007 import java.util.ArrayList; 008 import java.util.List; 009 010 import junit.framework.TestCase; 011 import aima.learning.framework.DataSet; 012 import aima.learning.framework.DataSetFactory; 013 import aima.learning.inductive.DecisionTree; 014 import aima.learning.learners.DecisionTreeLearner; 015 import aima.util.Util; 016 017 /** 018 * @author Ravi Mohan 019 * 020 */ 021 public class DecisionTreeTest extends TestCase { 022 private static final String YES = "Yes"; 023 024 private static final String NO = "No"; 025 026 public void testActualDecisionTreeClassifiesRestaurantDataSetCorrectly() 027 throws Exception { 028 DecisionTreeLearner learner = new DecisionTreeLearner( 029 createActualRestaurantDecisionTree(), "Unable to clasify"); 030 int[] results = learner.test(DataSetFactory.getRestaurantDataSet()); 031 assertEquals(12, results[0]); 032 assertEquals(0, results[1]); 033 } 034 035 public void testInducedDecisionTreeClassifiesRestaurantDataSetCorrectly() 036 throws Exception { 037 DecisionTreeLearner learner = new DecisionTreeLearner( 038 createInducedRestaurantDecisionTree(), "Unable to clasify"); 039 int[] results = learner.test(DataSetFactory.getRestaurantDataSet()); 040 assertEquals(12, results[0]); 041 assertEquals(0, results[1]); 042 } 043 044 public void testStumpCreationForSpecifiedAttributeValuePair() 045 throws Exception { 046 DataSet ds = DataSetFactory.getRestaurantDataSet(); 047 List<String> unmatchedValues = new ArrayList<String>(); 048 unmatchedValues.add(NO); 049 DecisionTree dt = DecisionTree.getStumpFor(ds, "alternate", YES, YES, 050 unmatchedValues, NO); 051 assertNotNull(dt); 052 } 053 054 public void testStumpCreationForDataSet() throws Exception { 055 DataSet ds = DataSetFactory.getRestaurantDataSet(); 056 List<DecisionTree> dt = DecisionTree.getStumpsFor(ds, YES, 057 "Unable to classify"); 058 assertEquals(26, dt.size()); 059 } 060 061 public void testStumpPredictionForDataSet() throws Exception { 062 DataSet ds = DataSetFactory.getRestaurantDataSet(); 063 List<DecisionTree> trees = DecisionTree.getStumpsFor(ds, YES, 064 "Unable to classify"); 065 // for (DecisionTree tree : trees){ 066 // DecisionTreeLearner learner = new DecisionTreeLearner(tree,"Unable to 067 // Classify"); 068 // int[] result = learner.test(ds); 069 // System.out.println("On stump " +tree.getAttributeName()+ " " + 070 // result[0]+ " successes "+ result[1]+ " failures"); 071 // } 072 List<String> unmatchedValues = new ArrayList<String>(); 073 unmatchedValues.add(NO); 074 DecisionTree tree = DecisionTree.getStumpFor(ds, "hungry", YES, YES, 075 unmatchedValues, "Unable to Classify"); 076 DecisionTreeLearner learner = new DecisionTreeLearner(tree, 077 "Unable to Classify"); 078 int[] result = learner.test(ds); 079 assertEquals(5, result[0]); 080 assertEquals(7, result[1]); 081 // System.out.println("On stump " +tree.getAttributeName()+ " " + 082 // result[0]+ " successes "+ result[1]+ " failures"); 083 084 } 085 086 private static DecisionTree createInducedRestaurantDecisionTree() {// from 087 // AIMA 088 // 2nd 089 // ED 090 // Fig 091 // 18.6 092 // friday saturday node 093 DecisionTree frisat = new DecisionTree("fri/sat"); 094 frisat.addLeaf(Util.YES, Util.YES); 095 frisat.addLeaf(Util.NO, Util.NO); 096 097 // type node 098 DecisionTree type = new DecisionTree("type"); 099 type.addLeaf("French", Util.YES); 100 type.addLeaf("Italian", Util.NO); 101 type.addNode("Thai", frisat); 102 type.addLeaf("Burger", Util.YES); 103 104 // hungry node 105 DecisionTree hungry = new DecisionTree("hungry"); 106 hungry.addLeaf(Util.NO, Util.NO); 107 hungry.addNode(Util.YES, type); 108 109 // patrons node 110 DecisionTree patrons = new DecisionTree("patrons"); 111 patrons.addLeaf("None", Util.NO); 112 patrons.addLeaf("Some", Util.YES); 113 patrons.addNode("Full", hungry); 114 115 return patrons; 116 117 } 118 119 private static DecisionTree createActualRestaurantDecisionTree() {// from 120 // AIMA 121 // 2nd 122 // ED 123 // Fig 124 // 18.2 125 126 // raining node 127 DecisionTree raining = new DecisionTree("raining"); 128 raining.addLeaf(Util.YES, Util.YES); 129 raining.addLeaf(Util.NO, Util.NO); 130 131 // bar node 132 DecisionTree bar = new DecisionTree("bar"); 133 bar.addLeaf(Util.YES, Util.YES); 134 bar.addLeaf(Util.NO, Util.NO); 135 136 // friday saturday node 137 DecisionTree frisat = new DecisionTree("fri/sat"); 138 frisat.addLeaf(Util.YES, Util.YES); 139 frisat.addLeaf(Util.NO, Util.NO); 140 141 // second alternate node to the right of the diagram below hungry 142 DecisionTree alternate2 = new DecisionTree("alternate"); 143 alternate2.addNode(Util.YES, raining); 144 alternate2.addLeaf(Util.NO, Util.YES); 145 146 // reservation node 147 DecisionTree reservation = new DecisionTree("reservation"); 148 frisat.addNode(Util.NO, bar); 149 frisat.addLeaf(Util.YES, Util.YES); 150 151 // first alternate node to the left of the diagram below waitestimate 152 DecisionTree alternate1 = new DecisionTree("alternate"); 153 alternate1.addNode(Util.NO, reservation); 154 alternate1.addNode(Util.YES, frisat); 155 156 // hungry node 157 DecisionTree hungry = new DecisionTree("hungry"); 158 hungry.addLeaf(Util.NO, Util.YES); 159 hungry.addNode(Util.YES, alternate2); 160 161 // wait estimate node 162 DecisionTree waitEstimate = new DecisionTree("wait_estimate"); 163 waitEstimate.addLeaf(">60", Util.NO); 164 waitEstimate.addNode("30-60", alternate1); 165 waitEstimate.addNode("10-30", hungry); 166 waitEstimate.addLeaf("0-10", Util.YES); 167 168 // patrons node 169 DecisionTree patrons = new DecisionTree("patrons"); 170 patrons.addLeaf("None", Util.NO); 171 patrons.addLeaf("Some", Util.YES); 172 patrons.addNode("Full", waitEstimate); 173 174 return patrons; 175 176 } 177 }