Example 1
0
  @Test
  public void baseline() throws Exception {
    // Train and evaluate with the default feature template, optimizer type 0
    // (perceptron), and weight averaging enabled.
    run(new POSFeatureTemplate1(), 0, true);

    // NOTE(review): passing a null node array to a fresh AmbiguityClassMap looks
    // suspicious — confirm AmbiguityClassMap.add tolerates null before relying on this.
    AmbiguityClassMap ambiguityMap = new AmbiguityClassMap();
    DEPNode[] nullNodes = null;
    ambiguityMap.add(nullNodes);
  }
Example 2
0
  /**
   * Trains a part-of-speech tagger on the WSJ training files and evaluates it on the
   * development files after each epoch, printing per-epoch accuracy and the best score.
   *
   * @param template feature template used to extract features for each node
   * @param type     optimizer selector: 0 = Perceptron, 1 = AdaGrad, 2 = LogisticRegression
   * @param average  whether the optimizer uses weight averaging
   * @throws Exception                if reading or training fails
   * @throws IllegalArgumentException if {@code type} is not 0, 1, or 2
   */
  public void run(FeatureTemplate<POSNode, POSState<POSNode>> template, int type, boolean average)
      throws Exception {
    // NOTE(review): hard-coded local path — this only runs on the original author's machine.
    final String root = "/Users/jdchoi/Documents/Data/experiments/wsj/wsj-pos/";
    TSVReader<POSNode> reader = new TSVReader<>(new POSIndex(0, 1));
    List<String> trnFiles = FileUtils.getFileList(root + "trn/", "pos");
    List<String> devFiles = FileUtils.getFileList(root + "dev/", "pos");
    // Sort so the iteration order (and hence training-instance order) is deterministic.
    Collections.sort(trnFiles);
    Collections.sort(devFiles);

    // Collect ambiguity classes from the training data.
    AmbiguityClassMap map = new AmbiguityClassMap();
    iterate(reader, trnFiles, map::add);
    map.expand(0.4);

    // Collect training instances in TRAIN mode.
    StringModel model = new StringModel(new MultinomialWeightVector(), false);
    POSTagger<POSNode> tagger = new POSTagger<>(model);
    tagger.setFlag(NLPFlag.TRAIN);
    tagger.setAmbiguityClassMap(map);
    tagger.setFeatureTemplate(template);
    iterate(reader, trnFiles, tagger::process);

    // Map string features into a vector space (no cutoffs).
    final int labelCutoff = 0;
    final int featureCutoff = 0;
    model.vectorize(labelCutoff, featureCutoff, false);

    // Train the statistical model.
    final double learningRate = 0.02;
    final int epochs = 1;

    WeightVector weight = model.getWeightVector();
    OnlineOptimizer sgd;

    switch (type) {
      case 0:
        sgd = new Perceptron(weight, average, learningRate);
        break;
      case 1:
        sgd = new AdaGrad(weight, average, learningRate);
        break;
      case 2:
        sgd = new LogisticRegression(weight, average, learningRate);
        break;
      default:
        // Previously an unknown type left sgd == null and caused an NPE at sgd.train(...);
        // fail fast with a clear message instead.
        throw new IllegalArgumentException("Unknown optimizer type: " + type);
    }

    // Switch the tagger to EVALUATE mode and score on the dev set after each epoch.
    Eval eval = new AccuracyEval();
    tagger.setFlag(NLPFlag.EVALUATE);
    tagger.setEval(eval);

    DoubleIntPair best = new DoubleIntPair(-1, -1);
    double currScore;

    for (int i = 0; i < epochs; i++) {
      sgd.train(model.getInstanceList());
      eval.clear();
      iterate(reader, devFiles, tagger::process);
      currScore = eval.score();
      System.out.printf("%4d: %5.2f\n", i, currScore);
      if (best.d < currScore) best.set(currScore, i);
    }

    System.out.printf("Best: %d - %5.2f\n", best.i, best.d);
  }