/**
   * Fine-tunes a {@link DeepAutoEncoder} built from a previously serialized {@link DBN} on a
   * small slice of MNIST, producing diagnostics after every mini-batch.
   *
   * <p>For each batch this logs the coding-layer output, fine-tunes the encoder, plots a
   * histogram of the decoder layers' weight matrices, renders the decoder output layer's
   * weights to {@code outputlayer.png}, and draws the batch through a
   * {@link DeepAutoEncoderDataSetReconstructionRender}. Afterwards the encoder is saved to
   * {@code deepautoencoder.ser}, the iterator is reset, and the whole data set is drawn again.
   *
   * @param args {@code args[0]} is the path of the serialized {@link DBN} to start from
   * @throws IllegalArgumentException if no DBN path is supplied
   * @throws Exception if loading, training, rendering, or saving fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 1) {
      throw new IllegalArgumentException("Usage: <path to serialized DBN>");
    }

    // Batches of 10 examples, 300 examples total (a small slice, not the full 60000 set).
    DataSetIterator iter = new MnistDataSetIterator(10, 300);

    // NOTE(review): a ".ser" file is presumably Java native serialization; only load
    // files you trust.
    DBN dbn = SerializationUtils.readObject(new File(args[0]));
    for (int i = 0; i < dbn.getnLayers(); i++) {
      dbn.getLayers()[i].setRenderEpochs(10);
    }

    DeepAutoEncoder encoder = new DeepAutoEncoder(dbn);
    encoder.setRoundCodeLayerInput(true);
    encoder.setNormalizeCodeLayerOutput(false);
    encoder.setOutputLayerLossFunction(OutputLayer.LossFunction.RMSE_XENT);
    encoder.setOutputLayerActivation(Activations.sigmoid());
    encoder.setVisibleUnit(RBM.VisibleUnit.GAUSSIAN);
    encoder.setHiddenUnit(RBM.HiddenUnit.BINARY);

    while (iter.hasNext()) {
      DataSet next = iter.next();
      if (next == null) {
        break;
      }
      log.info("Training on " + next.numExamples());
      log.info("Coding layer is " + encoder.encode(next.getFirst()));
      encoder.finetune(next.getFirst(), 1e-1, 1000);
      NeuralNetPlotter plotter = new NeuralNetPlotter();
      // The result is discarded. NOTE(review): presumably kept for its side effects on the
      // decoder before its weights are inspected below; confirm, or remove.
      encoder.getDecoder().feedForward(encoder.encodeWithScaling(next.getFirst()));

      // Label each decoder layer by its index and collect its weight matrix for the histogram.
      String[] layers = new String[encoder.getDecoder().getnLayers()];
      DoubleMatrix[] weights = new DoubleMatrix[layers.length];
      for (int i = 0; i < encoder.getDecoder().getnLayers(); i++) {
        layers[i] = "" + i;
        weights[i] = encoder.getDecoder().getLayers()[i].getW();
      }

      plotter.histogram(layers, weights);

      // 28 x 28 matches the MNIST image size (784 inputs).
      FilterRenderer f = new FilterRenderer();
      f.renderFilters(
          encoder.getDecoder().getOutputLayer().getW(),
          "outputlayer.png",
          28,
          28,
          next.numExamples());
      DeepAutoEncoderDataSetReconstructionRender render =
          new DeepAutoEncoderDataSetReconstructionRender(
              next.iterator(next.numExamples()), encoder);
      render.draw();
    }

    SerializationUtils.saveObject(encoder, new File("deepautoencoder.ser"));

    // Rewind so the final render walks the full data set again.
    iter.reset();

    DeepAutoEncoderDataSetReconstructionRender render =
        new DeepAutoEncoderDataSetReconstructionRender(iter, encoder);
    render.draw();
  }
  /**
   * Visual smoke test.
   *
   * <p>Pretrains a {@link DBN} on a filtered MNIST sample and fine-tunes a
   * {@link DeepAutoEncoder} built from it. It then shows each original example ("REAL") next to
   * its reconstruction ("TEST").
   *
   * <p>This test makes no assertions. When the JVM is headless it skips the display step,
   * because opening windows would fail there; training and reconstruction still run. Each
   * displayed pair stays up for 10 seconds. Its windows are disposed even if drawing or the
   * wait throws.
   */
  @Test
  public void testWithMnist() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    fetcher.fetch(200);
    DataSet data = fetcher.next();
    // Presumably keeps only examples labeled 0 or 1, hence the 2 outputs configured below.
    data.filterAndStrip(new int[] {0, 1});
    log.info("Training on " + data.numExamples());

    DBN dbn =
        new DBN.Builder()
            .hiddenLayerSizes(new int[] {1000, 500, 250, 10})
            .numberOfInputs(784)
            .numberOfOutPuts(2)
            .build();

    dbn.pretrain(data.getFirst(), new Object[] {1, 1e-1, 10000});

    DeepAutoEncoder encoder = new DeepAutoEncoder(dbn);
    encoder.finetune(data.getFirst(), 1e-3, 1000);

    DoubleMatrix reconstruct = encoder.reconstruct(data.getFirst());

    // Opening windows throws in a headless JVM (e.g. CI), so only display when a screen exists.
    if (java.awt.GraphicsEnvironment.isHeadless()) {
      log.info("Headless environment; skipping display of reconstructions");
      return;
    }

    for (int j = 0; j < data.numExamples(); j++) {
      // Scale up by 255 for greyscale drawing. NOTE(review): assumes pixel values in [0, 1].
      DoubleMatrix draw1 = data.get(j).getFirst().mul(255);
      DoubleMatrix draw2 = reconstruct.getRow(j).mul(255);

      DrawMnistGreyScale d = new DrawMnistGreyScale(draw1);
      DrawMnistGreyScale d2 = null;
      try {
        d.title = "REAL";
        d.draw();
        d2 = new DrawMnistGreyScale(draw2);
        d2.title = "TEST";
        d2.draw();
        Thread.sleep(10000);
      } finally {
        // Release the windows even when drawing or the wait fails. Frames may be unset if
        // draw() never completed.
        if (d.frame != null) {
          d.frame.dispose();
        }
        if (d2 != null && d2.frame != null) {
          d2.frame.dispose();
        }
      }
    }
  }