/**
 * Fine-tunes a deep autoencoder built from a pretrained DBN on a slice of MNIST, rendering
 * diagnostics after every mini-batch.
 *
 * <p>For each mini-batch this method does the following:
 * <ul>
 *   <li>logs the code-layer encoding, computed before fine-tuning that batch;</li>
 *   <li>fine-tunes the encoder;</li>
 *   <li>plots a histogram of the decoder layers' weights;</li>
 *   <li>renders the decoder output layer's weights as 28x28 filters to {@code outputlayer.png};</li>
 *   <li>draws the reconstruction of that batch.</li>
 * </ul>
 * After all batches, the encoder is serialized to {@code deepautoencoder.ser}. The iterator is
 * then reset and reconstructions are drawn for the whole data set.
 *
 * @param args {@code args[0]} is the path of a serialized {@code DBN} to build the autoencoder from
 * @throws IllegalArgumentException if no model path is given or it does not name a regular file
 * @throws Exception propagated from data loading, training, rendering or serialization
 */
public static void main(String[] args) throws Exception {
    // Validate the command line first, before any potentially expensive data setup.
    if (args.length < 1) {
      throw new IllegalArgumentException("usage: <path to serialized DBN>");
    }
    File dbnFile = new File(args[0]);
    if (!dbnFile.isFile()) {
      throw new IllegalArgumentException(
          "Serialized DBN not found: " + dbnFile.getAbsolutePath());
    }

    final int batchSize = 10;
    final int numExamples = 300;
    final int imageSide = 28; // MNIST digits are 28x28 pixels

    // batches of 10, 300 examples total (a small slice, not the full MNIST training set)
    DataSetIterator iter = new MnistDataSetIterator(batchSize, numExamples);

    DBN dbn = SerializationUtils.readObject(dbnFile);
    for (int i = 0; i < dbn.getnLayers(); i++) {
      // NOTE(review): presumably makes each layer render diagnostics every 10 epochs -- confirm.
      dbn.getLayers()[i].setRenderEpochs(10);
    }

    DeepAutoEncoder encoder = new DeepAutoEncoder(dbn);
    encoder.setRoundCodeLayerInput(true);
    encoder.setNormalizeCodeLayerOutput(false);
    encoder.setOutputLayerLossFunction(OutputLayer.LossFunction.RMSE_XENT);
    encoder.setOutputLayerActivation(Activations.sigmoid());
    encoder.setVisibleUnit(RBM.VisibleUnit.GAUSSIAN);
    encoder.setHiddenUnit(RBM.HiddenUnit.BINARY);

    while (iter.hasNext()) {
      DataSet next = iter.next();
      if (next == null) {
        break;
      }
      log.info("Training on " + next.numExamples());
      // Logged before fine-tuning, so this shows the encoding prior to this batch's update.
      log.info("Coding layer is " + encoder.encode(next.getFirst()));
      // NOTE(review): arguments presumably are (input, learning rate, epochs) -- confirm.
      encoder.finetune(next.getFirst(), 1e-1, 1000);

      NeuralNetPlotter plotter = new NeuralNetPlotter();
      // Return value is unused; the call is kept for its effect on the decoder's state.
      // NOTE(review): confirm what feedForward mutates.
      encoder.getDecoder().feedForward(encoder.encodeWithScaling(next.getFirst()));

      // Histogram of every decoder layer's weight matrix, labelled by layer index.
      String[] layers = new String[encoder.getDecoder().getnLayers()];
      DoubleMatrix[] weights = new DoubleMatrix[layers.length];
      for (int i = 0; i < encoder.getDecoder().getnLayers(); i++) {
        layers[i] = "" + i;
        weights[i] = encoder.getDecoder().getLayers()[i].getW();
      }

      plotter.histogram(layers, weights);

      FilterRenderer f = new FilterRenderer();
      f.renderFilters(
          encoder.getDecoder().getOutputLayer().getW(),
          "outputlayer.png",
          imageSide,
          imageSide,
          next.numExamples());
      // Iterate this batch as a single mini-batch so only its reconstructions are drawn.
      DeepAutoEncoderDataSetReconstructionRender render =
          new DeepAutoEncoderDataSetReconstructionRender(
              next.iterator(next.numExamples()), encoder);
      render.draw();
    }

    SerializationUtils.saveObject(encoder, new File("deepautoencoder.ser"));

    // Re-walk the full iterator to draw reconstructions with the fully fine-tuned encoder.
    iter.reset();

    DeepAutoEncoderDataSetReconstructionRender render =
        new DeepAutoEncoderDataSetReconstructionRender(iter, encoder);
    render.draw();
  }