001    package aima.test.learningtest.neural;
002    
003    import junit.framework.TestCase;
004    import aima.learning.framework.DataSet;
005    import aima.learning.framework.DataSetFactory;
006    import aima.learning.neural.BackPropLearning;
007    import aima.learning.neural.FeedForwardNeuralNetwork;
008    import aima.learning.neural.IrisDataSetNumerizer;
009    import aima.learning.neural.IrisNNDataSet;
010    import aima.learning.neural.NNConfig;
011    import aima.learning.neural.NNDataSet;
012    import aima.learning.neural.Numerizer;
013    import aima.learning.neural.Perceptron;
014    import aima.learning.neural.Vector;
015    import aima.util.Matrix;
016    
017    public class BackPropagationTests extends TestCase {
018    
019            public void testFeedForwardAndBAckLoopWorks() {
020                    // example 11.14 of Neural Network Design by Hagan, Demuth and Beale
021                    Matrix hiddenLayerWeightMatrix = new Matrix(2, 1);
022                    hiddenLayerWeightMatrix.set(0, 0, -0.27);
023                    hiddenLayerWeightMatrix.set(1, 0, -0.41);
024    
025                    Vector hiddenLayerBiasVector = new Vector(2);
026                    hiddenLayerBiasVector.setValue(0, -0.48);
027                    hiddenLayerBiasVector.setValue(1, -0.13);
028    
029                    Vector input = new Vector(1);
030                    input.setValue(0, 1);
031    
032                    Matrix outputLayerWeightMatrix = new Matrix(1, 2);
033                    outputLayerWeightMatrix.set(0, 0, 0.09);
034                    outputLayerWeightMatrix.set(0, 1, -0.17);
035    
036                    Vector outputLayerBiasVector = new Vector(1);
037                    outputLayerBiasVector.setValue(0, 0.48);
038    
039                    Vector error = new Vector(1);
040                    error.setValue(0, 1.261);
041    
042                    double learningRate = 0.1;
043                    double momentumFactor = 0.0;
044                    FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(
045                                    hiddenLayerWeightMatrix, hiddenLayerBiasVector,
046                                    outputLayerWeightMatrix, outputLayerBiasVector);
047                    ffnn.setTrainingScheme(new BackPropLearning(learningRate,
048                                    momentumFactor));
049                    ffnn.processInput(input);
050                    ffnn.processError(error);
051    
052                    Matrix finalHiddenLayerWeights = ffnn.getHiddenLayerWeights();
053                    assertEquals(-0.265, finalHiddenLayerWeights.get(0, 0), 0.001);
054                    assertEquals(-0.419, finalHiddenLayerWeights.get(1, 0), 0.001);
055    
056                    Vector hiddenLayerBias = ffnn.getHiddenLayerBias();
057                    assertEquals(-0.475, hiddenLayerBias.getValue(0), 0.001);
058                    assertEquals(-0.1399, hiddenLayerBias.getValue(1), 0.001);
059    
060                    Matrix finalOutputLayerWeights = ffnn.getOutputLayerWeights();
061                    assertEquals(0.171, finalOutputLayerWeights.get(0, 0), 0.001);
062                    assertEquals(-0.0772, finalOutputLayerWeights.get(0, 1), 0.001);
063    
064                    Vector outputLayerBias = ffnn.getOutputLayerBias();
065                    assertEquals(0.7322, outputLayerBias.getValue(0), 0.001);
066    
067            }
068    
069            public void xtestFeedForwardAndBAckLoopWorksWithMomentum() {
070                    // example 11.14 of Neural Network Design by Hagan, Demuth and Beale
071                    Matrix hiddenLayerWeightMatrix = new Matrix(2, 1);
072                    hiddenLayerWeightMatrix.set(0, 0, -0.27);
073                    hiddenLayerWeightMatrix.set(1, 0, -0.41);
074    
075                    Vector hiddenLayerBiasVector = new Vector(2);
076                    hiddenLayerBiasVector.setValue(0, -0.48);
077                    hiddenLayerBiasVector.setValue(1, -0.13);
078    
079                    Vector input = new Vector(1);
080                    input.setValue(0, 1);
081    
082                    Matrix outputLayerWeightMatrix = new Matrix(1, 2);
083                    outputLayerWeightMatrix.set(0, 0, 0.09);
084                    outputLayerWeightMatrix.set(0, 1, -0.17);
085    
086                    Vector outputLayerBiasVector = new Vector(1);
087                    outputLayerBiasVector.setValue(0, 0.48);
088    
089                    Vector error = new Vector(1);
090                    error.setValue(0, 1.261);
091    
092                    double learningRate = 0.1;
093                    double momentumFactor = 0.5;
094                    FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(
095                                    hiddenLayerWeightMatrix, hiddenLayerBiasVector,
096                                    outputLayerWeightMatrix, outputLayerBiasVector);
097    
098                    ffnn.setTrainingScheme(new BackPropLearning(learningRate,
099                                    momentumFactor));
100                    ffnn.processInput(input);
101                    ffnn.processError(error);
102    
103                    Matrix finalHiddenLayerWeights = ffnn.getHiddenLayerWeights();
104                    assertEquals(-0.2675, finalHiddenLayerWeights.get(0, 0), 0.001);
105                    assertEquals(-0.4149, finalHiddenLayerWeights.get(1, 0), 0.001);
106    
107                    Vector hiddenLayerBias = ffnn.getHiddenLayerBias();
108                    assertEquals(-0.4775, hiddenLayerBias.getValue(0), 0.001);
109                    assertEquals(-0.1349, hiddenLayerBias.getValue(1), 0.001);
110    
111                    Matrix finalOutputLayerWeights = ffnn.getOutputLayerWeights();
112                    assertEquals(0.1304, finalOutputLayerWeights.get(0, 0), 0.001);
113                    assertEquals(-0.1235, finalOutputLayerWeights.get(0, 1), 0.001);
114    
115                    Vector outputLayerBias = ffnn.getOutputLayerBias();
116                    assertEquals(0.6061, outputLayerBias.getValue(0), 0.001);
117    
118            }
119    
120            public void xtestDataSetPopulation() throws Exception {
121                    DataSet irisDataSet = DataSetFactory.getIrisDataSet();
122                    Numerizer numerizer = new IrisDataSetNumerizer();
123                    NNDataSet innds = new IrisNNDataSet();
124    
125                    innds.createExamplesFromDataSet(irisDataSet, numerizer);
126    
127                    NNConfig config = new NNConfig();
128                    config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_INPUTS, 4);
129                    config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_OUTPUTS, 3);
130                    config.setConfig(FeedForwardNeuralNetwork.NUMBER_OF_HIDDEN_NEURONS, 6);
131                    config.setConfig(FeedForwardNeuralNetwork.LOWER_LIMIT_WEIGHTS, -2.0);
132                    config.setConfig(FeedForwardNeuralNetwork.UPPER_LIMIT_WEIGHTS, 2.0);
133    
134                    FeedForwardNeuralNetwork ffnn = new FeedForwardNeuralNetwork(config);
135                    ffnn.setTrainingScheme(new BackPropLearning(0.1, 0.9));
136    
137                    ffnn.trainOn(innds, 10);
138    
139                    innds.refreshDataset();
140                    int[] result = ffnn.testOnDataSet(innds);
141                    // System.out.println(result[0] + " right, " + result[1] + " wrong");
142    
143            }
144    
145            public void testPerceptron() throws Exception {
146                    DataSet irisDataSet = DataSetFactory.getIrisDataSet();
147                    Numerizer numerizer = new IrisDataSetNumerizer();
148                    NNDataSet innds = new IrisNNDataSet();
149    
150                    innds.createExamplesFromDataSet(irisDataSet, numerizer);
151    
152                    Perceptron perc = new Perceptron(3, 4);
153    
154                    perc.trainOn(innds, 10);
155    
156                    innds.refreshDataset();
157                    int[] result = perc.testOnDataSet(innds);
158                    // System.out.println(result[0] + " right, " + result[1] + " wrong");
159    
160            }
161    }