public static void train(File dataDir) { final File networkFile = new File(dataDir, Config.NETWORK_FILE); final File trainingFile = new File(dataDir, Config.TRAINING_FILE); // network file if (!networkFile.exists()) { System.out.println("Can't read file: " + networkFile.getAbsolutePath()); return; } BasicNetwork network = (BasicNetwork) EncogDirectoryPersistence.loadObject(networkFile); // training file if (!trainingFile.exists()) { System.out.println("Can't read file: " + trainingFile.getAbsolutePath()); return; } final MLDataSet trainingSet = EncogUtility.loadEGB2Memory(trainingFile); // train the neural network EncogUtility.trainConsole(network, trainingSet, Config.TRAINING_MINUTES); System.out.println("Final Error: " + (float) network.calculateError(trainingSet)); System.out.println("Training complete, saving network."); EncogDirectoryPersistence.saveObject(networkFile, network); System.out.println("Network saved."); Encog.getInstance().shutdown(); }
/**
 * Demonstrate pausing and resuming resilient-propagation training on the XOR problem.
 *
 * <p>Trains to an error of 0.01, pauses, and prints the gradient and update-value arrays held by
 * the continuation. It then round-trips the continuation through {@code resume.ser} and resumes a
 * fresh trainer from it. If the save/load fails, the stack trace is printed and the in-memory
 * continuation is used instead.
 *
 * @param args Unused.
 */
public static void main(String[] args) {
  Logging.stopConsoleLogging();

  final NeuralDataSet xorData = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
  final BasicNetwork network = EncogUtility.simpleFeedForward(2, 4, 0, 1, false);

  final ResilientPropagation firstTrain = new ResilientPropagation(network, xorData);
  firstTrain.addStrategy(new RequiredImprovementStrategy(5));

  System.out.println("Perform initial train.");
  EncogUtility.trainToError(firstTrain, network, xorData, 0.01);

  TrainingContinuation continuation = firstTrain.pause();

  final double[] lastGradients =
      (double[]) continuation.getContents().get(ResilientPropagation.LAST_GRADIENTS);
  System.out.println(Arrays.toString(lastGradients));

  final double[] updateValues =
      (double[]) continuation.getContents().get(ResilientPropagation.UPDATE_VALUES);
  System.out.println(Arrays.toString(updateValues));

  // Round-trip through disk to show the continuation survives serialization.
  try {
    SerializeObject.save("resume.ser", continuation);
    continuation = (TrainingContinuation) SerializeObject.load("resume.ser");
  } catch (Exception ex) {
    ex.printStackTrace();
  }

  System.out.println(
      "Now trying a second train, with continue from the first. Should stop after one iteration");
  final ResilientPropagation resumedTrain = new ResilientPropagation(network, xorData);
  resumedTrain.resume(continuation);
  EncogUtility.trainToError(resumedTrain, network, xorData, 0.01);
}
/**
 * Calculate the error for the given method and dataset.
 *
 * <p>If the normalization helper reports exactly one output column and that column is
 * {@code ColumnType.nominal}, the method is treated as a classifier and classification error is
 * returned. In every other case the method is treated as a regression model and regression error
 * is returned.
 *
 * @param method The method to use.
 * @param data The data to use.
 * @return The error.
 */
public double calculateError(MLMethod method, MLDataSet data) {
  final boolean singleNominalOutput =
      this.dataset.getNormHelper().getOutputColumns().size() == 1
          && this.dataset.getNormHelper().getOutputColumns().get(0).getDataType()
              == ColumnType.nominal;

  return singleNominalOutput
      ? EncogUtility.calculateClassificationError((MLClassification) method, data)
      : EncogUtility.calculateRegressionError((MLRegression) method, data);
}
/**
 * Evaluate the network and print, for every pair in the training set, one console line with the
 * input, the network's actual output and the ideal output.
 *
 * @param network The network to evaluate.
 * @param training The training set to evaluate.
 */
public static void evaluate(final MLRegression network, final MLDataSet training) {
  for (final MLDataPair item : training) {
    final MLData input = item.getInput();
    final MLData actual = network.compute(input);

    final String line =
        "Input="
            + EncogUtility.formatNeuralData(input)
            + ", Actual="
            + EncogUtility.formatNeuralData(actual)
            + ", Ideal="
            + EncogUtility.formatNeuralData(item.getIdeal());
    System.out.println(line);
  }
}
/** {@inheritDoc} */ @Override public boolean executeCommand(final String args) { // get filenames final String sourceID = getProp().getPropertyString(ScriptProperties.GENERATE_CONFIG_SOURCE_FILE); final String targetID = getProp().getPropertyString(ScriptProperties.GENERATE_CONFIG_TARGET_FILE); final CSVFormat format = getAnalyst().getScript().determineFormat(); EncogLogging.log(EncogLogging.LEVEL_DEBUG, "Beginning generate"); EncogLogging.log(EncogLogging.LEVEL_DEBUG, "source file:" + sourceID); EncogLogging.log(EncogLogging.LEVEL_DEBUG, "target file:" + targetID); final File sourceFile = getScript().resolveFilename(sourceID); final File targetFile = getScript().resolveFilename(targetID); // mark generated getScript().markGenerated(targetID); // read file final boolean headers = getScript().expectInputHeaders(sourceID); final CSVHeaders headerList = new CSVHeaders(sourceFile, headers, format); final int[] input = determineInputFields(headerList); final int[] ideal = determineIdealFields(headerList); EncogUtility.convertCSV2Binary(sourceFile, format, targetFile, input, ideal, headers); return false; }
/**
 * Build the image training set and create the feed-forward network that will learn it.
 *
 * <p>For each image pair, an ideal vector of length {@code outputCount} is built with {@code 1}
 * at the pair's identity index and {@code -1} everywhere else. The image is read and added to the
 * training set together with that vector. The training set is then downsampled, and a network is
 * created from the downsampled input size, the "hidden1"/"hidden2" arguments and the ideal size.
 *
 * @throws IOException if an image cannot be read, or if no registered ImageIO reader can decode
 *     it.
 * @throws NumberFormatException if the "hidden1" or "hidden2" argument is not an integer.
 */
private void processNetwork() throws IOException {
  System.out.println("Downsampling images...");

  for (final ImagePair pair : this.imageList) {
    final NeuralData ideal = new BasicNeuralData(this.outputCount);
    final int idx = pair.getIdentity();
    for (int i = 0; i < this.outputCount; i++) {
      // Target vector: 1 for this image's identity, -1 for every other output.
      ideal.setData(i, i == idx ? 1 : -1);
    }

    final Image img = ImageIO.read(pair.getFile());
    if (img == null) {
      // ImageIO.read returns null (instead of throwing) when no registered reader can
      // decode the file; report the offending file here rather than failing later.
      throw new IOException("Unsupported or unreadable image file: " + pair.getFile());
    }
    this.training.add(new ImageNeuralData(img), ideal);
  }

  // Parse layer sizes before the downsample step so a malformed argument fails fast.
  final int hidden1 = Integer.parseInt(getArg("hidden1"));
  final int hidden2 = Integer.parseInt(getArg("hidden2"));

  this.training.downsample(this.downsampleHeight, this.downsampleWidth);

  this.network =
      EncogUtility.simpleFeedForward(
          this.training.getInputSize(), hidden1, hidden2, this.training.getIdealSize(), true);
  System.out.println("Created network: " + this.network.toString());
}
/**
 * Build a small feed-forward network whose weights and biases are all 1.0, then check the sample
 * counts reported by {@code AnalyzeNetwork}.
 */
public void testAnalyze() {
  final BasicNetwork network = EncogUtility.simpleFeedForward(2, 2, 0, 1, false);

  final double[] allOnes = new double[network.encodedArrayLength()];
  EngineArray.fill(allOnes, 1.0);
  network.decodeFromArray(allOnes);

  final AnalyzeNetwork analysis = new AnalyzeNetwork(network);

  // Combined weight+bias samples should cover every encoded value.
  Assert.assertEquals(allOnes.length, analysis.getWeightsAndBias().getSamples());
  // NOTE(review): presumably 2 hidden + 1 output bias, and 2*2 + 2*1 weights — confirm.
  Assert.assertEquals(3, analysis.getBias().getSamples());
  Assert.assertEquals(6, analysis.getWeights().getSamples());
}
/**
 * Train the method to a specific error, sending the output to the console.
 *
 * <p>An {@link SVM} is trained with {@link SVMTrain}. Any other method must implement
 * {@link ContainsFlat} and is trained with {@link ResilientPropagation}.
 *
 * @param method The method to train.
 * @param dataSet The training set to use.
 * @param error The error level to train to.
 * @throws IllegalArgumentException if {@code method} is neither an {@link SVM} nor a
 *     {@link ContainsFlat}.
 */
public static void trainToError(
    final MLMethod method, final MLDataSet dataSet, final double error) {
  final MLTrain train;
  if (method instanceof SVM) {
    train = new SVMTrain((SVM) method, dataSet);
  } else if (method instanceof ContainsFlat) {
    train = new ResilientPropagation((ContainsFlat) method, dataSet);
  } else {
    // Without this check an unsupported method would surface as a bare ClassCastException
    // from the cast above; name the offending type instead.
    final String type = (method == null) ? "null" : method.getClass().getName();
    throw new IllegalArgumentException("Unsupported method type for trainToError: " + type);
  }
  EncogUtility.trainToError(train, error);
}
/**
 * Train a 2-5-7-1 network on XOR with RPROP and check that it reaches an error of at most 0.01 in
 * fewer than 40 iterations. The loop is capped at 5000 iterations.
 */
public void testCompleteTrain() {
  final MLDataSet xorData = new BasicMLDataSet(XOR.XOR_INPUT, XOR.XOR_IDEAL);
  final BasicNetwork network = EncogUtility.simpleFeedForward(2, 5, 7, 1, true);
  // NOTE(review): ConsistentRandomizer is presumably deterministic, which is what makes the
  // iteration bound below reproducible — confirm.
  new ConsistentRandomizer(-1, 1).randomize(network);

  final MLTrain rprop = new ResilientPropagation(network, xorData);

  int iterations = 0;
  while (true) {
    rprop.iteration();
    iterations++;
    final boolean keepTraining = iterations < 5000 && rprop.getError() > 0.01;
    if (!keepTraining) {
      break;
    }
  }

  Assert.assertTrue(iterations < 40);
}
/**
 * Train the network on the image training set with RPROP plus a reset strategy.
 *
 * <p>The "strategyerror" and "strategycycles" arguments configure the {@code ResetStrategy}. When
 * the "mode" argument is "gui" (case-insensitive), training runs in a dialog. Otherwise it runs on
 * the console for the number of minutes given by the "minutes" argument.
 *
 * @throws IOException declared for callers. NOTE(review): no I/O is visible in this method itself.
 */
private void processTrain() throws IOException {
  final String mode = getArg("mode");
  final String minutesArg = getArg("minutes");
  final String strategyErrorArg = getArg("strategyerror");
  final String strategyCyclesArg = getArg("strategycycles");

  System.out.println("Training Beginning... Output patterns=" + this.outputCount);

  final double resetError = Double.parseDouble(strategyErrorArg);
  final int resetCycles = Integer.parseInt(strategyCyclesArg);

  final ResilientPropagation rprop = new ResilientPropagation(this.network, this.training);
  rprop.addStrategy(new ResetStrategy(resetError, resetCycles));

  if (!mode.equalsIgnoreCase("gui")) {
    EncogUtility.trainConsole(
        rprop, this.network, this.training, Integer.parseInt(minutesArg));
  } else {
    EncogUtility.trainDialog(rprop, this.network, this.training);
  }

  System.out.println("Training Stopped...");
}
/**
 * Build the "SIN Wave Predict" window.
 *
 * <p>Lays out a graph panel in the centre and a "Train" button at the bottom. Creates and resets
 * the feed-forward network, generates the training data and prepares an RPROP trainer. Finally it
 * shows the network's initial error on the graph.
 */
public PredictSIN() {
  setTitle("SIN Wave Predict");
  setSize(640, 480);

  final Container pane = getContentPane();
  pane.setLayout(new BorderLayout());

  graph = new GraphPanel();
  pane.add(graph, BorderLayout.CENTER);

  network = EncogUtility.simpleFeedForward(INPUT_WINDOW, PREDICT_WINDOW * 2, 0, 1, true);
  network.reset();
  graph.setNetwork(network);

  this.trainingData = generateTraining();
  this.train = new ResilientPropagation(this.network, this.trainingData);

  btnTrain = new JButton("Train");
  btnTrain.addActionListener(this);
  pane.add(btnTrain, BorderLayout.SOUTH);

  // Show the untrained network's error so the graph starts with a meaningful value.
  graph.setError(network.calculateError(this.trainingData));
}
/**
 * Calculate the error for this neural network against the given data set.
 *
 * <p>Delegates to {@code EncogUtility.calculateRegressionError}, passing this network as the
 * model.
 *
 * @param data The training set.
 * @return The error computed by {@code EncogUtility.calculateRegressionError}.
 */
@Override
public double calculateError(final MLDataSet data) {
  final double regressionError = EncogUtility.calculateRegressionError(this, data);
  return regressionError;
}
/**
 * Train the neural network using resilient propagation (RPROP), and output status to the console.
 *
 * <p>Creates a {@link ResilientPropagation} trainer for the network and delegates to the
 * trainer-based {@code EncogUtility.trainConsole} overload.
 *
 * @param network The network to train.
 * @param trainingSet The training set.
 * @param minutes The number of minutes to train for.
 */
public static void trainConsole(
    final BasicNetwork network, final MLDataSet trainingSet, final int minutes) {
  final Propagation train = new ResilientPropagation(network, trainingSet);
  // NOTE(review): a thread count of 0 presumably lets Encog choose the thread count
  // automatically — confirm against Propagation#setThreadCount.
  train.setThreadCount(0);
  EncogUtility.trainConsole(train, network, trainingSet, minutes);
}