コード例 #1
0
ファイル: TrainResume.java プロジェクト: arjuncomar/VAFusion2
  /**
   * Demonstrates pausing a resilient-propagation training run on XOR, persisting the paused state
   * with {@code SerializeObject}, and resuming it in a second {@code ResilientPropagation} trainer
   * bound to the same network.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    Logging.stopConsoleLogging();
    NeuralDataSet trainingSet = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
    BasicNetwork network = EncogUtility.simpleFeedForward(2, 4, 0, 1, false);
    ResilientPropagation train = new ResilientPropagation(network, trainingSet);
    train.addStrategy(new RequiredImprovementStrategy(5));

    System.out.println("Perform initial train.");
    EncogUtility.trainToError(train, network, trainingSet, 0.01);
    TrainingContinuation cont = train.pause();
    // Show the values pause() captured under the LAST_GRADIENTS and UPDATE_VALUES keys.
    System.out.println(
        Arrays.toString((double[]) cont.getContents().get(ResilientPropagation.LAST_GRADIENTS)));
    System.out.println(
        Arrays.toString((double[]) cont.getContents().get(ResilientPropagation.UPDATE_VALUES)));

    // Round-trip the continuation through a file to show it survives persistence.
    // NOTE(review): Java-native deserialization; only safe because this process wrote the file.
    try {
      SerializeObject.save("resume.ser", cont);
      cont = (TrainingContinuation) SerializeObject.load("resume.ser");
    } catch (Exception ex) {
      // Best-effort: `cont` is only reassigned after a successful load, so it still holds the
      // in-memory continuation and the resume below remains valid. Say so explicitly instead of
      // silently continuing after a stack trace.
      System.err.println(
          "Could not save/load resume.ser; resuming from the in-memory continuation instead.");
      ex.printStackTrace();
    }

    System.out.println(
        "Now trying a second train, with continue from the first.  Should stop after one iteration");
    ResilientPropagation train2 = new ResilientPropagation(network, trainingSet);
    train2.resume(cont);
    EncogUtility.trainToError(train2, network, trainingSet, 0.01);
  }
コード例 #2
0
  /**
   * Loads every image in {@code imageList} into the training set, downsamples the set, and builds
   * a feed-forward network sized from the training data and the "hidden1"/"hidden2" arguments.
   *
   * @throws IOException if an image file cannot be read, or cannot be decoded by any installed
   *     {@code ImageIO} reader
   */
  private void processNetwork() throws IOException {
    System.out.println("Downsampling images...");

    for (final ImagePair pair : this.imageList) {
      final NeuralData ideal = new BasicNeuralData(this.outputCount);
      final int idx = pair.getIdentity();
      // One-of-N target: +1 at this image's identity index, -1 at every other output.
      for (int i = 0; i < this.outputCount; i++) {
        ideal.setData(i, i == idx ? 1 : -1);
      }

      final Image img = ImageIO.read(pair.getFile());
      // ImageIO.read signals "no reader understands this file" by returning null rather than
      // throwing; fail here, naming the file, instead of handing null to ImageNeuralData.
      if (img == null) {
        throw new IOException("Unsupported or unreadable image format: " + pair.getFile());
      }
      final ImageNeuralData data = new ImageNeuralData(img);
      this.training.add(data, ideal);
    }

    final String strHidden1 = getArg("hidden1");
    final String strHidden2 = getArg("hidden2");

    this.training.downsample(this.downsampleHeight, this.downsampleWidth);

    final int hidden1 = Integer.parseInt(strHidden1);
    final int hidden2 = Integer.parseInt(strHidden2);

    this.network =
        EncogUtility.simpleFeedForward(
            this.training.getInputSize(), hidden1, hidden2, this.training.getIdealSize(), true);
    System.out.println("Created network: " + this.network.toString());
  }
コード例 #3
0
 /**
  * Checks AnalyzeNetwork's weight/bias sample counts for the network built by
  * {@code simpleFeedForward(2, 2, 0, 1, false)} with every encoded value set to 1.0.
  */
 public void testAnalyze() {
   final BasicNetwork net = EncogUtility.simpleFeedForward(2, 2, 0, 1, false);

   // Overwrite the network's parameters with a known constant rather than random values.
   final double[] encoded = new double[net.encodedArrayLength()];
   EngineArray.fill(encoded, 1.0);
   net.decodeFromArray(encoded);

   final AnalyzeNetwork analysis = new AnalyzeNetwork(net);
   // The combined sample must account for every value in the encoded array.
   Assert.assertEquals(encoded.length, analysis.getWeightsAndBias().getSamples());
   // Presumably 2 hidden + 1 output bias, and 2*2 + 2*1 connection weights — TODO confirm.
   Assert.assertEquals(3, analysis.getBias().getSamples());
   Assert.assertEquals(6, analysis.getWeights().getSamples());
 }
コード例 #4
0
  /**
   * Checks that resilient propagation trains XOR to an error of at most 0.01 in fewer than 40
   * iterations, starting from weights set by {@code ConsistentRandomizer(-1, 1)}.
   */
  public void testCompleteTrain() {
    // The loop stops at the same bound the assertion checks. A run that has not converged by then
    // fails either way, so there is no point spending thousands more iterations on it.
    final int maxIterations = 40;
    final double targetError = 0.01;

    MLDataSet trainingData = new BasicMLDataSet(XOR.XOR_INPUT, XOR.XOR_IDEAL);

    BasicNetwork network = EncogUtility.simpleFeedForward(2, 5, 7, 1, true);
    // Presumably gives reproducible starting weights, which the iteration bound relies on.
    (new ConsistentRandomizer(-1, 1)).randomize(network);
    MLTrain rprop = new ResilientPropagation(network, trainingData);
    int iteration = 0;
    do {
      rprop.iteration();
      iteration++;
    } while (iteration < maxIterations && rprop.getError() > targetError);

    final double finalError = rprop.getError();
    Assert.assertTrue(
        "RPROP did not reach error "
            + targetError
            + " in fewer than "
            + maxIterations
            + " iterations (stopped at "
            + iteration
            + ", error="
            + finalError
            + ")",
        iteration < maxIterations);
  }
コード例 #5
0
  /**
   * Builds the "SIN Wave Predict" window.
   *
   * <p>The window has a graph panel in the centre and a "Train" button at the bottom. It is backed
   * by a freshly reset feed-forward network and a resilient-propagation trainer over the data from
   * {@code generateTraining()}. The graph is seeded with the reset network's error on that data.
   */
  public PredictSIN() {
    setTitle("SIN Wave Predict");
    setSize(640, 480);

    final Container pane = getContentPane();
    pane.setLayout(new BorderLayout());
    this.graph = new GraphPanel();
    pane.add(this.graph, BorderLayout.CENTER);

    // Input layer spans INPUT_WINDOW; hidden layer is twice PREDICT_WINDOW; one output.
    this.network = EncogUtility.simpleFeedForward(INPUT_WINDOW, PREDICT_WINDOW * 2, 0, 1, true);
    this.network.reset();
    this.graph.setNetwork(this.network);

    this.trainingData = generateTraining();
    this.train = new ResilientPropagation(this.network, this.trainingData);

    this.btnTrain = new JButton("Train");
    // NOTE(review): registers `this` as a listener before the constructor has finished.
    this.btnTrain.addActionListener(this);
    pane.add(this.btnTrain, BorderLayout.SOUTH);

    // Seed the graph with the error of the freshly reset network.
    this.graph.setError(this.network.calculateError(this.trainingData));
  }