// @Test public void test() throws Exception { InputStream ftrn = IOUtils.createFileInputStream("/Users/jdchoi/Documents/Data/mnist/mnist_trn.txt"); InputStream ftst = IOUtils.createFileInputStream("/Users/jdchoi/Documents/Data/mnist/mnist_tst.txt"); boolean sparse = true; List<Instance> trn = read(ftrn, sparse); List<Instance> tst = read(ftst, sparse); WeightVector w = new WeightVector(); OnlineOptimizer op; // best: 89.87 at 46 epoch // op = new Perceptron(w, 0.01f, 0); // develop(trn, tst, op, 1, sparse); // best: 90.64 at 41 epoch // op = new AdaGrad(w, 0.01f, 0); // develop(trn, tst, op, 1, sparse); // best: 90.60 at 26 epoch // op = new AdaGrad(w, 0.01f, 0, new RegularizedDualAveraging(w, 0.001f)); // develop(trn, tst, op, 1, sparse); // best: 90.49 at 24 epoch // op = new AdaGradMiniBatch(w, 0.01f, 0); // develop(trn, tst, op, 5, sparse); // best: 90.63 at 45 epoch // op = new AdaGradMiniBatch(w, 0.01f, 0, new RegularizedDualAveraging(w, 0.001f)); // develop(trn, tst, op, 5, sparse); // best: 89.12 at 24 epoch // op = new AdaDeltaMiniBatch(w, 0.01f, 0.4f, 0); // develop(trn, tst, op, 5, sparse); // best: 89.36 at 7 epoch // op = new AdaDeltaMiniBatch(w, 0.01f, 0.4f, 0, new RegularizedDualAveraging(w, 0.001f)); // develop(trn, tst, op, 5, sparse); // best: 92.38 at 47 epoch // op = new SoftmaxRegression(w, 0.00000001f, 0); // develop(trn, tst, op, 1, sparse); // best: 92.66 at 46 epoch op = new AdaGradRegression(w, 0.0001f, 0); develop(trn, tst, op, 1, sparse); // ActivationFunction sigmoid = new SigmoidFunction(); // op = new FeedForwardNeuralNetworkSoftmax(new int[]{300}, new ActivationFunction[]{sigmoid}, // 0.0001f, 0, new RandomWeightGenerator(new XORShiftRandom(9), -0.2f, 0.2f)); // develop(trn, tst, op, 1, sparse); }
/**
 * Opens each file in turn with {@code reader} and passes every node array the
 * reader yields to {@code f}.
 *
 * @param reader    the reader that is (re)opened on each file and closed after it.
 * @param filenames the files to process, in iteration order.
 * @param f         the callback invoked once per non-null result of {@code reader.next()}.
 * @throws Exception if opening, reading, or closing a file fails, or if {@code f} throws.
 */
void iterate(TSVReader<POSNode> reader, List<String> filenames, Consumer<POSNode[]> f) throws Exception {
    for (String filename : filenames) {
        reader.open(IOUtils.createFileInputStream(filename));

        // Close in finally so a failure in next() or in the callback
        // does not leave the current file open.
        try {
            POSNode[] nodes;

            while ((nodes = reader.next()) != null) {
                f.accept(nodes);
            }
        } finally {
            reader.close();
        }
    }
}
/**
 * Reads {@code in} line by line and converts each line into an instance.
 *
 * <p>Each line is tokenised with {@code Splitter.splitSpace} and the tokens are
 * passed to {@code toSparseInstance} or {@code toDenseInstance} depending on
 * {@code sparse}.
 *
 * <p>The reader created over {@code in} is closed before returning; this
 * presumably closes {@code in} as well (depends on
 * {@code IOUtils.createBufferedReader} - confirm), so callers should not reuse
 * the stream afterwards.
 *
 * @param in     the stream to read from.
 * @param sparse {@code true} to build sparse instances, {@code false} for dense ones.
 * @return the instances in the order their lines appear in the input.
 * @throws Exception if reading fails or a line cannot be converted.
 */
List<Instance> read(InputStream in, boolean sparse) throws Exception {
    List<Instance> instances = new ArrayList<>();

    // try-with-resources releases the reader on both the normal and the failure path.
    try (BufferedReader reader = IOUtils.createBufferedReader(in)) {
        String line;

        while ((line = reader.readLine()) != null) {
            String[] t = Splitter.splitSpace(line);
            instances.add(sparse ? toSparseInstance(t) : toDenseInstance(t));
        }
    }

    return instances;
}