Example #1
 /**
  * A helper function for dumping the accuracy of the trained classifier.
  *
  * @param classifier The classifier to evaluate.
  * @param dataset The dataset to evaluate the classifier on.
  */
 public static void dumpAccuracy(
     Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
     GeneralDataset<ClauseSplitter.ClauseClassifierLabel, String> dataset) {
   DecimalFormat df = new DecimalFormat("0.00%");
   log("size:         " + dataset.size());
   log(
       "split count:  "
           + StreamSupport.stream(dataset.spliterator(), false)
               .filter(x -> x.label() == ClauseSplitter.ClauseClassifierLabel.CLAUSE_SPLIT)
               .count());
   log(
       "interm count: "
           + StreamSupport.stream(dataset.spliterator(), false)
               .filter(x -> x.label() == ClauseSplitter.ClauseClassifierLabel.CLAUSE_INTERM)
               .count());
   Pair<Double, Double> pr =
       classifier.evaluatePrecisionAndRecall(
           dataset, ClauseSplitter.ClauseClassifierLabel.CLAUSE_SPLIT);
   log("p  (split):   " + df.format(pr.first));
   log("r  (split):   " + df.format(pr.second));
   log("f1 (split):   " + df.format(2 * pr.first * pr.second / (pr.first + pr.second)));
   pr =
       classifier.evaluatePrecisionAndRecall(
           dataset, ClauseSplitter.ClauseClassifierLabel.CLAUSE_INTERM);
   log("p  (interm):  " + df.format(pr.first));
   log("r  (interm):  " + df.format(pr.second));
   log("f1 (interm):  " + df.format(2 * pr.first * pr.second / (pr.first + pr.second)));
 }
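
The F1 values above are the harmonic mean of precision and recall. A minimal, self-contained sketch of that computation; the precision/recall numbers are made up purely for illustration:

  import java.text.DecimalFormat;

  public class F1Demo {
    /** Harmonic mean of precision and recall, as printed by dumpAccuracy. */
    static double f1(double precision, double recall) {
      return 2 * precision * recall / (precision + recall);
    }

    public static void main(String[] args) {
      DecimalFormat df = new DecimalFormat("0.00%");
      double p = 0.82; // hypothetical precision
      double r = 0.75; // hypothetical recall
      System.out.println("f1: " + df.format(f1(p, r))); // prints "f1: 78.34%"
    }
  }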
Example #2
  public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {

    List<L> guesses = new ArrayList<L>();
    List<L> labels = new ArrayList<L>();

    // First pass: let the classifier guess a label for every datum.
    for (int i = 0; i < data.size(); i++) {
      Datum<L, F> d = data.getRVFDatum(i);
      L guess = classifier.classOf(d);
      guesses.add(guess);
    }

    // Second pass: recover the gold labels via the dataset's own label index.
    int[] labelsArr = data.getLabelsArray();
    labelIndex = data.labelIndex;
    for (int i = 0; i < data.size(); i++) {
      labels.add(labelIndex.get(labelsArr[i]));
    }

    // Rebuild the label index so it covers both the dataset's and the classifier's labels.
    labelIndex = new HashIndex<L>();
    labelIndex.addAll(data.labelIndex().objectsList());
    labelIndex.addAll(classifier.labels());

    // Per-class true-positive, false-positive, and false-negative counters.
    int numClasses = labelIndex.size();
    tpCount = new int[numClasses];
    fpCount = new int[numClasses];
    fnCount = new int[numClasses];

    negIndex = labelIndex.indexOf(negLabel);

    // Tally the counts, giving no credit for the designated negative label.
    for (int i = 0; i < guesses.size(); ++i) {
      L guess = guesses.get(i);
      int guessIndex = labelIndex.indexOf(guess);
      L label = labels.get(i);
      int trueIndex = labelIndex.indexOf(label);

      if (guessIndex == trueIndex) {
        if (guessIndex != negIndex) {
          tpCount[guessIndex]++;
        }
      } else {
        if (guessIndex != negIndex) {
          fpCount[guessIndex]++;
        }
        if (trueIndex != negIndex) {
          fnCount[trueIndex]++;
        }
      }
    }

    return getFMeasure();
  }
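
The score method above leaves getFMeasure() to the surrounding class. A plausible sketch of how an aggregate F-measure could be derived from the tpCount/fpCount/fnCount arrays, assuming micro-averaging over every class except the negative one (the helper name and the averaging choice are assumptions, not the CoreNLP source):

  // Assumed helper, not the actual CoreNLP implementation: micro-averaged
  // F1 over all classes except the designated negative class.
  private double microF1(int[] tpCount, int[] fpCount, int[] fnCount, int negIndex) {
    long tp = 0, fp = 0, fn = 0;
    for (int c = 0; c < tpCount.length; c++) {
      if (c == negIndex) continue; // the negative label earns no credit
      tp += tpCount[c];
      fp += fpCount[c];
      fn += fnCount[c];
    }
    double precision = (tp + fp == 0) ? 1.0 : tp / (double) (tp + fp);
    double recall = (tp + fn == 0) ? 1.0 : tp / (double) (tp + fn);
    return (precision + recall == 0.0) ? 0.0 : 2 * precision * recall / (precision + recall);
  }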
Example #3
  public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
    // Build one label index covering both the classifier's and the dataset's labels.
    labelIndex = new HashIndex<L>();
    labelIndex.addAll(classifier.labels());
    labelIndex.addAll(data.labelIndex.objectsList());
    clearCounts();
    int[] labelsArr = data.getLabelsArray();
    // Single pass: classify each datum and tally the guess against the gold label.
    // Note: labelsArr holds positions in data.labelIndex, so the gold label is
    // looked up there rather than in the merged labelIndex built above.
    for (int i = 0; i < data.size(); i++) {
      Datum<L, F> d = data.getRVFDatum(i);
      L guess = classifier.classOf(d);
      addGuess(guess, data.labelIndex.get(labelsArr[i]));
    }
    finalizeCounts();

    return getFMeasure();
  }
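
This second variant computes the same statistics in a single pass, delegating the bookkeeping to clearCounts(), addGuess(), and finalizeCounts(), none of which the excerpt shows. A sketch of addGuess consistent with the counting logic of the first variant (the real helper may differ):

  // Hypothetical counting step mirroring the first variant: a match on a
  // non-negative class is a true positive; a mismatch charges a false
  // positive to the guessed class and a false negative to the gold class.
  private void addGuess(L guess, L label) {
    int guessIndex = labelIndex.indexOf(guess);
    int trueIndex = labelIndex.indexOf(label);
    if (guessIndex == trueIndex) {
      if (guessIndex != negIndex) tpCount[guessIndex]++;
    } else {
      if (guessIndex != negIndex) fpCount[guessIndex]++;
      if (trueIndex != negIndex) fnCount[trueIndex]++;
    }
  }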
Example #4
  /**
   * Generate the training features from the CoNLL input file.
   *
   * @param props Properties used to configure the dictionaries and the mention extractor.
   * @return Dataset of feature vectors
   * @throws Exception if the dictionaries or the CoNLL input cannot be loaded
   */
  public GeneralDataset<String, String> generateFeatureVectors(Properties props) throws Exception {

    GeneralDataset<String, String> dataset = new Dataset<>();

    Dictionaries dict = new Dictionaries(props);
    MentionExtractor mentionExtractor = new CoNLLMentionExtractor(dict, props, new Semantics(dict));

    Document document;
    while ((document = mentionExtractor.nextDoc()) != null) {
      setTokenIndices(document);
      document.extractGoldCorefClusters();
      Map<Integer, CorefCluster> entities = document.goldCorefClusters;

      // Generate features for coreferent mentions with class label 1
      for (CorefCluster entity : entities.values()) {
        for (Mention mention : entity.getCorefMentions()) {
          // Ignore verbal mentions
          if (mention.headWord.tag().startsWith("V")) continue;

          IndexedWord head = mention.dependency.getNodeByIndexSafe(mention.headWord.index());
          if (head == null) continue;
          ArrayList<String> feats = mention.getSingletonFeatures(dict);
          dataset.add(new BasicDatum<>(feats, "1"));
        }
      }

      // Generate features for singletons with class label 0
      ArrayList<CoreLabel> gold_heads = new ArrayList<>();
      for (Mention gold_men : document.allGoldMentions.values()) {
        gold_heads.add(gold_men.headWord);
      }
      for (Mention predicted_men : document.allPredictedMentions.values()) {
        SemanticGraph dep = predicted_men.dependency;
        IndexedWord head = dep.getNodeByIndexSafe(predicted_men.headWord.index());
        if (head == null) continue;

        // Ignore verbal mentions
        if (predicted_men.headWord.tag().startsWith("V")) continue;
        // If the mention is in the gold set, it is not a singleton, so skip it
        if (gold_heads.contains(predicted_men.headWord)) continue;

        dataset.add(new BasicDatum<>(predicted_men.getSingletonFeatures(dict), "0"));
      }
    }

    dataset.summaryStatistics();
    return dataset;
  }
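
A hedged sketch of how the resulting dataset might feed a training step. LinearClassifierFactory is CoreNLP's generic linear trainer, but whether this particular pipeline uses it, and the driver method's name, are assumptions for illustration:

  // Hypothetical driver: build the feature vectors, then train a classifier on them.
  public Classifier<String, String> trainSingletonClassifier(Properties props) throws Exception {
    GeneralDataset<String, String> dataset = generateFeatureVectors(props);
    LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
    return factory.trainClassifier(dataset);
  }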