コード例 #1
0
 /**
  * Computes the fraction of instances in {@code instanceList} whose best label, as assigned by
  * {@code classifier}, is the correct one.
  *
  * @param classifier the classifier to evaluate
  * @param instanceList the instances to classify
  * @return correct / total as a double; NaN when {@code instanceList} is empty (0.0 / 0)
  */
 double getAccuracy(Classifier classifier, InstanceList instanceList) {
   final int total = instanceList.size();
   int hits = 0;
   for (Instance candidate : instanceList) {
     boolean right = classifier.classify(candidate).bestLabelIsCorrect();
     if (right) {
       hits++;
     }
   }
   // Widen before dividing so this is floating-point, not integer, division.
   return (double) hits / total;
 }
コード例 #2
0
  /**
   * Shows accuracy according to Ben Wellner's definition of accuracy.
   *
   * <p>Each misclassified instance is described in the file {@code arg1Error.log}, which is
   * created or overwritten on every call. The number of errors per (lower-cased) connective is
   * then printed to standard output, most frequent first, followed by the total, the number
   * correct and the accuracy. The accuracy prints as NaN when {@code instanceList} is empty.
   *
   * @param classifier the classifier to evaluate
   * @param instanceList the instances to classify; every misclassified instance is cast to
   *     {@code Arg1RankInstance} unconditionally
   * @throws IOException if the error log cannot be created or written
   */
  private void showAccuracy(Classifier classifier, InstanceList instanceList) throws IOException {
    int total = instanceList.size();
    int correct = 0;
    HashMap<String, Integer> errorMap = new HashMap<String, Integer>();

    // try-with-resources: the log is closed even if classification, the cast, or the
    // label parsing throws part-way through the loop.
    try (FileWriter errorWriter = new FileWriter("arg1Error.log")) {
      for (Instance instance : instanceList) {
        Classification classification = classifier.classify(instance);
        if (classification.bestLabelIsCorrect()) {
          correct++;
          continue;
        }
        Arg1RankInstance rankInstance = (Arg1RankInstance) instance;
        Document doc = rankInstance.getDocument();
        Sentence s = doc.getSentence(rankInstance.getArg2Line());
        String conn =
            s.toString(rankInstance.getConnStart(), rankInstance.getConnEnd()).toLowerCase();
        // String category = connAnalyzer.getCategory(conn);
        Integer seen = errorMap.get(conn);
        errorMap.put(conn, seen == null ? 1 : seen + 1);
        writeArg1Error(errorWriter, rankInstance, classification, doc, s, conn);
      }
    }

    printErrorsByConnective(errorMap);

    System.out.println("Total: " + total);
    System.out.println("Correct: " + correct);
    System.out.println("Accuracy: " + (1.0 * correct) / total);
  }

  /**
   * Appends a description of one misclassified instance to {@code errorWriter}: the gold Arg1
   * location and head, then the predicted Arg1 sentence and head.
   */
  private void writeArg1Error(
      FileWriter errorWriter,
      Arg1RankInstance rankInstance,
      Classification classification,
      Document doc,
      Sentence s,
      String conn)
      throws IOException {
    int arg2Line = rankInstance.getArg2Line();
    int arg1Line = rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).first();
    int arg1HeadPos =
        rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).second();
    // The classifier's label is the index of the chosen candidate.
    int predictedCandidateIndex =
        Integer.parseInt(classification.getLabeling().getBestLabel().toString());

    errorWriter.write("FileName: " + doc.getFileName() + "\n");
    if (arg1Line == arg2Line) {
      errorWriter.write("Sentential\n");
      errorWriter.write("Conn: " + conn + "\n");
      errorWriter.write("Arg1Head: " + s.get(arg1HeadPos).word() + "\n");
      errorWriter.write(s.toString() + "\n\n");
    } else {
      errorWriter.write("Inter-Sentential\n");
      errorWriter.write("Arg1 in : " + arg1Line + "\n");
      errorWriter.write("Arg2 in : " + arg2Line + "\n");
      errorWriter.write("Conn: " + conn + "\n");
      errorWriter.write(s.toString() + "\n");
      Sentence s1 = doc.getSentence(arg1Line);
      // Print the head word, consistent with the sentential branch and the predicted head below
      // (previously the whole token object was printed here).
      errorWriter.write("Arg1Head: " + s1.get(arg1HeadPos).word() + "\n");
      errorWriter.write(s1.toString() + "\n\n");
    }

    int predictedArg1Line = rankInstance.getCandidates().get(predictedCandidateIndex).first();
    int predictedArg1HeadPos =
        rankInstance.getCandidates().get(predictedCandidateIndex).second();
    Sentence pSentence = doc.getSentence(predictedArg1Line);
    errorWriter.write(
        "Predicted arg1 sentence: "
            + pSentence.toString()
            + " [Correct: "
            + (predictedArg1Line == arg1Line)
            + "]\n");
    errorWriter.write("Predicted head: " + pSentence.get(predictedArg1HeadPos).word() + "\n\n");
  }

  /** Prints each connective as {@code conn-count}, in descending order of error count. */
  private void printErrorsByConnective(HashMap<String, Integer> errorMap) {
    List<Entry<String, Integer>> list =
        new ArrayList<Entry<String, Integer>>(errorMap.entrySet());
    Collections.sort(
        list,
        new Comparator<Entry<String, Integer>>() {
          @Override
          public int compare(Entry<String, Integer> o1, Entry<String, Integer> o2) {
            // Arguments swapped for descending order.
            return Integer.compare(o2.getValue(), o1.getValue());
          }
        });
    for (Entry<String, Integer> item : list) {
      System.out.println(item.getKey() + "-" + item.getValue());
    }
  }