示例#1
0
  /**
   * Loads the network stored in {@code dataDir}, trains it against the training set stored in the
   * same directory, and writes the trained network back over the original network file.
   *
   * <p>Both input files are verified before anything is loaded. If either is not a readable
   * regular file, a message naming it is printed and the method returns without training. Once
   * both files have been validated, {@code Encog.getInstance().shutdown()} runs in a
   * {@code finally} block, so it is invoked even if loading, training, error calculation or saving
   * throws.
   *
   * @param dataDir directory holding {@code Config.NETWORK_FILE} and {@code Config.TRAINING_FILE}
   */
  public static void train(File dataDir) {
    final File networkFile = new File(dataDir, Config.NETWORK_FILE);
    final File trainingFile = new File(dataDir, Config.TRAINING_FILE);

    // Validate both inputs up front so a missing training file does not cost a network load.
    // isFile() + canRead() matches the "Can't read file" message; exists() alone also accepts
    // directories and unreadable files, which would then fail inside the loaders.
    if (!networkFile.isFile() || !networkFile.canRead()) {
      System.out.println("Can't read file: " + networkFile.getAbsolutePath());
      return;
    }
    if (!trainingFile.isFile() || !trainingFile.canRead()) {
      System.out.println("Can't read file: " + trainingFile.getAbsolutePath());
      return;
    }

    try {
      final BasicNetwork network =
          (BasicNetwork) EncogDirectoryPersistence.loadObject(networkFile);
      final MLDataSet trainingSet = EncogUtility.loadEGB2Memory(trainingFile);

      // Train for Config.TRAINING_MINUTES minutes, report the final error, then persist.
      EncogUtility.trainConsole(network, trainingSet, Config.TRAINING_MINUTES);
      System.out.println("Final Error: " + (float) network.calculateError(trainingSet));
      System.out.println("Training complete, saving network.");
      EncogDirectoryPersistence.saveObject(networkFile, network);
      System.out.println("Network saved.");
    } finally {
      // Previously skipped whenever any step above threw; always run it once work has begun.
      Encog.getInstance().shutdown();
    }
  }
  /**
   * Trains {@code this.network} on {@code this.training} using resilient propagation with an
   * attached {@code ResetStrategy}.
   *
   * <p>Arguments read via {@code getArg}: {@code mode}, {@code minutes}, {@code strategyerror} and
   * {@code strategycycles}. When {@code mode} equals {@code "gui"} (case-insensitive), training
   * runs in a dialog. Otherwise it runs on the console for the number of minutes given by
   * {@code minutes}.
   *
   * @throws IOException kept from the original signature; nothing in this body throws it directly
   */
  private void processTrain() throws IOException {
    // Fetch every argument first, in the same order as before.
    final String modeArg = getArg("mode");
    final String minutesArg = getArg("minutes");
    final String resetErrorArg = getArg("strategyerror");
    final String resetCyclesArg = getArg("strategycycles");

    System.out.println("Training Beginning... Output patterns=" + this.outputCount);

    // Parse the reset-strategy settings before building the trainer.
    final double resetError = Double.parseDouble(resetErrorArg);
    final int resetCycles = Integer.parseInt(resetCyclesArg);

    final ResilientPropagation rprop = new ResilientPropagation(this.network, this.training);
    rprop.addStrategy(new ResetStrategy(resetError, resetCycles));

    final boolean interactive = modeArg.equalsIgnoreCase("gui");
    if (!interactive) {
      // Console mode: "minutes" is only parsed on this path.
      EncogUtility.trainConsole(rprop, this.network, this.training, Integer.parseInt(minutesArg));
    } else {
      EncogUtility.trainDialog(rprop, this.network, this.training);
    }
    System.out.println("Training Stopped...");
  }
 /**
  * Train the neural network using resilient propagation (RPROP), and output status to the console.
  *
  * <p>Builds a {@link ResilientPropagation} trainer for {@code network} and {@code trainingSet},
  * sets its thread count to 0, and delegates to the four-argument {@code trainConsole} overload,
  * which runs the time-limited console training loop.
  *
  * @param network The network to train.
  * @param trainingSet The training set.
  * @param minutes The number of minutes to train for.
  */
 public static void trainConsole(
     final BasicNetwork network, final MLDataSet trainingSet, final int minutes) {
    final Propagation train = new ResilientPropagation(network, trainingSet);
    // A thread count of 0 asks Encog to choose the number of worker threads itself.
    train.setThreadCount(0);
    EncogUtility.trainConsole(train, network, trainingSet, minutes);
  }