/*
 * Created on Jul 31, 2005
 *
 */
package aima.test.learningtest;

import java.util.ArrayList;
import java.util.List;

import junit.framework.TestCase;
import aima.learning.framework.DataSet;
import aima.learning.framework.DataSetFactory;
import aima.learning.framework.Learner;
import aima.learning.inductive.DecisionTree;
import aima.learning.learners.AdaBoostLearner;
import aima.learning.learners.StumpLearner;

/**
 * Exercises {@link AdaBoostLearner} when it is built from a collection of
 * {@link StumpLearner}s derived from the restaurant data set.
 *
 * @author Ravi Mohan
 *
 */
public class EnsembleLearningTest extends TestCase {
    /** Value passed to the stump factory as the "matched" return value. */
    private static final String YES = "Yes";

    /**
     * Value passed to the stump factory as the "unmatched" return value and
     * to each {@link StumpLearner} as its second constructor argument.
     */
    private static final String NO = "No";

    /**
     * Builds one {@link StumpLearner} per stump returned by
     * {@link DecisionTree#getStumpsFor}, trains an {@link AdaBoostLearner}
     * over them on the restaurant data set, then tests the learner against
     * that same data set.
     *
     * @throws Exception
     *             propagated from any of the project calls made here
     */
    public void testAdaBoostEnablesCollectionOfStumpsToClassifyDataSetAccurately()
            throws Exception {
        DataSet ds = DataSetFactory.getRestaurantDataSet();

        // Wildcard rather than a raw List: it accepts whatever list type the
        // factory declares, and each element is cast explicitly below.
        List<?> stumps = DecisionTree.getStumpsFor(ds, YES, NO);

        List<Learner> learners = new ArrayList<Learner>();
        for (Object stump : stumps) {
            learners.add(new StumpLearner((DecisionTree) stump, NO));
        }

        AdaBoostLearner learner = new AdaBoostLearner(learners, ds);
        learner.train(ds);

        // NOTE(review): presumably result[0] is the number of correctly
        // classified examples and result[1] the number misclassified --
        // confirm against AdaBoostLearner.test().
        int[] result = learner.test(ds);
        assertEquals(12, result[0]);
        assertEquals(0, result[1]);
    }

}