/**
 * Runs a forward pass: feeds the input through every layer in order and returns the result of
 * the last one.
 *
 * <p>The caller's matrix is cloned before it is handed to the first layer, so the argument
 * reference itself is never passed to a layer.
 *
 * @param input The input to the network; its row count must equal the input size reported by
 *     the first layer.
 * @return The matrix produced by the final layer's activation.
 * @throws InvalidParameterException if the row count of {@code input} differs from the first
 *     layer's input size.
 */
public Matrix predict(Matrix input) {
  // Row count is read before the first layer is looked up, matching the comparison order.
  boolean rowsMatch = input.getNumRows() == layers.get(0).getLayerSize().getInputSize();
  if (!rowsMatch) {
    throw new InvalidParameterException(
        "Input size did not match the input size of the first layer");
  }

  Matrix signal = (Matrix) input.clone();
  for (Layer layer : layers) {
    // Each layer consumes the previous layer's output.
    signal = layer.activate(signal);
  }
  return signal;
}
/**
 * Evaluates the layer's activation function on every element of the given matrix.
 *
 * <p>When the activation function is a {@code Softmax}, the element-wise result is additionally
 * scaled by the reciprocal of the sum of its elements, provided that sum is non-zero.
 *
 * @param input The matrix to feed through the activation function; it is read but not written.
 * @return A new matrix holding the activated (and, for Softmax, normalized) values.
 */
private Matrix applyFunction(Matrix input) {
  // Work on a clone so the element-wise writes never touch the caller's matrix.
  Matrix result = (Matrix) input.clone();
  for (int r = 0; r < input.getNumRows(); r++) {
    for (int c = 0; c < input.getNumCols(); c++) {
      result.set(r, c, function.activate(input.get(r, c)));
    }
  }

  if (!(function instanceof Softmax)) {
    return result;
  }

  // Softmax only: divide by the total, skipping the division when the total is zero.
  double total = result.sum();
  if (total != 0) {
    result = result.multiply(1 / total);
  }
  return result;
}
/**
 * Computes the matrix this layer uses as the derivative of its activation function.
 *
 * <p>For every activation function other than {@code Softmax}, each element is mapped through
 * {@code function.derivative}. For {@code Softmax} the steps are: exponentiate each element,
 * scale by the reciprocal of the element sum (when that sum is non-zero), subtract the original
 * input, and finally map {@code function.activate} over the difference.
 *
 * <p>NOTE(review): the Softmax branch does not look like the usual softmax derivative; confirm
 * that this is the intended formula before relying on it.
 *
 * @param input The matrix to differentiate at.
 * @return A matrix of the same origin as a clone of {@code input}, transformed as described.
 */
private Matrix applyFunctionDerivative(Matrix input) {
  Matrix result = (Matrix) input.clone();

  if (!(function instanceof Softmax)) {
    // Ordinary activations: plain element-wise derivative.
    return result.map(
        new Matrix.Function() {
          @Override
          public double function(double x) {
            return function.derivative(x);
          }
        });
  }

  // Softmax: exponentiate, then normalize by the total when it is non-zero.
  result =
      result.map(
          new Matrix.Function() {
            @Override
            public double function(double x) {
              return Math.exp(x);
            }
          });
  double total = result.sum();
  if (total != 0) {
    result = result.multiply(1 / total);
  }

  // The subtraction and the final activation pass run even when the total was zero.
  result = result.subtract(input);
  return result.map(
      new Matrix.Function() {
        @Override
        public double function(double x) {
          return function.activate(x);
        }
      });
}
/**
 * Replaces the weight matrix of one layer of the neural network.
 *
 * <p>The layer is looked up with {@code layers.get(layer)} and its {@code weightMatrix} field is
 * assigned a clone of {@code weights}, so the layer does not hold the caller's reference. No
 * check of the matrix dimensions is performed here.
 *
 * @param layer The index of the layer, as accepted by {@code layers.get}.
 * @param weights The new weight matrix for the layer; a clone of it is stored.
 */
public void setWeights(int layer, Matrix weights) {
  // The layer lookup happens before the clone, so a bad index fails before any copy is made.
  layers.get(layer).weightMatrix = (Matrix) weights.clone();
}