/** * Performs the network's training iterations. The training continues until it reach the max error * or the max number of iterations set on the network. * * @param network AbstractNeuralNetwork * @return List of iterations */ public static List<String> train(AbstractNeuralNetwork network) { MLTrain training = network.getTrainStrategy(); List<String> output = new ArrayList<String>(); double error = 0; int epoch = 1; do { training.iteration(); error = training.getError(); output.add(epoch + "\t" + String.format(Locale.US, "%.20f", error)); // System.out.println("Iteration #" + epoch + " Error = " + training.getError()); epoch++; } while (continueIterations(network, training.getError(), epoch)); network.updateTrainError(); System.out.println("Ended training: Iteration #" + --epoch + " Error = " + training.getError()); training.finishTraining(); return output; }
/**
 * Decides whether training may run another iteration.
 *
 * <p>Training continues only while {@code error} is strictly greater than the network's maximum
 * allowed error. When the network's maximum iteration count is positive and {@code iteration} is
 * positive, training additionally stops once {@code iteration} exceeds that maximum. A
 * non-positive maximum iteration count disables the iteration limit.
 *
 * @param network the network whose error and iteration limits are consulted
 * @param error the current training error
 * @param iteration the number of the iteration about to be run
 * @return {@code true} if another iteration should be performed
 */
private static boolean continueIterations(
    AbstractNeuralNetwork network, double error, int iteration) {
  final boolean iterationLimitApplies = network.getMaxIterations() > 0 && iteration > 0;
  if (iterationLimitApplies && iteration > network.getMaxIterations()) {
    return false;
  }
  return network.getMaxError() < error;
}