/**
 * Classifies a set of candidates with {@code model} and returns the model's best label as an int.
 *
 * <p>The candidate features are converted with
 * {@code TrainMalletMaxEntRank.candidateFeatures2FV} against the model's own alphabet. They are
 * then wrapped in an unlabeled {@code Instance} (no target, name or source) and classified.
 *
 * @param candidateFeatures feature strings per candidate (presumably one inner list per
 *     candidate — TODO confirm against candidateFeatures2FV)
 * @return the best label of the classification, parsed as an integer (presumably the index of
 *     the top-ranked candidate — confirm)
 * @throws NumberFormatException if the best label's string form is not an integer
 */
public int rank(List<List<String>> candidateFeatures) {
    FeatureVector[] candidateVectors =
        TrainMalletMaxEntRank.candidateFeatures2FV(candidateFeatures, model.getAlphabet());
    Instance unlabeled = new Instance(candidateVectors, null, null, null);
    Labeling labeling = model.classify(unlabeled).getLabeling();
    String bestLabelText = labeling.getBestLabel().toString();
    return Integer.parseInt(bestLabelText);
}
void showInterpolatedTCAccuracy( InstanceList trainingInstanceList, InstanceList testingInstanceList) { trainer = new MyClassifierTrainer(new RankMaxEntTrainer()); RankMaxEnt generalClassifier = (RankMaxEnt) trainer.train(trainingInstanceList); InstanceList[] trainingInstanceLists = new InstanceList[3]; InstanceList[] testingInstanceLists = new InstanceList[3]; for (int i = 0; i < 3; i++) { trainingInstanceLists[i] = new InstanceList(trainingInstanceList.getPipe()); testingInstanceLists[i] = new InstanceList(testingInstanceList.getPipe()); } for (Instance instance : trainingInstanceList) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line()); String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase(); String category = connAnalyzer.getCategory(conn); if (category == null) category = "Conj-adverbial"; if (category.startsWith("Coord")) { trainingInstanceLists[0].add(instance); } else if (category.startsWith("Sub")) { trainingInstanceLists[1].add(instance); } else { trainingInstanceLists[2].add(instance); } } for (Instance instance : testingInstanceList) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line()); String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase(); String category = connAnalyzer.getCategory(conn); if (category == null) category = "Conj-adverbial"; if (category.startsWith("Coord")) { testingInstanceLists[0].add(instance); } else if (category.startsWith("Sub")) { testingInstanceLists[1].add(instance); } else { testingInstanceLists[2].add(instance); } } MyClassifierTrainer trainers[] = new MyClassifierTrainer[3]; RankMaxEnt classifiers[] = new RankMaxEnt[3]; double total = 0; double correct = 0; for (int i = 0; i < 3; i++) { trainers[i] = new MyClassifierTrainer(new RankMaxEntTrainer()); classifiers[i] = 
(RankMaxEnt) trainers[i].train(trainingInstanceLists[i]); total += testingInstanceLists[i].size(); // correct += getAccuracy(classifiers[i], testingInstanceLists[i]) * // testingInstanceLists[i].size(); //accuracy * total for (Instance instance : testingInstanceLists[i]) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; int trueIndex = rankInstance.trueArg1Candidate; double genScores[] = new double[((FeatureVectorSequence) instance.getData()).size()]; generalClassifier.getClassificationScores(instance, genScores); double tcScores[] = new double[((FeatureVectorSequence) instance.getData()).size()]; classifiers[i].getClassificationScores(instance, tcScores); double max = 0; int maxIndex = -1; for (int j = 0; j < genScores.length; j++) { double score = genScores[j] * 0.4 + tcScores[j] * 0.6; if (score > max) { max = score; maxIndex = j; } } if (maxIndex == trueIndex) { correct++; } } } System.out.println("Using interpolated model:"); System.out.println("Total: " + total); System.out.println("Correct: " + correct); System.out.println("Accuracy: " + correct / total); }