@Override public Instance pipe(Instance carrier) { Arg1RankInstance instance = (Arg1RankInstance) carrier; Document document = (Document) instance.getData(); List<Pair<Integer, Integer>> candidates = instance.getCandidates(); int connStart = instance.getConnStart(); int connEnd = instance.getConnEnd(); int arg2Line = instance.getArg2Line(); int arg2HeadPos = instance.getArg2HeadPos(); FeatureVector fvs[] = new FeatureVector[candidates.size()]; for (int i = 0; i < candidates.size(); i++) { Pair<Integer, Integer> candidate = candidates.get(i); PropertyList pl = null; pl = addBaselineFeatures(pl, document, candidate, arg2Line, arg2HeadPos, connStart, connEnd); pl = addConstituentFeatures( pl, document, candidate, arg2Line, arg2HeadPos, connStart, connEnd); pl = addDependencyFeatures(pl, document, candidate, arg2Line, arg2HeadPos, connStart, connEnd); // pl = addLexicoSyntacticFeatures(pl, document, candidate, arg2Line, arg2HeadPos, connStart, // connEnd); fvs[i] = new FeatureVector(getDataAlphabet(), pl, true, true); } // set target label LabelAlphabet ldict = (LabelAlphabet) getTargetAlphabet(); carrier.setTarget(ldict.lookupLabel(String.valueOf(instance.getTrueArg1Candidate()))); carrier.setData(new FeatureVectorSequence(fvs)); return carrier; }
double[] getTypeSpecificAccuracy( InstanceList trainingInstanceList, InstanceList testingInstanceList, boolean show) { InstanceList[] trainingInstanceLists = new InstanceList[3]; InstanceList[] testingInstanceLists = new InstanceList[3]; for (int i = 0; i < 3; i++) { trainingInstanceLists[i] = new InstanceList(trainingInstanceList.getPipe()); testingInstanceLists[i] = new InstanceList(testingInstanceList.getPipe()); } for (Instance instance : trainingInstanceList) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line()); String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase(); String category = connAnalyzer.getCategory(conn); if (category == null) category = "Conj-adverbial"; if (category.startsWith("Coord")) { trainingInstanceLists[0].add(instance); } else if (category.startsWith("Sub")) { trainingInstanceLists[1].add(instance); } else { trainingInstanceLists[2].add(instance); } } for (Instance instance : testingInstanceList) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line()); String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase(); String category = connAnalyzer.getCategory(conn); if (category == null) category = "Conj-adverbial"; if (category.startsWith("Coord")) { testingInstanceLists[0].add(instance); } else if (category.startsWith("Sub")) { testingInstanceLists[1].add(instance); } else { testingInstanceLists[2].add(instance); } } MyClassifierTrainer trainers[] = new MyClassifierTrainer[3]; Classifier classifiers[] = new Classifier[3]; double total = 0; double correct = 0; for (int i = 0; i < 3; i++) { trainers[i] = new MyClassifierTrainer(new RankMaxEntTrainer()); classifiers[i] = trainers[i].train(trainingInstanceLists[i]); total += testingInstanceLists[i].size(); correct += getAccuracy(classifiers[i], testingInstanceLists[i]) * testingInstanceLists[i].size(); // accuracy * total } if (show) { System.out.println("Using type specific models:"); System.out.println("Total: " + total); System.out.println("Correct: " + correct); System.out.println("Accuracy: " + correct / total); } return new double[] {total, correct, 1.0 * correct / total}; }
void showInterpolatedTCAccuracy( InstanceList trainingInstanceList, InstanceList testingInstanceList) { trainer = new MyClassifierTrainer(new RankMaxEntTrainer()); RankMaxEnt generalClassifier = (RankMaxEnt) trainer.train(trainingInstanceList); InstanceList[] trainingInstanceLists = new InstanceList[3]; InstanceList[] testingInstanceLists = new InstanceList[3]; for (int i = 0; i < 3; i++) { trainingInstanceLists[i] = new InstanceList(trainingInstanceList.getPipe()); testingInstanceLists[i] = new InstanceList(testingInstanceList.getPipe()); } for (Instance instance : trainingInstanceList) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line()); String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase(); String category = connAnalyzer.getCategory(conn); if (category == null) category = "Conj-adverbial"; if (category.startsWith("Coord")) { trainingInstanceLists[0].add(instance); } else if (category.startsWith("Sub")) { trainingInstanceLists[1].add(instance); } else { trainingInstanceLists[2].add(instance); } } for (Instance instance : testingInstanceList) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line()); String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase(); String category = connAnalyzer.getCategory(conn); if (category == null) category = "Conj-adverbial"; if (category.startsWith("Coord")) { testingInstanceLists[0].add(instance); } else if (category.startsWith("Sub")) { testingInstanceLists[1].add(instance); } else { testingInstanceLists[2].add(instance); } } MyClassifierTrainer trainers[] = new MyClassifierTrainer[3]; RankMaxEnt classifiers[] = new RankMaxEnt[3]; double total = 0; double correct = 0; for (int i = 0; i < 3; i++) { trainers[i] = new MyClassifierTrainer(new RankMaxEntTrainer()); classifiers[i] = (RankMaxEnt) trainers[i].train(trainingInstanceLists[i]); total += testingInstanceLists[i].size(); // correct += getAccuracy(classifiers[i], testingInstanceLists[i]) * // testingInstanceLists[i].size(); //accuracy * total for (Instance instance : testingInstanceLists[i]) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; int trueIndex = rankInstance.trueArg1Candidate; double genScores[] = new double[((FeatureVectorSequence) instance.getData()).size()]; generalClassifier.getClassificationScores(instance, genScores); double tcScores[] = new double[((FeatureVectorSequence) instance.getData()).size()]; classifiers[i].getClassificationScores(instance, tcScores); double max = 0; int maxIndex = -1; for (int j = 0; j < genScores.length; j++) { double score = genScores[j] * 0.4 + tcScores[j] * 0.6; if (score > max) { max = score; maxIndex = j; } } if (maxIndex == trueIndex) { correct++; } } } System.out.println("Using interpolated model:"); System.out.println("Total: " + total); System.out.println("Correct: " + correct); System.out.println("Accuracy: " + correct / total); }
/** * Shows accuracy according to Ben Wellner's definition of accuracy * * @param classifier * @param instanceList */ private void showAccuracy(Classifier classifier, InstanceList instanceList) throws IOException { int total = instanceList.size(); int correct = 0; HashMap<String, Integer> errorMap = new HashMap<String, Integer>(); FileWriter errorWriter = new FileWriter("arg1Error.log"); for (Instance instance : instanceList) { Classification classification = classifier.classify(instance); if (classification.bestLabelIsCorrect()) { correct++; } else { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Document doc = rankInstance.getDocument(); Sentence s = doc.getSentence(rankInstance.getArg2Line()); String conn = s.toString(rankInstance.getConnStart(), rankInstance.getConnEnd()).toLowerCase(); // String category = connAnalyzer.getCategory(conn); if (errorMap.containsKey(conn)) { errorMap.put(conn, errorMap.get(conn) + 1); } else { errorMap.put(conn, 1); } int arg2Line = rankInstance.getArg2Line(); int arg1Line = rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).first(); int arg1HeadPos = rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).second(); int predictedCandidateIndex = Integer.parseInt(classification.getLabeling().getBestLabel().toString()); if (arg1Line == arg2Line) { errorWriter.write("FileName: " + doc.getFileName() + "\n"); errorWriter.write("Sentential\n"); errorWriter.write("Conn: " + conn + "\n"); errorWriter.write("Arg1Head: " + s.get(arg1HeadPos).word() + "\n"); errorWriter.write(s.toString() + "\n\n"); } else { errorWriter.write("FileName: " + doc.getFileName() + "\n"); errorWriter.write("Inter-Sentential\n"); errorWriter.write("Arg1 in : " + arg1Line + "\n"); errorWriter.write("Arg2 in : " + arg2Line + "\n"); errorWriter.write("Conn: " + conn + "\n"); errorWriter.write(s.toString() + "\n"); Sentence s1 = doc.getSentence(arg1Line); errorWriter.write("Arg1Head: " + s1.get(arg1HeadPos) + "\n"); errorWriter.write(s1.toString() + "\n\n"); } int predictedArg1Line = rankInstance.getCandidates().get(predictedCandidateIndex).first(); int predictedArg1HeadPos = rankInstance.getCandidates().get(predictedCandidateIndex).second(); Sentence pSentence = doc.getSentence(predictedArg1Line); errorWriter.write( "Predicted arg1 sentence: " + pSentence.toString() + " [Correct: " + (predictedArg1Line == arg1Line) + "]\n"); errorWriter.write("Predicted head: " + pSentence.get(predictedArg1HeadPos).word() + "\n\n"); } } errorWriter.close(); Set<Entry<String, Integer>> entrySet = errorMap.entrySet(); List<Entry<String, Integer>> list = new ArrayList<Entry<String, Integer>>(entrySet); Collections.sort( list, new Comparator<Entry<String, Integer>>() { @Override public int compare(Entry<String, Integer> o1, Entry<String, Integer> o2) { if (o1.getValue() > o2.getValue()) return -1; else if (o1.getValue() < o2.getValue()) return 1; return 0; } }); for (Entry<String, Integer> item : list) { System.out.println(item.getKey() + "-" + item.getValue()); } System.out.println("Total: " + total); System.out.println("Correct: " + correct); System.out.println("Accuracy: " + (1.0 * correct) / total); }