/**
 * Renders the network gradient once every {@code getRenderIterations()} epochs.
 *
 * <p>A non-positive render interval disables rendering entirely. Epoch 0 is always rendered when
 * rendering is enabled, since {@code 0 % n == 0}.
 *
 * @param epoch the epoch that just finished
 */
@Override
 public void iterationDone(int epoch) {
   int plotEvery = getRenderIterations();
   boolean shouldPlot = plotEvery > 0 && epoch % plotEvery == 0;
   if (!shouldPlot) {
     return;
   }
   NeuralNetPlotter netPlotter = new NeuralNetPlotter();
   // NOTE(review): the meaning of {1, 0.001, 1000} is defined by getGradient; confirm and name them.
   netPlotter.plotNetworkGradient(
       this, this.getGradient(new Object[] {1, 0.001, 1000}), getInput().rows);
 }
  /**
   * Fine-tunes a {@link DeepAutoEncoder} built from a previously serialized {@link DBN} on MNIST
   * mini-batches, rendering weight histograms, output-layer filters and reconstructions after each
   * batch. It then saves the encoder to {@code deepautoencoder.ser} and renders reconstructions for
   * the whole iterator.
   *
   * @param args {@code args[0]} is the path of the serialized {@link DBN} to start from
   * @throws IllegalArgumentException if no DBN path is supplied
   * @throws Exception if loading, training, rendering or saving fails
   */
  public static void main(String[] args) throws Exception {
    if (args == null || args.length < 1) {
      throw new IllegalArgumentException("Usage: <path to serialized DBN>");
    }

    // batches of 10, 300 examples total
    DataSetIterator iter = new MnistDataSetIterator(10, 300);

    DBN dbn = SerializationUtils.readObject(new File(args[0]));
    for (int i = 0; i < dbn.getnLayers(); i++) {
      dbn.getLayers()[i].setRenderEpochs(10);
    }

    DeepAutoEncoder encoder = new DeepAutoEncoder(dbn);
    encoder.setRoundCodeLayerInput(true);
    encoder.setNormalizeCodeLayerOutput(false);
    encoder.setOutputLayerLossFunction(OutputLayer.LossFunction.RMSE_XENT);
    encoder.setOutputLayerActivation(Activations.sigmoid());
    encoder.setVisibleUnit(RBM.VisibleUnit.GAUSSIAN);
    encoder.setHiddenUnit(RBM.HiddenUnit.BINARY);

    while (iter.hasNext()) {
      DataSet next = iter.next();
      if (next == null) {
        break;
      }
      log.info("Training on " + next.numExamples());
      log.info("Coding layer is " + encoder.encode(next.getFirst()));
      encoder.finetune(next.getFirst(), 1e-1, 1000);
      encoder.getDecoder().feedForward(encoder.encodeWithScaling(next.getFirst()));

      // Histogram of every decoder layer's weight matrix, labelled by layer index.
      String[] layers = new String[encoder.getDecoder().getnLayers()];
      DoubleMatrix[] weights = new DoubleMatrix[layers.length];
      for (int i = 0; i < layers.length; i++) {
        layers[i] = "" + i;
        weights[i] = encoder.getDecoder().getLayers()[i].getW();
      }
      NeuralNetPlotter plotter = new NeuralNetPlotter();
      plotter.histogram(layers, weights);

      // 28x28 matches the MNIST image size.
      FilterRenderer f = new FilterRenderer();
      f.renderFilters(
          encoder.getDecoder().getOutputLayer().getW(),
          "outputlayer.png",
          28,
          28,
          next.numExamples());
      DeepAutoEncoderDataSetReconstructionRender render =
          new DeepAutoEncoderDataSetReconstructionRender(
              next.iterator(next.numExamples()), encoder);
      render.draw();
    }

    SerializationUtils.saveObject(encoder, new File("deepautoencoder.ser"));

    iter.reset();

    DeepAutoEncoderDataSetReconstructionRender render =
        new DeepAutoEncoderDataSetReconstructionRender(iter, encoder);
    render.draw();
  }
 /**
  * Plots the network gradient every {@code renderWeightsEveryNumEpochs} iterations, as configured
  * on the network. A non-positive setting disables plotting.
  *
  * @param iterationDone index of the iteration that just completed
  */
 @Override
 public void iterationDone(int iterationDone) {
   int renderEvery = network.conf().getRenderWeightsEveryNumEpochs();
   // Short-circuit keeps the modulo from running when plotting is disabled (avoids % 0).
   boolean due = renderEvery > 0 && iterationDone % renderEvery == 0;
   if (due) {
     plotter.plotNetworkGradient(network, network.getGradient(extraParams), 100);
   }
 }
  /**
   * Backprop with the output being the reconstruction.
   *
   * <p>Repeatedly measures {@code squaredLoss()}, keeping a clone of the last network state whose
   * loss improved. The loop stops after at most {@code iterations} passes, or earlier when one of
   * these happens:
   *
   * <ul>
   *   <li>the loss gets worse: it increases, or it decreases below an already negative loss. The
   *       network is reverted to the snapshot.
   *   <li>the loss becomes NaN or infinite. The network is reverted to the snapshot.
   *   <li>the loss is exactly unchanged.
   * </ul>
   *
   * <p>Every {@code getRenderIterations()} passes (when positive) the gradient is plotted.
   *
   * @param lr learning rate. NOTE(review): not used by this method as written; confirm intent
   * @param iterations maximum number of passes through the loop
   * @param extraParams extra parameters forwarded to {@code getGradient} when plotting
   */
  @Override
  public void backProp(double lr, int iterations, Object[] extraParams) {
    double currRecon = squaredLoss();
    NeuralNetwork revert = clone();
    // Loop-invariant plot interval; a non-positive value disables plotting.
    int plotIterations = getRenderIterations();
    // A dedicated counter: the original compared `iterations > iterations` (always false) and
    // incremented the parameter itself, so the iteration limit never applied.
    int numIterations = 0;
    // NOTE(review): nothing in this loop updates the parameters, so unless squaredLoss() is
    // non-deterministic the loss will be unchanged and the loop exits on the first pass. A gradient
    // step using lr was presumably intended here; confirm.
    while (numIterations < iterations) {
      double newRecon = this.squaredLoss();
      // NaN compares false with everything, so it needs its own check.
      boolean diverged = Double.isNaN(newRecon) || Double.isInfinite(newRecon);
      // Prevent weights from exploding too far in either direction; we want this as close to zero
      // as possible.
      boolean gotWorse = newRecon > currRecon || (currRecon < 0 && newRecon < currRecon);
      if (gotWorse || diverged) {
        update((BaseNeuralNetwork) revert);
        log.info("Converged for new recon; breaking...");
        break;
      }
      if (newRecon == currRecon) {
        break;
      }
      currRecon = newRecon;
      revert = clone();
      log.info("Recon went down " + currRecon);

      numIterations++;

      if (plotIterations > 0 && numIterations % plotIterations == 0) {
        NeuralNetPlotter plotter = new NeuralNetPlotter();
        plotter.plotNetworkGradient(this, getGradient(extraParams), getInput().rows);
      }
    }
  }
  /**
   * Trains a single RBM layer on a tiny two-row synthetic dataset and plots its gradient.
   *
   * <p>Row 0 has every feature set to 0.1 and label column 1. Row 1 has every feature set to 0.2
   * and label column 0.
   *
   * @param args unused
   * @throws Exception if model construction, training or plotting fails
   */
  public static void main(String... args) throws Exception {
    int numFeatures = 40;
    int iterations = 5;
    int seed = 123;
    // Clamp to 1 so a small iteration count never yields a zero listener frequency
    // (a frequency of 0 would otherwise be used as a modulus by the listener).
    int listenerFreq = Math.max(1, iterations / 5);
    Nd4j.getRandom().setSeed(seed);

    log.info("Load data....");
    // Need at least two rows, otherwise the output layer gradient is a scalar and causes an
    // exception.
    INDArray input = Nd4j.create(2, numFeatures);
    INDArray labels = Nd4j.create(2, 2);

    INDArray row0 = Nd4j.create(1, numFeatures);
    row0.assign(0.1);
    input.putRow(0, row0);
    labels.put(0, 1, 1); // row 0 -> column index 1 (the 2nd column)

    INDArray row1 = Nd4j.create(1, numFeatures);
    row1.assign(0.2);
    input.putRow(1, row1);
    labels.put(1, 0, 1); // row 1 -> column index 0 (the 1st column)

    DataSet trainingSet = new DataSet(input, labels);

    log.info("Build model....");
    NeuralNetConfiguration conf =
        new NeuralNetConfiguration.Builder()
            .layer(new RBM())
            .nIn(trainingSet.numInputs())
            .nOut(trainingSet.numOutcomes())
            .seed(seed)
            .weightInit(WeightInit.SIZE)
            .constrainGradientToUnitNorm(true)
            .iterations(iterations)
            .activationFunction("tanh")
            .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
            .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
            .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
            .learningRate(1e-1f)
            .optimizationAlgo(OptimizationAlgorithm.ITERATION_GRADIENT_DESCENT)
            .build();
    Layer model = LayerFactories.getFactory(conf).create(conf);
    model.setIterationListeners(
        Collections.singletonList((IterationListener) new ScoreIterationListener(listenerFreq)));

    log.info("Evaluate weights....");
    INDArray w = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
    log.info("Weights: " + w);

    log.info("Train model....");
    model.fit(trainingSet.getFeatureMatrix());

    log.info("Visualize training results....");
    // Work in progress to get NeuralNetPlotter functioning
    NeuralNetPlotter plotter = new NeuralNetPlotter();
    plotter.plotNetworkGradient(model, model.gradient(), 10);
  }