Beispiel #1
0
  /**
   * Demonstrates resumable RPROP training on the XOR data set.
   *
   * <p>The steps are:
   *
   * <ol>
   *   <li>Train a small feed-forward network to an error of 0.01.
   *   <li>Pause the trainer and print its stored gradient and update-value arrays.
   *   <li>Round-trip the resulting continuation through {@code resume.ser}.
   *   <li>Resume a brand-new trainer from that continuation.
   * </ol>
   *
   * @param args ignored
   */
  public static void main(String[] args) {
    Logging.stopConsoleLogging();

    final NeuralDataSet xorData = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
    // NOTE(review): presumably (input, hidden1, hidden2, output, tanh) -- confirm against
    // EncogUtility.simpleFeedForward.
    final BasicNetwork net = EncogUtility.simpleFeedForward(2, 4, 0, 1, false);

    final ResilientPropagation firstTrainer = new ResilientPropagation(net, xorData);
    firstTrainer.addStrategy(new RequiredImprovementStrategy(5));

    System.out.println("Perform initial train.");
    EncogUtility.trainToError(firstTrainer, net, xorData, 0.01);

    // Snapshot the trainer's internal state so another trainer can continue from it.
    final TrainingContinuation snapshot = firstTrainer.pause();
    final double[] lastGradients =
        (double[]) snapshot.getContents().get(ResilientPropagation.LAST_GRADIENTS);
    System.out.println(Arrays.toString(lastGradients));
    final double[] updateValues =
        (double[]) snapshot.getContents().get(ResilientPropagation.UPDATE_VALUES);
    System.out.println(Arrays.toString(updateValues));

    final TrainingContinuation restored = roundTripContinuation(snapshot);

    System.out.println(
        "Now trying a second train, with continue from the first.  Should stop after one iteration");
    final ResilientPropagation resumedTrainer = new ResilientPropagation(net, xorData);
    resumedTrainer.resume(restored);
    EncogUtility.trainToError(resumedTrainer, net, xorData, 0.01);
  }

  /**
   * Saves {@code original} to {@code resume.ser} and reads it back.
   *
   * <p>Failures are best-effort: if saving, loading or the cast fails, the stack trace is printed
   * and {@code original} is returned unchanged so the demo can still continue.
   *
   * @param original the continuation to persist
   * @return the continuation read back from disk, or {@code original} on any failure
   */
  private static TrainingContinuation roundTripContinuation(
      final TrainingContinuation original) {
    try {
      SerializeObject.save("resume.ser", original);
      return (TrainingContinuation) SerializeObject.load("resume.ser");
    } catch (Exception ex) {
      ex.printStackTrace();
      return original;
    }
  }
  /**
   * Train the method to a specific error, sending the progress output to the console.
   *
   * <p>The trainer depends on the method type:
   *
   * <ul>
   *   <li>An {@code SVM} is trained with {@code SVMTrain}.
   *   <li>Any other method must implement {@code ContainsFlat* and is trained with {@code
   *       ResilientPropagation}.
   * </ul>
   *
   * <p>The training loop itself is delegated to {@code EncogUtility.trainToError(MLTrain,
   * double)}.
   *
   * @param method The method to train.
   * @param dataSet The training set to use.
   * @param error The error level to train to.
   * @throws NullPointerException if {@code method} is null.
   * @throws IllegalArgumentException if {@code method} is neither an {@code SVM} nor a {@code
   *     ContainsFlat}.
   */
  public static void trainToError(
      final MLMethod method, final MLDataSet dataSet, final double error) {
    if (method == null) {
      throw new NullPointerException("method");
    }

    final MLTrain train;
    if (method instanceof SVM) {
      train = new SVMTrain((SVM) method, dataSet);
    } else if (method instanceof ContainsFlat) {
      train = new ResilientPropagation((ContainsFlat) method, dataSet);
    } else {
      // Fail fast with a descriptive message instead of an opaque ClassCastException.
      throw new IllegalArgumentException(
          "Cannot train method of type "
              + method.getClass().getName()
              + ": expected an SVM or a ContainsFlat implementation");
    }
    EncogUtility.trainToError(train, error);
  }