001    package aima.test.learningtest.neural;
002    
003    import junit.framework.TestCase;
004    import aima.learning.neural.BackPropLearning;
005    import aima.learning.neural.Layer;
006    import aima.learning.neural.LayerSensitivity;
007    import aima.learning.neural.LogSigActivationFunction;
008    import aima.learning.neural.PureLinearActivationFunction;
009    import aima.learning.neural.Vector;
010    import aima.util.Matrix;
011    
012    public class LayerTests extends TestCase {
013            public void testFeedForward() {
014                    // example 11.14 of Neural Network Design by Hagan, Demuth and Beale
015                    // lots of tedious tests necessary to ensure nn is fundamentally correct
016                    Matrix weightMatrix1 = new Matrix(2, 1);
017                    weightMatrix1.set(0, 0, -0.27);
018                    weightMatrix1.set(1, 0, -0.41);
019    
020                    Vector biasVector1 = new Vector(2);
021                    biasVector1.setValue(0, -0.48);
022                    biasVector1.setValue(1, -0.13);
023    
024                    Layer layer1 = new Layer(weightMatrix1, biasVector1,
025                                    new LogSigActivationFunction());
026    
027                    Vector inputVector1 = new Vector(1);
028                    inputVector1.setValue(0, 1);
029    
030                    Vector expected = new Vector(2);
031                    expected.setValue(0, 0.321);
032                    expected.setValue(1, 0.368);
033    
034                    Vector result1 = layer1.feedForward(inputVector1);
035                    assertEquals(expected.getValue(0), result1.getValue(0), 0.001);
036                    assertEquals(expected.getValue(1), result1.getValue(1), 0.001);
037    
038                    Matrix weightMatrix2 = new Matrix(1, 2);
039                    weightMatrix2.set(0, 0, 0.09);
040                    weightMatrix2.set(0, 1, -0.17);
041    
042                    Vector biasVector2 = new Vector(1);
043                    biasVector2.setValue(0, 0.48);
044    
045                    Layer layer2 = new Layer(weightMatrix2, biasVector2,
046                                    new PureLinearActivationFunction());
047                    Vector inputVector2 = layer1.getLastActivationValues();
048                    Vector result2 = layer2.feedForward(inputVector2);
049                    assertEquals(0.446, result2.getValue(0), 0.001);
050    
051            }
052    
053            public void testSensitivityMatrixCalculationFromErrorVector() {
054                    Matrix weightMatrix1 = new Matrix(2, 1);
055                    weightMatrix1.set(0, 0, -0.27);
056                    weightMatrix1.set(1, 0, -0.41);
057    
058                    Vector biasVector1 = new Vector(2);
059                    biasVector1.setValue(0, -0.48);
060                    biasVector1.setValue(1, -0.13);
061    
062                    Layer layer1 = new Layer(weightMatrix1, biasVector1,
063                                    new LogSigActivationFunction());
064    
065                    Vector inputVector1 = new Vector(1);
066                    inputVector1.setValue(0, 1);
067    
068                    layer1.feedForward(inputVector1);
069    
070                    Matrix weightMatrix2 = new Matrix(1, 2);
071                    weightMatrix2.set(0, 0, 0.09);
072                    weightMatrix2.set(0, 1, -0.17);
073    
074                    Vector biasVector2 = new Vector(1);
075                    biasVector2.setValue(0, 0.48);
076    
077                    Layer layer2 = new Layer(weightMatrix2, biasVector2,
078                                    new PureLinearActivationFunction());
079                    Vector inputVector2 = layer1.getLastActivationValues();
080                    layer2.feedForward(inputVector2);
081    
082                    Vector errorVector = new Vector(1);
083                    errorVector.setValue(0, 1.261);
084                    LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
085                    layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
086    
087                    Matrix sensitivityMatrix = layer2Sensitivity.getSensitivityMatrix();
088                    assertEquals(-2.522, sensitivityMatrix.get(0, 0));
089    
090            }
091    
092            public void testSensitivityMatrixCalculationFromSucceedingLayer() {
093                    Matrix weightMatrix1 = new Matrix(2, 1);
094                    weightMatrix1.set(0, 0, -0.27);
095                    weightMatrix1.set(1, 0, -0.41);
096    
097                    Vector biasVector1 = new Vector(2);
098                    biasVector1.setValue(0, -0.48);
099                    biasVector1.setValue(1, -0.13);
100    
101                    Layer layer1 = new Layer(weightMatrix1, biasVector1,
102                                    new LogSigActivationFunction());
103                    LayerSensitivity layer1Sensitivity = new LayerSensitivity(layer1);
104    
105                    Vector inputVector1 = new Vector(1);
106                    inputVector1.setValue(0, 1);
107    
108                    layer1.feedForward(inputVector1);
109    
110                    Matrix weightMatrix2 = new Matrix(1, 2);
111                    weightMatrix2.set(0, 0, 0.09);
112                    weightMatrix2.set(0, 1, -0.17);
113    
114                    Vector biasVector2 = new Vector(1);
115                    biasVector2.setValue(0, 0.48);
116    
117                    Layer layer2 = new Layer(weightMatrix2, biasVector2,
118                                    new PureLinearActivationFunction());
119                    Vector inputVector2 = layer1.getLastActivationValues();
120                    layer2.feedForward(inputVector2);
121    
122                    Vector errorVector = new Vector(1);
123                    errorVector.setValue(0, 1.261);
124                    LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
125                    layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
126    
127                    layer1Sensitivity
128                                    .sensitivityMatrixFromSucceedingLayer(layer2Sensitivity);
129                    Matrix sensitivityMatrix = layer1Sensitivity.getSensitivityMatrix();
130    
131                    assertEquals(2, sensitivityMatrix.getRowDimension());
132                    assertEquals(1, sensitivityMatrix.getColumnDimension());
133                    assertEquals(-0.0495, sensitivityMatrix.get(0, 0), 0.001);
134                    assertEquals(0.0997, sensitivityMatrix.get(1, 0), 0.001);
135    
136            }
137    
138            public void testWeightUpdateMatrixesFormedCorrectly() {
139                    Matrix weightMatrix1 = new Matrix(2, 1);
140                    weightMatrix1.set(0, 0, -0.27);
141                    weightMatrix1.set(1, 0, -0.41);
142    
143                    Vector biasVector1 = new Vector(2);
144                    biasVector1.setValue(0, -0.48);
145                    biasVector1.setValue(1, -0.13);
146    
147                    Layer layer1 = new Layer(weightMatrix1, biasVector1,
148                                    new LogSigActivationFunction());
149                    LayerSensitivity layer1Sensitivity = new LayerSensitivity(layer1);
150    
151                    Vector inputVector1 = new Vector(1);
152                    inputVector1.setValue(0, 1);
153    
154                    layer1.feedForward(inputVector1);
155    
156                    Matrix weightMatrix2 = new Matrix(1, 2);
157                    weightMatrix2.set(0, 0, 0.09);
158                    weightMatrix2.set(0, 1, -0.17);
159    
160                    Vector biasVector2 = new Vector(1);
161                    biasVector2.setValue(0, 0.48);
162    
163                    Layer layer2 = new Layer(weightMatrix2, biasVector2,
164                                    new PureLinearActivationFunction());
165                    Vector inputVector2 = layer1.getLastActivationValues();
166                    layer2.feedForward(inputVector2);
167    
168                    Vector errorVector = new Vector(1);
169                    errorVector.setValue(0, 1.261);
170                    LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
171                    layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
172    
173                    layer1Sensitivity
174                                    .sensitivityMatrixFromSucceedingLayer(layer2Sensitivity);
175    
176                    Matrix weightUpdateMatrix2 = BackPropLearning.calculateWeightUpdates(
177                                    layer2Sensitivity, layer1.getLastActivationValues(), 0.1);
178                    assertEquals(0.0809, weightUpdateMatrix2.get(0, 0), 0.001);
179                    assertEquals(0.0928, weightUpdateMatrix2.get(0, 1), 0.001);
180    
181                    Matrix lastWeightUpdateMatrix2 = layer2.getLastWeightUpdateMatrix();
182                    assertEquals(0.0809, lastWeightUpdateMatrix2.get(0, 0), 0.001);
183                    assertEquals(0.0928, lastWeightUpdateMatrix2.get(0, 1), 0.001);
184    
185                    Matrix penultimateWeightUpdatematrix2 = layer2
186                                    .getPenultimateWeightUpdateMatrix();
187                    assertEquals(0.0, penultimateWeightUpdatematrix2.get(0, 0), 0.001);
188                    assertEquals(0.0, penultimateWeightUpdatematrix2.get(0, 1), 0.001);
189    
190                    Matrix weightUpdateMatrix1 = BackPropLearning.calculateWeightUpdates(
191                                    layer1Sensitivity, inputVector1, 0.1);
192                    assertEquals(0.0049, weightUpdateMatrix1.get(0, 0), 0.001);
193                    assertEquals(-0.00997, weightUpdateMatrix1.get(1, 0), 0.001);
194    
195                    Matrix lastWeightUpdateMatrix1 = layer1.getLastWeightUpdateMatrix();
196                    assertEquals(0.0049, lastWeightUpdateMatrix1.get(0, 0), 0.001);
197                    assertEquals(-0.00997, lastWeightUpdateMatrix1.get(1, 0), 0.001);
198                    Matrix penultimateWeightUpdatematrix1 = layer1
199                                    .getPenultimateWeightUpdateMatrix();
200                    assertEquals(0.0, penultimateWeightUpdatematrix1.get(0, 0), 0.001);
201                    assertEquals(0.0, penultimateWeightUpdatematrix1.get(1, 0), 0.001);
202    
203            }
204    
205            public void testBiasUpdateMatrixesFormedCorrectly() {
206                    Matrix weightMatrix1 = new Matrix(2, 1);
207                    weightMatrix1.set(0, 0, -0.27);
208                    weightMatrix1.set(1, 0, -0.41);
209    
210                    Vector biasVector1 = new Vector(2);
211                    biasVector1.setValue(0, -0.48);
212                    biasVector1.setValue(1, -0.13);
213    
214                    Layer layer1 = new Layer(weightMatrix1, biasVector1,
215                                    new LogSigActivationFunction());
216                    LayerSensitivity layer1Sensitivity = new LayerSensitivity(layer1);
217    
218                    Vector inputVector1 = new Vector(1);
219                    inputVector1.setValue(0, 1);
220    
221                    layer1.feedForward(inputVector1);
222    
223                    Matrix weightMatrix2 = new Matrix(1, 2);
224                    weightMatrix2.set(0, 0, 0.09);
225                    weightMatrix2.set(0, 1, -0.17);
226    
227                    Vector biasVector2 = new Vector(1);
228                    biasVector2.setValue(0, 0.48);
229    
230                    Layer layer2 = new Layer(weightMatrix2, biasVector2,
231                                    new PureLinearActivationFunction());
232                    LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
233                    Vector inputVector2 = layer1.getLastActivationValues();
234                    layer2.feedForward(inputVector2);
235    
236                    Vector errorVector = new Vector(1);
237                    errorVector.setValue(0, 1.261);
238                    layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
239    
240                    layer1Sensitivity
241                                    .sensitivityMatrixFromSucceedingLayer(layer2Sensitivity);
242    
243                    Vector biasUpdateVector2 = BackPropLearning.calculateBiasUpdates(
244                                    layer2Sensitivity, 0.1);
245                    assertEquals(0.2522, biasUpdateVector2.getValue(0), 0.001);
246    
247                    Vector lastBiasUpdateVector2 = layer2.getLastBiasUpdateVector();
248                    assertEquals(0.2522, lastBiasUpdateVector2.getValue(0), 0.001);
249    
250                    Vector penultimateBiasUpdateVector2 = layer2
251                                    .getPenultimateBiasUpdateVector();
252                    assertEquals(0.0, penultimateBiasUpdateVector2.getValue(0), 0.001);
253    
254                    Vector biasUpdateVector1 = BackPropLearning.calculateBiasUpdates(
255                                    layer1Sensitivity, 0.1);
256                    assertEquals(0.00495, biasUpdateVector1.getValue(0), 0.001);
257                    assertEquals(-0.00997, biasUpdateVector1.getValue(1), 0.001);
258    
259                    Vector lastBiasUpdateVector1 = layer1.getLastBiasUpdateVector();
260    
261                    assertEquals(0.00495, lastBiasUpdateVector1.getValue(0), 0.001);
262                    assertEquals(-0.00997, lastBiasUpdateVector1.getValue(1), 0.001);
263    
264                    Vector penultimateBiasUpdateVector1 = layer1
265                                    .getPenultimateBiasUpdateVector();
266                    assertEquals(0.0, penultimateBiasUpdateVector1.getValue(0), 0.001);
267                    assertEquals(0.0, penultimateBiasUpdateVector1.getValue(1), 0.001);
268    
269            }
270    
271            public void testWeightsAndBiasesUpdatedCorrectly() {
272                    Matrix weightMatrix1 = new Matrix(2, 1);
273                    weightMatrix1.set(0, 0, -0.27);
274                    weightMatrix1.set(1, 0, -0.41);
275    
276                    Vector biasVector1 = new Vector(2);
277                    biasVector1.setValue(0, -0.48);
278                    biasVector1.setValue(1, -0.13);
279    
280                    Layer layer1 = new Layer(weightMatrix1, biasVector1,
281                                    new LogSigActivationFunction());
282                    LayerSensitivity layer1Sensitivity = new LayerSensitivity(layer1);
283    
284                    Vector inputVector1 = new Vector(1);
285                    inputVector1.setValue(0, 1);
286    
287                    layer1.feedForward(inputVector1);
288    
289                    Matrix weightMatrix2 = new Matrix(1, 2);
290                    weightMatrix2.set(0, 0, 0.09);
291                    weightMatrix2.set(0, 1, -0.17);
292    
293                    Vector biasVector2 = new Vector(1);
294                    biasVector2.setValue(0, 0.48);
295    
296                    Layer layer2 = new Layer(weightMatrix2, biasVector2,
297                                    new PureLinearActivationFunction());
298                    Vector inputVector2 = layer1.getLastActivationValues();
299                    layer2.feedForward(inputVector2);
300    
301                    Vector errorVector = new Vector(1);
302                    errorVector.setValue(0, 1.261);
303                    LayerSensitivity layer2Sensitivity = new LayerSensitivity(layer2);
304                    layer2Sensitivity.sensitivityMatrixFromErrorMatrix(errorVector);
305    
306                    layer1Sensitivity
307                                    .sensitivityMatrixFromSucceedingLayer(layer2Sensitivity);
308    
309                    BackPropLearning.calculateWeightUpdates(layer2Sensitivity, layer1
310                                    .getLastActivationValues(), 0.1);
311    
312                    BackPropLearning.calculateBiasUpdates(layer2Sensitivity, 0.1);
313    
314                    BackPropLearning.calculateWeightUpdates(layer1Sensitivity,
315                                    inputVector1, 0.1);
316    
317                    BackPropLearning.calculateBiasUpdates(layer1Sensitivity, 0.1);
318    
319                    layer2.updateWeights();
320                    Matrix newWeightMatrix2 = layer2.getWeightMatrix();
321                    assertEquals(0.171, newWeightMatrix2.get(0, 0), 0.001);
322                    assertEquals(-0.0772, newWeightMatrix2.get(0, 1), 0.001);
323    
324                    layer2.updateBiases();
325                    Vector newBiasVector2 = layer2.getBiasVector();
326                    assertEquals(0.7322, newBiasVector2.getValue(0));
327    
328                    layer1.updateWeights();
329                    Matrix newWeightMatrix1 = layer1.getWeightMatrix();
330    
331                    assertEquals(-0.265, newWeightMatrix1.get(0, 0), 0.001);
332                    assertEquals(-0.419, newWeightMatrix1.get(1, 0), 0.001);
333    
334                    layer1.updateBiases();
335                    Vector newBiasVector1 = layer1.getBiasVector();
336    
337                    assertEquals(-0.475, newBiasVector1.getValue(0), 0.001);
338                    assertEquals(-0.139, newBiasVector1.getValue(1), 0.001);
339            }
340    }