@Override
public void iterationDone(int iteration) {
  int plotEpochs = network.conf().getRenderWeightsEveryNumEpochs();
  // a value <= 0 disables rendering entirely
  if (plotEpochs <= 0) return;
  if (iteration % plotEpochs == 0) {
    plotter.plotNetworkGradient(network, network.getGradient(extraParams), 100);
  }
}
@Override
public void iterationDone(int iteration) {
  int plotIterations = getRenderIterations();
  if (plotIterations <= 0) return;
  // plot on the first iteration and on every plotIterations-th one thereafter
  if (iteration % plotIterations == 0 || iteration == 0) {
    NeuralNetPlotter plotter = new NeuralNetPlotter();
    plotter.plotNetworkGradient(
        // the array here is a hard-coded set of extra params passed straight through to getGradient
        this, this.getGradient(new Object[] {1, 0.001, 1000}), getInput().rows);
  }
}
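
/**
 * Both overrides above gate plotting on the same cadence check. Below is a
 * minimal sketch of that check as a standalone helper; shouldPlot is a
 * hypothetical name of ours, not part of the deeplearning4j API.
 */
private static boolean shouldPlot(int iteration, int renderEvery) {
  // rendering is disabled when renderEvery <= 0; otherwise fire on iteration 0
  // and on every renderEvery-th iteration
  return renderEvery > 0 && (iteration == 0 || iteration % renderEvery == 0);
}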
  /** Backprop with the output being the reconstruction */
  @Override
   public void backProp(double lr, int iterations, Object[] extraParams) {
     double currRecon = squaredLoss();
     boolean train = true;
     NeuralNetwork revert = clone();
     int epoch = 0; // number of passes completed so far
     while (train) {
       if (epoch > iterations) break;

       double newRecon = this.squaredLoss();
       // prevent the weights from exploding too far in either direction; we
       // want the reconstruction error as close to zero as possible
       if (newRecon > currRecon || (currRecon < 0 && newRecon < currRecon)) {
         update((BaseNeuralNetwork) revert);
         log.info("Recon went up; reverting and breaking...");
         break;
       } else if (Double.isNaN(newRecon) || Double.isInfinite(newRecon)) {
         update((BaseNeuralNetwork) revert);
         log.info("Recon is NaN or infinite; reverting and breaking...");
         break;
       } else if (newRecon == currRecon) break; // no change in recon: converged
       else {
         currRecon = newRecon;
         revert = clone(); // checkpoint the improved model so we can revert to it later
         log.info("Recon went down: " + currRecon);
       }

       epoch++;

       int plotIterations = getRenderIterations();
       if (plotIterations > 0 && epoch % plotIterations == 0) {
         NeuralNetPlotter plotter = new NeuralNetPlotter();
         plotter.plotNetworkGradient(this, getGradient(extraParams), getInput().rows);
       }
    }
  }
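
   /**
    * A minimal sketch of the revert-on-regression logic above, expressed over a
    * plain sequence of reconstruction-error values so the control flow can be
    * checked in isolation. This helper and its name are ours, not part of the
    * deeplearning4j API; it returns the index at which the loop would stop.
    */
   static int stoppingPoint(double[] recon) {
     double curr = recon[0];
     for (int i = 1; i < recon.length; i++) {
       double next = recon[i];
       if (Double.isNaN(next) || Double.isInfinite(next)) return i; // diverged: revert
       if (next > curr || (curr < 0 && next < curr)) return i;      // got worse: revert
       if (next == curr) return i;                                  // converged
       curr = next; // still improving: keep training
     }
     return recon.length;
   }
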
  public static void main(String... args) throws Exception {
    int numFeatures = 40;
    int iterations = 5;
    int seed = 123;
    int listenerFreq = iterations / 5;
    Nd4j.getRandom().setSeed(seed);

     log.info("Load data....");
     // at least two rows are needed; otherwise the output layer gradient is a
     // scalar and causes an exception
     INDArray input = Nd4j.create(2, numFeatures);
     INDArray labels = Nd4j.create(2, 2);

    INDArray row0 = Nd4j.create(1, numFeatures);
    row0.assign(0.1);
    input.putRow(0, row0);
     labels.put(0, 1, 1); // set the 2nd column (class 1) for row 0

    INDArray row1 = Nd4j.create(1, numFeatures);
    row1.assign(0.2);

    input.putRow(1, row1);
     labels.put(1, 0, 1); // set the 1st column (class 0) for row 1

    DataSet trainingSet = new DataSet(input, labels);

    log.info("Build model....");
    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .layer(new RBM())
            .nIn(trainingSet.numInputs())
            .nOut(trainingSet.numOutcomes())
            .seed(seed)
            .weightInit(WeightInit.SIZE)
            .constrainGradientToUnitNorm(true)
            .iterations(iterations)
            .activationFunction("tanh")
            .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
            .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
            .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
            .learningRate(1e-1f)
            .optimizationAlgo(OptimizationAlgorithm.ITERATION_GRADIENT_DESCENT)
            .build();
    Layer model = LayerFactories.getFactory(conf).create(conf);
    model.setIterationListeners(
        Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq)));

    log.info("Evaluate weights....");
    INDArray w = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
    log.info("Weights: " + w);

    log.info("Train model....");
    model.fit(trainingSet.getFeatureMatrix());

    log.info("Visualize training results....");
    // Work in progress to get NeuralNetPlotter functioning
    NeuralNetPlotter plotter = new NeuralNetPlotter();
    plotter.plotNetworkGradient(model, model.gradient(), 10);
  }
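
   /**
    * Sketch of the weight inspection above as a reusable helper, so the same
    * check can be run both before and after fit(). It uses only the getParam
    * call already shown in main(); the helper itself is ours, not dl4j API.
    */
   private static void logWeights(Layer model, String stage) {
     INDArray w = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
     log.info(stage + " weights: " + w);
   }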