/**
 * Prints an evaluation report for the {@code classifier} field to standard output.
 *
 * <p>The report consists of the given description, the accuracy on {@code instanceList}, the
 * precision/recall/F1 of every label in the classifier's label alphabet, and finally the overall
 * precision, recall, F1 and fall-out reported by an {@code ExtendedTrial}, followed by that
 * trial's {@code showErrors()} output.
 *
 * @param instanceList instances the classifier is evaluated on
 * @param description heading printed as the first line of the report
 */
public void printPerformance(InstanceList instanceList, String description) {
  System.out.println(description);
  System.out.println("Accuracy: " + classifier.getAccuracy(instanceList));

  // Per-label metrics, in the iteration order of the label alphabet.
  for (Iterator labels = classifier.getLabelAlphabet().iterator(); labels.hasNext(); ) {
    Object label = labels.next();

    // All three metrics are computed before anything is printed for this label.
    double precision = classifier.getPrecision(instanceList, label);
    double recall = classifier.getRecall(instanceList, label);
    double f1Score = classifier.getF1(instanceList, label);

    System.out.println("Precision[" + label + "] = " + precision);
    System.out.println("Recall[" + label + "] = " + recall);
    System.out.println("F1[" + label + "] = " + f1Score);
    System.out.println("");
  }

  // Overall figures come from a single trial over the whole instance list.
  ExtendedTrial overall = new ExtendedTrial(classifier, instanceList);
  System.out.println("Overall performance\n=====");
  System.out.println("Precision: " + overall.getPrecision());
  System.out.println("Recall: " + overall.getRecall());
  System.out.println("F1: " + overall.getF1());
  System.out.println("Fall-out: " + overall.getFallOut());
  overall.showErrors();
}
/**
 * Trains a classifier on the complete instance list and saves it below the {@code Models/}
 * directory, using the trained classifier's {@code toString()} value as the file name.
 *
 * @param ilist instances used for training (no hold-out split is made here)
 * @param trainer trainer that produces the classifier
 * @throws IOException declared for the call to {@code saveClassifier}
 */
public static void trainAndSaveOnWholeData(InstanceList ilist, ClassifierTrainer trainer)
    throws IOException {
  Classifier trained = trainer.train(ilist);
  File modelFile = new File("Models/" + trained.toString());
  saveClassifier(trained, modelFile);
}
public static void printLabelings(File modelFile, String inputFile) throws IOException, ClassNotFoundException { // load the classifier from model file, form the output file Classifier classifier = loadClassifier(modelFile); String modelFilename = modelFile.getName(); String outputFile = "Outputs/" + modelFilename.substring(19, modelFilename.length() - 9) + "Output"; System.out.println( "######################Getting outputs for model " + outputFile + " classifier###############"); FileWriter wr = new FileWriter(outputFile); // read the input from input file and create instance variable ArrayList<Instance> instances = readInputFromTSVFile(inputFile); Iterator<?> instancesItr = classifier.getInstancePipe().newIteratorFrom(instances.iterator()); // label and write to the output file while (instancesItr.hasNext()) { Labeling labeling = classifier.classify(instancesItr.next()).getLabeling(); // print the labels with their weights in descending order (ie best first) for (int rank = 0; rank < labeling.numLocations(); rank++) { // System.out.print(labeling.getLabelAtRank(rank) + ":" + // labeling.getValueAtRank(rank) + " "); wr.write(labeling.getLabelAtRank(rank) + ":" + labeling.getValueAtRank(rank) + " "); } // System.out.println(); wr.write("\n"); } wr.close(); }
public static boolean hasPopulation(File modelFile, String inputSentence) throws FileNotFoundException, ClassNotFoundException, IOException { Classifier classifier = loadClassifier(modelFile); String templabel = "1"; Instance input = new Instance(inputSentence, templabel, 1, inputSentence); Instance instanceInput = classifier.getInstancePipe().instanceFrom(input); // since 0th index represent 0th class and 1st index represent 1st class so index==label double index = classifier.classify(instanceInput).getLabeling().getBestIndex(); if (index == 1.0) return true; else return false; }
/**
 * Computes the fraction of instances whose best predicted label is correct.
 *
 * @param classifier classifier used to label each instance
 * @param instanceList instances to classify
 * @return number of correctly classified instances divided by the list size; for an empty list
 *     this is the floating-point division 0.0 / 0, i.e. {@code NaN}
 */
double getAccuracy(Classifier classifier, InstanceList instanceList) {
  final int total = instanceList.size();
  int hits = 0;
  for (Instance candidate : instanceList) {
    if (classifier.classify(candidate).bestLabelIsCorrect()) {
      hits++;
    }
  }
  return (double) hits / total;
}
/** * @param targetTerm * @param sourceFile * @param trainingAlgo * @param outputFileClassifier * @param outputFileResults * @param termWindowSize * @param pipe * @return */ private static List<ClassificationResult> runTrainingAndClassification( String targetTerm, String sourceFile, String trainingAlgo, String outputFileClassifier, String outputFileResults, int termWindowSize, Pipe pipe, boolean useCollocationalVector) { // Read in concordance file and create list of Mallet training instances // TODO: Remove duplication of code (see execConvertToMalletFormat(...)) String vectorType = useCollocationalVector ? "coll" : "bow"; InstanceList instanceList = readConcordanceFileToInstanceList( targetTerm, sourceFile, termWindowSize, pipe, useCollocationalVector); // Creating splits for training and testing double[] proportions = {0.9, 0.1}; InstanceList[] splitLists = instanceList.split(proportions); InstanceList trainingList = splitLists[0]; InstanceList testList = splitLists[1]; // Train the classifier ClassifierTrainer classifierTrainer = getClassifierTrainerForAlgorithm(trainingAlgo); Classifier classifier = classifierTrainer.train(trainingList); if (classifier.getLabelAlphabet() != null) { // TODO: Make sure this is not null in RandomClassifier System.out.println("Labels:\n" + classifier.getLabelAlphabet()); System.out.println( "Size of data alphabet (= type count of training list): " + classifier.getAlphabet().size()); } // Run tests and get results Trial trial = new Trial(classifier, testList); List<ClassificationResult> results = new ArrayList<ClassificationResult>(); for (int i = 0; i < classifier.getLabelAlphabet().size(); i++) { Label label = classifier.getLabelAlphabet().lookupLabel(i); ClassificationResult result = new MalletClassificationResult( trainingAlgo, targetTerm, vectorType, label.toString(), termWindowSize, trial, sourceFile); results.add(result); System.out.println(result.toString()); } // Save classifier saveClassifierToFile(outputFileClassifier, 
classifier, trainingAlgo, termWindowSize); return results; }
/** * Shows accuracy according to Ben Wellner's definition of accuracy * * @param classifier * @param instanceList */ private void showAccuracy(Classifier classifier, InstanceList instanceList) throws IOException { int total = instanceList.size(); int correct = 0; HashMap<String, Integer> errorMap = new HashMap<String, Integer>(); FileWriter errorWriter = new FileWriter("arg1Error.log"); for (Instance instance : instanceList) { Classification classification = classifier.classify(instance); if (classification.bestLabelIsCorrect()) { correct++; } else { Arg1RankInstance rankInstance = (Arg1RankInstance) instance; Document doc = rankInstance.getDocument(); Sentence s = doc.getSentence(rankInstance.getArg2Line()); String conn = s.toString(rankInstance.getConnStart(), rankInstance.getConnEnd()).toLowerCase(); // String category = connAnalyzer.getCategory(conn); if (errorMap.containsKey(conn)) { errorMap.put(conn, errorMap.get(conn) + 1); } else { errorMap.put(conn, 1); } int arg2Line = rankInstance.getArg2Line(); int arg1Line = rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).first(); int arg1HeadPos = rankInstance.getCandidates().get(rankInstance.getTrueArg1Candidate()).second(); int predictedCandidateIndex = Integer.parseInt(classification.getLabeling().getBestLabel().toString()); if (arg1Line == arg2Line) { errorWriter.write("FileName: " + doc.getFileName() + "\n"); errorWriter.write("Sentential\n"); errorWriter.write("Conn: " + conn + "\n"); errorWriter.write("Arg1Head: " + s.get(arg1HeadPos).word() + "\n"); errorWriter.write(s.toString() + "\n\n"); } else { errorWriter.write("FileName: " + doc.getFileName() + "\n"); errorWriter.write("Inter-Sentential\n"); errorWriter.write("Arg1 in : " + arg1Line + "\n"); errorWriter.write("Arg2 in : " + arg2Line + "\n"); errorWriter.write("Conn: " + conn + "\n"); errorWriter.write(s.toString() + "\n"); Sentence s1 = doc.getSentence(arg1Line); errorWriter.write("Arg1Head: " + s1.get(arg1HeadPos) + "\n"); 
errorWriter.write(s1.toString() + "\n\n"); } int predictedArg1Line = rankInstance.getCandidates().get(predictedCandidateIndex).first(); int predictedArg1HeadPos = rankInstance.getCandidates().get(predictedCandidateIndex).second(); Sentence pSentence = doc.getSentence(predictedArg1Line); errorWriter.write( "Predicted arg1 sentence: " + pSentence.toString() + " [Correct: " + (predictedArg1Line == arg1Line) + "]\n"); errorWriter.write("Predicted head: " + pSentence.get(predictedArg1HeadPos).word() + "\n\n"); } } errorWriter.close(); Set<Entry<String, Integer>> entrySet = errorMap.entrySet(); List<Entry<String, Integer>> list = new ArrayList<Entry<String, Integer>>(entrySet); Collections.sort( list, new Comparator<Entry<String, Integer>>() { @Override public int compare(Entry<String, Integer> o1, Entry<String, Integer> o2) { if (o1.getValue() > o2.getValue()) return -1; else if (o1.getValue() < o2.getValue()) return 1; return 0; } }); for (Entry<String, Integer> item : list) { System.out.println(item.getKey() + "-" + item.getValue()); } System.out.println("Total: " + total); System.out.println("Correct: " + correct); System.out.println("Accuracy: " + (1.0 * correct) / total); }
public static ArrayList<Double> getAverageCrossValidationScore( InstanceList ilist, int i, ClassifierTrainer trainer) { double crossValidAccSum = 0; double crossValidPrcSum = 0; double crossValidRecSum = 0; double crossValidF1Sum = 0; int count = 0; // get gross validation folds CrossValidationIterator cvIlists = ilist.crossValidationIterator(i); while (cvIlists.hasNext()) { System.out.println("#############Performing " + count + " iteration###########"); InstanceList[] ilists = cvIlists.next(); System.out.println("The train set size is " + ilists[0].size()); System.out.println("The test set size is " + ilists[1].size()); Classifier classifier = trainer.train(ilists[0]); System.out.println("The training accuracy is " + classifier.getAccuracy(ilists[0])); System.out.println("The testing accuracy is " + classifier.getAccuracy(ilists[1])); System.out.println("The testing precision is " + classifier.getPrecision(ilists[1], 1)); System.out.println("The testing recall is " + classifier.getRecall(ilists[1], 1)); System.out.println("The testing f1score is " + classifier.getF1(ilists[1], 1)); crossValidAccSum += classifier.getAccuracy(ilists[1]); crossValidPrcSum += classifier.getPrecision(ilists[1], 1); crossValidRecSum += classifier.getRecall(ilists[1], 1); crossValidF1Sum += classifier.getF1(ilists[1], 1); count++; // additional calculations ArrayList<Classification> outClassifications = classifier.classify(ilists[1]); int p1l1 = 0; int p1l0 = 0; int p0l1 = 0; int p0l0 = 0; int countCorrect = 0; int countIncorrect = 0; System.out.println("Outclassification size " + outClassifications.size()); for (int k = 0; k < outClassifications.size(); k++) { // System.out.println("Data "+outClassifications.get(k).getInstance().getName()); // System.out.println("Labeling "+outClassifications.get(k).getLabeling()); uncomment to get // score double predictedLabel = outClassifications.get(k).getLabeling().getBestIndex(); // System.out.println("Predicted label "+ predictedLabel); double 
targetLabel = Double.valueOf(outClassifications.get(k).getInstance().getTarget().toString()); // System.out.println("Target "+ targetLabel); boolean bestlabelIsCorrect = outClassifications.get(k).bestLabelIsCorrect(); // System.out.println("Prediction "+bestlabelIsCorrect); if (bestlabelIsCorrect) countCorrect++; else countIncorrect++; if ((predictedLabel == 1.0) && (targetLabel == 1.0)) p1l1++; else if ((predictedLabel == 1.0) && (targetLabel == 0.0)) p1l0++; else if ((predictedLabel == 0.0) && (targetLabel == 1.0)) p0l1++; else if ((predictedLabel == 0.0) && (targetLabel == 0.0)) p0l0++; } System.out.println("Count Correct " + countCorrect); System.out.println("Count Incorrect " + countIncorrect); System.out.println("p1l1 " + p1l1); System.out.println("p1l0 " + p1l0); System.out.println("p0l1 " + p0l1); System.out.println("p0l0 " + p0l0); } ArrayList<Double> results = new ArrayList<Double>(); double crossValidAccAvg = crossValidAccSum / count; double crossValidPrcAvg = crossValidPrcSum / count; double crossValidRecAvg = crossValidRecSum / count; double crossValidF1Avg = crossValidF1Sum / count; results.add(crossValidAccAvg); results.add(crossValidPrcAvg); results.add(crossValidRecAvg); results.add(crossValidF1Avg); return results; }