Ejemplo n.º 1
0
  /**
   * Load a previously saved network and its binary training data from the
   * given directory, train for the configured number of minutes, then persist
   * the updated network back to the same file.
   *
   * @param dataDir Directory containing the network and training files.
   */
  public static void train(File dataDir) {
    final File netFile = new File(dataDir, Config.NETWORK_FILE);
    final File trainFile = new File(dataDir, Config.TRAINING_FILE);

    // The serialized network must already exist on disk.
    if (!netFile.exists()) {
      System.out.println("Can't read file: " + netFile.getAbsolutePath());
      return;
    }

    final BasicNetwork network =
        (BasicNetwork) EncogDirectoryPersistence.loadObject(netFile);

    // The EGB training data must exist as well.
    if (!trainFile.exists()) {
      System.out.println("Can't read file: " + trainFile.getAbsolutePath());
      return;
    }

    final MLDataSet trainingSet = EncogUtility.loadEGB2Memory(trainFile);

    // Train with console progress output, report the final error, and save.
    EncogUtility.trainConsole(network, trainingSet, Config.TRAINING_MINUTES);
    System.out.println("Final Error: " + (float) network.calculateError(trainingSet));
    System.out.println("Training complete, saving network.");
    EncogDirectoryPersistence.saveObject(netFile, network);
    System.out.println("Network saved.");
    Encog.getInstance().shutdown();
  }
Ejemplo n.º 2
0
  /**
   * Demonstrates pausing RPROP training, serializing the continuation state to
   * disk, reloading it, and resuming on a fresh trainer instance. The second
   * training run should converge almost immediately because it starts from the
   * saved gradients and update values.
   */
  public static void main(String[] args) {
    Logging.stopConsoleLogging();
    final NeuralDataSet xorData = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
    final BasicNetwork net = EncogUtility.simpleFeedForward(2, 4, 0, 1, false);
    final ResilientPropagation firstTrain = new ResilientPropagation(net, xorData);
    firstTrain.addStrategy(new RequiredImprovementStrategy(5));

    System.out.println("Perform initial train.");
    EncogUtility.trainToError(firstTrain, net, xorData, 0.01);

    // Pause and dump the RPROP internal state (gradients and update values).
    TrainingContinuation cont = firstTrain.pause();
    System.out.println(
        Arrays.toString((double[]) cont.getContents().get(ResilientPropagation.LAST_GRADIENTS)));
    System.out.println(
        Arrays.toString((double[]) cont.getContents().get(ResilientPropagation.UPDATE_VALUES)));

    // Round-trip the continuation object through Java serialization.
    try {
      SerializeObject.save("resume.ser", cont);
      cont = (TrainingContinuation) SerializeObject.load("resume.ser");
    } catch (Exception ex) {
      ex.printStackTrace();
    }

    System.out.println(
        "Now trying a second train, with continue from the first.  Should stop after one iteration");
    final ResilientPropagation secondTrain = new ResilientPropagation(net, xorData);
    secondTrain.resume(cont);
    EncogUtility.trainToError(secondTrain, net, xorData, 0.01);
  }
Ejemplo n.º 3
0
  /**
   * Calculate the error for the given method and dataset.
   *
   * <p>If the dataset has exactly one output column and that column is
   * nominal, a classification error is computed; otherwise a regression
   * error is used.
   *
   * @param method The method to use.
   * @param data The data to use.
   * @return The error.
   */
  public double calculateError(MLMethod method, MLDataSet data) {
    final boolean singleOutput =
        this.dataset.getNormHelper().getOutputColumns().size() == 1;
    if (singleOutput) {
      final ColumnDefinition onlyColumn =
          this.dataset.getNormHelper().getOutputColumns().get(0);
      if (onlyColumn.getDataType() == ColumnType.nominal) {
        return EncogUtility.calculateClassificationError((MLClassification) method, data);
      }
    }

    return EncogUtility.calculateRegressionError((MLRegression) method, data);
  }
Ejemplo n.º 4
0
 /**
  * Evaluate the network and display (to the console) the output for every
  * value in the training set. Each line shows the input, the network's actual
  * output, and the ideal (expected) output.
  *
  * @param network The network to evaluate.
  * @param training The training set to evaluate.
  */
 public static void evaluate(final MLRegression network, final MLDataSet training) {
   for (final MLDataPair pair : training) {
     final MLData actual = network.compute(pair.getInput());
     final String line =
         "Input="
             + EncogUtility.formatNeuralData(pair.getInput())
             + ", Actual="
             + EncogUtility.formatNeuralData(actual)
             + ", Ideal="
             + EncogUtility.formatNeuralData(pair.getIdeal());
     System.out.println(line);
   }
 }
Ejemplo n.º 5
0
  /** {@inheritDoc} */
  @Override
  public boolean executeCommand(final String args) {
    // Resolve the configured source and target file identifiers.
    final String srcId =
        getProp().getPropertyString(ScriptProperties.GENERATE_CONFIG_SOURCE_FILE);
    final String dstId =
        getProp().getPropertyString(ScriptProperties.GENERATE_CONFIG_TARGET_FILE);
    final CSVFormat csvFormat = getAnalyst().getScript().determineFormat();

    EncogLogging.log(EncogLogging.LEVEL_DEBUG, "Beginning generate");
    EncogLogging.log(EncogLogging.LEVEL_DEBUG, "source file:" + srcId);
    EncogLogging.log(EncogLogging.LEVEL_DEBUG, "target file:" + dstId);

    final File srcFile = getScript().resolveFilename(srcId);
    final File dstFile = getScript().resolveFilename(dstId);

    // Record that the target file is produced by this command.
    getScript().markGenerated(dstId);

    // Inspect the CSV header row (if expected) to map input/ideal fields.
    final boolean hasHeaders = getScript().expectInputHeaders(srcId);
    final CSVHeaders headerList = new CSVHeaders(srcFile, hasHeaders, csvFormat);

    final int[] inputFields = determineInputFields(headerList);
    final int[] idealFields = determineIdealFields(headerList);

    // Convert the CSV source into Encog's binary (EGB) training format.
    EncogUtility.convertCSV2Binary(
        srcFile, csvFormat, dstFile, inputFields, idealFields, hasHeaders);

    // This command never requests that script processing stop.
    return false;
  }
  /**
   * Build the image training set (ideal output is +1 at the matching identity
   * index and -1 elsewhere), downsample all images, and construct the
   * feed-forward network sized from the "hidden1"/"hidden2" arguments.
   *
   * @throws IOException If an image file cannot be read.
   */
  private void processNetwork() throws IOException {
    System.out.println("Downsampling images...");

    for (final ImagePair pair : this.imageList) {
      final NeuralData ideal = new BasicNeuralData(this.outputCount);
      final int identity = pair.getIdentity();
      // One-vs-rest encoding: +1 at the identity index, -1 everywhere else.
      for (int i = 0; i < this.outputCount; i++) {
        ideal.setData(i, i == identity ? 1 : -1);
      }

      final Image img = ImageIO.read(pair.getFile());
      this.training.add(new ImageNeuralData(img), ideal);
    }

    final String hidden1Arg = getArg("hidden1");
    final String hidden2Arg = getArg("hidden2");

    this.training.downsample(this.downsampleHeight, this.downsampleWidth);

    this.network =
        EncogUtility.simpleFeedForward(
            this.training.getInputSize(),
            Integer.parseInt(hidden1Arg),
            Integer.parseInt(hidden2Arg),
            this.training.getIdealSize(),
            true);
    System.out.println("Created network: " + this.network.toString());
  }
 /** Verify AnalyzeNetwork sample counts for a 2-2-1 network with all-ones weights. */
 public void testAnalyze() {
   final BasicNetwork net = EncogUtility.simpleFeedForward(2, 2, 0, 1, false);
   final double[] ones = new double[net.encodedArrayLength()];
   EngineArray.fill(ones, 1.0);
   net.decodeFromArray(ones);

   final AnalyzeNetwork analysis = new AnalyzeNetwork(net);
   // The combined sample count covers every encoded value (weights + biases);
   // bias and weight counts are then checked individually.
   Assert.assertEquals(ones.length, analysis.getWeightsAndBias().getSamples());
   Assert.assertEquals(3, analysis.getBias().getSamples());
   Assert.assertEquals(6, analysis.getWeights().getSamples());
 }
Ejemplo n.º 8
0
  /**
   * Train the method, to a specific error, send the output to the console.
   *
   * @param method The method to train.
   * @param dataSet The training set to use.
   * @param error The error level to train to.
   */
  public static void trainToError(
      final MLMethod method, final MLDataSet dataSet, final double error) {
    // SVMs use their dedicated trainer; everything else trains with RPROP.
    final MLTrain trainer =
        (method instanceof SVM)
            ? new SVMTrain((SVM) method, dataSet)
            : new ResilientPropagation((ContainsFlat) method, dataSet);
    EncogUtility.trainToError(trainer, error);
  }
Ejemplo n.º 9
0
  /** XOR should converge in well under 40 RPROP iterations from this fixed seed. */
  public void testCompleteTrain() {
    final MLDataSet xorData = new BasicMLDataSet(XOR.XOR_INPUT, XOR.XOR_IDEAL);

    final BasicNetwork net = EncogUtility.simpleFeedForward(2, 5, 7, 1, true);
    // Deterministic weight initialization so the iteration count is reproducible.
    new ConsistentRandomizer(-1, 1).randomize(net);

    final MLTrain rprop = new ResilientPropagation(net, xorData);
    int iterations = 0;
    do {
      rprop.iteration();
      iterations++;
    } while (iterations < 5000 && rprop.getError() > 0.01);

    Assert.assertTrue(iterations < 40);
  }
  /**
   * Train the current network on the image data: in "gui" mode a training
   * dialog is shown, otherwise training runs on the console for the requested
   * number of minutes. A reset strategy restarts training when the error
   * stalls above the configured threshold for the configured cycle count.
   *
   * @throws IOException Declared by this class's processing convention.
   */
  private void processTrain() throws IOException {
    // Fetch all arguments up front, before any parsing or output.
    final String mode = getArg("mode");
    final String minutesArg = getArg("minutes");
    final String strategyErrorArg = getArg("strategyerror");
    final String strategyCyclesArg = getArg("strategycycles");

    System.out.println("Training Beginning... Output patterns=" + this.outputCount);

    final double strategyError = Double.parseDouble(strategyErrorArg);
    final int strategyCycles = Integer.parseInt(strategyCyclesArg);

    final ResilientPropagation rprop = new ResilientPropagation(this.network, this.training);
    rprop.addStrategy(new ResetStrategy(strategyError, strategyCycles));

    if (mode.equalsIgnoreCase("gui")) {
      EncogUtility.trainDialog(rprop, this.network, this.training);
    } else {
      final int minutes = Integer.parseInt(minutesArg);
      EncogUtility.trainConsole(rprop, this.network, this.training, minutes);
    }
    System.out.println("Training Stopped...");
  }
Ejemplo n.º 11
0
  /**
   * Build the SIN-wave prediction demo window: a graph panel in the center,
   * a "Train" button at the bottom, a feed-forward network, and an RPROP
   * trainer over generated sine training data.
   */
  public PredictSIN() {
    // Frame setup: title, fixed size, graph panel filling the center.
    this.setTitle("SIN Wave Predict");
    this.setSize(640, 480);
    Container content = this.getContentPane();
    content.setLayout(new BorderLayout());
    content.add(graph = new GraphPanel(), BorderLayout.CENTER);

    // One hidden layer of PREDICT_WINDOW * 2 neurons, single output.
    // NOTE(review): the final boolean presumably selects the activation
    // function — confirm against EncogUtility.simpleFeedForward docs.
    network = EncogUtility.simpleFeedForward(INPUT_WINDOW, PREDICT_WINDOW * 2, 0, 1, true);
    network.reset();
    graph.setNetwork(network);

    // Generate the sine training data and attach an RPROP trainer.
    this.trainingData = generateTraining();
    this.train = new ResilientPropagation(this.network, this.trainingData);
    // "Train" button at the bottom; this frame handles its action events.
    btnTrain = new JButton("Train");
    this.btnTrain.addActionListener(this);
    content.add(btnTrain, BorderLayout.SOUTH);
    // Seed the graph with the untrained network's initial error.
    graph.setError(network.calculateError(this.trainingData));
  }
Ejemplo n.º 12
0
 /**
  * Calculate the error for this neural network.
  *
  * <p>Delegates to {@link EncogUtility#calculateRegressionError}, treating
  * this network as a regression model over the supplied data set.
  *
  * @param data The training set.
  * @return The error percentage.
  */
 @Override
 public double calculateError(final MLDataSet data) {
   return EncogUtility.calculateRegressionError(this, data);
 }
Ejemplo n.º 13
0
 /**
  * Train the neural network, using RPROP (resilient propagation) training,
  * and output status to the console.
  *
  * <p>The previous comment claimed SCG training, but the implementation
  * constructs a {@code ResilientPropagation} trainer; the documentation is
  * corrected to match the code.
  *
  * @param network The network to train.
  * @param trainingSet The training set.
  * @param minutes The number of minutes to train for.
  */
 public static void trainConsole(
     final BasicNetwork network, final MLDataSet trainingSet, final int minutes) {
   final Propagation train = new ResilientPropagation(network, trainingSet);
   // A thread count of 0 asks Encog to pick the thread count automatically.
   train.setThreadCount(0);
   EncogUtility.trainConsole(train, network, trainingSet, minutes);
 }