예제 #1
0
 /**
  * Fetches {@code num} MNIST examples and returns the {@link DataSet} produced by the fetcher.
  *
  * @param num number of examples passed to {@code MnistDataFetcher.fetch}
  * @return the result of {@code MnistDataFetcher.next()} after fetching
  * @throws RuntimeException wrapping any {@link IOException} raised while creating the fetcher,
  *     fetching, or reading the batch; the original exception is kept as the cause
  */
 public static DataSet mnist(int num) {
   DataSet batch;
   try {
     MnistDataFetcher loader = new MnistDataFetcher();
     loader.fetch(num);
     batch = loader.next();
   } catch (IOException cause) {
     // Callers get an unchecked exception; the IOException is preserved as the cause.
     throw new RuntimeException(cause);
   }
   return batch;
 }
  /**
   * Pretrains a DBN on a small MNIST subset, fine-tunes it as a deep autoencoder, and then shows
   * each original digit next to its reconstruction for manual inspection.
   *
   * <p>The display step opens windows and pauses between examples. It is skipped when the JVM is
   * headless, so the training part can still run in CI instead of failing while creating windows.
   */
  @Test
  public void testWithMnist() throws Exception {
    // NOTE(review): the meaning of the boolean flag is defined by MnistDataFetcher — confirm.
    MnistDataFetcher fetcher = new MnistDataFetcher(true);
    fetcher.fetch(200);
    DataSet data = fetcher.next();
    // Presumably keeps only labels 0 and 1, matching numberOfOutPuts(2) below — TODO confirm.
    data.filterAndStrip(new int[] {0, 1});
    log.info("Training on " + data.numExamples());

    DBN dbn =
        new DBN.Builder()
            .hiddenLayerSizes(new int[] {1000, 500, 250, 10})
            .numberOfInputs(784)
            .numberOfOutPuts(2)
            .build();

    // NOTE(review): this positional Object[] is interpreted by DBN.pretrain; its meaning is not
    // visible here.
    dbn.pretrain(data.getFirst(), new Object[] {1, 1e-1, 10000});

    DeepAutoEncoder encoder = new DeepAutoEncoder(dbn);
    // NOTE(review): presumably learning rate 1e-3 and 1000 iterations — confirm against
    // DeepAutoEncoder.finetune.
    encoder.finetune(data.getFirst(), 1e-3, 1000);

    // Row j of the reconstruction corresponds to example j of the data set.
    DoubleMatrix reconstruct = encoder.reconstruct(data.getFirst());

    // The viewers below own windows (they expose frame.dispose()), which cannot be created
    // without a display, so skip visual inspection instead of failing the test.
    if (java.awt.GraphicsEnvironment.isHeadless()) {
      log.info("Headless environment; skipping reconstruction display");
      return;
    }

    // How long each real/reconstructed pair stays on screen.
    final long displayMillis = 10000;
    for (int j = 0; j < data.numExamples(); j++) {
      // Presumably scales [0,1] intensities to 0-255 grey levels for drawing — TODO confirm.
      DoubleMatrix real = data.get(j).getFirst().mul(255);
      DoubleMatrix reconstructed = reconstruct.getRow(j).mul(255);
      showPair(real, reconstructed, displayMillis);
    }
  }

  /**
   * Draws {@code real} (titled "REAL") and {@code reconstructed} (titled "TEST"), waits {@code
   * millis} ms, then disposes both viewers' frames, even if drawing or sleeping throws.
   */
  private void showPair(DoubleMatrix real, DoubleMatrix reconstructed, long millis)
      throws Exception {
    DrawMnistGreyScale realView = new DrawMnistGreyScale(real);
    realView.title = "REAL";
    DrawMnistGreyScale testView = null;
    try {
      realView.draw();
      testView = new DrawMnistGreyScale(reconstructed);
      testView.title = "TEST";
      testView.draw();
      Thread.sleep(millis);
    } finally {
      // frame is presumably assigned by draw(); guard against null in case draw() failed early.
      if (realView.frame != null) {
        realView.frame.dispose();
      }
      if (testView != null && testView.frame != null) {
        testView.frame.dispose();
      }
    }
  }