/** * Create and train a CRF model from the given training data, optionally testing it on the given * test data. * * @param training training data * @param testing test data (possibly <code>null</code>) * @param eval accuracy evaluator (possibly <code>null</code>) * @param orders label Markov orders (main and backoff) * @param defaultLabel default label * @param forbidden regular expression specifying impossible label transitions <em>current</em> * <code>,</code><em>next</em> (<code>null</code> indicates no forbidden transitions) * @param allowed regular expression specifying allowed label transitions (<code>null</code> * indicates everything is allowed that is not forbidden) * @param connected whether to include even transitions not occurring in the training data. * @param iterations number of training iterations * @param var Gaussian prior variance * @return the trained model */ public static CRF train( InstanceList training, InstanceList testing, TransducerEvaluator eval, int[] orders, String defaultLabel, String forbidden, String allowed, boolean connected, int iterations, double var, CRF crf) { Pattern forbiddenPat = Pattern.compile(forbidden); Pattern allowedPat = Pattern.compile(allowed); if (crf == null) { crf = new CRF(training.getPipe(), (Pipe) null); String startName = crf.addOrderNStates( training, orders, null, defaultLabel, forbiddenPat, allowedPat, connected); CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf); crft.setGaussianPriorVariance(var); for (int i = 0; i < crf.numStates(); i++) crf.getState(i).setInitialWeight(Transducer.IMPOSSIBLE_WEIGHT); crf.getState(startName).setInitialWeight(0.0); } logger.info("Training on " + training.size() + " instances"); if (testing != null) logger.info("Testing on " + testing.size() + " instances"); CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood(crf); if (featureInductionOption.value) { crft.trainWithFeatureInduction( training, null, testing, eval, iterations, 10, 20, 
500, 0.5, false, null); } else { boolean converged; for (int i = 1; i <= iterations; i++) { converged = crft.train(training, 1); if (i % 1 == 0 && eval != null) // Change the 1 to higher integer to evaluate less often eval.evaluate(crft); if (viterbiOutputOption.value && i % 10 == 0) new ViterbiWriter( "", new InstanceList[] {training, testing}, new String[] {"training", "testing"}) .evaluate(crft); if (converged) break; } } return crf; }
/**
 * Trains the CRF model on the training dataset and writes the learned model and the feature
 * definitions to {@code baseDir + "/learntModels/" + outDir + "/crf"} and {@code ".../features"}.
 *
 * @throws Exception propagated from the model / feature-generator calls (reading data,
 *     training, or writing the model files)
 */
public void train() throws Exception {
  /*
   * Read the training dataset into an object which implements the DataIter interface
   * (trainData). Each training instance is encapsulated in an object which provides the
   * DataSequence interface; the DataIter returns those DataSequence objects from next().
   */
  DataIterImpl trainData = new DataIterImpl();

  /*
   * With the training dataset available, allocate the objects for the model to be learned.
   */
  allocModel();

  /*
   * Some feature types must themselves be trained on the data, for instance dictionary
   * features built from the training set.
   */
  featureGen.train(trainData);

  /*
   * Train the CRF model. crfModel.train returns the learned feature weights; the array is not
   * used here. NOTE(review): presumably the model retains the weights and persists them via
   * crfModel.write below — confirm against the CRF implementation.
   */
  crfModel.train(trainData);

  /*
   * Store the learned model for later use: both the features and their corresponding weights
   * are needed to reload it.
   */
  String modelDir = baseDir + "/learntModels/" + outDir;
  crfModel.write(modelDir + "/crf");
  featureGen.write(modelDir + "/features");
}
/**
 * Loads the model and feature definitions from the paths that {@code train()} writes to
 * ({@code baseDir + "/learntModels/" + outDir + "/features"} and {@code ".../crf"}) and labels
 * each test record with the CRF model.
 *
 * <p>NOTE(review): this is template code and does not compile as written. The loop condition
 * {@code while(...)} and the variable {@code testRecord} are placeholders; neither is defined in
 * this method. Replace them with an actual iteration over the test records.
 *
 * @throws Exception propagated from the model / feature-generator read calls
 */
public void test() throws Exception {
  /*
   * Read the test dataset. Each test instance is encapsulated in an object which provides the
   * DataSequence interface.
   * TODO(review): no statement in this method actually loads the test dataset yet.
   */

  /*
   * Allocate the model objects and read back the parameters that were stored on disk after
   * training. If the model is already available in memory, you do not need to reallocate it;
   * the next three statements can be skipped in that case.
   */
  allocModel();
  featureGen.read(baseDir+"/learntModels/"+outDir+"/features");
  crfModel.read(baseDir+"/learntModels/"+outDir+"/crf");
  /*
   * Iterate over the test dataset and apply the CRF model to each test instance.
   * TODO(review): "(...)" is a placeholder condition; the loop must also bind testRecord.
   */
  while(...) {
    /*
     * Apply the CRF model to this test instance.
     */
    crfModel.apply(testRecord);
    /*
     * After labeling, the instance holds model state values rather than the labels supplied
     * during training. Map those states back to labels by calling the following method on
     * the labeled testRecord.
     */
    featureGen.mapStatesToLabels(testRecord);
  }
}
/**
 * Command-line wrapper to train, test, or run a generic CRF-based tagger.
 *
 * @param args the command line arguments. Options (shell and Java quoting should be added as
 *     needed):
 *     <dl>
 *       <dt><code>--help</code> <em>boolean</em>
 *       <dd>Print this command line option usage information. Give <code>true</code> for longer
 *           documentation. Default is <code>false</code>.
 *       <dt><code>--prefix-code</code> <em>Java-code</em>
 *       <dd>Java code you want run before any other interpreted code. Note that the text is
 *           interpreted without modification, so unlike some other Java code options, you need
 *           to include any necessary 'new's. Default is null.
 *       <dt><code>--gaussian-variance</code> <em>positive-number</em>
 *       <dd>The Gaussian prior variance used for training. Default is 10.0.
 *       <dt><code>--train</code> <em>boolean</em>
 *       <dd>Whether to train. Default is <code>false</code>.
 *       <dt><code>--iterations</code> <em>positive-integer</em>
 *       <dd>Number of training iterations. Default is 500.
 *       <dt><code>--test</code> <code>lab</code> or <code>seg=</code><em>start-1</em><code>.
 *           </code><em>continue-1</em><code>,</code>...<code>,</code><em>start-n</em><code>.
 *           </code><em>continue-n</em>
 *       <dd>Test measuring labeling or segmentation (<em>start-i</em>, <em>continue-i</em>)
 *           accuracy. Default is no testing.
 *       <dt><code>--training-proportion</code> <em>number-between-0-and-1</em>
 *       <dd>Fraction of data to use for training in a random split. Default is 0.5.
 *       <dt><code>--model-file</code> <em>filename</em>
 *       <dd>The filename for reading (train/run) or saving (train) the model. Required when not
 *           training. Default is null.
 *       <dt><code>--random-seed</code> <em>integer</em>
 *       <dd>The random seed for randomly selecting a proportion of the instance list for
 *           training. Default is 0.
 *       <dt><code>--orders</code> <em>comma-separated-integers</em>
 *       <dd>List of label Markov orders (main and backoff). Default is 1.
 *       <dt><code>--forbidden</code> <em>regular-expression</em>
 *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> matches the expression, the
 *           corresponding transition is forbidden. Default is <code>\\s</code> (nothing
 *           forbidden).
 *       <dt><code>--allowed</code> <em>regular-expression</em>
 *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> does not match the expression, the
 *           corresponding transition is forbidden. Default is <code>.*</code> (everything
 *           allowed).
 *       <dt><code>--default-label</code> <em>string</em>
 *       <dd>Label for initial context and uninteresting tokens. Default is <code>O</code>.
 *       <dt><code>--viterbi-output</code> <em>boolean</em>
 *       <dd>Print Viterbi periodically during training. Default is <code>false</code>.
 *       <dt><code>--fully-connected</code> <em>boolean</em>
 *       <dd>Include all allowed transitions, even those not in training data. Default is
 *           <code>true</code>.
 *       <dt><code>--n-best</code> <em>positive-integer</em>
 *       <dd>Number of answers to output when applying model. Default is 1.
 *       <dt><code>--include-input</code> <em>boolean</em>
 *       <dd>Whether to include input features when printing decoding output. Default is
 *           <code>false</code>.
 *     </dl>
 *     Remaining arguments:
 *     <ul>
 *       <li><em>training-data-file</em> if training
 *       <li><em>training-and-test-data-file</em>, if training and testing with random split
 *       <li><em>training-data-file</em> <em>test-data-file</em> if training and testing from
 *           separate files
 *       <li><em>test-data-file</em> if testing
 *       <li><em>input-data-file</em> if applying to new data (unlabeled)
 *     </ul>
 *
 * @throws Exception if an error occurs
 */
public static void main(String[] args) throws Exception {
  int restArgs = commandOptions.processOptions(args);
  if (restArgs == args.length) {
    commandOptions.printUsage(true);
    throw new IllegalArgumentException("Missing data file(s)");
  }

  Reader trainingFile = null;
  Reader testFile = null;
  InstanceList trainingData = null;
  InstanceList testData = null;
  Pipe p = null;
  CRF crf = null;
  try {
    if (trainOption.value) {
      trainingFile = new FileReader(new File(args[restArgs]));
      if (testOption.value != null && restArgs < args.length - 1) {
        testFile = new FileReader(new File(args[restArgs + 1]));
      }
    } else {
      testFile = new FileReader(new File(args[restArgs]));
    }

    // A stored model is needed both to continue training and to test/apply without training.
    if (continueTrainingOption.value || !trainOption.value) {
      if (modelOption.value == null) {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Missing model file option");
      }
      // SECURITY NOTE(review): Java native deserialization can execute code embedded in the
      // stream; only load model files from trusted sources.
      FileInputStream modelIn = new FileInputStream(modelOption.value);
      try {
        crf = (CRF) new ObjectInputStream(modelIn).readObject();
      } finally {
        modelIn.close();
      }
      p = crf.getInputPipe();
    } else {
      p = new SimpleTaggerSentence2FeatureVectorSequence();
      p.getTargetAlphabet().lookupIndex(defaultOption.value);
    }

    // Instances in the data files are separated by blank lines.
    Pattern blankLine = Pattern.compile("^\\s*$");
    if (trainOption.value) {
      p.setTargetProcessing(true);
      trainingData = new InstanceList(p);
      trainingData.addThruPipe(new LineGroupIterator(trainingFile, blankLine, true));
      logger.info("Number of features in training data: " + p.getDataAlphabet().size());
      if (testOption.value != null) {
        if (testFile != null) {
          testData = new InstanceList(p);
          testData.addThruPipe(new LineGroupIterator(testFile, blankLine, true));
        } else {
          // No separate test file: carve the test set out of the training data.
          Random r = new Random(randomSeedOption.value);
          InstanceList[] trainingLists =
              trainingData.split(
                  r,
                  new double[] {trainingFractionOption.value, 1 - trainingFractionOption.value});
          trainingData = trainingLists[0];
          testData = trainingLists[1];
        }
      }
    } else {
      // Labelled data is only read as targets when it is being tested; raw input for tagging
      // carries no labels.
      p.setTargetProcessing(testOption.value != null);
      testData = new InstanceList(p);
      testData.addThruPipe(new LineGroupIterator(testFile, blankLine, true));
    }
  } finally {
    // All data has been piped into InstanceLists above, so the readers are no longer needed.
    try {
      if (trainingFile != null) {
        trainingFile.close();
      }
    } finally {
      if (testFile != null) {
        testFile.close();
      }
    }
  }
  logger.info("Number of predicates: " + p.getDataAlphabet().size());

  TransducerEvaluator eval = null;
  if (testOption.value != null) {
    if (testOption.value.startsWith("lab")) {
      eval =
          new TokenAccuracyEvaluator(
              new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"});
    } else if (testOption.value.startsWith("seg=")) {
      String segmentSpec = testOption.value.substring(4);
      String[] pairs = segmentSpec.split(",");
      // "".split(",") yields {""}, so an empty spec must be checked explicitly.
      if (segmentSpec.length() == 0 || pairs.length < 1) {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException(
            "Missing segment start/continue labels: " + testOption.value);
      }
      String[] startTags = new String[pairs.length];
      String[] continueTags = new String[pairs.length];
      for (int i = 0; i < pairs.length; i++) {
        String[] pair = pairs[i].split("\\.");
        if (pair.length != 2) {
          commandOptions.printUsage(true);
          throw new IllegalArgumentException(
              "Incorrectly-specified segment start and end labels: " + pairs[i]);
        }
        startTags[i] = pair[0];
        continueTags[i] = pair[1];
      }
      eval =
          new MultiSegmentationEvaluator(
              new InstanceList[] {trainingData, testData},
              new String[] {"Training", "Testing"},
              startTags,
              continueTags);
    } else {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Invalid test option: " + testOption.value);
    }
  }

  if (p.isTargetProcessing()) {
    Alphabet targets = p.getTargetAlphabet();
    StringBuilder buf = new StringBuilder("Labels:");
    for (int i = 0; i < targets.size(); i++) {
      buf.append(" ").append(targets.lookupObject(i).toString());
    }
    logger.info(buf.toString());
  }

  if (trainOption.value) {
    crf =
        train(
            trainingData,
            testData,
            eval,
            ordersOption.value,
            defaultOption.value,
            forbiddenOption.value,
            allowedOption.value,
            connectedOption.value,
            iterationsOption.value,
            gaussianVarianceOption.value,
            crf);
    if (modelOption.value != null) {
      FileOutputStream modelOut = new FileOutputStream(modelOption.value);
      try {
        ObjectOutputStream s = new ObjectOutputStream(modelOut);
        s.writeObject(crf);
        // Closing the object stream flushes its internal buffer before the file is closed.
        s.close();
      } finally {
        modelOut.close();
      }
    }
  } else if (eval != null) {
    // Not training: crf was necessarily loaded from the model file above.
    test(new NoopTransducerTrainer(crf), eval, testData);
  } else {
    boolean includeInput = includeInputOption.value();
    for (int i = 0; i < testData.size(); i++) {
      Sequence input = (Sequence) testData.get(i).getData();
      Sequence[] outputs = apply(crf, input, nBestOption.value);
      int k = outputs.length;
      boolean error = false;
      for (int a = 0; a < k; a++) {
        if (outputs[a].size() != input.size()) {
          System.err.println("Failed to decode input sequence " + i + ", answer " + a);
          error = true;
        }
      }
      if (!error) {
        // One line per token: the k best labels, optionally followed by the input features.
        for (int j = 0; j < input.size(); j++) {
          StringBuilder buf = new StringBuilder();
          for (int a = 0; a < k; a++) {
            buf.append(outputs[a].get(j).toString()).append(" ");
          }
          if (includeInput) {
            FeatureVector fv = (FeatureVector) input.get(j);
            buf.append(fv.toString(true));
          }
          System.out.println(buf.toString());
        }
        System.out.println();
      }
    }
  }
}