/**
 * Demonstrates pausing a resilient-propagation trainer and resuming from the
 * saved continuation: trains once on the XOR data, pauses, round-trips the
 * continuation through a file, then resumes in a second trainer.
 *
 * @param args command-line arguments (not used)
 */
public static void main(String[] args) {
    Logging.stopConsoleLogging();

    NeuralDataSet xorData = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
    BasicNetwork net = EncogUtility.simpleFeedForward(2, 4, 0, 1, false);

    // First training pass.
    ResilientPropagation firstTrainer = new ResilientPropagation(net, xorData);
    firstTrainer.addStrategy(new RequiredImprovementStrategy(5));
    System.out.println("Perform initial train.");
    EncogUtility.trainToError(firstTrainer, net, xorData, 0.01);

    // Pause the trainer and print the state it hands back.
    TrainingContinuation state = firstTrainer.pause();
    Object lastGradients = state.getContents().get(ResilientPropagation.LAST_GRADIENTS);
    System.out.println(Arrays.toString((double[]) lastGradients));
    Object updateValues = state.getContents().get(ResilientPropagation.UPDATE_VALUES);
    System.out.println(Arrays.toString((double[]) updateValues));

    // Round-trip the continuation through a file. On failure the stack trace is
    // printed and the in-memory continuation is used unchanged.
    final String stateFile = "resume.ser";
    try {
        SerializeObject.save(stateFile, state);
        state = (TrainingContinuation) SerializeObject.load(stateFile);
    } catch (Exception ex) {
        ex.printStackTrace();
    }

    // Second training pass, resumed from the continuation.
    System.out.println(
            "Now trying a second train, with continue from the first. Should stop after one iteration");
    ResilientPropagation secondTrainer = new ResilientPropagation(net, xorData);
    secondTrainer.resume(state);
    EncogUtility.trainToError(secondTrainer, net, xorData, 0.01);
}
/**
 * Loads every image in {@code imageList} into the training set, downsamples
 * the training set, and creates the feed-forward network from the training
 * set's input/ideal sizes and the "hidden1"/"hidden2" arguments.
 *
 * <p>For each image an ideal vector of length {@code outputCount} is built
 * with 1 at the index given by the pair's identity and -1 everywhere else.
 *
 * @throws IOException if an image file cannot be read, or if no registered
 *         image reader is able to decode it
 * @throws NumberFormatException if the "hidden1" or "hidden2" argument is not
 *         a parsable integer
 */
private void processNetwork() throws IOException {
    System.out.println("Downsampling images...");

    for (final ImagePair pair : this.imageList) {
        final NeuralData ideal = new BasicNeuralData(this.outputCount);
        final int idx = pair.getIdentity();
        for (int i = 0; i < this.outputCount; i++) {
            if (i == idx) {
                ideal.setData(i, 1);
            } else {
                ideal.setData(i, -1);
            }
        }

        final Image img = ImageIO.read(pair.getFile());
        // ImageIO.read signals an undecodable file by returning null rather than
        // throwing; fail here with the file name instead of passing null along.
        if (img == null) {
            throw new IOException(
                    "Unable to decode image (unsupported or corrupt format): " + pair.getFile());
        }

        final ImageNeuralData data = new ImageNeuralData(img);
        this.training.add(data, ideal);
    }

    final String strHidden1 = getArg("hidden1");
    final String strHidden2 = getArg("hidden2");

    this.training.downsample(this.downsampleHeight, this.downsampleWidth);

    final int hidden1 = Integer.parseInt(strHidden1);
    final int hidden2 = Integer.parseInt(strHidden2);

    this.network = EncogUtility.simpleFeedForward(
            this.training.getInputSize(),
            hidden1,
            hidden2,
            this.training.getIdealSize(),
            true);
    System.out.println("Created network: " + this.network.toString());
}
/**
 * Checks the sample counts reported by {@code AnalyzeNetwork} for a small
 * network whose encoded weights have all been set to 1.0.
 */
public void testAnalyze() {
    BasicNetwork net = EncogUtility.simpleFeedForward(2, 2, 0, 1, false);

    // Overwrite every encoded value with 1.0 so the network content is known.
    final int encodedLength = net.encodedArrayLength();
    double[] ones = new double[encodedLength];
    EngineArray.fill(ones, 1.0);
    net.decodeFromArray(ones);

    AnalyzeNetwork analysis = new AnalyzeNetwork(net);

    // The combined count must equal the encoded array length; bias and weight
    // counts are checked separately.
    Assert.assertEquals(ones.length, analysis.getWeightsAndBias().getSamples());
    Assert.assertEquals(3, analysis.getBias().getSamples());
    Assert.assertEquals(6, analysis.getWeights().getSamples());
}
/**
 * Trains a consistently-randomized network on the XOR data with resilient
 * propagation and asserts that training finishes in fewer than 40 iterations.
 */
public void testCompleteTrain() {
    MLDataSet xorSet = new BasicMLDataSet(XOR.XOR_INPUT, XOR.XOR_IDEAL);

    BasicNetwork net = EncogUtility.simpleFeedForward(2, 5, 7, 1, true);
    ConsistentRandomizer randomizer = new ConsistentRandomizer(-1, 1);
    randomizer.randomize(net);

    MLTrain trainer = new ResilientPropagation(net, xorSet);

    // Always run at least one iteration, then continue while under the 5000
    // iteration cap and the error is still above 0.01.
    int count = 0;
    boolean keepTraining = true;
    while (keepTraining) {
        trainer.iteration();
        count++;
        keepTraining = count < 5000 && trainer.getError() > 0.01;
    }

    Assert.assertTrue(count < 40);
}
/**
 * Builds the "SIN Wave Predict" window: a graph panel in the center, a
 * "Train" button at the bottom, plus the network, training data and trainer
 * that the window works with.
 */
public PredictSIN() {
    // Window basics.
    this.setTitle("SIN Wave Predict");
    this.setSize(640, 480);

    Container pane = this.getContentPane();
    pane.setLayout(new BorderLayout());

    // Graph panel fills the center of the window.
    graph = new GraphPanel();
    pane.add(graph, BorderLayout.CENTER);

    // Create and reset the network, then hand it to the graph.
    final int hiddenCount = PREDICT_WINDOW * 2;
    network = EncogUtility.simpleFeedForward(INPUT_WINDOW, hiddenCount, 0, 1, true);
    network.reset();
    graph.setNetwork(network);

    // Training data and the trainer that works on it.
    this.trainingData = generateTraining();
    this.train = new ResilientPropagation(this.network, this.trainingData);

    // Train button; this object is registered as its action listener.
    btnTrain = new JButton("Train");
    this.btnTrain.addActionListener(this);
    pane.add(btnTrain, BorderLayout.SOUTH);

    // Show the network's current error on the training data.
    final double currentError = network.calculateError(this.trainingData);
    graph.setError(currentError);
}