001 package aima.learning.neural; 002 003 import aima.util.Matrix; 004 005 public class BackPropLearning implements NNTrainingScheme { 006 private final double learningRate; 007 private final double momentum; 008 009 private Layer hiddenLayer; 010 private Layer outputLayer; 011 private LayerSensitivity hiddenSensitivity; 012 private LayerSensitivity outputSensitivity; 013 014 public BackPropLearning(double learningRate, double momentum) { 015 016 this.learningRate = learningRate; 017 this.momentum = momentum; 018 019 } 020 021 public void setNeuralNetwork(FunctionApproximator fapp) { 022 FeedForwardNeuralNetwork ffnn = (FeedForwardNeuralNetwork) fapp; 023 this.hiddenLayer = ffnn.getHiddenLayer(); 024 this.outputLayer = ffnn.getOutputLayer(); 025 this.hiddenSensitivity = new LayerSensitivity(hiddenLayer); 026 this.outputSensitivity = new LayerSensitivity(outputLayer); 027 } 028 029 public Vector processInput(FeedForwardNeuralNetwork network, Vector input) { 030 031 hiddenLayer.feedForward(input); 032 outputLayer.feedForward(hiddenLayer.getLastActivationValues()); 033 return outputLayer.getLastActivationValues(); 034 } 035 036 public void processError(FeedForwardNeuralNetwork network, Vector error) { 037 // TODO calculate total error somewhere 038 // create Sensitivity Matrices 039 outputSensitivity.sensitivityMatrixFromErrorMatrix(error); 040 041 hiddenSensitivity 042 .sensitivityMatrixFromSucceedingLayer(outputSensitivity); 043 044 // calculate weight Updates 045 calculateWeightUpdates(outputSensitivity, hiddenLayer 046 .getLastActivationValues(), learningRate, momentum); 047 calculateWeightUpdates(hiddenSensitivity, hiddenLayer 048 .getLastInputValues(), learningRate, momentum); 049 050 // calculate Bias Updates 051 calculateBiasUpdates(outputSensitivity, learningRate, momentum); 052 calculateBiasUpdates(hiddenSensitivity, learningRate, momentum); 053 054 // update weightsAndBiases 055 outputLayer.updateWeights(); 056 outputLayer.updateBiases(); 057 058 hiddenLayer.updateWeights(); 059 hiddenLayer.updateBiases(); 060 061 } 062 063 public Matrix calculateWeightUpdates(LayerSensitivity layerSensitivity, 064 Vector previousLayerActivationOrInput, double alpha, double momentum) { 065 Layer layer = layerSensitivity.getLayer(); 066 Matrix activationTranspose = previousLayerActivationOrInput.transpose(); 067 Matrix momentumLessUpdate = layerSensitivity.getSensitivityMatrix() 068 .times(activationTranspose).times(alpha).times(-1.0); 069 Matrix updateWithMomentum = layer.getLastWeightUpdateMatrix().times( 070 momentum).plus(momentumLessUpdate.times(1.0 - momentum)); 071 layer.acceptNewWeightUpdate(updateWithMomentum.copy()); 072 return updateWithMomentum; 073 } 074 075 public static Matrix calculateWeightUpdates( 076 LayerSensitivity layerSensitivity, 077 Vector previousLayerActivationOrInput, double alpha) { 078 Layer layer = layerSensitivity.getLayer(); 079 Matrix activationTranspose = previousLayerActivationOrInput.transpose(); 080 Matrix weightUpdateMatrix = layerSensitivity.getSensitivityMatrix() 081 .times(activationTranspose).times(alpha).times(-1.0); 082 layer.acceptNewWeightUpdate(weightUpdateMatrix.copy()); 083 return weightUpdateMatrix; 084 } 085 086 public Vector calculateBiasUpdates(LayerSensitivity layerSensitivity, 087 double alpha, double momentum) { 088 Layer layer = layerSensitivity.getLayer(); 089 Matrix biasUpdateMatrixWithoutMomentum = layerSensitivity 090 .getSensitivityMatrix().times(alpha).times(-1.0); 091 092 Matrix biasUpdateMatrixWithMomentum = layer.getLastBiasUpdateVector() 093 .times(momentum).plus( 094 biasUpdateMatrixWithoutMomentum.times(1.0 - momentum)); 095 Vector result = new Vector(biasUpdateMatrixWithMomentum 096 .getRowDimension()); 097 for (int i = 0; i < biasUpdateMatrixWithMomentum.getRowDimension(); i++) { 098 result.setValue(i, biasUpdateMatrixWithMomentum.get(i, 0)); 099 } 100 layer.acceptNewBiasUpdate(result.copyVector()); 101 return result; 102 } 103 104 public static Vector calculateBiasUpdates( 105 LayerSensitivity layerSensitivity, double alpha) { 106 Layer layer = layerSensitivity.getLayer(); 107 Matrix biasUpdateMatrix = layerSensitivity.getSensitivityMatrix() 108 .times(alpha).times(-1.0); 109 110 Vector result = new Vector(biasUpdateMatrix.getRowDimension()); 111 for (int i = 0; i < biasUpdateMatrix.getRowDimension(); i++) { 112 result.setValue(i, biasUpdateMatrix.get(i, 0)); 113 } 114 layer.acceptNewBiasUpdate(result.copyVector()); 115 return result; 116 } 117 }