/**
 * Builds a 2-2-1 {@link BasicNetwork}, trains it on the XOR data with
 * {@link ResilientPropagation} until the training error is no longer above 0.009, and stores it
 * under the name {@code "network"} in the collection file {@code FILENAME}.
 *
 * <p>NOTE(review): the training loop has no iteration cap, so it only terminates once the
 * error threshold is reached.
 */
public void trainAndSave() {
    System.out.println("Training XOR network to under 1% error rate.");
    BasicNetwork network = new BasicNetwork();
    // Layer sizes 2-2-1, added in order.
    network.addLayer(new BasicLayer(2));
    network.addLayer(new BasicLayer(2));
    network.addLayer(new BasicLayer(1));
    network.getStructure().finalizeStructure();
    network.reset();

    NeuralDataSet trainingSet = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);

    // train the neural network
    final Train train = new ResilientPropagation(network, trainingSet);

    // 0.009 keeps the reported error strictly under the advertised 1%.
    do {
      train.iteration();
    } while (train.getError() > 0.009);

    double e = network.calculateError(trainingSet);
    System.out.println("Network trained to error: " + e);

    System.out.println("Saving network");
    final EncogPersistedCollection encog = new EncogPersistedCollection(FILENAME);
    encog.create();
    encog.add("network", network);
  }
// Example #2
// 0
  /**
   * Trains a feedforward network on XOR with {@link Backpropagation} for 5000 iterations, checks
   * that it fits the data, round-trips it through an {@link EncogPersistedCollection} XML file,
   * and checks that the reloaded network still fits the data.
   *
   * <p>The scratch file {@code encogtest.xml} is always deleted, even if saving, loading or an
   * assertion fails.
   *
   * @throws Throwable anything thrown by training, persistence, or a failed assertion
   */
  public void testFeedforwardPersist() throws Throwable {
    NeuralDataSet trainingData = new BasicNeuralDataSet(XOR.XOR_INPUT, XOR.XOR_IDEAL);

    BasicNetwork network = createNetwork();
    Train train = new Backpropagation(network, trainingData, 0.7, 0.9);

    for (int i = 0; i < 5000; i++) {
      train.iteration();
    }
    // Fetch the trained network once; the result is only used after the loop.
    network = (BasicNetwork) train.getNetwork();

    TestCase.assertTrue("Error too high for backpropagation", train.getError() < 0.1);
    TestCase.assertTrue("XOR outputs not correct", XOR.verifyXOR(network, 0.1));

    final String fileName = "encogtest.xml";
    try {
      EncogPersistedCollection encog = new EncogPersistedCollection();
      encog.add(network);
      encog.save(fileName);

      EncogPersistedCollection encog2 = new EncogPersistedCollection();
      encog2.load(fileName);

      BasicNetwork n = (BasicNetwork) encog2.getList().get(0);
      TestCase.assertTrue("Error too high for load", n.calculateError(trainingData) < 0.1);
    } finally {
      // Never leave the scratch file behind, whatever happened above.
      new File(fileName).delete();
    }
  }
  /**
   * Loads a network and its training set from {@code dataDir}, trains the network, and saves it
   * back over the original network file.
   *
   * <p>Training is delegated to {@link EncogUtility#trainConsole} with
   * {@code Config.TRAINING_MINUTES}. If either input file cannot be read, a message is printed
   * and nothing is loaded or trained. Once loading has started, Encog is shut down whether or not
   * the remaining steps succeed.
   *
   * @param dataDir directory expected to contain {@code Config.NETWORK_FILE} and
   *     {@code Config.TRAINING_FILE}
   */
  public static void train(File dataDir) {
    final File networkFile = new File(dataDir, Config.NETWORK_FILE);
    final File trainingFile = new File(dataDir, Config.TRAINING_FILE);

    // Validate both inputs up front so nothing is loaded when one of them is unusable.
    if (!networkFile.canRead()) {
      System.out.println("Can't read file: " + networkFile.getAbsolutePath());
      return;
    }
    if (!trainingFile.canRead()) {
      System.out.println("Can't read file: " + trainingFile.getAbsolutePath());
      return;
    }

    try {
      final BasicNetwork network =
          (BasicNetwork) EncogDirectoryPersistence.loadObject(networkFile);
      final MLDataSet trainingSet = EncogUtility.loadEGB2Memory(trainingFile);

      // train the neural network
      EncogUtility.trainConsole(network, trainingSet, Config.TRAINING_MINUTES);
      System.out.println("Final Error: " + (float) network.calculateError(trainingSet));
      System.out.println("Training complete, saving network.");
      EncogDirectoryPersistence.saveObject(networkFile, network);
      System.out.println("Network saved.");
    } finally {
      // Previously only reached on success; shut down even if load/train/save throws.
      Encog.getInstance().shutdown();
    }
  }
  /**
   * Reloads the network stored under {@code "network"} in {@code FILENAME} and prints its error on
   * the XOR training data, for comparison with the error printed after training.
   *
   * <p>If the collection has no entry for that name, a message is printed instead of failing with
   * a {@link NullPointerException}.
   */
  public void loadAndEvaluate() {
    System.out.println("Loading network");

    final EncogPersistedCollection encog = new EncogPersistedCollection(FILENAME);
    final BasicNetwork network = (BasicNetwork) encog.find("network");
    if (network == null) {
      // NOTE(review): assumes find() signals "not present" with null, e.g. before trainAndSave ran.
      System.out.println("No network named \"network\" found in " + FILENAME);
      return;
    }

    NeuralDataSet trainingSet = new BasicNeuralDataSet(XOR_INPUT, XOR_IDEAL);
    double e = network.calculateError(trainingSet);
    System.out.println("Loaded network's error is(should be same as above): " + e);
  }
  /**
   * Opens an {@link EvaluateDialog} on the workbench's main window. If {@code process()} returns
   * true, it evaluates the chosen network on the chosen training set and shows the error,
   * formatted with {@code Format.formatPercent}.
   *
   * <p>Any {@link Throwable} raised along the way is reported through
   * {@code EncogWorkBench.displayError} rather than propagated.
   */
  public void performEvaluate() {
    try {
      final EvaluateDialog evaluateDialog =
          new EvaluateDialog(EncogWorkBench.getInstance().getMainWindow());
      if (!evaluateDialog.process()) {
        return; // nothing to evaluate
      }

      final BasicNetwork selectedNetwork = evaluateDialog.getNetwork();
      final NeuralDataSet selectedData = evaluateDialog.getTrainingSet();
      final double networkError = selectedNetwork.calculateError(selectedData);
      final String message = "" + Format.formatPercent(networkError);
      EncogWorkBench.displayMessage("Error For this Network", message);
    } catch (Throwable failure) {
      EncogWorkBench.displayError("Error Evaluating Network", failure);
    }
  }
// Example #6
// 0
  /**
   * Trains {@code network} on {@code data} for 20 {@link ResilientPropagation} iterations,
   * printing the error after each one, then prints and returns the elapsed time.
   *
   * @param network network to train; its error on {@code data} is printed afterwards
   * @param data training data
   * @return elapsed time in seconds, covering the iterations and {@code finishTraining()}
   */
  public static double evaluateMPROP(BasicNetwork network, NeuralDataSet data) {

    ResilientPropagation train = new ResilientPropagation(network, data);
    // NOTE(review): 0 presumably lets Encog pick the thread count itself — confirm.
    train.setNumThreads(0);
    // nanoTime is monotonic, so it is not affected by wall-clock adjustments mid-measurement.
    long start = System.nanoTime();
    System.out.println("Training 20 Iterations with MPROP");
    try {
      for (int i = 1; i <= 20; i++) {
        train.iteration();
        System.out.println("Iteration #" + i + " Error:" + train.getError());
      }
    } finally {
      // Always finish the trainer, even if an iteration throws.
      train.finishTraining();
    }
    long stop = System.nanoTime();
    double diff = (stop - start) / 1.0e9;
    System.out.println("MPROP Result:" + diff + " seconds.");
    System.out.println("Final MPROP error: " + network.calculateError(data));
    return diff;
  }
  /**
   * Builds the "SIN Wave Predict" window (640x480):
   *
   * <ul>
   *   <li>a {@link GraphPanel} in the centre and a "Train" button at the bottom, wired to this
   *       instance as its action listener;
   *   <li>a freshly reset feedforward network, created by {@code EncogUtility.simpleFeedForward};
   *   <li>the data from {@code generateTraining()} and a {@link ResilientPropagation} trainer
   *       over it.
   * </ul>
   *
   * <p>The graph is seeded with the untrained network's error.
   */
  public PredictSIN() {
    setTitle("SIN Wave Predict");
    setSize(640, 480);

    final Container pane = getContentPane();
    pane.setLayout(new BorderLayout());
    graph = new GraphPanel();
    pane.add(graph, BorderLayout.CENTER);

    network = EncogUtility.simpleFeedForward(INPUT_WINDOW, PREDICT_WINDOW * 2, 0, 1, true);
    network.reset();
    graph.setNetwork(network);

    trainingData = generateTraining();
    train = new ResilientPropagation(network, trainingData);

    btnTrain = new JButton("Train");
    btnTrain.addActionListener(this);
    pane.add(btnTrain, BorderLayout.SOUTH);

    // Starting value for the graph's error display, before any training.
    graph.setError(network.calculateError(trainingData));
  }
// Example #8
// 0
  /**
   * Trains a linear-activation network with {@link TrainAdaline} on the character images in
   * {@code DIGITS}, then prints each image with the winning output index on its last row.
   *
   * <p>The network has {@code CHAR_WIDTH * CHAR_HEIGHT} inputs and one output per digit. Training
   * repeats until the error is no longer above 0.01, and the error is printed after every epoch.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    final int inputCount = CHAR_WIDTH * CHAR_HEIGHT;
    final int outputCount = DIGITS.length;

    final BasicNetwork network = new BasicNetwork();

    final Layer firstLayer = new BasicLayer(new ActivationLinear(), false, inputCount);
    final Layer lastLayer = new BasicLayer(new ActivationLinear(), true, outputCount);

    network.addLayer(firstLayer);
    network.addLayer(lastLayer);
    network.getStructure().finalizeStructure();

    new RangeRandomizer(-0.5, 0.5).randomize(network);

    // train it
    final NeuralDataSet training = generateTraining();
    final Train train = new TrainAdaline(network, training, 0.01);

    for (int epoch = 1; ; epoch++) {
      train.iteration();
      System.out.println("Epoch #" + epoch + " Error:" + train.getError());
      // Same exit condition as "while (error > 0.01)", including stopping on NaN.
      if (!(train.getError() > 0.01)) {
        break;
      }
    }

    System.out.println("Error:" + network.calculateError(training));

    // test it
    final int lastRow = CHAR_HEIGHT - 1;
    for (int digit = 0; digit < DIGITS.length; digit++) {
      final int winner = network.winner(image2data(DIGITS[digit]));

      for (int row = 0; row < CHAR_HEIGHT; row++) {
        if (row != lastRow) {
          System.out.println(DIGITS[digit][row]);
        } else {
          System.out.println(DIGITS[digit][row] + " -> " + winner);
        }
      }

      System.out.println();
    }
  }
 /**
  * Returns the error of the {@code ml} field's model on the given ensemble data set.
  *
  * @param testset data set to evaluate against
  * @return the value of {@code ml.calculateError(testset)}
  */
 @Override
 public double getError(EnsembleDataSet testset) {
   final double error = ml.calculateError(testset);
   return error;
 }