001    package aima.learning.neural;
002    
003    import aima.learning.framework.DataSet;
004    import aima.util.Matrix;
005    
006    public class FeedForwardNeuralNetwork implements FunctionApproximator {
007    
008            public static final String UPPER_LIMIT_WEIGHTS = "upper_limit_weights";
009            public static final String LOWER_LIMIT_WEIGHTS = "lower_limit_weights";
010            public static final String NUMBER_OF_OUTPUTS = "number_of_outputs";
011            public static final String NUMBER_OF_HIDDEN_NEURONS = "number_of_hidden_neurons";
012            public static final String NUMBER_OF_INPUTS = "number_of_inputs";
013            private final Layer hiddenLayer;
014            private final Layer outputLayer;
015    
016            private NNTrainingScheme trainingScheme;
017    
018            /*
019             * constructor to be used for non testing code.
020             */
021            public FeedForwardNeuralNetwork(NNConfig config) {
022    
023                    int numberOfInputNeurons = config
024                                    .getParameterAsInteger(NUMBER_OF_INPUTS);
025                    int numberOfHiddenNeurons = config
026                                    .getParameterAsInteger(NUMBER_OF_HIDDEN_NEURONS);
027                    int numberOfOutputNeurons = config
028                                    .getParameterAsInteger(NUMBER_OF_OUTPUTS);
029    
030                    double lowerLimitForWeights = config
031                                    .getParameterAsDouble(LOWER_LIMIT_WEIGHTS);
032                    double upperLimitForWeights = config
033                                    .getParameterAsDouble(UPPER_LIMIT_WEIGHTS);
034    
035                    hiddenLayer = new Layer(numberOfHiddenNeurons, numberOfInputNeurons,
036                                    lowerLimitForWeights, upperLimitForWeights,
037                                    new LogSigActivationFunction());
038    
039                    outputLayer = new Layer(numberOfOutputNeurons, numberOfHiddenNeurons,
040                                    lowerLimitForWeights, upperLimitForWeights,
041                                    new PureLinearActivationFunction());
042    
043            }
044    
045            /*
046             * ONLY for testing to set up a network with known weights in future use to
047             * deserialize networks after adding variables for pen weightupdate,
048             * lastnput etc
049             */
050            public FeedForwardNeuralNetwork(Matrix hiddenLayerWeights,
051                            Vector hiddenLayerBias, Matrix outputLayerWeights,
052                            Vector outputLayerBias) {
053    
054                    hiddenLayer = new Layer(hiddenLayerWeights, hiddenLayerBias,
055                                    new LogSigActivationFunction());
056                    outputLayer = new Layer(outputLayerWeights, outputLayerBias,
057                                    new PureLinearActivationFunction());
058    
059            }
060    
061            public void processError(Vector error) {
062    
063                    trainingScheme.processError(this, error);
064    
065            }
066    
067            public Vector processInput(Vector input) {
068                    return trainingScheme.processInput(this, input);
069            }
070    
071            public void trainOn(NNDataSet innds, int numberofEpochs) {
072                    for (int i = 0; i < numberofEpochs; i++) {
073                            innds.refreshDataset();
074                            while (innds.hasMoreExamples()) {
075                                    NNExample nne = innds.getExampleAtRandom();
076                                    processInput(nne.getInput());
077                                    Vector error = getOutputLayer()
078                                                    .errorVectorFrom(nne.getTarget());
079                                    processError(error);
080                            }
081                    }
082    
083            }
084    
085            public Vector predict(NNExample nne) {
086                    return processInput(nne.getInput());
087            }
088    
089            public int[] testOnDataSet(NNDataSet nnds) {
090                    int[] result = new int[] { 0, 0 };
091                    nnds.refreshDataset();
092                    while (nnds.hasMoreExamples()) {
093                            NNExample nne = nnds.getExampleAtRandom();
094                            Vector prediction = predict(nne);
095                            if (nne.isCorrect(prediction)) {
096                                    result[0] = result[0] + 1;
097                            } else {
098                                    result[1] = result[1] + 1;
099                            }
100                    }
101                    return result;
102            }
103    
104            public void testOn(DataSet ds) {
105                    // TODO Auto-generated method stub
106            }
107    
108            public Matrix getHiddenLayerWeights() {
109    
110                    return hiddenLayer.getWeightMatrix();
111            }
112    
113            public Vector getHiddenLayerBias() {
114    
115                    return hiddenLayer.getBiasVector();
116            }
117    
118            public Matrix getOutputLayerWeights() {
119    
120                    return outputLayer.getWeightMatrix();
121            }
122    
123            public Vector getOutputLayerBias() {
124    
125                    return outputLayer.getBiasVector();
126            }
127    
128            public Layer getHiddenLayer() {
129                    return hiddenLayer;
130            }
131    
132            public Layer getOutputLayer() {
133                    return outputLayer;
134            }
135    
136            public void setTrainingScheme(NNTrainingScheme trainingScheme) {
137                    this.trainingScheme = trainingScheme;
138                    trainingScheme.setNeuralNetwork(this);
139            }
140    
141    }