/** * Construct the backpropagation trainer. * * @param theNetwork The network to train. * @param theTraining The training data to use. * @param theLearningRate The learning rate. Can be changed as training runs. * @param theMomentum The momentum. Can be changed as training runs. */ public BackPropagation( BasicNetwork theNetwork, List<BasicData> theTraining, double theLearningRate, double theMomentum) { this.network = theNetwork; this.training = theTraining; this.learningRate = theLearningRate; this.momentum = theMomentum; this.gradients = new GradientCalc(this.network, new CrossEntropyErrorFunction(), this); this.lastDelta = new double[theNetwork.getWeights().length]; }
/** Run the example. */ public void process() { try { final InputStream istream = this.getClass().getResourceAsStream("/iris.csv"); if (istream == null) { System.out.println("Cannot access data set, make sure the resources are available."); System.exit(1); } final DataSet ds = DataSet.load(istream); // The following ranges are setup for the Iris data set. If you wish to normalize other files // you will // need to modify the below function calls other files. ds.normalizeRange(0, -1, 1); ds.normalizeRange(1, -1, 1); ds.normalizeRange(2, -1, 1); ds.normalizeRange(3, -1, 1); final Map<String, Integer> species = ds.encodeOneOfN(4); // species is column 4 istream.close(); final List<BasicData> trainingData = ds.extractSupervised(0, 4, 4, 3); BasicNetwork network = new BasicNetwork(); network.addLayer(new BasicLayer(null, true, 4)); network.addLayer(new BasicLayer(new ActivationReLU(), true, 20)); network.addLayer(new BasicLayer(new ActivationSoftMax(), false, 3)); network.finalizeStructure(); network.reset(); final BackPropagation train = new BackPropagation(network, trainingData, 0.001, 0.9); performIterations(train, 100000, 0.01, true); queryOneOfN(network, trainingData, species); } catch (Throwable t) { t.printStackTrace(); } }