001 /* 002 * Created on Jul 25, 2005 003 * 004 */ 005 package aima.test.learningtest; 006 007 import java.util.ArrayList; 008 009 import junit.framework.TestCase; 010 import aima.learning.framework.DataSet; 011 import aima.learning.framework.DataSetFactory; 012 import aima.learning.inductive.DLTest; 013 import aima.learning.inductive.DLTestFactory; 014 import aima.learning.learners.CurrentBestLearner; 015 import aima.learning.learners.DecisionListLearner; 016 import aima.learning.learners.DecisionTreeLearner; 017 import aima.learning.learners.MajorityLearner; 018 019 /** 020 * @author Ravi Mohan 021 * 022 */ 023 024 public class LearnerTests extends TestCase { 025 public void testMajorityLearner() throws Exception { 026 MajorityLearner learner = new MajorityLearner(); 027 DataSet ds = DataSetFactory.getRestaurantDataSet(); 028 learner.train(ds); 029 int[] result = learner.test(ds); 030 assertEquals(6, result[0]); 031 assertEquals(6, result[1]); 032 } 033 034 public void testDefaultUsedWhenTrainingDataSetHasNoExamples() 035 throws Exception { 036 // tests RecursionBaseCase#1 037 DataSet ds = DataSetFactory.getRestaurantDataSet(); 038 DecisionTreeLearner learner = new DecisionTreeLearner(); 039 040 DataSet ds2 = ds.emptyDataSet(); 041 assertEquals(0, ds2.size()); 042 043 learner.train(ds2); 044 assertEquals("Unable To Classify", learner.predict(ds.getExample(0))); 045 046 } 047 048 public void testClassificationReturnedWhenAllExamplesHaveTheSameClassification() 049 throws Exception { 050 // tests RecursionBaseCase#2 051 DataSet ds = DataSetFactory.getRestaurantDataSet(); 052 DecisionTreeLearner learner = new DecisionTreeLearner(); 053 054 DataSet ds2 = ds.emptyDataSet(); 055 056 // all 3 examples have the same classification (willWait = yes) 057 ds2.add(ds.getExample(0)); 058 ds2.add(ds.getExample(2)); 059 ds2.add(ds.getExample(3)); 060 061 learner.train(ds2); 062 assertEquals("Yes", learner.predict(ds.getExample(0))); 063 064 } 065 066 public void testMajorityReturnedWhenAttributesToExamineIsEmpty() 067 throws Exception { 068 // tests RecursionBaseCase#2 069 DataSet ds = DataSetFactory.getRestaurantDataSet(); 070 DecisionTreeLearner learner = new DecisionTreeLearner(); 071 072 DataSet ds2 = ds.emptyDataSet(); 073 074 // 3 examples have classification = "yes" and one ,"no" 075 ds2.add(ds.getExample(0)); 076 ds2.add(ds.getExample(1));// "no" 077 ds2.add(ds.getExample(2)); 078 ds2.add(ds.getExample(3)); 079 ds2.setSpecification(new MockDataSetSpecification("will_wait")); 080 081 learner.train(ds2); 082 assertEquals("Yes", learner.predict(ds.getExample(1))); 083 084 } 085 086 public void testInducedTreeClassifiesDataSetCorrectly() throws Exception { 087 DataSet ds = DataSetFactory.getRestaurantDataSet(); 088 DecisionTreeLearner learner = new DecisionTreeLearner(); 089 learner.train(ds); 090 int[] result = learner.test(ds); 091 assertEquals(12, result[0]); 092 assertEquals(0, result[1]); 093 } 094 095 public void testDecisionListLearnerReturnsNegativeDLWhenDataSetEmpty() 096 throws Exception { 097 // tests first base case of DL Learner 098 DecisionListLearner learner = new DecisionListLearner("Yes", "No", 099 new MockDLTestFactory(null)); 100 DataSet ds = DataSetFactory.getRestaurantDataSet(); 101 DataSet empty = ds.emptyDataSet(); 102 learner.train(empty); 103 assertEquals("No", learner.predict(ds.getExample(0))); 104 assertEquals("No", learner.predict(ds.getExample(1))); 105 assertEquals("No", learner.predict(ds.getExample(2))); 106 } 107 108 public void testDecisionListLearnerReturnsFailureWhenTestsEmpty() 109 throws Exception { 110 // tests second base case of DL Learner 111 DecisionListLearner learner = new DecisionListLearner("Yes", "No", 112 new MockDLTestFactory(new ArrayList<DLTest>())); 113 DataSet ds = DataSetFactory.getRestaurantDataSet(); 114 learner.train(ds); 115 assertEquals(DecisionListLearner.FAILURE, learner.predict(ds 116 .getExample(0))); 117 } 118 119 public void testDecisionListTestRunOnRestaurantDataSet() throws Exception { 120 DataSet ds = DataSetFactory.getRestaurantDataSet(); 121 DecisionListLearner learner = new DecisionListLearner("Yes", "No", 122 new DLTestFactory()); 123 learner.train(ds); 124 // System.out.println(learner.getDecisionList()); 125 int[] result = learner.test(ds); 126 assertEquals(12, result[0]); 127 assertEquals(0, result[1]); 128 } 129 130 public void testCurrentBestLearnerOnRestaurantDataSet() throws Exception { 131 DataSet ds = DataSetFactory.getRestaurantDataSet(); 132 CurrentBestLearner learner = new CurrentBestLearner("Yes"); 133 learner.train(ds); 134 135 int[] result = learner.test(ds); 136 assertEquals(12, result[0]); 137 assertEquals(0, result[1]); 138 } 139 }