/*
 * Created on Aug 6, 2005
 *
 */
package aima.learning.demos;

import java.util.ArrayList;
import java.util.List;

import aima.learning.framework.DataSet;
import aima.learning.framework.DataSetFactory;
import aima.learning.framework.Learner;
import aima.learning.inductive.DLTestFactory;
import aima.learning.inductive.DecisionTree;
import aima.learning.learners.AdaBoostLearner;
import aima.learning.learners.DecisionListLearner;
import aima.learning.learners.DecisionTreeLearner;
import aima.learning.learners.StumpLearner;
import aima.learning.neural.BackPropLearning;
import aima.learning.neural.FeedForwardNeuralNetwork;
import aima.learning.neural.IrisDataSetNumerizer;
import aima.learning.neural.IrisNNDataSet;
import aima.learning.neural.NNConfig;
import aima.learning.neural.NNDataSet;
import aima.learning.neural.Numerizer;
import aima.learning.neural.Perceptron;
import aima.util.Util;

/**
 * Console demos of the learners in {@code aima.learning}: an induced decision
 * tree, an induced decision list, AdaBoost over decision stumps, a perceptron
 * and a feed-forward network trained with back-propagation.
 *
 * <p>Each demo catches its own exceptions and reports them (message plus stack
 * trace), so one failing demo never hides its failure or stops the others.
 */
public class LearningDemo {

    /** Width of the row of asterisks printed above and below each demo title. */
    private static final int BANNER_WIDTH = 100;

    /**
     * Number of training epochs for both neural-network demos. The banners are
     * built from this constant so the printed text always matches the training.
     */
    private static final int TRAINING_EPOCHS = 10;

    /**
     * Runs every learning demo in turn, printing results to standard output.
     * For Reinforcement Learning Demos see Probability Demo.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        decisionTreeDemo();
        decisionListDemo();
        ensembleLearningDemo();
        perceptronDemo();
        backPropagationDemo();
    }

    /** Induces a decision tree from the restaurant data set and tests it on the same data. */
    private static void decisionTreeDemo() {
        // The original banner wrongly said "Inducing a DecisionList" here.
        printBanner("\nDecisionTree Demo - Inducing a DecisionTree from the Restaurant DataSet\n ");
        try {
            DataSet ds = DataSetFactory.getRestaurantDataSet();
            DecisionTreeLearner learner = new DecisionTreeLearner();
            learner.train(ds);
            System.out.println("The Induced Decision Tree is ");
            System.out.println(learner.getDecisionTree());
            printClassificationResult("Decision Tree", learner.test(ds));
        } catch (Exception e) {
            reportFailure("Decision Tree", e);
        }
    }

    /** Induces a decision list from the restaurant data set and tests it on the same data. */
    private static void decisionListDemo() {
        printBanner("DecisionList Demo - Inducing a DecisionList from the Restaurant DataSet\n ");
        try {
            DataSet ds = DataSetFactory.getRestaurantDataSet();
            DecisionListLearner learner = new DecisionListLearner("Yes", "No",
                    new DLTestFactory());
            learner.train(ds);
            System.out.println("The Induced DecisionList is");
            System.out.println(learner.getDecisionList());
            printClassificationResult("Decision List", learner.test(ds));
        } catch (Exception e) {
            // Previously only a message was printed and the cause was lost.
            reportFailure("Decision List", e);
        }
    }

    /**
     * Builds one StumpLearner per stump of the restaurant data set, combines
     * them with AdaBoost, and tests the ensemble on the same data.
     */
    private static void ensembleLearningDemo() {
        printBanner("\n Ensemble Decision Demo - Weak Learners co operating to give Superior decisions ");
        try {
            DataSet ds = DataSetFactory.getRestaurantDataSet();
            // Wildcard instead of a raw List; each element is cast to DecisionTree below.
            List<?> stumps = DecisionTree.getStumpsFor(ds, "Yes", "No");
            List<Learner> learners = new ArrayList<Learner>();

            System.out.println("\nStump Learners vote to decide in this algorithm");
            for (Object stump : stumps) {
                learners.add(new StumpLearner((DecisionTree) stump, "No"));
            }
            AdaBoostLearner learner = new AdaBoostLearner(learners, ds);
            learner.train(ds);
            printClassificationResult("Ensemble Learner", learner.test(ds));
        } catch (Exception e) {
            // Previously an empty catch block silently swallowed every failure.
            reportFailure("Ensemble Learning", e);
        }
    }

    /** Trains a perceptron on the Iris data and prints how many examples it classifies right/wrong. */
    private static void perceptronDemo() {
        printBanner("\n Perceptron Demo - Running Perceptron on Iris data Set with "
                + TRAINING_EPOCHS + " epochs of learning ");
        try {
            NNDataSet innds = createIrisNNDataSet();

            // NOTE(review): the argument order appears to be (outputs, inputs),
            // matching the 3 outputs / 4 inputs of the back-prop config below — confirm.
            Perceptron perc = new Perceptron(3, 4);
            perc.trainOn(innds, TRAINING_EPOCHS);

            // Reset the data set between training and testing, as the original did.
            innds.refreshDataset();
            printRightWrong(perc.testOnDataSet(innds));
        } catch (Exception e) {
            reportFailure("Perceptron", e);
        }
    }

    /**
     * Trains a feed-forward network (4 inputs, 6 hidden neurons, 3 outputs,
     * initial weights in [-2.0, 2.0]) with back-propagation on the Iris data and
     * prints how many examples it classifies right/wrong.
     */
    private static void backPropagationDemo() {
        printBanner("\n BackpropagationDemo - Running BackProp on Iris data Set with "
                + TRAINING_EPOCHS + " epochs of learning ");
        try {
            NNDataSet innds = createIrisNNDataSet();

            NNConfig config = new NNConfig();
            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_INPUTS, 4);
            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_OUTPUTS, 3);
            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_HIDDEN_NEURONS, 6);
            config.setConfig(FeedForwardNeuralNetwork.LOWER_LIMIT_WEIGHTS, -2.0);
            config.setConfig(FeedForwardNeuralNetwork.UPPER_LIMIT_WEIGHTS, 2.0);

            FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(config);
            // NOTE(review): presumably (learning rate, momentum) — confirm against BackPropLearning.
            ffnn.setTrainingScheme(new BackPropLearning(0.1, 0.9));
            ffnn.trainOn(innds, TRAINING_EPOCHS);

            innds.refreshDataset();
            printRightWrong(ffnn.testOnDataSet(innds));
        } catch (Exception e) {
            reportFailure("Back Propagation", e);
        }
    }

    /**
     * Loads the Iris data set and converts it, via the Iris numerizer, into the
     * numeric example set used by both neural-network demos.
     */
    private static NNDataSet createIrisNNDataSet() throws Exception {
        DataSet irisDataSet = DataSetFactory.getIrisDataSet();
        Numerizer numerizer = new IrisDataSetNumerizer();
        NNDataSet innds = new IrisNNDataSet();
        innds.createExamplesFromDataSet(irisDataSet, numerizer);
        return innds;
    }

    /** Prints {@code title} framed by a row of asterisks above and below. */
    private static void printBanner(String title) {
        System.out.println(Util.ntimes("*", BANNER_WIDTH));
        System.out.println(title);
        System.out.println(Util.ntimes("*", BANNER_WIDTH));
    }

    /**
     * Prints a learner's test outcome.
     *
     * @param learnerName name used in the sentence, e.g. "Decision Tree"
     * @param result      element 0 is the success count, element 1 the failure count
     */
    private static void printClassificationResult(String learnerName, int[] result) {
        System.out.println("\nThis " + learnerName + " classifies the data set with "
                + result[0] + " successes" + " and " + result[1] + " failures");
        System.out.println("\n");
    }

    /** Prints a neural-network test outcome: element 0 right, element 1 wrong. */
    private static void printRightWrong(int[] result) {
        System.out.println(result[0] + " right, " + result[1] + " wrong");
    }

    /**
     * Reports a failed demo with its stack trace. Every demo uses this, so no
     * failure is swallowed and the remaining demos still run.
     */
    private static void reportFailure(String demoName, Exception e) {
        System.out.println(demoName + " Demo Failed");
        e.printStackTrace();
    }
}