001    /*
002     * Created on Jul 25, 2005
003     *
004     */
005    package aima.test.learningtest;
006    
007    import java.util.ArrayList;
008    import java.util.List;
009    
010    import junit.framework.TestCase;
011    import aima.learning.framework.DataSet;
012    import aima.learning.framework.DataSetFactory;
013    import aima.learning.inductive.DecisionTree;
014    import aima.learning.learners.DecisionTreeLearner;
015    import aima.util.Util;
016    
017    /**
018     * @author Ravi Mohan
019     * 
020     */
021    public class DecisionTreeTest extends TestCase {
022            private static final String YES = "Yes";
023    
024            private static final String NO = "No";
025    
026            public void testActualDecisionTreeClassifiesRestaurantDataSetCorrectly()
027                            throws Exception {
028                    DecisionTreeLearner learner = new DecisionTreeLearner(
029                                    createActualRestaurantDecisionTree(), "Unable to clasify");
030                    int[] results = learner.test(DataSetFactory.getRestaurantDataSet());
031                    assertEquals(12, results[0]);
032                    assertEquals(0, results[1]);
033            }
034    
035            public void testInducedDecisionTreeClassifiesRestaurantDataSetCorrectly()
036                            throws Exception {
037                    DecisionTreeLearner learner = new DecisionTreeLearner(
038                                    createInducedRestaurantDecisionTree(), "Unable to clasify");
039                    int[] results = learner.test(DataSetFactory.getRestaurantDataSet());
040                    assertEquals(12, results[0]);
041                    assertEquals(0, results[1]);
042            }
043    
044            public void testStumpCreationForSpecifiedAttributeValuePair()
045                            throws Exception {
046                    DataSet ds = DataSetFactory.getRestaurantDataSet();
047                    List<String> unmatchedValues = new ArrayList<String>();
048                    unmatchedValues.add(NO);
049                    DecisionTree dt = DecisionTree.getStumpFor(ds, "alternate", YES, YES,
050                                    unmatchedValues, NO);
051                    assertNotNull(dt);
052            }
053    
054            public void testStumpCreationForDataSet() throws Exception {
055                    DataSet ds = DataSetFactory.getRestaurantDataSet();
056                    List<DecisionTree> dt = DecisionTree.getStumpsFor(ds, YES,
057                                    "Unable to classify");
058                    assertEquals(26, dt.size());
059            }
060    
061            public void testStumpPredictionForDataSet() throws Exception {
062                    DataSet ds = DataSetFactory.getRestaurantDataSet();
063                    List<DecisionTree> trees = DecisionTree.getStumpsFor(ds, YES,
064                                    "Unable to classify");
065                    // for (DecisionTree tree : trees){
066                    // DecisionTreeLearner learner = new DecisionTreeLearner(tree,"Unable to
067                    // Classify");
068                    // int[] result = learner.test(ds);
069                    // System.out.println("On stump " +tree.getAttributeName()+ " " +
070                    // result[0]+ " successes "+ result[1]+ " failures");
071                    // }
072                    List<String> unmatchedValues = new ArrayList<String>();
073                    unmatchedValues.add(NO);
074                    DecisionTree tree = DecisionTree.getStumpFor(ds, "hungry", YES, YES,
075                                    unmatchedValues, "Unable to Classify");
076                    DecisionTreeLearner learner = new DecisionTreeLearner(tree,
077                                    "Unable to Classify");
078                    int[] result = learner.test(ds);
079                    assertEquals(5, result[0]);
080                    assertEquals(7, result[1]);
081                    // System.out.println("On stump " +tree.getAttributeName()+ " " +
082                    // result[0]+ " successes "+ result[1]+ " failures");
083    
084            }
085    
086            private static DecisionTree createInducedRestaurantDecisionTree() {// from
087                    // AIMA
088                    // 2nd
089                    // ED
090                    // Fig
091                    // 18.6
092                    // friday saturday node
093                    DecisionTree frisat = new DecisionTree("fri/sat");
094                    frisat.addLeaf(Util.YES, Util.YES);
095                    frisat.addLeaf(Util.NO, Util.NO);
096    
097                    // type node
098                    DecisionTree type = new DecisionTree("type");
099                    type.addLeaf("French", Util.YES);
100                    type.addLeaf("Italian", Util.NO);
101                    type.addNode("Thai", frisat);
102                    type.addLeaf("Burger", Util.YES);
103    
104                    // hungry node
105                    DecisionTree hungry = new DecisionTree("hungry");
106                    hungry.addLeaf(Util.NO, Util.NO);
107                    hungry.addNode(Util.YES, type);
108    
109                    // patrons node
110                    DecisionTree patrons = new DecisionTree("patrons");
111                    patrons.addLeaf("None", Util.NO);
112                    patrons.addLeaf("Some", Util.YES);
113                    patrons.addNode("Full", hungry);
114    
115                    return patrons;
116    
117            }
118    
119            private static DecisionTree createActualRestaurantDecisionTree() {// from
120                    // AIMA
121                    // 2nd
122                    // ED
123                    // Fig
124                    // 18.2
125    
126                    // raining node
127                    DecisionTree raining = new DecisionTree("raining");
128                    raining.addLeaf(Util.YES, Util.YES);
129                    raining.addLeaf(Util.NO, Util.NO);
130    
131                    // bar node
132                    DecisionTree bar = new DecisionTree("bar");
133                    bar.addLeaf(Util.YES, Util.YES);
134                    bar.addLeaf(Util.NO, Util.NO);
135    
136                    // friday saturday node
137                    DecisionTree frisat = new DecisionTree("fri/sat");
138                    frisat.addLeaf(Util.YES, Util.YES);
139                    frisat.addLeaf(Util.NO, Util.NO);
140    
141                    // second alternate node to the right of the diagram below hungry
142                    DecisionTree alternate2 = new DecisionTree("alternate");
143                    alternate2.addNode(Util.YES, raining);
144                    alternate2.addLeaf(Util.NO, Util.YES);
145    
146                    // reservation node
147                    DecisionTree reservation = new DecisionTree("reservation");
148                    frisat.addNode(Util.NO, bar);
149                    frisat.addLeaf(Util.YES, Util.YES);
150    
151                    // first alternate node to the left of the diagram below waitestimate
152                    DecisionTree alternate1 = new DecisionTree("alternate");
153                    alternate1.addNode(Util.NO, reservation);
154                    alternate1.addNode(Util.YES, frisat);
155    
156                    // hungry node
157                    DecisionTree hungry = new DecisionTree("hungry");
158                    hungry.addLeaf(Util.NO, Util.YES);
159                    hungry.addNode(Util.YES, alternate2);
160    
161                    // wait estimate node
162                    DecisionTree waitEstimate = new DecisionTree("wait_estimate");
163                    waitEstimate.addLeaf(">60", Util.NO);
164                    waitEstimate.addNode("30-60", alternate1);
165                    waitEstimate.addNode("10-30", hungry);
166                    waitEstimate.addLeaf("0-10", Util.YES);
167    
168                    // patrons node
169                    DecisionTree patrons = new DecisionTree("patrons");
170                    patrons.addLeaf("None", Util.NO);
171                    patrons.addLeaf("Some", Util.YES);
172                    patrons.addNode("Full", waitEstimate);
173    
174                    return patrons;
175    
176            }
177    }