/**
 * Trains a classifier on 90% of the instances read from a concordance file, evaluates it on the
 * remaining 10%, and saves the trained classifier.
 *
 * @param targetTerm the term whose concordance lines are classified
 * @param sourceFile the concordance file to read instances from
 * @param trainingAlgo the name of the training algorithm, as accepted by
 *     getClassifierTrainerForAlgorithm(...)
 * @param outputFileClassifier the file the trained classifier is saved to
 * @param outputFileResults the file for classification results (currently unused by this method)
 * @param termWindowSize the size of the context window around the target term
 * @param pipe the Mallet pipe used to turn concordance lines into instances
 * @param useCollocationalVector whether to build collocational ("coll") instead of
 *     bag-of-words ("bow") feature vectors
 * @return one classification result per label of the trained classifier
 */
private static List<ClassificationResult> runTrainingAndClassification(
    String targetTerm,
    String sourceFile,
    String trainingAlgo,
    String outputFileClassifier,
    String outputFileResults,
    int termWindowSize,
    Pipe pipe,
    boolean useCollocationalVector) {
  // Read in concordance file and create list of Mallet training instances
  // TODO: Remove duplication of code (see execConvertToMalletFormat(...))
  String vectorType = useCollocationalVector ? "coll" : "bow";
  InstanceList instanceList =
      readConcordanceFileToInstanceList(
          targetTerm, sourceFile, termWindowSize, pipe, useCollocationalVector);

  // Randomly split the instances: 90% for training, 10% for testing
  double[] proportions = {0.9, 0.1};
  InstanceList[] splitLists = instanceList.split(proportions);
  InstanceList trainingList = splitLists[0];
  InstanceList testList = splitLists[1];

  // Train the classifier
  ClassifierTrainer classifierTrainer = getClassifierTrainerForAlgorithm(trainingAlgo);
  Classifier classifier = classifierTrainer.train(trainingList);
  if (classifier.getLabelAlphabet() != null) {
    // TODO: Make sure this is not null in RandomClassifier
    System.out.println("Labels:\n" + classifier.getLabelAlphabet());
    System.out.println(
        "Size of data alphabet (= type count of training list): "
            + classifier.getAlphabet().size());
  }

  // Run tests and collect one result per label
  Trial trial = new Trial(classifier, testList);
  List<ClassificationResult> results = new ArrayList<>();
  // Guard again: without a label alphabet there are no per-label results (avoids an NPE)
  if (classifier.getLabelAlphabet() != null) {
    for (int i = 0; i < classifier.getLabelAlphabet().size(); i++) {
      Label label = classifier.getLabelAlphabet().lookupLabel(i);
      ClassificationResult result =
          new MalletClassificationResult(
              trainingAlgo,
              targetTerm,
              vectorType,
              label.toString(),
              termWindowSize,
              trial,
              sourceFile);
      results.add(result);
      System.out.println(result.toString());
    }
  }

  // Save classifier
  saveClassifierToFile(outputFileClassifier, classifier, trainingAlgo, termWindowSize);
  return results;
}
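`MalletClassificationResult` is not shown here. Assuming it reports per-label precision, recall and F1 derived from the `Trial` (Mallet's `Trial` offers per-label precision, recall and F1 accessors), the plain-Java sketch below shows how those figures follow from gold and predicted labels on the test split. `PerLabelScores` and `score` are hypothetical names, independent of Mallet.

```java
import java.util.List;

/** Hypothetical helper (not part of Mallet): per-label precision, recall and F1. */
final class PerLabelScores {
  private PerLabelScores() {}

  /**
   * Returns {precision, recall, f1} for {@code label}, given parallel lists of gold and predicted
   * labels. A metric whose denominator is zero is reported as 0.0 (a convention of this sketch).
   */
  static double[] score(List<String> gold, List<String> predicted, String label) {
    if (gold.size() != predicted.size()) {
      throw new IllegalArgumentException("gold and predicted must have the same length");
    }
    int truePositives = 0;
    int predictedCount = 0;
    int goldCount = 0;
    for (int i = 0; i < gold.size(); i++) {
      boolean isGold = label.equals(gold.get(i));
      boolean isPredicted = label.equals(predicted.get(i));
      if (isGold) goldCount++;
      if (isPredicted) predictedCount++;
      if (isGold && isPredicted) truePositives++;
    }
    double precision = predictedCount == 0 ? 0.0 : (double) truePositives / predictedCount;
    double recall = goldCount == 0 ? 0.0 : (double) truePositives / goldCount;
    // F1 is the harmonic mean of precision and recall
    double f1 = precision + recall == 0.0 ? 0.0 : 2 * precision * recall / (precision + recall);
    return new double[] {precision, recall, f1};
  }
}
```

With gold labels `a a b b` and predictions `a b b b`, label `a` scores precision 1.0 and recall 0.5, while `b` scores precision 2/3 and recall 1.0. Mallet's own `Trial` may treat a label that is never predicted differently from the 0.0 convention used here.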
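/**
 * Trains a classifier on a random portion of the given instances and tests it on the rest.
 *
 * @param trainingInstances the instances to split into training and test data
 * @param trainingPortion the fraction of instances to use for training (between 0.0 and 1.0);
 *     the rest is used for testing
 * @return the trained classifier
 */
public Classifier train(InstanceList trainingInstances, double trainingPortion) {
  if (trainingPortion < 0.0 || trainingPortion > 1.0) {
    throw new IllegalArgumentException(
        "trainingPortion must be between 0.0 and 1.0: " + trainingPortion);
  }
  // new Random() is unseeded, so the split (and thus the classifier) differs between runs
  InstanceList[] instanceLists =
      trainingInstances.split(
          new Random(), new double[] {trainingPortion, (1 - trainingPortion)});
  // InstanceList[] instanceLists =
  //     trainingInstances.splitInOrder(new double[] {trainingPortion, (1 - trainingPortion)});
  return this.train(instanceLists[0], instanceLists[1]);
}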
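Because `train(InstanceList, double)` splits with an unseeded `new Random()`, two runs may train on different subsets (`main` below avoids this by seeding with `--random-seed`). A minimal sketch of a reproducible split, assuming the shuffle-then-cut idea: shuffle a copy with a seeded `Random`, then cut it at cumulative, normalized proportions. `SeededSplit` is a hypothetical plain-Java helper, not Mallet's `InstanceList.split`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper (not Mallet's InstanceList.split): reproducible proportional split. */
final class SeededSplit {
  private SeededSplit() {}

  static <T> List<List<T>> split(List<T> items, long seed, double... proportions) {
    double total = 0.0;
    for (double p : proportions) {
      if (p < 0.0) throw new IllegalArgumentException("proportions must be non-negative");
      total += p;
    }
    if (total <= 0.0) throw new IllegalArgumentException("proportions must not all be zero");

    // Shuffle a copy with a seeded generator so the caller's list is untouched
    // and the same seed always produces the same order
    List<T> shuffled = new ArrayList<>(items);
    Collections.shuffle(shuffled, new Random(seed));

    List<List<T>> parts = new ArrayList<>();
    double cumulative = 0.0;
    int start = 0;
    for (int i = 0; i < proportions.length; i++) {
      cumulative += proportions[i] / total; // normalized, so the fractions sum to 1
      // The last part always ends at the end of the list, so rounding never drops items
      int end =
          (i == proportions.length - 1)
              ? shuffled.size()
              : (int) Math.round(cumulative * shuffled.size());
      parts.add(new ArrayList<>(shuffled.subList(start, end)));
      start = end;
    }
    return parts;
  }
}
```

Calling `SeededSplit.split(items, 0L, 0.9, 0.1)` twice yields identical parts, which keeps comparisons between training algorithms repeatable; in `train(...)` the analogous change would be passing a `Random` built from a fixed seed.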
/**
 * Command-line wrapper to train, test, or run a generic CRF-based tagger.
 *
 * @param args the command line arguments. Options (shell and Java quoting should be added as
 *     needed):
 *     <dl>
 *       <dt><code>--help</code> <em>boolean</em>
 *       <dd>Print this command line option usage information. Give <code>true</code> for longer
 *           documentation. Default is <code>false</code>.
 *       <dt><code>--prefix-code</code> <em>Java-code</em>
 *       <dd>Java code to run before any other interpreted code. Note that the text is
 *           interpreted without modification, so unlike some other Java code options, you need to
 *           include any necessary 'new's. Default is null.
 *       <dt><code>--gaussian-variance</code> <em>positive-number</em>
 *       <dd>The Gaussian prior variance used for training. Default is 10.0.
 *       <dt><code>--train</code> <em>boolean</em>
 *       <dd>Whether to train. Default is <code>false</code>.
 *       <dt><code>--iterations</code> <em>positive-integer</em>
 *       <dd>Number of training iterations. Default is 500.
 *       <dt><code>--test</code> <code>lab</code> or
 *           <code>seg=</code><em>start-1</em><code>.</code><em>continue-1</em><code>,</code>...<code>,</code><em>start-n</em><code>.</code><em>continue-n</em>
 *       <dd>Measure labeling accuracy (<code>lab</code>) or segmentation accuracy, where each
 *           segment type is given by a start label <em>start-i</em> and a continue label
 *           <em>continue-i</em>. Default is no testing.
 *       <dt><code>--training-proportion</code> <em>number-between-0-and-1</em>
 *       <dd>Fraction of data to use for training in a random split (used when training and
 *           testing on a single file). Default is 0.5.
 *       <dt><code>--model-file</code> <em>filename</em>
 *       <dd>The filename for reading (train/run) or saving (train) the model. Default is null.
 *       <dt><code>--random-seed</code> <em>integer</em>
 *       <dd>The random seed for randomly selecting a proportion of the instance list for
 *           training. Default is 0.
 *       <dt><code>--orders</code> <em>comma-separated-integers</em>
 *       <dd>List of label Markov orders (main and backoff). Default is 1.
 *       <dt><code>--forbidden</code> <em>regular-expression</em>
 *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> matches the expression, the
 *           corresponding transition is forbidden. Default is <code>\s</code> (nothing
 *           forbidden).
 *       <dt><code>--allowed</code> <em>regular-expression</em>
 *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> does not match the expression, the
 *           corresponding transition is forbidden. Default is <code>.*</code> (everything
 *           allowed).
 *       <dt><code>--default-label</code> <em>string</em>
 *       <dd>Label for initial context and uninteresting tokens. Default is <code>O</code>.
 *       <dt><code>--viterbi-output</code> <em>boolean</em>
 *       <dd>Print Viterbi output periodically during training. Default is <code>false</code>.
 *       <dt><code>--fully-connected</code> <em>boolean</em>
 *       <dd>Include all allowed transitions, even those not in training data. Default is
 *           <code>true</code>.
 *       <dt><code>--n-best</code> <em>positive-integer</em>
 *       <dd>Number of answers to output when applying the model. Default is 1.
 *       <dt><code>--include-input</code> <em>boolean</em>
 *       <dd>Whether to include input features when printing decoding output. Default is
 *           <code>false</code>.
 *     </dl>
 *     Remaining arguments:
 *     <ul>
 *       <li><em>training-data-file</em>, if training
 *       <li><em>training-and-test-data-file</em>, if training and testing with a random split
 *       <li><em>training-data-file</em> <em>test-data-file</em>, if training and testing from
 *           separate files
 *       <li><em>test-data-file</em>, if testing
 *       <li><em>input-data-file</em>, if applying the model to new (unlabeled) data
 *     </ul>
 *
 * @throws Exception if an error occurs
 */
public static void main(String[] args) throws Exception {
  Reader trainingFile = null, testFile = null;
  InstanceList trainingData = null, testData = null;
  int restArgs = commandOptions.processOptions(args);
  if (restArgs == args.length) {
    commandOptions.printUsage(true);
    throw new IllegalArgumentException("Missing data file(s)");
  }
  if (trainOption.value) {
    trainingFile = new FileReader(new File(args[restArgs]));
    if (testOption.value != null && restArgs < args.length - 1) {
      testFile = new FileReader(new File(args[restArgs + 1]));
    }
  } else {
    testFile = new FileReader(new File(args[restArgs]));
  }

  Pipe p = null;
  CRF crf = null;
  TransducerEvaluator eval = null;
  if (continueTrainingOption.value || !trainOption.value) {
    if (modelOption.value == null) {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Missing model file option");
    }
    // try-with-resources closes the stream even if deserialization fails
    try (ObjectInputStream s = new ObjectInputStream(new FileInputStream(modelOption.value))) {
      crf = (CRF) s.readObject();
    }
    p = crf.getInputPipe();
  } else {
    p = new SimpleTaggerSentence2FeatureVectorSequence();
    p.getTargetAlphabet().lookupIndex(defaultOption.value);
  }

  // Blank lines separate sequences (sentences) in all input files
  if (trainOption.value) {
    p.setTargetProcessing(true);
    trainingData = new InstanceList(p);
    trainingData.addThruPipe(
        new LineGroupIterator(trainingFile, Pattern.compile("^\\s*$"), true));
    logger.info("Number of features in training data: " + p.getDataAlphabet().size());
    if (testOption.value != null) {
      if (testFile != null) {
        testData = new InstanceList(p);
        testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
      } else {
        // No separate test file: randomly split the training data
        Random r = new Random(randomSeedOption.value);
        InstanceList[] trainingLists =
            trainingData.split(
                r, new double[] {trainingFractionOption.value, 1 - trainingFractionOption.value});
        trainingData = trainingLists[0];
        testData = trainingLists[1];
      }
    }
  } else if (testOption.value != null) {
    p.setTargetProcessing(true);
    testData = new InstanceList(p);
    testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
  } else {
    p.setTargetProcessing(false);
    testData = new InstanceList(p);
    testData.addThruPipe(new LineGroupIterator(testFile, Pattern.compile("^\\s*$"), true));
  }
  logger.info("Number of predicates: " + p.getDataAlphabet().size());

  if (testOption.value != null) {
    if (testOption.value.startsWith("lab")) {
      eval =
          new TokenAccuracyEvaluator(
              new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"});
    } else if (testOption.value.startsWith("seg=")) {
      String[] pairs = testOption.value.substring(4).split(",");
      if (pairs.length < 1) {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException(
            "Missing segment start/continue labels: " + testOption.value);
      }
      String[] startTags = new String[pairs.length];
      String[] continueTags = new String[pairs.length];
      for (int i = 0; i < pairs.length; i++) {
        String[] pair = pairs[i].split("\\.");
        if (pair.length != 2) {
          commandOptions.printUsage(true);
          throw new IllegalArgumentException(
              "Incorrectly-specified segment start and continue labels: " + pairs[i]);
        }
        startTags[i] = pair[0];
        continueTags[i] = pair[1];
      }
      eval =
          new MultiSegmentationEvaluator(
              new InstanceList[] {trainingData, testData},
              new String[] {"Training", "Testing"},
              startTags,
              continueTags);
    } else {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Invalid test option: " + testOption.value);
    }
  }

  if (p.isTargetProcessing()) {
    Alphabet targets = p.getTargetAlphabet();
    StringBuilder buf = new StringBuilder("Labels:");
    for (int i = 0; i < targets.size(); i++) {
      buf.append(" ").append(targets.lookupObject(i).toString());
    }
    logger.info(buf.toString());
  }

  if (trainOption.value) {
    crf =
        train(
            trainingData,
            testData,
            eval,
            ordersOption.value,
            defaultOption.value,
            forbiddenOption.value,
            allowedOption.value,
            connectedOption.value,
            iterationsOption.value,
            gaussianVarianceOption.value,
            crf);
    if (modelOption.value != null) {
      try (ObjectOutputStream s =
          new ObjectOutputStream(new FileOutputStream(modelOption.value))) {
        s.writeObject(crf);
      }
    }
  } else {
    if (crf == null) {
      if (modelOption.value == null) {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Missing model file option");
      }
      try (ObjectInputStream s = new ObjectInputStream(new FileInputStream(modelOption.value))) {
        crf = (CRF) s.readObject();
      }
    }
    if (eval != null) {
      test(new NoopTransducerTrainer(crf), eval, testData);
    } else {
      boolean includeInput = includeInputOption.value();
      for (int i = 0; i < testData.size(); i++) {
        Sequence input = (Sequence) testData.get(i).getData();
        Sequence[] outputs = apply(crf, input, nBestOption.value);
        int k = outputs.length;
        boolean error = false;
        for (int a = 0; a < k; a++) {
          if (outputs[a].size() != input.size()) {
            System.err.println("Failed to decode input sequence " + i + ", answer " + a);
            error = true;
          }
        }
        if (!error) {
          // One line per token: the k best labels, optionally followed by the input features
          for (int j = 0; j < input.size(); j++) {
            StringBuilder buf = new StringBuilder();
            for (int a = 0; a < k; a++) {
              buf.append(outputs[a].get(j).toString()).append(" ");
            }
            if (includeInput) {
              FeatureVector fv = (FeatureVector) input.get(j);
              buf.append(fv.toString(true));
            }
            System.out.println(buf.toString());
          }
          // A blank line separates output sequences
          System.out.println();
        }
      }
    }
  }
}
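The `--test seg=` format and the `--forbidden`/`--allowed` options documented above both reduce to small string operations. The sketch below restates them in plain Java. `parseSegmentSpec` mirrors the parsing loop in `main`. `isTransitionAllowed` assumes, as the option documentation describes, that a transition is checked by joining the two labels with a comma and matching the whole string, with the forbidden pattern tested first. The patterns themselves are passed to `train(...)`, which is not shown, so this is an illustration rather than the tagger's actual check. `TaggerOptionSketch` and both method names are hypothetical.

```java
import java.util.regex.Pattern;

/** Hypothetical helpers illustrating the SimpleTagger-style options described above. */
final class TaggerOptionSketch {
  private TaggerOptionSketch() {}

  /** Parses e.g. "seg=B-PER.I-PER,B-LOC.I-LOC" into {startTags, continueTags}. */
  static String[][] parseSegmentSpec(String testOption) {
    if (!testOption.startsWith("seg=")) {
      throw new IllegalArgumentException("Not a segmentation test option: " + testOption);
    }
    String[] pairs = testOption.substring(4).split(",");
    String[] startTags = new String[pairs.length];
    String[] continueTags = new String[pairs.length];
    for (int i = 0; i < pairs.length; i++) {
      // Split on a literal dot, so labels themselves must not contain '.'
      String[] pair = pairs[i].split("\\.");
      if (pair.length != 2) {
        throw new IllegalArgumentException("Incorrectly-specified segment labels: " + pairs[i]);
      }
      startTags[i] = pair[0];
      continueTags[i] = pair[1];
    }
    return new String[][] {startTags, continueTags};
  }

  /**
   * Assumed semantics: the transition prev -> next is rejected if "prev,next" fully matches
   * {@code forbidden}; otherwise it is allowed only if it fully matches {@code allowed}.
   */
  static boolean isTransitionAllowed(String prev, String next, Pattern forbidden, Pattern allowed) {
    String pair = prev + "," + next;
    if (forbidden.matcher(pair).matches()) return false;
    return allowed.matcher(pair).matches();
  }
}
```

Under these assumptions the defaults `\s` and `.*` forbid nothing and allow everything, and a pattern such as `O,I-.*` passed as the forbidden expression would block entering any `I-` label directly from `O`.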