コード例 #1
0
  /**
   * Converts an {@link Arg1RankInstance} into a ranking instance: the data becomes a
   * {@link FeatureVectorSequence} holding one {@link FeatureVector} per Arg1 candidate, and the
   * target becomes the label looked up for the index of the true Arg1 candidate.
   *
   * @param carrier the instance to convert; it is cast to {@link Arg1RankInstance} and its data
   *     is cast to {@link Document}
   * @return the same {@code carrier} object, with its target and data replaced
   */
  @Override
  public Instance pipe(Instance carrier) {
    final Arg1RankInstance rankInstance = (Arg1RankInstance) carrier;

    final Document doc = (Document) rankInstance.getData();
    final List<Pair<Integer, Integer>> arg1Candidates = rankInstance.getCandidates();
    final int connStart = rankInstance.getConnStart();
    final int connEnd = rankInstance.getConnEnd();
    final int arg2Line = rankInstance.getArg2Line();
    final int arg2HeadPos = rankInstance.getArg2HeadPos();

    final int candidateCount = arg1Candidates.size();
    final FeatureVector[] vectors = new FeatureVector[candidateCount];

    int slot = 0;
    for (Pair<Integer, Integer> arg1Candidate : arg1Candidates) {
      // Each extractor receives the property list built so far and returns the extended one.
      PropertyList features = null;
      features =
          addBaselineFeatures(
              features, doc, arg1Candidate, arg2Line, arg2HeadPos, connStart, connEnd);
      features =
          addConstituentFeatures(
              features, doc, arg1Candidate, arg2Line, arg2HeadPos, connStart, connEnd);
      features =
          addDependencyFeatures(
              features, doc, arg1Candidate, arg2Line, arg2HeadPos, connStart, connEnd);
      // Lexico-syntactic features are currently disabled:
      // features = addLexicoSyntacticFeatures(features, doc, arg1Candidate, arg2Line,
      // arg2HeadPos, connStart, connEnd);

      vectors[slot] = new FeatureVector(getDataAlphabet(), features, true, true);
      slot++;
    }

    // The target label is the string form of the true candidate's index.
    final LabelAlphabet labels = (LabelAlphabet) getTargetAlphabet();
    final String trueCandidateName = String.valueOf(rankInstance.getTrueArg1Candidate());
    carrier.setTarget(labels.lookupLabel(trueCandidateName));
    carrier.setData(new FeatureVectorSequence(vectors));

    return carrier;
  }
コード例 #2
0
  /**
   * Trains one ranking model per connective category and measures their combined accuracy.
   *
   * <p>Training and testing instances are split into three buckets by the category that {@code
   * connAnalyzer} reports for the lower-cased connective text: categories starting with {@code
   * "Coord"} go to bucket 0, those starting with {@code "Sub"} to bucket 1, and everything else
   * (including connectives with no category) to bucket 2. A separate model is trained on each
   * training bucket and evaluated on the matching testing bucket.
   *
   * @param trainingInstanceList instances used for training; each must be an {@link
   *     Arg1RankInstance}
   * @param testingInstanceList instances used for evaluation; each must be an {@link
   *     Arg1RankInstance}
   * @param show whether to print the totals to standard output
   * @return a three-element array: total test instances, the size-weighted sum of per-bucket
   *     accuracies (i.e. the estimated number correct), and their ratio
   */
  double[] getTypeSpecificAccuracy(
      InstanceList trainingInstanceList, InstanceList testingInstanceList, boolean show) {
    final int bucketCount = 3;
    InstanceList[] trainingBuckets = new InstanceList[bucketCount];
    InstanceList[] testingBuckets = new InstanceList[bucketCount];
    for (int b = 0; b < bucketCount; b++) {
      trainingBuckets[b] = new InstanceList(trainingInstanceList.getPipe());
      testingBuckets[b] = new InstanceList(testingInstanceList.getPipe());
    }

    // Partition the training list first, then the testing list, with the same rule.
    InstanceList[] sources = {trainingInstanceList, testingInstanceList};
    InstanceList[][] destinations = {trainingBuckets, testingBuckets};
    for (int s = 0; s < sources.length; s++) {
      for (Instance current : sources[s]) {
        Arg1RankInstance ranked = (Arg1RankInstance) current;
        Sentence arg2Sentence = ranked.document.getSentence(ranked.getArg2Line());
        String connective = arg2Sentence.toString(ranked.connStart, ranked.connEnd).toLowerCase();
        String category = connAnalyzer.getCategory(connective);
        if (category == null) {
          category = "Conj-adverbial";
        }

        int bucket;
        if (category.startsWith("Coord")) {
          bucket = 0;
        } else if (category.startsWith("Sub")) {
          bucket = 1;
        } else {
          bucket = 2;
        }
        destinations[s][bucket].add(current);
      }
    }

    double total = 0;
    double correct = 0;
    for (int b = 0; b < bucketCount; b++) {
      MyClassifierTrainer bucketTrainer = new MyClassifierTrainer(new RankMaxEntTrainer());
      Classifier bucketModel = bucketTrainer.train(trainingBuckets[b]);
      total += testingBuckets[b].size();
      // accuracy * bucket size = number of correct predictions in this bucket
      correct += getAccuracy(bucketModel, testingBuckets[b]) * testingBuckets[b].size();
    }

    if (show) {
      System.out.println("Using type specific models:");
      System.out.println("Total: " + total);
      System.out.println("Correct: " + correct);
      System.out.println("Accuracy: " + correct / total);
    }
    return new double[] {total, correct, 1.0 * correct / total};
  }
コード例 #3
0
  /**
   * Prints the accuracy of a model that interpolates a general ranker with type-specific rankers.
   *
   * <p>A general model is trained on all training instances (and stored in the {@code trainer}
   * field). Training and testing instances are then split into three buckets by the category
   * that {@code connAnalyzer} reports for the lower-cased connective text: categories starting
   * with {@code "Coord"} go to bucket 0, those starting with {@code "Sub"} to bucket 1, and
   * everything else (including connectives with no category) to bucket 2. One model is trained
   * per bucket. For every test instance the candidate scores of the general model and of the
   * bucket's model are combined as {@code 0.4 * general + 0.6 * typeSpecific}, and the
   * highest-scoring candidate is compared with the true Arg1 candidate index.
   *
   * @param trainingInstanceList instances used for training; each must be an {@link
   *     Arg1RankInstance}
   * @param testingInstanceList instances used for evaluation; each must be an {@link
   *     Arg1RankInstance} whose data is a {@link FeatureVectorSequence}
   */
  void showInterpolatedTCAccuracy(
      InstanceList trainingInstanceList, InstanceList testingInstanceList) {
    // Interpolation weights for the general and the type-specific model scores.
    final double generalWeight = 0.4;
    final double typeSpecificWeight = 0.6;
    final int bucketCount = 3;

    trainer = new MyClassifierTrainer(new RankMaxEntTrainer());
    RankMaxEnt generalClassifier = (RankMaxEnt) trainer.train(trainingInstanceList);

    InstanceList[] trainingInstanceLists = new InstanceList[bucketCount];
    InstanceList[] testingInstanceLists = new InstanceList[bucketCount];
    for (int i = 0; i < bucketCount; i++) {
      trainingInstanceLists[i] = new InstanceList(trainingInstanceList.getPipe());
      testingInstanceLists[i] = new InstanceList(testingInstanceList.getPipe());
    }

    // Partition the training list first, then the testing list, with the same rule.
    InstanceList[] sources = {trainingInstanceList, testingInstanceList};
    InstanceList[][] destinations = {trainingInstanceLists, testingInstanceLists};
    for (int s = 0; s < sources.length; s++) {
      for (Instance instance : sources[s]) {
        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line());
        String conn =
            sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase();
        String category = connAnalyzer.getCategory(conn);
        if (category == null) {
          category = "Conj-adverbial";
        }

        int bucket;
        if (category.startsWith("Coord")) {
          bucket = 0;
        } else if (category.startsWith("Sub")) {
          bucket = 1;
        } else {
          bucket = 2;
        }
        destinations[s][bucket].add(instance);
      }
    }

    double total = 0;
    double correct = 0;
    for (int i = 0; i < bucketCount; i++) {
      MyClassifierTrainer bucketTrainer = new MyClassifierTrainer(new RankMaxEntTrainer());
      RankMaxEnt bucketClassifier = (RankMaxEnt) bucketTrainer.train(trainingInstanceLists[i]);
      total += testingInstanceLists[i].size();

      for (Instance instance : testingInstanceLists[i]) {
        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        int trueIndex = rankInstance.trueArg1Candidate;

        // One score per candidate feature vector.
        int candidateCount = ((FeatureVectorSequence) instance.getData()).size();
        double[] genScores = new double[candidateCount];
        generalClassifier.getClassificationScores(instance, genScores);
        double[] tcScores = new double[candidateCount];
        bucketClassifier.getClassificationScores(instance, tcScores);

        // Start below every possible score so that a candidate is always selected, even when
        // all interpolated scores are zero or negative (a seed of 0 would leave maxIndex at -1).
        double max = Double.NEGATIVE_INFINITY;
        int maxIndex = -1;
        for (int j = 0; j < genScores.length; j++) {
          double score = genScores[j] * generalWeight + tcScores[j] * typeSpecificWeight;
          if (score > max) {
            max = score;
            maxIndex = j;
          }
        }
        if (maxIndex == trueIndex) {
          correct++;
        }
      }
    }

    System.out.println("Using interpolated model:");
    System.out.println("Total: " + total);
    System.out.println("Correct: " + correct);
    System.out.println("Accuracy: " + correct / total);
  }
コード例 #4
0
  /**
   * Shows accuracy according to Ben Wellner's definition of accuracy.
   *
   * <p>Every instance is classified. For each misclassified instance a description of the true
   * and the predicted Arg1 is written to the file {@code arg1Error.log} (relative path), and the
   * error is counted under its lower-cased connective text. Afterwards the per-connective error
   * counts are printed to standard output in descending order of count, followed by the total,
   * the number correct and the accuracy.
   *
   * @param classifier the classifier to evaluate
   * @param instanceList the instances to classify; misclassified ones must be {@link
   *     Arg1RankInstance}s whose best label parses as a candidate index
   * @throws IOException if the error log cannot be opened, written or closed
   */
  private void showAccuracy(Classifier classifier, InstanceList instanceList) throws IOException {
    int total = instanceList.size();
    int correct = 0;
    HashMap<String, Integer> errorMap = new HashMap<String, Integer>();

    FileWriter errorWriter = new FileWriter("arg1Error.log");
    try {
      for (Instance instance : instanceList) {
        Classification classification = classifier.classify(instance);
        if (classification.bestLabelIsCorrect()) {
          correct++;
          continue;
        }

        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        Document doc = rankInstance.getDocument();
        Sentence s = doc.getSentence(rankInstance.getArg2Line());
        String conn =
            s.toString(rankInstance.getConnStart(), rankInstance.getConnEnd()).toLowerCase();

        // Count this error under its connective.
        Integer seen = errorMap.get(conn);
        errorMap.put(conn, seen == null ? 1 : seen + 1);

        int arg2Line = rankInstance.getArg2Line();
        int arg1Line =
            rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).first();
        int arg1HeadPos =
            rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).second();
        // The predicted label is the index of the chosen candidate.
        int predictedCandidateIndex =
            Integer.parseInt(classification.getLabeling().getBestLabel().toString());

        errorWriter.write("FileName: " + doc.getFileName() + "\n");
        if (arg1Line == arg2Line) {
          errorWriter.write("Sentential\n");
          errorWriter.write("Conn: " + conn + "\n");
          errorWriter.write("Arg1Head: " + s.get(arg1HeadPos).word() + "\n");
          errorWriter.write(s.toString() + "\n\n");
        } else {
          errorWriter.write("Inter-Sentential\n");
          errorWriter.write("Arg1 in : " + arg1Line + "\n");
          errorWriter.write("Arg2 in : " + arg2Line + "\n");
          errorWriter.write("Conn: " + conn + "\n");
          errorWriter.write(s.toString() + "\n");
          Sentence s1 = doc.getSentence(arg1Line);
          // Log the head word itself, consistent with the other "head" lines in this log.
          errorWriter.write("Arg1Head: " + s1.get(arg1HeadPos).word() + "\n");
          errorWriter.write(s1.toString() + "\n\n");
        }

        int predictedArg1Line = rankInstance.getCandidates().get(predictedCandidateIndex).first();
        int predictedArg1HeadPos =
            rankInstance.getCandidates().get(predictedCandidateIndex).second();
        Sentence pSentence = doc.getSentence(predictedArg1Line);
        errorWriter.write(
            "Predicted arg1 sentence: "
                + pSentence.toString()
                + " [Correct: "
                + (predictedArg1Line == arg1Line)
                + "]\n");
        errorWriter.write("Predicted head: " + pSentence.get(predictedArg1HeadPos).word() + "\n\n");
      }
    } finally {
      // Always release the log file, even if classification or writing fails part-way.
      errorWriter.close();
    }

    List<Entry<String, Integer>> list = new ArrayList<Entry<String, Integer>>(errorMap.entrySet());
    Collections.sort(
        list,
        new Comparator<Entry<String, Integer>>() {

          @Override
          public int compare(Entry<String, Integer> o1, Entry<String, Integer> o2) {
            // Descending by error count.
            return o2.getValue().compareTo(o1.getValue());
          }
        });

    for (Entry<String, Integer> item : list) {
      System.out.println(item.getKey() + "-" + item.getValue());
    }

    System.out.println("Total: " + total);
    System.out.println("Correct: " + correct);
    System.out.println("Accuracy: " + (1.0 * correct) / total);
  }