001    /*
002     * Created on Jul 31, 2005
003     *
004     */
005    package aima.test.learningtest;
006    
007    import java.util.ArrayList;
008    import java.util.List;
009    
010    import junit.framework.TestCase;
011    import aima.learning.framework.DataSet;
012    import aima.learning.framework.DataSetFactory;
013    import aima.learning.framework.Learner;
014    import aima.learning.inductive.DecisionTree;
015    import aima.learning.learners.AdaBoostLearner;
016    import aima.learning.learners.StumpLearner;
017    
018    /**
019     * @author Ravi Mohan
020     * 
021     */
022    
023    public class EnsembleLearningTest extends TestCase {
024            private static final String UNABLE_TO_CLASSIFY = "Unable to Classify";
025    
026            private static final String YES = "Yes";
027    
028            public void testAdaBoostEnablesCollectionOfStumpsToClassifyDataSetAccurately()
029                            throws Exception {
030                    DataSet ds = DataSetFactory.getRestaurantDataSet();
031                    List stumps = DecisionTree.getStumpsFor(ds, YES, "No");
032                    List<Learner> learners = new ArrayList<Learner>();
033                    for (Object stump : stumps) {
034                            DecisionTree sl = (DecisionTree) stump;
035                            StumpLearner stumpLearner = new StumpLearner(sl, "No");
036                            learners.add(stumpLearner);
037                    }
038                    AdaBoostLearner learner = new AdaBoostLearner(learners, ds);
039                    learner.train(ds);
040                    int[] result = learner.test(ds);
041                    assertEquals(12, result[0]);
042                    assertEquals(0, result[1]);
043            }
044    
045    }