/**
 * A helper function for dumping the accuracy of the trained classifier.
 *
 * @param classifier The classifier to evaluate.
 * @param dataset The dataset to evaluate the classifier on.
 */
public static void dumpAccuracy(Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
                                GeneralDataset<ClauseSplitter.ClauseClassifierLabel, String> dataset) {
  DecimalFormat df = new DecimalFormat("0.00%");
  log("size: " + dataset.size());
  log("split count: " + StreamSupport.stream(dataset.spliterator(), false)
      .filter(x -> x.label() == ClauseSplitter.ClauseClassifierLabel.CLAUSE_SPLIT)
      .count());
  log("interm count: " + StreamSupport.stream(dataset.spliterator(), false)
      .filter(x -> x.label() == ClauseSplitter.ClauseClassifierLabel.CLAUSE_INTERM)
      .count());
  // Precision, recall, and F1 for the CLAUSE_SPLIT label
  Pair<Double, Double> pr = classifier.evaluatePrecisionAndRecall(dataset,
      ClauseSplitter.ClauseClassifierLabel.CLAUSE_SPLIT);
  log("p (split): " + df.format(pr.first));
  log("r (split): " + df.format(pr.second));
  log("f1 (split): " + df.format(2 * pr.first * pr.second / (pr.first + pr.second)));
  // Precision, recall, and F1 for the CLAUSE_INTERM label
  pr = classifier.evaluatePrecisionAndRecall(dataset,
      ClauseSplitter.ClauseClassifierLabel.CLAUSE_INTERM);
  log("p (interm): " + df.format(pr.first));
  log("r (interm): " + df.format(pr.second));
  log("f1 (interm): " + df.format(2 * pr.first * pr.second / (pr.first + pr.second)));
}
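// Usage note (illustrative, not from the surrounding source): dumpAccuracy is
// typically called after training, with a held-out split. The sketch below
// assumes pre-built GeneralDataset instances named trainDataset and devDataset,
// and uses CoreNLP's LinearClassifierFactory with default settings as a
// stand-in for whatever configuration the real training run uses.
LinearClassifierFactory<ClauseSplitter.ClauseClassifierLabel, String> factory =
    new LinearClassifierFactory<>();
Classifier<ClauseSplitter.ClauseClassifierLabel, String> trained =
    factory.trainClassifier(trainDataset);  // trainDataset: assumed featurized training split
dumpAccuracy(trained, devDataset);          // devDataset: assumed held-out split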
public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
  // Collect the classifier's guesses and the gold labels for every datum
  List<L> guesses = new ArrayList<L>();
  List<L> labels = new ArrayList<L>();
  for (int i = 0; i < data.size(); i++) {
    Datum<L, F> d = data.getRVFDatum(i);
    guesses.add(classifier.classOf(d));
  }
  int[] labelsArr = data.getLabelsArray();
  labelIndex = data.labelIndex();
  for (int i = 0; i < data.size(); i++) {
    labels.add(labelIndex.get(labelsArr[i]));
  }
  // Re-index over the union of the dataset's and the classifier's label sets
  labelIndex = new HashIndex<L>();
  labelIndex.addAll(data.labelIndex().objectsList());
  labelIndex.addAll(classifier.labels());
  int numClasses = labelIndex.size();
  tpCount = new int[numClasses];
  fpCount = new int[numClasses];
  fnCount = new int[numClasses];
  negIndex = labelIndex.indexOf(negLabel);
  // Tally true positives, false positives, and false negatives per class;
  // the designated negative label never earns credit
  for (int i = 0; i < guesses.size(); ++i) {
    int guessIndex = labelIndex.indexOf(guesses.get(i));
    int trueIndex = labelIndex.indexOf(labels.get(i));
    if (guessIndex == trueIndex) {
      if (guessIndex != negIndex) {
        tpCount[guessIndex]++;
      }
    } else {
      if (guessIndex != negIndex) {
        fpCount[guessIndex]++;
      }
      if (trueIndex != negIndex) {
        fnCount[trueIndex]++;
      }
    }
  }
  return getFMeasure();
}
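// getFMeasure() itself is not shown in this excerpt. Given the counting scheme
// above, a micro-averaged F1 over the non-negative classes could be derived as
// in the sketch below; the name microAveragedF1 is hypothetical, and this is an
// assumption about the shape of the computation, not the class's actual code.
private double microAveragedF1() {
  int tp = 0, fp = 0, fn = 0;
  for (int c = 0; c < tpCount.length; c++) {
    if (c == negIndex) continue;  // the negative label earns no credit
    tp += tpCount[c];
    fp += fpCount[c];
    fn += fnCount[c];
  }
  double p = (tp + fp == 0) ? 0.0 : tp / (double) (tp + fp);
  double r = (tp + fn == 0) ? 0.0 : tp / (double) (tp + fn);
  return (p + r == 0.0) ? 0.0 : 2 * p * r / (p + r);
}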
public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
  labelIndex = new HashIndex<L>();
  labelIndex.addAll(classifier.labels());
  labelIndex.addAll(data.labelIndex().objectsList());
  clearCounts();
  int[] labelsArr = data.getLabelsArray();
  for (int i = 0; i < data.size(); i++) {
    Datum<L, F> d = data.getRVFDatum(i);
    L guess = classifier.classOf(d);
    addGuess(guess, labelIndex.get(labelsArr[i]));
  }
  finalizeCounts();
  return getFMeasure();
}
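// This second score(...) is a leaner variant of the one above: the label index
// is built once up front, and the per-datum bookkeeping moves into the
// clearCounts()/addGuess()/finalizeCounts() helpers instead of buffering
// guesses and gold labels in parallel lists. The helper bodies are not part of
// this excerpt; if addGuess mirrors the inline tallying of the previous
// version, it might look roughly like this (an assumption, not the actual
// helper):
private void addGuess(L guess, L label) {
  int guessIndex = labelIndex.indexOf(guess);
  int trueIndex = labelIndex.indexOf(label);
  if (guessIndex == trueIndex) {
    if (guessIndex != negIndex) { tpCount[guessIndex]++; }
  } else {
    if (guessIndex != negIndex) { fpCount[guessIndex]++; }
    if (trueIndex != negIndex) { fnCount[trueIndex]++; }
  }
}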
/**
 * Generate the training features from the CoNLL input file.
 *
 * @param props Properties specifying the dictionaries and the CoNLL input.
 * @return Dataset of feature vectors.
 * @throws Exception If reading or processing the CoNLL input fails.
 */
public GeneralDataset<String, String> generateFeatureVectors(Properties props) throws Exception {
  GeneralDataset<String, String> dataset = new Dataset<>();
  Dictionaries dict = new Dictionaries(props);
  MentionExtractor mentionExtractor = new CoNLLMentionExtractor(dict, props, new Semantics(dict));
  Document document;
  while ((document = mentionExtractor.nextDoc()) != null) {
    setTokenIndices(document);
    document.extractGoldCorefClusters();
    Map<Integer, CorefCluster> entities = document.goldCorefClusters;
    // Generate features for coreferent mentions with class label 1
    for (CorefCluster entity : entities.values()) {
      for (Mention mention : entity.getCorefMentions()) {
        // Ignore verbal mentions
        if (mention.headWord.tag().startsWith("V")) continue;
        IndexedWord head = mention.dependency.getNodeByIndexSafe(mention.headWord.index());
        if (head == null) continue;
        ArrayList<String> feats = mention.getSingletonFeatures(dict);
        dataset.add(new BasicDatum<>(feats, "1"));
      }
    }
    // Generate features for singletons with class label 0
    ArrayList<CoreLabel> goldHeads = new ArrayList<>();
    for (Mention goldMention : document.allGoldMentions.values()) {
      goldHeads.add(goldMention.headWord);
    }
    for (Mention predictedMention : document.allPredictedMentions.values()) {
      SemanticGraph dep = predictedMention.dependency;
      IndexedWord head = dep.getNodeByIndexSafe(predictedMention.headWord.index());
      if (head == null) continue;
      // Ignore verbal mentions
      if (predictedMention.headWord.tag().startsWith("V")) continue;
      // If the mention is in the gold set, it is not a singleton, so skip it
      if (goldHeads.contains(predictedMention.headWord)) continue;
      dataset.add(new BasicDatum<>(predictedMention.getSingletonFeatures(dict), "0"));
    }
  }
  dataset.summaryStatistics();
  return dataset;
}
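// Usage note (illustrative): a plausible end-to-end flow featurizes the CoNLL
// data with generateFeatureVectors and trains a binary logistic classifier on
// the result, using CoreNLP's edu.stanford.nlp.classify.LogisticClassifierFactory.
// The property key and path below are stand-ins, not values taken from this
// source.
Properties props = new Properties();
props.setProperty("dcoref.conll2011", "/path/to/conll/train");  // placeholder path
GeneralDataset<String, String> trainSet = generateFeatureVectors(props);
LogisticClassifierFactory<String, String> lcf = new LogisticClassifierFactory<>();
LogisticClassifier<String, String> singletonPredictor = lcf.trainClassifier(trainSet);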