示例#1
0
 /**
  * Classifies the given candidates with {@code model} and returns the best label as an int.
  *
  * <p>The candidate feature lists are converted to feature vectors using the model's alphabet,
  * wrapped in a single {@link Instance} (target, name and source all {@code null}), and
  * classified; the text of the best label is then parsed as a decimal integer.
  *
  * @param candidateFeatures one list of feature strings per candidate
  * @return the integer value of the best label's string form
  * @throws NumberFormatException if the best label's text is not a parsable integer
  */
 public int rank(List<List<String>> candidateFeatures) {
   FeatureVector[] candidateVectors =
       TrainMalletMaxEntRank.candidateFeatures2FV(candidateFeatures, model.getAlphabet());
   Instance candidateInstance = new Instance(candidateVectors, null, null, null);
   String bestLabelText =
       model.classify(candidateInstance).getLabeling().getBestLabel().toString();
   return Integer.parseInt(bestLabelText);
 }
  /**
   * Evaluates a linear interpolation of one general ranker with three connective-category
   * specific rankers and prints the result to standard output.
   *
   * <p>Steps:
   *
   * <ol>
   *   <li>A general {@link RankMaxEnt} is trained on all of {@code trainingInstanceList}; the
   *       trainer used for it is stored in {@code trainer}, which is declared outside this
   *       method.
   *   <li>Both lists are split into three buckets by the category of each instance's connective
   *       (category starting with {@code "Coord"} -> 0, {@code "Sub"} -> 1, anything else or no
   *       category -> 2).
   *   <li>One {@link RankMaxEnt} is trained per bucket. For every test instance of that bucket,
   *       each candidate is scored as {@code 0.4 * general + 0.6 * categorySpecific} and the
   *       highest-scoring candidate is compared with {@code trueArg1Candidate}.
   * </ol>
   *
   * <p>Prints the number of test instances, the number ranked correctly, and their ratio. With
   * an empty testing list the ratio is {@code 0.0 / 0.0} and prints as {@code NaN}.
   *
   * @param trainingInstanceList instances used to train the general and per-category rankers;
   *     every element must be an {@code Arg1RankInstance}
   * @param testingInstanceList instances to evaluate on; every element must be an {@code
   *     Arg1RankInstance} whose data is a {@code FeatureVectorSequence}
   */
  void showInterpolatedTCAccuracy(
      InstanceList trainingInstanceList, InstanceList testingInstanceList) {
    final int categoryCount = 3;
    final double generalWeight = 0.4;
    final double categoryWeight = 0.6;

    trainer = new MyClassifierTrainer(new RankMaxEntTrainer());
    RankMaxEnt generalClassifier = (RankMaxEnt) trainer.train(trainingInstanceList);

    InstanceList[] trainingInstanceLists =
        splitByConnectiveCategory(trainingInstanceList, categoryCount);
    InstanceList[] testingInstanceLists =
        splitByConnectiveCategory(testingInstanceList, categoryCount);

    // Kept as doubles so the printed totals keep their original "N.0" formatting.
    double total = 0;
    double correct = 0;
    for (int i = 0; i < categoryCount; i++) {
      MyClassifierTrainer categoryTrainer = new MyClassifierTrainer(new RankMaxEntTrainer());
      RankMaxEnt categoryClassifier = (RankMaxEnt) categoryTrainer.train(trainingInstanceLists[i]);
      total += testingInstanceLists[i].size();

      for (Instance instance : testingInstanceLists[i]) {
        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        int trueIndex = rankInstance.trueArg1Candidate;

        // One score slot per candidate; the classifiers are expected to fill the arrays.
        int candidateCount = ((FeatureVectorSequence) instance.getData()).size();
        double[] genScores = new double[candidateCount];
        generalClassifier.getClassificationScores(instance, genScores);
        double[] tcScores = new double[candidateCount];
        categoryClassifier.getClassificationScores(instance, tcScores);

        int bestIndex =
            argMaxInterpolatedScore(genScores, tcScores, generalWeight, categoryWeight);
        if (bestIndex == trueIndex) {
          correct++;
        }
      }
    }

    System.out.println("Using interpolated model:");
    System.out.println("Total: " + total);
    System.out.println("Correct: " + correct);
    System.out.println("Accuracy: " + correct / total);
  }

  /**
   * Splits {@code source} into {@code bucketCount} new lists (sharing {@code source}'s pipe)
   * according to {@link #connectiveCategoryBucket(Arg1RankInstance)}.
   */
  private InstanceList[] splitByConnectiveCategory(InstanceList source, int bucketCount) {
    InstanceList[] buckets = new InstanceList[bucketCount];
    for (int i = 0; i < bucketCount; i++) {
      buckets[i] = new InstanceList(source.getPipe());
    }
    for (Instance instance : source) {
      buckets[connectiveCategoryBucket((Arg1RankInstance) instance)].add(instance);
    }
    return buckets;
  }

  /**
   * Maps an instance's connective to a bucket: 0 for categories starting with "Coord", 1 for
   * "Sub", 2 for every other category and for connectives without a category.
   */
  private int connectiveCategoryBucket(Arg1RankInstance rankInstance) {
    Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line());
    // Locale.ROOT keeps the lookup key independent of the JVM's default locale.
    String conn =
        sentence
            .toString(rankInstance.connStart, rankInstance.connEnd)
            .toLowerCase(java.util.Locale.ROOT);
    String category = connAnalyzer.getCategory(conn);
    if (category == null) {
      // Unknown connectives were previously labelled "Conj-adverbial", i.e. the fallback bucket.
      return 2;
    }
    if (category.startsWith("Coord")) {
      return 0;
    }
    if (category.startsWith("Sub")) {
      return 1;
    }
    return 2;
  }

  /**
   * Returns the index maximising {@code first[j] * firstWeight + second[j] * secondWeight}, or
   * -1 when the arrays are empty (or no score is comparable, e.g. all NaN). Ties keep the lowest
   * index. The running maximum starts at negative infinity so that a best candidate is still
   * found when every interpolated score is zero or negative.
   */
  private int argMaxInterpolatedScore(
      double[] first, double[] second, double firstWeight, double secondWeight) {
    double max = Double.NEGATIVE_INFINITY;
    int maxIndex = -1;
    for (int j = 0; j < first.length; j++) {
      double score = first[j] * firstWeight + second[j] * secondWeight;
      if (score > max) {
        max = score;
        maxIndex = j;
      }
    }
    return maxIndex;
  }