public static void main(String[] args) throws Exception { // batches of 10, 60000 examples total DataSetIterator iter = new MnistDataSetIterator(10, 300); DBN dbn = SerializationUtils.readObject(new File(args[0])); for (int i = 0; i < dbn.getnLayers(); i++) dbn.getLayers()[i].setRenderEpochs(10); DeepAutoEncoder encoder = new DeepAutoEncoder(dbn); encoder.setRoundCodeLayerInput(true); encoder.setNormalizeCodeLayerOutput(false); encoder.setOutputLayerLossFunction(OutputLayer.LossFunction.RMSE_XENT); encoder.setOutputLayerActivation(Activations.sigmoid()); encoder.setVisibleUnit(RBM.VisibleUnit.GAUSSIAN); encoder.setHiddenUnit(RBM.HiddenUnit.BINARY); while (iter.hasNext()) { DataSet next = iter.next(); if (next == null) break; log.info("Training on " + next.numExamples()); log.info("Coding layer is " + encoder.encode(next.getFirst())); encoder.finetune(next.getFirst(), 1e-1, 1000); NeuralNetPlotter plotter = new NeuralNetPlotter(); encoder.getDecoder().feedForward(encoder.encodeWithScaling(next.getFirst())); String[] layers = new String[encoder.getDecoder().getnLayers()]; DoubleMatrix[] weights = new DoubleMatrix[layers.length]; for (int i = 0; i < encoder.getDecoder().getnLayers(); i++) { layers[i] = "" + i; weights[i] = encoder.getDecoder().getLayers()[i].getW(); } plotter.histogram(layers, weights); FilterRenderer f = new FilterRenderer(); f.renderFilters( encoder.getDecoder().getOutputLayer().getW(), "outputlayer.png", 28, 28, next.numExamples()); DeepAutoEncoderDataSetReconstructionRender render = new DeepAutoEncoderDataSetReconstructionRender( next.iterator(next.numExamples()), encoder); render.draw(); } SerializationUtils.saveObject(encoder, new File("deepautoencoder.ser")); iter.reset(); DeepAutoEncoderDataSetReconstructionRender render = new DeepAutoEncoderDataSetReconstructionRender(iter, encoder); render.draw(); }
/**
 * Visual check of a {@link DeepAutoEncoder} on MNIST: fetches 200 examples, applies
 * {@code filterAndStrip} with labels {0, 1}, pretrains a DBN, fine-tunes an encoder built from
 * it and then, for every example, shows the original ("REAL") next to its reconstruction
 * ("TEST") for ten seconds before disposing both windows.
 *
 * @throws Exception if fetching, training or drawing fails, or the sleep is interrupted
 */
@Test
public void testWithMnist() throws Exception {
  MnistDataFetcher mnist = new MnistDataFetcher(true);
  mnist.fetch(200);

  DataSet data = mnist.next();
  // NOTE(review): presumably restricts the set to the digits 0 and 1 -- confirm in DataSet.
  data.filterAndStrip(new int[] {0, 1});
  log.info("Training on " + data.numExamples());

  DBN.Builder builder = new DBN.Builder();
  DBN dbn =
      builder
          .hiddenLayerSizes(new int[] {1000, 500, 250, 10})
          .numberOfInputs(784)
          .numberOfOutPuts(2)
          .build();
  dbn.pretrain(data.getFirst(), new Object[] {1, 1e-1, 10000});

  DeepAutoEncoder encoder = new DeepAutoEncoder(dbn);
  encoder.finetune(data.getFirst(), 1e-3, 1000);
  DoubleMatrix reconstruct = encoder.reconstruct(data.getFirst());

  for (int row = 0; row < data.numExamples(); row++) {
    // Both images are multiplied by 255 before being handed to the grey-scale drawer.
    DoubleMatrix realPixels = data.get(row).getFirst().mul(255);
    DoubleMatrix reconstructedPixels = reconstruct.getRow(row).mul(255);

    DrawMnistGreyScale realView = new DrawMnistGreyScale(realPixels);
    realView.title = "REAL";
    realView.draw();

    DrawMnistGreyScale reconstructedView = new DrawMnistGreyScale(reconstructedPixels);
    reconstructedView.title = "TEST";
    reconstructedView.draw();

    // Keep both windows on screen long enough for a side-by-side look.
    Thread.sleep(10000);

    realView.frame.dispose();
    reconstructedView.frame.dispose();
  }
}