예제 #1
0
  //	@Test
  /**
   * Manual experiment driver: loads the MNIST train/test files and runs {@code develop} with one
   * online optimizer configuration.
   *
   * <p>The {@code @Test} annotation above is disabled, presumably because the data files live at
   * absolute, machine-specific paths. The commented-out configurations below record previously
   * observed results ("best: X at N epoch"); uncomment one (and comment out the active one) to
   * rerun it.
   *
   * @throws Exception if the data files cannot be opened/read or {@code develop} fails
   */
  public void test() throws Exception {
    boolean sparse = true;
    List<Instance> trn;
    List<Instance> tst;

    // read() consumes each stream to EOF but never closes it, so close both here once the
    // instances have been loaded; previously these file handles were leaked.
    try (InputStream ftrn =
            IOUtils.createFileInputStream("/Users/jdchoi/Documents/Data/mnist/mnist_trn.txt");
        InputStream ftst =
            IOUtils.createFileInputStream("/Users/jdchoi/Documents/Data/mnist/mnist_tst.txt")) {
      trn = read(ftrn, sparse);
      tst = read(ftst, sparse);
    }

    WeightVector w = new WeightVector();
    OnlineOptimizer op;

    // best: 89.87 at 46 epoch
    //		op = new Perceptron(w, 0.01f, 0);
    //		develop(trn, tst, op, 1, sparse);

    // best: 90.64 at 41 epoch
    //		op = new AdaGrad(w, 0.01f, 0);
    //		develop(trn, tst, op, 1, sparse);

    // best: 90.60 at 26 epoch
    //		op = new AdaGrad(w, 0.01f, 0, new RegularizedDualAveraging(w, 0.001f));
    //		develop(trn, tst, op, 1, sparse);

    // best: 90.49 at 24 epoch
    //		op = new AdaGradMiniBatch(w, 0.01f, 0);
    //		develop(trn, tst, op, 5, sparse);

    // best: 90.63 at 45 epoch
    //		op = new AdaGradMiniBatch(w, 0.01f, 0, new RegularizedDualAveraging(w, 0.001f));
    //		develop(trn, tst, op, 5, sparse);

    // best: 89.12 at 24 epoch
    //		op = new AdaDeltaMiniBatch(w, 0.01f, 0.4f, 0);
    //		develop(trn, tst, op, 5, sparse);

    // best: 89.36 at 7 epoch
    //		op = new AdaDeltaMiniBatch(w, 0.01f, 0.4f, 0, new RegularizedDualAveraging(w, 0.001f));
    //		develop(trn, tst, op, 5, sparse);

    // best: 92.38 at 47 epoch
    //		op = new SoftmaxRegression(w, 0.00000001f, 0);
    //		develop(trn, tst, op, 1, sparse);

    // best: 92.66 at 46 epoch
    op = new AdaGradRegression(w, 0.0001f, 0);
    develop(trn, tst, op, 1, sparse);

    //		ActivationFunction sigmoid = new SigmoidFunction();
    //		op = new FeedForwardNeuralNetworkSoftmax(new int[]{300}, new ActivationFunction[]{sigmoid},
    // 0.0001f, 0, new RandomWeightGenerator(new XORShiftRandom(9), -0.2f, 0.2f));
    //		develop(trn, tst, op, 1, sparse);
  }
예제 #2
0
  /**
   * Opens each file in {@code filenames} with {@code reader} and passes every non-null
   * {@code POSNode[]} returned by {@code reader.next()} to {@code f}.
   *
   * <p>Files are processed in list order. {@code reader} is reused: it is opened on each file and
   * closed afterwards, even when {@code reader.next()} or the consumer throws.
   *
   * @param reader reader reused across all files; reopened for each one
   * @param filenames paths of the files to read, in processing order
   * @param f callback invoked once per array returned by {@code reader.next()}
   * @throws Exception if opening, reading, or closing a file fails, or if {@code f} throws
   */
  void iterate(TSVReader<POSNode> reader, List<String> filenames, Consumer<POSNode[]> f)
      throws Exception {
    for (String filename : filenames) {
      reader.open(IOUtils.createFileInputStream(filename));
      try {
        POSNode[] nodes;
        while ((nodes = reader.next()) != null) {
          f.accept(nodes);
        }
      } finally {
        // Previously close() was skipped on any exception above. It presumably releases the
        // stream passed to open(); confirm against TSVReader.
        reader.close();
      }
    }
  }
예제 #3
0
  /**
   * Parses one instance per line of {@code in}, reading until end of stream.
   *
   * <p>Each line is tokenized with {@code Splitter.splitSpace} (presumably splitting on
   * whitespace) and converted by {@code toSparseInstance} or {@code toDenseInstance}, depending on
   * {@code sparse}. The stream is not closed here; the caller remains responsible for it.
   *
   * @param in stream to read, one instance per line
   * @param sparse {@code true} to build instances via {@code toSparseInstance}, {@code false} to
   *     use {@code toDenseInstance}
   * @return the parsed instances, in input order
   * @throws Exception if reading fails or a line cannot be converted
   */
  List<Instance> read(InputStream in, boolean sparse) throws Exception {
    BufferedReader lines = IOUtils.createBufferedReader(in);
    List<Instance> result = new ArrayList<>();

    for (String row = lines.readLine(); row != null; row = lines.readLine()) {
      String[] tokens = Splitter.splitSpace(row);
      Instance instance;
      if (sparse) {
        instance = toSparseInstance(tokens);
      } else {
        instance = toDenseInstance(tokens);
      }
      result.add(instance);
    }

    return result;
  }