public static void main(String[] args) throws Exception { // batches of 10, 60000 examples total DataSetIterator iter = new MnistDataSetIterator(10, 300); DBN dbn = SerializationUtils.readObject(new File(args[0])); for (int i = 0; i < dbn.getnLayers(); i++) dbn.getLayers()[i].setRenderEpochs(10); DeepAutoEncoder encoder = new DeepAutoEncoder(dbn); encoder.setRoundCodeLayerInput(true); encoder.setNormalizeCodeLayerOutput(false); encoder.setOutputLayerLossFunction(OutputLayer.LossFunction.RMSE_XENT); encoder.setOutputLayerActivation(Activations.sigmoid()); encoder.setVisibleUnit(RBM.VisibleUnit.GAUSSIAN); encoder.setHiddenUnit(RBM.HiddenUnit.BINARY); while (iter.hasNext()) { DataSet next = iter.next(); if (next == null) break; log.info("Training on " + next.numExamples()); log.info("Coding layer is " + encoder.encode(next.getFirst())); encoder.finetune(next.getFirst(), 1e-1, 1000); NeuralNetPlotter plotter = new NeuralNetPlotter(); encoder.getDecoder().feedForward(encoder.encodeWithScaling(next.getFirst())); String[] layers = new String[encoder.getDecoder().getnLayers()]; DoubleMatrix[] weights = new DoubleMatrix[layers.length]; for (int i = 0; i < encoder.getDecoder().getnLayers(); i++) { layers[i] = "" + i; weights[i] = encoder.getDecoder().getLayers()[i].getW(); } plotter.histogram(layers, weights); FilterRenderer f = new FilterRenderer(); f.renderFilters( encoder.getDecoder().getOutputLayer().getW(), "outputlayer.png", 28, 28, next.numExamples()); DeepAutoEncoderDataSetReconstructionRender render = new DeepAutoEncoderDataSetReconstructionRender( next.iterator(next.numExamples()), encoder); render.draw(); } SerializationUtils.saveObject(encoder, new File("deepautoencoder.ser")); iter.reset(); DeepAutoEncoderDataSetReconstructionRender render = new DeepAutoEncoderDataSetReconstructionRender(iter, encoder); render.draw(); }