Пример #1
0
  /**
   * Trains a 2-2-1 {@code BasicNetwork} on {@code XOR_INPUT}/{@code XOR_IDEAL} with resilient
   * propagation, prints the resulting error, and stores the network under the name
   * {@code "network"} in the collection file {@code FILENAME}.
   *
   * <p>Training repeats until the trainer's error is no longer above 0.009. NOTE(review): the loop
   * has no iteration cap, so this method never returns if training stalls above that threshold.
   */
  public void trainAndSave() {
    System.out.println("Training XOR network to under 1% error rate.");
    BasicNetwork network = new BasicNetwork();
    network.addLayer(new BasicLayer(2));
    network.addLayer(new BasicLayer(2));
    network.addLayer(new BasicLayer(1));
    network.getStructure().finalizeStructure();
    network.reset(); // presumably randomizes the initial weights — confirm for this Encog version

    NeuralDataSet trainingSet = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);

    // train the neural network
    final Train train = new ResilientPropagation(network, trainingSet);

    // Error is checked after each iteration, so at least one iteration always runs.
    do {
      train.iteration();
    } while (train.getError() > 0.009);

    // Error as computed by the network over the whole training set, reported for information.
    double e = network.calculateError(trainingSet);
    System.out.println("Network trained to error: " + e);

    System.out.println("Saving network");
    final EncogPersistedCollection encog = new EncogPersistedCollection(FILENAME);
    encog.create(); // presumably (re)creates the persistence file before adding to it
    encog.add("network", network);
  }
Пример #2
0
  /**
   * Builds a linear network with one input per pixel ({@code CHAR_WIDTH * CHAR_HEIGHT}) and one
   * output per entry of {@code DIGITS}, trains it with {@code TrainAdaline} until the error is no
   * longer above 0.01, then prints every glyph row by row, tagging its last row with the index of
   * the winning output neuron.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    final int pixelCount = CHAR_WIDTH * CHAR_HEIGHT;
    final int digitCount = DIGITS.length;

    final BasicNetwork network = new BasicNetwork();

    final Layer inLayer = new BasicLayer(new ActivationLinear(), false, pixelCount);
    final Layer outLayer = new BasicLayer(new ActivationLinear(), true, digitCount);

    network.addLayer(inLayer);
    network.addLayer(outLayer);
    network.getStructure().finalizeStructure();

    // Start from small random weights in [-0.5, 0.5].
    final RangeRandomizer randomizer = new RangeRandomizer(-0.5, 0.5);
    randomizer.randomize(network);

    final NeuralDataSet samples = generateTraining();
    final Train trainer = new TrainAdaline(network, samples, 0.01);

    // Log every epoch; stop once the error is no longer above 0.01 (a NaN error also stops).
    for (int epoch = 1; ; epoch++) {
      trainer.iteration();
      System.out.println("Epoch #" + epoch + " Error:" + trainer.getError());
      if (!(trainer.getError() > 0.01)) {
        break;
      }
    }

    System.out.println("Error:" + network.calculateError(samples));

    // Show each glyph and the output neuron that wins for it.
    for (int digit = 0; digit < DIGITS.length; digit++) {
      final int winner = network.winner(image2data(DIGITS[digit]));

      for (int row = 0; row < CHAR_HEIGHT; row++) {
        if (row < CHAR_HEIGHT - 1) {
          System.out.println(DIGITS[digit][row]);
        } else {
          System.out.println(DIGITS[digit][row] + " -> " + winner);
        }
      }

      System.out.println();
    }
  }
Пример #3
0
  /**
   * Trains the neural network stored at index {@code id} with resilient propagation on that
   * slot's training set.
   *
   * <p>Runs at most {@code max_epoch} iterations and stops earlier once the trainer's error is no
   * longer above {@code error_allowed}.
   *
   * @param id index into both {@code array_of_neural_nets} and {@code input_datas}
   */
  private void trainNeuralNets(int id) {

    final Train train =
        new ResilientPropagation(array_of_neural_nets[id], input_datas[id].getTrainingSet());

    // The error is only meaningful after an iteration has run. Checking it up front (as the
    // previous version did) sees the trainer's initial value — presumably 0 in Encog's
    // BasicTraining, confirm for this version — which ends the loop before any training happens.
    for (int epoch = 0; epoch < max_epoch; epoch++) {
      train.iteration();
      // Negated '>' rather than '<=' so that a NaN error still stops training.
      if (!(train.getError() > error_allowed)) {
        break;
      }
    }
  }
  /**
   * Trains {@code network} on {@code trainingSet} with a {@code Backpropagation} trainer
   * (constructor args 0.00001 and 0.0 — presumably learning rate and momentum). The trainer is
   * decorated with a {@code Greedy} strategy, a {@code HybridStrategy} wrapping simulated
   * annealing, and a {@code StopTrainingStrategy}. Iterations continue until that stop strategy
   * reports it should stop, and the error is printed after every iteration.
   *
   * @param what label included in each progress line
   * @param network the network whose weights are trained
   * @param trainingSet data used by backpropagation and to score the annealing trainer
   * @return the backpropagation trainer's error when the loop ends
   */
  public static double trainNetwork(
      final String what, final BasicNetwork network, final NeuralDataSet trainingSet) {
    // Secondary trainer: simulated annealing scored against the same training set
    // (args 10, 2, 100 are presumably start temperature, stop temperature and cycles).
    final CalculateScore scorer = new TrainingSetScore(trainingSet);
    final Train annealing = new NeuralSimulatedAnnealing(network, scorer, 10, 2, 100);

    // Primary trainer plus its strategies; registration order is kept as-is.
    final Train backprop = new Backpropagation(network, trainingSet, 0.00001, 0.0);
    final StopTrainingStrategy stopper = new StopTrainingStrategy();
    backprop.addStrategy(new Greedy());
    backprop.addStrategy(new HybridStrategy(annealing));
    backprop.addStrategy(stopper);

    for (int epoch = 0; !stopper.shouldStop(); epoch++) {
      backprop.iteration();
      final String progress =
          "Training " + what + ", Epoch #" + epoch + " Error:" + backprop.getError();
      System.out.println(progress);
    }

    return backprop.getError();
  }