/** * @param targetTerm * @param sourceFile * @param trainingAlgo * @param outputFileClassifier * @param outputFileResults * @param termWindowSize * @param pipe * @return */ private static List<ClassificationResult> runTrainingAndClassification( String targetTerm, String sourceFile, String trainingAlgo, String outputFileClassifier, String outputFileResults, int termWindowSize, Pipe pipe, boolean useCollocationalVector) { // Read in concordance file and create list of Mallet training instances // TODO: Remove duplication of code (see execConvertToMalletFormat(...)) String vectorType = useCollocationalVector ? "coll" : "bow"; InstanceList instanceList = readConcordanceFileToInstanceList( targetTerm, sourceFile, termWindowSize, pipe, useCollocationalVector); // Creating splits for training and testing double[] proportions = {0.9, 0.1}; InstanceList[] splitLists = instanceList.split(proportions); InstanceList trainingList = splitLists[0]; InstanceList testList = splitLists[1]; // Train the classifier ClassifierTrainer classifierTrainer = getClassifierTrainerForAlgorithm(trainingAlgo); Classifier classifier = classifierTrainer.train(trainingList); if (classifier.getLabelAlphabet() != null) { // TODO: Make sure this is not null in RandomClassifier System.out.println("Labels:\n" + classifier.getLabelAlphabet()); System.out.println( "Size of data alphabet (= type count of training list): " + classifier.getAlphabet().size()); } // Run tests and get results Trial trial = new Trial(classifier, testList); List<ClassificationResult> results = new ArrayList<ClassificationResult>(); for (int i = 0; i < classifier.getLabelAlphabet().size(); i++) { Label label = classifier.getLabelAlphabet().lookupLabel(i); ClassificationResult result = new MalletClassificationResult( trainingAlgo, targetTerm, vectorType, label.toString(), termWindowSize, trial, sourceFile); results.add(result); System.out.println(result.toString()); } // Save classifier saveClassifierToFile(outputFileClassifier, 
classifier, trainingAlgo, termWindowSize); return results; }
/**
 * Prints evaluation metrics of {@code classifier} on the given instances.
 *
 * <p>The output is, in order:
 *
 * <ul>
 *   <li>the description and the accuracy;
 *   <li>precision, recall and F1 for every label in the classifier's label alphabet;
 *   <li>overall precision, recall, F1 and fall-out computed by an {@code ExtendedTrial};
 *   <li>the output of {@code trial.showErrors()}.
 * </ul>
 *
 * @param instanceList instances to evaluate the classifier on
 * @param description heading printed before the metrics
 */
public void printPerformance(InstanceList instanceList, String description) {
    System.out.println(description);
    System.out.println("Accuracy: " + classifier.getAccuracy(instanceList));

    LabelAlphabet labelAlphabet = classifier.getLabelAlphabet();
    // NOTE(review): getLabelAlphabet() can apparently be null for some classifiers
    // (e.g. RandomClassifier, per a TODO elsewhere in this file); skip the per-label
    // section instead of throwing a NullPointerException.
    if (labelAlphabet != null) {
        for (Iterator<?> iterator = labelAlphabet.iterator(); iterator.hasNext(); ) {
            Object label = iterator.next();
            double p = classifier.getPrecision(instanceList, label);
            double r = classifier.getRecall(instanceList, label);
            double f1 = classifier.getF1(instanceList, label);
            System.out.println("Precision[" + label + "] = " + p);
            System.out.println("Recall[" + label + "] = " + r);
            System.out.println("F1[" + label + "] = " + f1);
            System.out.println();
        }
    }

    ExtendedTrial trial = new ExtendedTrial(classifier, instanceList);
    System.out.println("Overall performance\n=====");
    System.out.println("Precision: " + trial.getPrecision());
    System.out.println("Recall: " + trial.getRecall());
    System.out.println("F1: " + trial.getF1());
    System.out.println("Fall-out: " + trial.getFallOut());
    trial.showErrors();
}