public void train(String[] trainSections) throws IOException { pipe = defaultPipe(); InstanceList trainingInstanceList = prepareInstanceList(trainSections); /*NFoldEvaluator evaluator = new NFoldEvaluator(); evaluator.evaluate(trainer, trainingInstanceList, 10); */ /*InstanceList[] instanceLists = trainingInstanceList.splitInOrder(new double[]{0.9, 0.1}); Classifier classifier = trainer.train(instanceLists[0]); showAccuracy(classifier, instanceLists[1]);*/ // showNFoldAccuracy(trainingInstanceList, 10, 2633); trainer.train(trainingInstanceList); // showNFoldTypeSpecificAccuracy(trainingInstanceList, 10); }
public void train(String[] trainSections, String[] testSections) throws IOException { pipe = defaultPipe(); InstanceList trainingInstanceList = prepareInstanceList(trainSections); InstanceList testingInstanceList = prepareInstanceList(testSections); // Classifier classifier = trainer.train(trainingInstanceList, testingInstanceList); Classifier classifier = trainer.train(trainingInstanceList); System.out.println("training size: " + trainingInstanceList.size()); System.out.println("testing size: " + testingInstanceList.size()); // showAccuracy(classifier, testingInstanceList); // getTypeSpecificAccuracy(trainingInstanceList, testingInstanceList, true); // showInterpolatedTCAccuracy(trainingInstanceList, testingInstanceList); }
/**
 * Runs n-fold cross-validation over {@code instanceList} and prints two accuracy figures.
 *
 * <p>For every split a fresh {@code MyClassifierTrainer} wrapping a
 * {@code RankMaxEntTrainer} is stored in the {@code trainer} field, trained on the
 * first part of the split and scored on the second part with {@code getAccuracy}.
 * Printed afterwards:
 * <ul>
 *   <li>the sum of the per-fold accuracies divided by {@code n};</li>
 *   <li>the sum of (fold accuracy x fold test size), and that sum divided by
 *       {@code count}.</li>
 * </ul>
 *
 * @param instanceList instances to cross-validate on
 * @param n number of folds requested from the cross-validation iterator
 * @param count denominator for the second accuracy figure; supplied by the caller
 *     (presumably the total number of instances, not checked here)
 */
private void showNFoldAccuracy(InstanceList instanceList, int n, int count) {
    InstanceList.CrossValidationIterator folds = instanceList.crossValidationIterator(n);
    double[] foldAccuracies = new double[n];
    double accuracySum = 0;
    double weightedCorrect = 0;

    for (int fold = 0; folds.hasNext(); fold++) {
        InstanceList[] split = folds.nextSplit();
        InstanceList trainPart = split[0];
        InstanceList testPart = split[1];

        // A new trainer per fold, kept in the field exactly as before.
        trainer = new MyClassifierTrainer(new RankMaxEntTrainer());
        Classifier model = trainer.train(trainPart);

        double foldAccuracy = getAccuracy(model, testPart);
        foldAccuracies[fold] = foldAccuracy;
        accuracySum += foldAccuracies[fold];
        weightedCorrect += foldAccuracies[fold] * testPart.size();
    }

    System.out.println(n + "-Fold accuracy(avg): " + accuracySum / n);
    System.out.println("Total tp:" + weightedCorrect);
    System.out.println("Total count:" + count);
    System.out.println(n + "-Fold accuracy: " + weightedCorrect / count);
}
void showInterpolatedTCAccuracy( InstanceList trainingInstanceList, InstanceList testingInstanceList) { trainer = new MyClassifierTrainer(new RankMaxEntTrainer()); RankMaxEnt generalClassifier = (RankMaxEnt) trainer.train(trainingInstanceList); InstanceList[] trainingInstanceLists = new InstanceList[3]; InstanceList[] testingInstanceLists = new InstanceList[3]; for (int i = 0; i < 3; i++) { trainingInstanceLists[i] = new InstanceList(trainingInstanceList.getPipe()); testingInstanceLists[i] = new InstanceList(testingInstanceList.getPipe()); } for (Instance instance : trainingInstanceList) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line()); String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase(); String category = connAnalyzer.getCategory(conn); if (category == null) category = "Conj-adverbial"; if (category.startsWith("Coord")) { trainingInstanceLists[0].add(instance); } else if (category.startsWith("Sub")) { trainingInstanceLists[1].add(instance); } else { trainingInstanceLists[2].add(instance); } } for (Instance instance : testingInstanceList) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Sentence sentence = rankInstance.document.getSentence(rankInstance.getArg2Line()); String conn = sentence.toString(rankInstance.connStart, rankInstance.connEnd).toLowerCase(); String category = connAnalyzer.getCategory(conn); if (category == null) category = "Conj-adverbial"; if (category.startsWith("Coord")) { testingInstanceLists[0].add(instance); } else if (category.startsWith("Sub")) { testingInstanceLists[1].add(instance); } else { testingInstanceLists[2].add(instance); } } MyClassifierTrainer trainers[] = new MyClassifierTrainer[3]; RankMaxEnt classifiers[] = new RankMaxEnt[3]; double total = 0; double correct = 0; for (int i = 0; i < 3; i++) { trainers[i] = new MyClassifierTrainer(new RankMaxEntTrainer()); classifiers[i] = 
(RankMaxEnt) trainers[i].train(trainingInstanceLists[i]); total += testingInstanceLists[i].size(); // correct += getAccuracy(classifiers[i], testingInstanceLists[i]) * // testingInstanceLists[i].size(); //accuracy * total for (Instance instance : testingInstanceLists[i]) { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; int trueIndex = rankInstance.trueArg1Candidate; double genScores[] = new double[((FeatureVectorSequence) instance.getData()).size()]; generalClassifier.getClassificationScores(instance, genScores); double tcScores[] = new double[((FeatureVectorSequence) instance.getData()).size()]; classifiers[i].getClassificationScores(instance, tcScores); double max = 0; int maxIndex = -1; for (int j = 0; j < genScores.length; j++) { double score = genScores[j] * 0.4 + tcScores[j] * 0.6; if (score > max) { max = score; maxIndex = j; } } if (maxIndex == trueIndex) { correct++; } } } System.out.println("Using interpolated model:"); System.out.println("Total: " + total); System.out.println("Correct: " + correct); System.out.println("Accuracy: " + correct / total); }