001    package aima.learning.neural;
002    
003    import aima.util.Matrix;
004    
005    public class BackPropLearning implements NNTrainingScheme {
006            private final double learningRate;
007            private final double momentum;
008    
009            private Layer hiddenLayer;
010            private Layer outputLayer;
011            private LayerSensitivity hiddenSensitivity;
012            private LayerSensitivity outputSensitivity;
013    
014            public BackPropLearning(double learningRate, double momentum) {
015    
016                    this.learningRate = learningRate;
017                    this.momentum = momentum;
018    
019            }
020    
021            public void setNeuralNetwork(FunctionApproximator fapp) {
022                    FeedForwardNeuralNetwork ffnn = (FeedForwardNeuralNetwork) fapp;
023                    this.hiddenLayer = ffnn.getHiddenLayer();
024                    this.outputLayer = ffnn.getOutputLayer();
025                    this.hiddenSensitivity = new LayerSensitivity(hiddenLayer);
026                    this.outputSensitivity = new LayerSensitivity(outputLayer);
027            }
028    
029            public Vector processInput(FeedForwardNeuralNetwork network, Vector input) {
030    
031                    hiddenLayer.feedForward(input);
032                    outputLayer.feedForward(hiddenLayer.getLastActivationValues());
033                    return outputLayer.getLastActivationValues();
034            }
035    
036            public void processError(FeedForwardNeuralNetwork network, Vector error) {
037                    // TODO calculate total error somewhere
038                    // create Sensitivity Matrices
039                    outputSensitivity.sensitivityMatrixFromErrorMatrix(error);
040    
041                    hiddenSensitivity
042                                    .sensitivityMatrixFromSucceedingLayer(outputSensitivity);
043    
044                    // calculate weight Updates
045                    calculateWeightUpdates(outputSensitivity, hiddenLayer
046                                    .getLastActivationValues(), learningRate, momentum);
047                    calculateWeightUpdates(hiddenSensitivity, hiddenLayer
048                                    .getLastInputValues(), learningRate, momentum);
049    
050                    // calculate Bias Updates
051                    calculateBiasUpdates(outputSensitivity, learningRate, momentum);
052                    calculateBiasUpdates(hiddenSensitivity, learningRate, momentum);
053    
054                    // update weightsAndBiases
055                    outputLayer.updateWeights();
056                    outputLayer.updateBiases();
057    
058                    hiddenLayer.updateWeights();
059                    hiddenLayer.updateBiases();
060    
061            }
062    
063            public Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity,
064                            Vector previousLayerActivationOrInput, double alpha, double momentum) {
065                    Layer layer = layerSensitivity.getLayer();
066                    Matrix activationTranspose = previousLayerActivationOrInput.transpose();
067                    Matrix momentumLessUpdate = layerSensitivity.getSensitivityMatrix()
068                                    .times(activationTranspose).times(alpha).times(-1.0);
069                    Matrix updateWithMomentum = layer.getLastWeightUpdateMatrix().times(
070                                    momentum).plus(momentumLessUpdate.times(1.0 - momentum));
071                    layer.acceptNewWeightUpdate(updateWithMomentum.copy());
072                    return updateWithMomentum;
073            }
074    
075            public static Matrix calculateWeightUpdates(
076                            LayerSensitivity layerSensitivity,
077                            Vector previousLayerActivationOrInput, double alpha) {
078                    Layer layer = layerSensitivity.getLayer();
079                    Matrix activationTranspose = previousLayerActivationOrInput.transpose();
080                    Matrix weightUpdateMatrix = layerSensitivity.getSensitivityMatrix()
081                                    .times(activationTranspose).times(alpha).times(-1.0);
082                    layer.acceptNewWeightUpdate(weightUpdateMatrix.copy());
083                    return weightUpdateMatrix;
084            }
085    
086            public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity,
087                            double alpha, double momentum) {
088                    Layer layer = layerSensitivity.getLayer();
089                    Matrix biasUpdateMatrixWithoutMomentum = layerSensitivity
090                                    .getSensitivityMatrix().times(alpha).times(-1.0);
091    
092                    Matrix biasUpdateMatrixWithMomentum = layer.getLastBiasUpdateVector()
093                                    .times(momentum).plus(
094                                                    biasUpdateMatrixWithoutMomentum.times(1.0 - momentum));
095                    Vector result = new Vector(biasUpdateMatrixWithMomentum
096                                    .getRowDimension());
097                    for (int i = 0; i < biasUpdateMatrixWithMomentum.getRowDimension(); i++) {
098                            result.setValue(i, biasUpdateMatrixWithMomentum.get(i, 0));
099                    }
100                    layer.acceptNewBiasUpdate(result.copyVector());
101                    return result;
102            }
103    
104            public static Vector calculateBiasUpdates(
105                            LayerSensitivity layerSensitivity, double alpha) {
106                    Layer layer = layerSensitivity.getLayer();
107                    Matrix biasUpdateMatrix = layerSensitivity.getSensitivityMatrix()
108                                    .times(alpha).times(-1.0);
109    
110                    Vector result = new Vector(biasUpdateMatrix.getRowDimension());
111                    for (int i = 0; i < biasUpdateMatrix.getRowDimension(); i++) {
112                            result.setValue(i, biasUpdateMatrix.get(i, 0));
113                    }
114                    layer.acceptNewBiasUpdate(result.copyVector());
115                    return result;
116            }
117    }