001    package aima.learning.neural;
002    
003    import aima.util.Matrix;
004    import aima.util.Util;
005    
006    public class Layer {
007            // vectors are represented by n * 1 matrices;
008            private final Matrix weightMatrix;
009    
010            Vector biasVector, lastBiasUpdateVector;
011    
012            private final ActivationFunction activationFunction;
013    
014            private Vector lastActivationValues, lastInducedField;
015    
016            private Matrix lastWeightUpdateMatrix;
017    
018            private Matrix penultimateWeightUpdateMatrix;
019    
020            private Vector penultimateBiasUpdateVector;
021    
022            private Vector lastInput;
023    
024            public Layer(Matrix weightMatrix, Vector biasVector, ActivationFunction af) {
025    
026                    activationFunction = af;
027                    this.weightMatrix = weightMatrix;
028                    lastWeightUpdateMatrix = new Matrix(weightMatrix.getRowDimension(),
029                                    weightMatrix.getColumnDimension());
030                    penultimateWeightUpdateMatrix = new Matrix(weightMatrix
031                                    .getRowDimension(), weightMatrix.getColumnDimension());
032    
033                    this.biasVector = biasVector;
034                    lastBiasUpdateVector = new Vector(biasVector.getRowDimension());
035                    penultimateBiasUpdateVector = new Vector(biasVector.getRowDimension());
036            }
037    
038            public Layer(int numberOfNeurons, int numberOfInputs,
039                            double lowerLimitForWeights, double upperLimitForWeights,
040                            ActivationFunction af) {
041    
042                    activationFunction = af;
043                    this.weightMatrix = new Matrix(numberOfNeurons, numberOfInputs);
044                    lastWeightUpdateMatrix = new Matrix(weightMatrix.getRowDimension(),
045                                    weightMatrix.getColumnDimension());
046                    penultimateWeightUpdateMatrix = new Matrix(weightMatrix
047                                    .getRowDimension(), weightMatrix.getColumnDimension());
048    
049                    this.biasVector = new Vector(numberOfNeurons);
050                    lastBiasUpdateVector = new Vector(biasVector.getRowDimension());
051                    penultimateBiasUpdateVector = new Vector(biasVector.getRowDimension());
052    
053                    initializeMatrix(weightMatrix, lowerLimitForWeights,
054                                    upperLimitForWeights);
055                    initializeVector(biasVector, lowerLimitForWeights, upperLimitForWeights);
056    
057            }
058    
059            public Vector feedForward(Vector inputVector) {
060                    lastInput = inputVector;
061                    Matrix inducedField = weightMatrix.times(inputVector).plus(biasVector);
062    
063                    Vector inducedFieldVector = new Vector(numberOfNeurons());
064                    for (int i = 0; i < numberOfNeurons(); i++) {
065                            inducedFieldVector.setValue(i, inducedField.get(i, 0));
066                    }
067    
068                    lastInducedField = inducedFieldVector.copyVector();
069                    Vector resultVector = new Vector(numberOfNeurons());
070                    for (int i = 0; i < numberOfNeurons(); i++) {
071                            resultVector.setValue(i, activationFunction
072                                            .activation(inducedFieldVector.getValue(i)));
073                    }
074                    // set the result as the last activation value
075                    lastActivationValues = resultVector.copyVector();
076                    return resultVector;
077            }
078    
079            public Matrix getWeightMatrix() {
080                    return weightMatrix;
081            }
082    
083            public Vector getBiasVector() {
084                    return biasVector;
085            }
086    
087            public int numberOfNeurons() {
088                    return weightMatrix.getRowDimension();
089            }
090    
091            public int numberOfInputs() {
092                    return weightMatrix.getColumnDimension();
093            }
094    
095            public Vector getLastActivationValues() {
096                    return lastActivationValues;
097            }
098    
099            public Vector getLastInducedField() {
100                    return lastInducedField;
101            }
102    
103            private static void initializeMatrix(Matrix aMatrix, double lowerLimit,
104                            double upperLimit) {
105                    for (int i = 0; i < aMatrix.getRowDimension(); i++) {
106                            for (int j = 0; j < aMatrix.getColumnDimension(); j++) {
107                                    double random = Util.generateRandomDoubleBetween(lowerLimit,
108                                                    upperLimit);
109                                    aMatrix.set(i, j, random);
110                            }
111                    }
112    
113            }
114    
115            private static void initializeVector(Vector aVector, double lowerLimit,
116                            double upperLimit) {
117                    for (int i = 0; i < aVector.size(); i++) {
118    
119                            double random = Util.generateRandomDoubleBetween(lowerLimit,
120                                            upperLimit);
121                            aVector.setValue(i, random);
122                    }
123            }
124    
125            public Matrix getLastWeightUpdateMatrix() {
126                    return lastWeightUpdateMatrix;
127            }
128    
129            public void setLastWeightUpdateMatrix(Matrix m) {
130                    lastWeightUpdateMatrix = m;
131            }
132    
133            public Matrix getPenultimateWeightUpdateMatrix() {
134                    return penultimateWeightUpdateMatrix;
135            }
136    
137            public void setPenultimateWeightUpdateMatrix(Matrix m) {
138                    penultimateWeightUpdateMatrix = m;
139            }
140    
141            public Vector getLastBiasUpdateVector() {
142                    return lastBiasUpdateVector;
143            }
144    
145            public void setLastBiasUpdateVector(Vector v) {
146                    lastBiasUpdateVector = v;
147            }
148    
149            public Vector getPenultimateBiasUpdateVector() {
150                    return penultimateBiasUpdateVector;
151            }
152    
153            public void setPenultimateBiasUpdateVector(Vector v) {
154                    penultimateBiasUpdateVector = v;
155            }
156    
157            public void updateWeights() {
158                    weightMatrix.plusEquals(lastWeightUpdateMatrix);
159            }
160    
161            public void updateBiases() {
162                    Matrix biasMatrix = biasVector.plusEquals(lastBiasUpdateVector);
163                    Vector result = new Vector(biasMatrix.getRowDimension());
164                    for (int i = 0; i < biasMatrix.getRowDimension(); i++) {
165                            result.setValue(i, biasMatrix.get(i, 0));
166                    }
167                    biasVector = result;
168            }
169    
170            public Vector getLastInputValues() {
171    
172                    return lastInput;
173    
174            }
175    
176            public ActivationFunction getActivationFunction() {
177    
178                    return activationFunction;
179            }
180    
181            public void acceptNewWeightUpdate(Matrix weightUpdate) {
182                    /*
183                     * penultimate weightupdates maintained only to implement VLBP later
184                     */
185                    setPenultimateWeightUpdateMatrix(getLastWeightUpdateMatrix());
186                    setLastWeightUpdateMatrix(weightUpdate);
187            }
188    
189            public void acceptNewBiasUpdate(Vector biasUpdate) {
190                    setPenultimateBiasUpdateVector(getLastBiasUpdateVector());
191                    setLastBiasUpdateVector(biasUpdate);
192            }
193    
194            public Vector errorVectorFrom(Vector target) {
195                    return target.minus(getLastActivationValues());
196    
197            }
198    
199    }