001    package aima.learning.neural;
002    
003    import java.util.ArrayList;
004    import java.util.List;
005    
006    import aima.util.Matrix;
007    
008    public class LayerSensitivity {
009            /*
010             * contains sensitivity matrices and related calculations for each layer.
011             * Used for backprop learning
012             */
013    
014            private Matrix sensitivityMatrix;
015            private final Layer layer;
016    
017            public LayerSensitivity(Layer layer) {
018                    Matrix weightMatrix = layer.getWeightMatrix();
019                    this.sensitivityMatrix = new Matrix(weightMatrix.getRowDimension(),
020                                    weightMatrix.getColumnDimension());
021                    this.layer = layer;
022    
023            }
024    
025            public Matrix getSensitivityMatrix() {
026                    return sensitivityMatrix;
027            }
028    
029            public Matrix sensitivityMatrixFromErrorMatrix(Vector errorVector) {
030                    Matrix derivativeMatrix = createDerivativeMatrix(layer
031                                    .getLastInducedField());
032                    Matrix calculatedSensitivityMatrix = derivativeMatrix
033                                    .times(errorVector).times(-2.0);
034                    sensitivityMatrix = calculatedSensitivityMatrix.copy();
035                    return calculatedSensitivityMatrix;
036            }
037    
038            public Matrix sensitivityMatrixFromSucceedingLayer(
039                            LayerSensitivity nextLayerSensitivity) {
040                    Layer nextLayer = nextLayerSensitivity.getLayer();
041                    Matrix derivativeMatrix = createDerivativeMatrix(layer
042                                    .getLastInducedField());
043                    Matrix weightTranspose = nextLayer.getWeightMatrix().transpose();
044                    Matrix calculatedSensitivityMatrix = derivativeMatrix.times(
045                                    weightTranspose).times(
046                                    nextLayerSensitivity.getSensitivityMatrix());
047                    sensitivityMatrix = calculatedSensitivityMatrix.copy();
048                    return sensitivityMatrix;
049            }
050    
051            private Matrix createDerivativeMatrix(Vector lastInducedField) {
052                    List<Double> lst = new ArrayList<Double>();
053                    for (int i = 0; i < lastInducedField.size(); i++) {
054                            lst.add(new Double(layer.getActivationFunction().deriv(
055                                            lastInducedField.getValue(i))));
056                    }
057                    return Matrix.createDiagonalMatrix(lst);
058            }
059    
060            public Layer getLayer() {
061                    return layer;
062            }
063    
064    }