001    /*
002     * Created on Aug 6, 2005
003     *
004     */
005    package aima.learning.demos;
006    
007    import java.util.ArrayList;
008    import java.util.List;
009    
010    import aima.learning.framework.DataSet;
011    import aima.learning.framework.DataSetFactory;
012    import aima.learning.framework.Learner;
013    import aima.learning.inductive.DLTestFactory;
014    import aima.learning.inductive.DecisionTree;
015    import aima.learning.learners.AdaBoostLearner;
016    import aima.learning.learners.DecisionListLearner;
017    import aima.learning.learners.DecisionTreeLearner;
018    import aima.learning.learners.StumpLearner;
019    import aima.learning.neural.BackPropLearning;
020    import aima.learning.neural.FeedForwardNeuralNetwork;
021    import aima.learning.neural.IrisDataSetNumerizer;
022    import aima.learning.neural.IrisNNDataSet;
023    import aima.learning.neural.NNConfig;
024    import aima.learning.neural.NNDataSet;
025    import aima.learning.neural.Numerizer;
026    import aima.learning.neural.Perceptron;
027    import aima.util.Util;
028    
029    public class LearningDemo {
030            public static void main(String[] args) {
031    
032                    // For Reinforcement Learning Demos see Probability Demo
033                    decisionTreeDemo();
034                    decisionListDemo();
035                    ensembleLearningDemo();
036                    perceptronDemo();
037                    backPropogationDemo();
038            }
039    
040            private static void decisionTreeDemo() {
041                    System.out.println(Util.ntimes("*", 100));
042                    System.out
043                                    .println("\nDecisionTree Demo - Inducing a DecisionList from the Restaurant DataSet\n ");
044                    System.out.println(Util.ntimes("*", 100));
045                    try {
046                            DataSet ds = DataSetFactory.getRestaurantDataSet();
047                            DecisionTreeLearner learner = new DecisionTreeLearner();
048                            learner.train(ds);
049                            System.out.println("The Induced Decision Tree is ");
050                            System.out.println(learner.getDecisionTree());
051                            int[] result = learner.test(ds);
052    
053                            System.out
054                                            .println("\nThis Decision Tree classifies the data set with "
055                                                            + result[0]
056                                                            + " successes"
057                                                            + " and "
058                                                            + result[1]
059                                                            + " failures");
060                            System.out.println("\n");
061                    } catch (Exception e) {
062                            System.out.println("Decision Tree Demo Failed  ");
063                            e.printStackTrace();
064                    }
065    
066            }
067    
068            private static void decisionListDemo() {
069                    try {
070                            System.out.println(Util.ntimes("*", 100));
071                            System.out
072                                            .println("DecisionList Demo - Inducing a DecisionList from the Restaurant DataSet\n ");
073                            System.out.println(Util.ntimes("*", 100));
074                            DataSet ds = DataSetFactory.getRestaurantDataSet();
075                            DecisionListLearner learner = new DecisionListLearner("Yes", "No",
076                                            new DLTestFactory());
077                            learner.train(ds);
078                            System.out.println("The Induced DecisionList is");
079                            System.out.println(learner.getDecisionList());
080                            int[] result = learner.test(ds);
081    
082                            System.out
083                                            .println("\nThis Decision List classifies the data set with "
084                                                            + result[0]
085                                                            + " successes"
086                                                            + " and "
087                                                            + result[1]
088                                                            + " failures");
089                            System.out.println("\n");
090    
091                    } catch (Exception e) {
092                            System.out.println("Decision ListDemo Failed");
093                    }
094            }
095    
096            private static void ensembleLearningDemo() {
097                    System.out.println(Util.ntimes("*", 100));
098                    System.out
099                                    .println("\n Ensemble Decision Demo - Weak Learners co operating to give Superior decisions ");
100                    System.out.println(Util.ntimes("*", 100));
101                    try {
102                            DataSet ds = DataSetFactory.getRestaurantDataSet();
103                            List stumps = DecisionTree.getStumpsFor(ds, "Yes", "No");
104                            List<Learner> learners = new ArrayList<Learner>();
105    
106                            System.out
107                                            .println("\nStump Learners vote to decide in this algorithm");
108                            for (Object stump : stumps) {
109                                    DecisionTree sl = (DecisionTree) stump;
110                                    StumpLearner stumpLearner = new StumpLearner(sl, "No");
111                                    learners.add(stumpLearner);
112                            }
113                            AdaBoostLearner learner = new AdaBoostLearner(learners, ds);
114                            learner.train(ds);
115                            int[] result = learner.test(ds);
116                            System.out
117                                            .println("\nThis Ensemble Learner  classifies the data set with "
118                                                            + result[0]
119                                                            + " successes"
120                                                            + " and "
121                                                            + result[1]
122                                                            + " failures");
123                            System.out.println("\n");
124    
125                    } catch (Exception e) {
126    
127                    }
128    
129            }
130    
131            private static void perceptronDemo() {
132                    try {
133                            System.out.println(Util.ntimes("*", 100));
134                            System.out
135                                            .println("\n Perceptron Demo - Running Perceptron on Iris data Set with 10 epochs of learning ");
136                            System.out.println(Util.ntimes("*", 100));
137                            DataSet irisDataSet = DataSetFactory.getIrisDataSet();
138                            Numerizer numerizer = new IrisDataSetNumerizer();
139                            NNDataSet innds = new IrisNNDataSet();
140    
141                            innds.createExamplesFromDataSet(irisDataSet, numerizer);
142    
143                            Perceptron perc = new Perceptron(3, 4);
144    
145                            perc.trainOn(innds, 10);
146    
147                            innds.refreshDataset();
148                            int[] result = perc.testOnDataSet(innds);
149                            System.out.println(result[0] + " right, " + result[1] + " wrong");
150                    } catch (Exception e) {
151                            // TODO Auto-generated catch block
152                            e.printStackTrace();
153                    }
154    
155            }
156    
157            private static void backPropogationDemo() {
158                    try {
159                            System.out.println(Util.ntimes("*", 100));
160                            System.out
161                                            .println("\n BackpropagationDemo  - Running BackProp on Iris data Set with 10 epochs of learning ");
162                            System.out.println(Util.ntimes("*", 100));
163    
164                            DataSet irisDataSet = DataSetFactory.getIrisDataSet();
165                            Numerizer numerizer = new IrisDataSetNumerizer();
166                            NNDataSet innds = new IrisNNDataSet();
167    
168                            innds.createExamplesFromDataSet(irisDataSet, numerizer);
169    
170                            NNConfig config = new NNConfig();
171                            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_INPUTS, 4);
172                            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_OUTPUTS, 3);
173                            config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_HIDDEN_NEURONS,
174                                            6);
175                            config
176                                            .setConfig(FeedForwardNeuralNetwork.LOWER_LIMIT_WEIGHTS,
177                                                            -2.0);
178                            config.setConfig(FeedForwardNeuralNetwork.UPPER_LIMIT_WEIGHTS, 2.0);
179    
180                            FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(config);
181                            ffnn.setTrainingScheme(new BackPropLearning(0.1, 0.9));
182    
183                            ffnn.trainOn(innds, 10);
184    
185                            innds.refreshDataset();
186                            int[] result = ffnn.testOnDataSet(innds);
187                            System.out.println(result[0] + " right, " + result[1] + " wrong");
188                    } catch (Exception e) {
189                            // TODO Auto-generated catch block
190                            e.printStackTrace();
191                    }
192            }
193    
194    }