@Override
public void iterationDone(int iteration) {
    int plotEpochs = network.conf().getRenderWeightsEveryNumEpochs();
    if (plotEpochs <= 0)
        return;
    // Render the network gradient every plotEpochs iterations.
    if (iteration % plotEpochs == 0) {
        plotter.plotNetworkGradient(network, network.getGradient(extraParams), 100);
    }
}
@Override
public void iterationDone(int epoch) {
    int plotEpochs = getRenderIterations();
    if (plotEpochs <= 0)
        return;
    // Plot on the first iteration and every plotEpochs iterations thereafter.
    if (epoch % plotEpochs == 0 || epoch == 0) {
        NeuralNetPlotter plotter = new NeuralNetPlotter();
        // Gradient hyperparameters passed through as extraParams (values here: 1, 0.001, 1000).
        plotter.plotNetworkGradient(this,
                this.getGradient(new Object[] {1, 0.001, 1000}),
                getInput().rows);
    }
}
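Both listener variants above share one gating pattern: read a render frequency, return early when it is non-positive, and plot only when the iteration index is a multiple of that frequency. The sketch below isolates that pattern so it can be exercised without a network or plotter; RenderGate and everything in it are illustrative names, not part of the deeplearning4j API.

// Minimal, self-contained sketch of the render-frequency gating used by both
// listeners above. RenderGate is a hypothetical helper, not a dl4j class.
public final class RenderGate {
    private final int renderEvery; // plot every N iterations; <= 0 disables plotting

    public RenderGate(int renderEvery) {
        this.renderEvery = renderEvery;
    }

    /** True when this iteration should produce a plot. */
    public boolean shouldRender(int iteration) {
        if (renderEvery <= 0)
            return false; // mirrors the early return in the listeners above
        return iteration % renderEvery == 0; // iteration 0 also renders
    }

    public static void main(String[] args) {
        RenderGate gate = new RenderGate(5);
        for (int i = 0; i < 12; i++) {
            if (gate.shouldRender(i))
                System.out.println("would plot at iteration " + i); // 0, 5, 10
        }
    }
}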
/** Backprop with the output being the reconstruction */
@Override
public void backProp(double lr, int iterations, Object[] extraParams) {
    double currRecon = squaredLoss();
    boolean train = true;
    NeuralNetwork revert = clone();
    int iteration = 0;
    while (train) {
        if (iteration > iterations)
            break;
        double newRecon = this.squaredLoss();
        // Prevent weights from exploding too far in either direction;
        // we want the reconstruction loss as close to zero as possible.
        if (newRecon > currRecon || (currRecon < 0 && newRecon < currRecon)) {
            // Loss got worse: revert to the last known-good snapshot and stop.
            update((BaseNeuralNetwork) revert);
            log.info("Converged for new recon; breaking...");
            break;
        } else if (Double.isNaN(newRecon) || Double.isInfinite(newRecon)) {
            // Loss diverged: revert and stop.
            update((BaseNeuralNetwork) revert);
            log.info("Converged for new recon; breaking...");
            break;
        } else if (newRecon == currRecon) {
            // No progress this iteration: stop.
            break;
        } else {
            // Loss improved: accept the new state and snapshot it.
            currRecon = newRecon;
            revert = clone();
            log.info("Recon went down to " + currRecon);
        }

        iteration++;

        int plotIterations = getRenderIterations();
        if (plotIterations > 0 && iteration % plotIterations == 0) {
            NeuralNetPlotter plotter = new NeuralNetPlotter();
            plotter.plotNetworkGradient(this, getGradient(extraParams), getInput().rows);
        }
    }
}
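Stripped of the plotting, backProp above is a snapshot-and-revert guard around an optimizer: clone the model whenever the reconstruction loss improves, and roll back to that clone if the loss rises, stalls, or turns NaN/infinite. Here is that guard in isolation; Trainable and its methods (snapshot, restore, step, loss) are hypothetical stand-ins for the BaseNeuralNetwork calls, not deeplearning4j API, and step() represents the per-iteration weight update that the excerpt above leaves implicit.

// Snapshot-and-revert training guard, as a self-contained sketch.
// Trainable is a hypothetical interface standing in for BaseNeuralNetwork.
final class RevertingTrainer {
    interface Trainable {
        Trainable snapshot();          // like clone() above
        void restore(Trainable saved); // like update((BaseNeuralNetwork) revert)
        void step();                   // one weight update (implicit in the excerpt)
        double loss();                 // like squaredLoss()
    }

    static void train(Trainable model, int maxIterations) {
        double currLoss = model.loss();
        Trainable revert = model.snapshot();
        for (int i = 0; i <= maxIterations; i++) {
            model.step();
            double newLoss = model.loss();
            if (Double.isNaN(newLoss) || Double.isInfinite(newLoss) || newLoss > currLoss) {
                model.restore(revert); // diverged or got worse: roll back and stop
                break;
            }
            if (newLoss == currLoss)
                break;                 // no progress: stop
            currLoss = newLoss;        // improved: accept and take a new snapshot
            revert = model.snapshot();
        }
    }
}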
public static void main(String... args) throws Exception {
    int numFeatures = 40;
    int iterations = 5;
    int seed = 123;
    int listenerFreq = iterations / 5;
    Nd4j.getRandom().setSeed(seed);

    log.info("Load data....");
    // Needs at least two rows, or the output layer gradient is a scalar and causes an exception.
    INDArray input = Nd4j.create(2, numFeatures);
    INDArray labels = Nd4j.create(2, 2);

    INDArray row0 = Nd4j.create(1, numFeatures);
    row0.assign(0.1);
    input.putRow(0, row0);
    labels.put(0, 1, 1); // one-hot: set the second column for row 0 (class 1)

    INDArray row1 = Nd4j.create(1, numFeatures);
    row1.assign(0.2);
    input.putRow(1, row1);
    labels.put(1, 0, 1); // one-hot: set the first column for row 1 (class 0)

    DataSet trainingSet = new DataSet(input, labels);

    log.info("Build model....");
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder()
            .layer(new RBM())
            .nIn(trainingSet.numInputs())
            .nOut(trainingSet.numOutcomes())
            .seed(seed)
            .weightInit(WeightInit.SIZE)
            .constrainGradientToUnitNorm(true)
            .iterations(iterations)
            .activationFunction("tanh")
            .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
            .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
            .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
            .learningRate(1e-1f)
            .optimizationAlgo(OptimizationAlgorithm.ITERATION_GRADIENT_DESCENT)
            .build();
    Layer model = LayerFactories.getFactory(conf).create(conf);
    model.setIterationListeners(
            Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq)));

    log.info("Evaluate weights....");
    INDArray w = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
    log.info("Weights: " + w);

    log.info("Train model....");
    model.fit(trainingSet.getFeatureMatrix());

    log.info("Visualize training results....");
    // Work in progress to get NeuralNetPlotter functioning
    NeuralNetPlotter plotter = new NeuralNetPlotter();
    plotter.plotNetworkGradient(model, model.gradient(), 10);
}
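Since NeuralNetPlotter is still a work in progress here, a cheap fallback check is to diff the weight matrix around the fit call. The snippet below reuses only calls already shown above (getParam, fit); Transforms.abs and INDArray.meanNumber() are standard ND4J but assumed available in this version.

// Fallback sanity check while the plotter is a work in progress:
// snapshot the weights before training, then measure how much they moved.
// Assumes org.nd4j.linalg.ops.transforms.Transforms and INDArray.meanNumber()
// are available in this ND4J version.
INDArray before = model.getParam(DefaultParamInitializer.WEIGHT_KEY).dup();
model.fit(trainingSet.getFeatureMatrix());
INDArray after = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
double meanAbsChange = Transforms.abs(after.sub(before)).meanNumber().doubleValue();
log.info("Mean absolute weight change: " + meanAbsChange);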