コード例 #1
0
ファイル: TrainResume.java プロジェクト: arjuncomar/VAFusion2
  /**
   * Demonstrates pausing a resilient-propagation training run, writing the paused state to
   * {@code resume.ser}, reading it back, and resuming from it in a second trainer.
   *
   * <p>Steps, in order:
   * <ol>
   *   <li>Train a network on the data set built from {@code XOR_INPUT} / {@code XOR_IDEAL}
   *       with a target error of 0.01.</li>
   *   <li>Pause the trainer and print the arrays stored under
   *       {@code ResilientPropagation.LAST_GRADIENTS} and
   *       {@code ResilientPropagation.UPDATE_VALUES}.</li>
   *   <li>Save the paused state to disk and load it again; a failure here is reported via a
   *       stack trace and the in-memory state is used instead.</li>
   *   <li>Create a second trainer on the same network, resume it from that state, and train
   *       to the same target error.</li>
   * </ol>
   *
   * @param args command-line arguments; not read
   */
  public static void main(String[] args) {
    Logging.stopConsoleLogging();

    final NeuralDataSet xorData = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
    final BasicNetwork net = EncogUtility.simpleFeedForward(2, 4, 0, 1, false);

    final ResilientPropagation firstPass = new ResilientPropagation(net, xorData);
    firstPass.addStrategy(new RequiredImprovementStrategy(5));

    System.out.println("Perform initial train.");
    EncogUtility.trainToError(firstPass, net, xorData, 0.01);

    // Capture the trainer's internal state and show the two arrays it carries.
    final TrainingContinuation paused = firstPass.pause();
    final double[] lastGradients =
        (double[]) paused.getContents().get(ResilientPropagation.LAST_GRADIENTS);
    System.out.println(Arrays.toString(lastGradients));
    final double[] updateValues =
        (double[]) paused.getContents().get(ResilientPropagation.UPDATE_VALUES);
    System.out.println(Arrays.toString(updateValues));

    // Round-trip the paused state through a file. If either step fails, the in-memory
    // state is kept so the second training run below still has something to resume from.
    TrainingContinuation resumeState = paused;
    try {
      SerializeObject.save("resume.ser", paused);
      resumeState = (TrainingContinuation) SerializeObject.load("resume.ser");
    } catch (Exception ex) {
      ex.printStackTrace();
    }

    System.out.println(
        "Now trying a second train, with continue from the first.  Should stop after one iteration");
    final ResilientPropagation secondPass = new ResilientPropagation(net, xorData);
    secondPass.resume(resumeState);
    EncogUtility.trainToError(secondPass, net, xorData, 0.01);
  }