/** This is (mostly) copied from CRF4.java */ public boolean[][] labelConnectionsIn( Alphabet outputAlphabet, InstanceList trainingSet, String start) { int numLabels = outputAlphabet.size(); boolean[][] connections = new boolean[numLabels][numLabels]; for (int i = 0; i < trainingSet.size(); i++) { Instance instance = trainingSet.getInstance(i); FeatureSequence output = (FeatureSequence) instance.getTarget(); for (int j = 1; j < output.size(); j++) { int sourceIndex = outputAlphabet.lookupIndex(output.get(j - 1)); int destIndex = outputAlphabet.lookupIndex(output.get(j)); assert (sourceIndex >= 0 && destIndex >= 0); connections[sourceIndex][destIndex] = true; } } // Handle start state if (start != null) { int startIndex = outputAlphabet.lookupIndex(start); for (int j = 0; j < outputAlphabet.size(); j++) { connections[startIndex][j] = true; } } return connections; }
public static void main(String[] args) throws bsh.EvalError, java.io.IOException { // Process the command-line options commandOptions.process(args); System.out.println("Trainer = " + trainerConstructorOption.value.toString()); ClassifierTrainer trainer = (ClassifierTrainer) trainerConstructorOption.value; InstanceList ilist = InstanceList.load(new File(instanceListFilenameOption.value)); Random r = randomSeedOption.wasInvoked() ? new Random(randomSeedOption.value) : new Random(); double t = trainingProportionOption.value; double v = validationProportionOption.value; InstanceList[] ilists = ilist.split(r, new double[] {t, v, 1 - t - v}); System.err.println("Training..."); Classifier c = trainer.train( ilists[0], ilists[1], null, (ClassifierEvaluating) classifierEvaluatorOption.value, null); if (printTrainAccuracyOption.value) System.out.print("Train accuracy = " + c.getAccuracy(ilists[0]) + " "); if (printTestAccuracyOption.value) System.out.print("Test accuracy = " + c.getAccuracy(ilists[2])); if (printTrainAccuracyOption.value || printTestAccuracyOption.value) System.out.println(""); if (outputFilenameOption.wasInvoked()) { try { ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(instanceListFilenameOption.value)); oos.writeObject(c); oos.close(); } catch (Exception e) { e.printStackTrace(); throw new IllegalArgumentException( "Couldn't write classifier to filename " + instanceListFilenameOption.value); } } }
public static CRF4 createCRF(File trainingFile, CRFInfo crfInfo) throws FileNotFoundException { Reader trainingFileReader = new FileReader(trainingFile); // Create a pipe that we can use to convert the training // file to a feature vector sequence. Pipe p = new SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence(); // The training file does contain tags (aka targets) p.setTargetProcessing(true); // Register the default tag with the pipe, by looking it up // in the targetAlphabet before we look up any other tag. p.getTargetAlphabet().lookupIndex(crfInfo.defaultLabel); // Create a new instancelist to hold the training data. InstanceList trainingData = new InstanceList(p); // Read in the training data. trainingData.add(new LineGroupIterator(trainingFileReader, Pattern.compile("^\\s*$"), true)); // Create the CRF model. CRF4 crf = new CRF4(p, null); // Set various config options crf.setGaussianPriorVariance(crfInfo.gaussianVariance); crf.setTransductionType(crfInfo.transductionType); // Set up the model's states. if (crfInfo.stateInfoList != null) { Iterator stateIter = crfInfo.stateInfoList.iterator(); while (stateIter.hasNext()) { CRFInfo.StateInfo state = (CRFInfo.StateInfo) stateIter.next(); crf.addState( state.name, state.initialCost, state.finalCost, state.destinationNames, state.labelNames, state.weightNames); } } else if (crfInfo.stateStructure == CRFInfo.FULLY_CONNECTED_STRUCTURE) crf.addStatesForLabelsConnectedAsIn(trainingData); else if (crfInfo.stateStructure == CRFInfo.HALF_CONNECTED_STRUCTURE) crf.addStatesForHalfLabelsConnectedAsIn(trainingData); else if (crfInfo.stateStructure == CRFInfo.THREE_QUARTERS_CONNECTED_STRUCTURE) crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingData); else if (crfInfo.stateStructure == CRFInfo.BILABELS_STRUCTURE) crf.addStatesForBiLabelsConnectedAsIn(trainingData); else throw new RuntimeException("Unexpected state structure " + crfInfo.stateStructure); // Set up the weight groups. if (crfInfo.weightGroupInfoList != null) { Iterator wgIter = crfInfo.weightGroupInfoList.iterator(); while (wgIter.hasNext()) { CRFInfo.WeightGroupInfo wg = (CRFInfo.WeightGroupInfo) wgIter.next(); FeatureSelection fs = FeatureSelection.createFromRegex( crf.getInputAlphabet(), Pattern.compile(wg.featureSelectionRegex)); crf.setFeatureSelection(crf.getWeightsIndex(wg.name), fs); } } // Train the CRF. crf.train(trainingData, null, null, null, crfInfo.maxIterations); return crf; }
public void test( Transducer transducer, InstanceList data, String description, PrintStream viterbiOutputStream) { int[] ntrue = new int[segmentTags.length]; int[] npred = new int[segmentTags.length]; int[] ncorr = new int[segmentTags.length]; LabelAlphabet dict = (LabelAlphabet) transducer.getInputPipe().getTargetAlphabet(); for (int i = 0; i < data.size(); i++) { Instance instance = data.getInstance(i); Sequence input = (Sequence) instance.getData(); Sequence trueOutput = (Sequence) instance.getTarget(); assert (input.size() == trueOutput.size()); Sequence predOutput = transducer.viterbiPath(input).output(); assert (predOutput.size() == trueOutput.size()); List trueSegs = new ArrayList(); List predSegs = new ArrayList(); addSegs(trueSegs, trueOutput); addSegs(predSegs, predOutput); // System.out.println("FieldF1Evaluator instance "+instance.getName ()); // printSegs(dict, trueSegs, "True"); // printSegs(dict, predSegs, "Pred"); for (Iterator it = predSegs.iterator(); it.hasNext(); ) { Segment seg = (Segment) it.next(); npred[seg.tag]++; if (trueSegs.contains(seg)) { ncorr[seg.tag]++; } } for (Iterator it = trueSegs.iterator(); it.hasNext(); ) { Segment seg = (Segment) it.next(); ntrue[seg.tag]++; } } DecimalFormat f = new DecimalFormat("0.####"); logger.info(description + " per-field F1"); for (int tag = 0; tag < segmentTags.length; tag++) { double precision = ((double) ncorr[tag]) / npred[tag]; double recall = ((double) ncorr[tag]) / ntrue[tag]; double f1 = (2 * precision * recall) / (precision + recall); Label name = dict.lookupLabel(segmentTags[tag]); logger.info( " segments " + name + " true = " + ntrue[tag] + " pred = " + npred[tag] + " correct = " + ncorr[tag]); logger.info( " precision=" + f.format(precision) + " recall=" + f.format(recall) + " f1=" + f.format(f1)); } }