/**
 * Fetches {@code num} MNIST examples with a default-constructed {@link MnistDataFetcher}
 * and returns the batch it produces.
 *
 * @param num number of examples to request from the fetcher
 * @return the {@link DataSet} returned by the fetcher's {@code next()} after fetching
 * @throws RuntimeException wrapping any {@link IOException} raised while creating the fetcher
 *     or fetching the data
 */
public static DataSet mnist(int num) {
    try {
        MnistDataFetcher source = new MnistDataFetcher();
        source.fetch(num);
        return source.next();
    } catch (IOException cause) {
        // Callers get an unchecked exception; the original I/O failure is kept as the cause.
        throw new RuntimeException(cause);
    }
}
/** How long each real/reconstructed image pair stays on screen, in milliseconds. */
private static final long RECONSTRUCTION_DISPLAY_MILLIS = 10000L;

/**
 * Visual smoke test for the deep auto-encoder.
 *
 * <p>Fetches 200 MNIST examples (the fetcher is constructed with {@code true}) and keeps only
 * labels 0 and 1. It then pretrains a {@link DBN}, builds and fine-tunes a
 * {@link DeepAutoEncoder} from it, and reconstructs the inputs. Finally it displays each
 * original image next to its reconstruction.
 *
 * <p>The test makes no assertions. It passes if training and reconstruction complete without
 * throwing. When no display is available (e.g. a headless CI agent), the drawing step is
 * skipped, because opening a window there would fail. NOTE(review): this assumes
 * {@code DrawMnistGreyScale} uses AWT/Swing windows, as its {@code frame.dispose()} usage
 * suggests.
 */
@Test
public void testWithMnist() throws Exception {
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    fetcher.fetch(200);
    DataSet data = fetcher.next();
    // Restrict the data set to the examples whose label is 0 or 1.
    data.filterAndStrip(new int[] {0, 1});
    log.info("Training on " + data.numExamples());

    DBN dbn = new DBN.Builder()
            .hiddenLayerSizes(new int[] {1000, 500, 250, 10})
            .numberOfInputs(784)
            .numberOfOutPuts(2)
            .build();
    // NOTE(review): the Object[] is forwarded to DBN.pretrain; the meaning of each entry
    // is not visible here — confirm against DBN.
    dbn.pretrain(data.getFirst(), new Object[] {1, 1e-1, 10000});

    DeepAutoEncoder encoder = new DeepAutoEncoder(dbn);
    encoder.finetune(data.getFirst(), 1e-3, 1000);
    DoubleMatrix reconstruct = encoder.reconstruct(data.getFirst());

    // Drawing needs a display. Training and reconstruction above are still exercised without one.
    if (java.awt.GraphicsEnvironment.isHeadless()) {
        log.info("Headless environment detected; skipping reconstruction display");
        return;
    }

    for (int j = 0; j < data.numExamples(); j++) {
        // Scale by 255 before drawing (the same scaling is applied to both images).
        DoubleMatrix draw1 = data.get(j).getFirst().mul(255);
        DoubleMatrix draw2 = reconstruct.getRow(j).mul(255);
        showRealAndReconstructed(draw1, draw2);
    }
}

/**
 * Draws a real image and its reconstruction and keeps them visible for
 * {@link #RECONSTRUCTION_DISPLAY_MILLIS}. Each window that was successfully drawn is then
 * disposed, even if drawing the second image or the sleep fails.
 *
 * @param real the original image, already scaled for drawing
 * @param reconstructed the auto-encoder's reconstruction, already scaled for drawing
 * @throws InterruptedException if the display delay is interrupted (windows are still disposed)
 */
private static void showRealAndReconstructed(DoubleMatrix real, DoubleMatrix reconstructed)
        throws InterruptedException {
    DrawMnistGreyScale d = new DrawMnistGreyScale(real);
    d.title = "REAL";
    d.draw();
    try {
        DrawMnistGreyScale d2 = new DrawMnistGreyScale(reconstructed);
        d2.title = "TEST";
        d2.draw();
        try {
            Thread.sleep(RECONSTRUCTION_DISPLAY_MILLIS);
        } finally {
            d2.frame.dispose();
        }
    } finally {
        d.frame.dispose();
    }
}