/**
   * Loads the reference and header CRF extractors from the given model files.
   *
   * <p>For each non-null file: a file whose name ends with {@code ".partial"} is loaded through
   * {@code loadPartiallyTrainedModel}; any other file is loaded through {@code loadCrfExtor} and
   * target processing is switched off on its feature pipe. A {@code null} file leaves the
   * corresponding extractor field untouched.
   *
   * <p>Loading is best-effort: if anything goes wrong for one model, the failure is logged (with
   * the causing exception) and that extractor is set to {@code null}; the other model is still
   * attempted.
   *
   * @param referenceCrfFile model file for the references extractor, or {@code null} to skip
   * @param headerCrfFile model file for the headers extractor, or {@code null} to skip
   */
  private void initCrfs(File referenceCrfFile, File headerCrfFile) {
    try {
      if (referenceCrfFile != null) {
        if (referenceCrfFile.getName().endsWith(".partial")) {
          _referencesExtractor = loadPartiallyTrainedModel(referenceCrfFile);
          log.info("loaded partial reference crf '" + referenceCrfFile.getPath() + "'");
        } else {
          _referencesExtractor = loadCrfExtor(referenceCrfFile);
          Pipe pipe = _referencesExtractor.getFeaturePipe();
          pipe.setTargetProcessing(false);
          log.info("loaded reference crf '" + referenceCrfFile.getPath() + "'");
        }
      }
    } catch (Exception e) {
      // Hand the exception to the logger as a throwable so the stack trace is recorded,
      // rather than only its toString() glued onto the message.
      log.error("couldn't init crf '" + referenceCrfFile.getPath() + "', continuing...", e);
      _referencesExtractor = null;
    }

    try {
      if (headerCrfFile != null) {
        if (headerCrfFile.getName().endsWith(".partial")) {
          _headersExtractor = loadPartiallyTrainedModel(headerCrfFile);
          log.info("loaded partial header crf '" + headerCrfFile.getPath() + "'");
        } else {
          _headersExtractor = loadCrfExtor(headerCrfFile);
          Pipe pipe = _headersExtractor.getFeaturePipe();
          pipe.setTargetProcessing(false);
          log.info("loaded header crf '" + headerCrfFile.getPath() + "'");
        }
      }
    } catch (Exception e) {
      // Previously the cause was dropped here; log it the same way as for the reference model.
      log.error("couldn't init crf '" + headerCrfFile.getPath() + "', continuing...", e);
      _headersExtractor = null;
    }
  }
 /**
  * Deserializes a {@code CRF4} model from {@code crfFile} and wraps it in a {@code CRFExtractor}
  * whose tokenization pipe is a single {@code Noop} with target processing disabled.
  *
  * <p>NOTE(review): this uses Java-native deserialization ({@code ObjectInputStream.readObject});
  * it should only ever be pointed at trusted model files.
  *
  * @param crfFile file containing a serialized {@code CRF4}
  * @return an extractor built from the deserialized CRF and a trivial tokenization pipe
  * @throws IOException if the file cannot be opened, read, or closed
  * @throws ClassNotFoundException if the class of the serialized object cannot be found
  */
 private CRFExtractor loadPartiallyTrainedModel(File crfFile)
     throws IOException, ClassNotFoundException {
   CRF4 crf;
   // Open the file stream on its own so it is released even if wrapping or reading fails.
   FileInputStream fileIn = new FileInputStream(crfFile);
   try {
     crf = (CRF4) new ObjectInputStream(new BufferedInputStream(fileIn)).readObject();
   } finally {
     fileIn.close();
   }
   // Create the trivial tokenization pipe
   Pipe tokPipe =
       new SerialPipes(
           new Pipe[] {
             new Noop(),
           });
   tokPipe.setTargetProcessing(false);
   return new CRFExtractor(crf, tokPipe);
 }
예제 #3
0
  /**
   * Builds and trains a {@code CRF4} from a tagged training file.
   *
   * <p>The training file is read as groups of lines separated by blank (whitespace-only) lines
   * and converted with {@code SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence}. The CRF's
   * states come from {@code crfInfo.stateInfoList} when it is non-null, otherwise from the
   * structure selected by {@code crfInfo.stateStructure}. Weight-group feature selections are
   * applied when {@code crfInfo.weightGroupInfoList} is non-null, and the model is then trained
   * for at most {@code crfInfo.maxIterations} iterations.
   *
   * @param trainingFile file holding the tagged training data
   * @param crfInfo configuration: default label, prior variance, transduction type, state
   *     layout, weight groups and iteration limit
   * @return the trained CRF
   * @throws FileNotFoundException if {@code trainingFile} cannot be opened for reading
   * @throws RuntimeException if {@code crfInfo.stateInfoList} is null and
   *     {@code crfInfo.stateStructure} is not one of the known structures
   */
  public static CRF4 createCRF(File trainingFile, CRFInfo crfInfo) throws FileNotFoundException {
    // Create a pipe that we can use to convert the training
    // file to a feature vector sequence.
    Pipe p = new SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence();

    // The training file does contain tags (aka targets)
    p.setTargetProcessing(true);

    // Register the default tag with the pipe, by looking it up
    // in the targetAlphabet before we look up any other tag.
    p.getTargetAlphabet().lookupIndex(crfInfo.defaultLabel);

    // Create a new instancelist to hold the training data.
    InstanceList trainingData = new InstanceList(p);

    // Read in the training data, making sure the reader is released afterwards.
    // NOTE(review): assumes InstanceList.add drains the iterator before returning — confirm
    // against the MALLET version in use.
    Reader trainingFileReader = new FileReader(trainingFile);
    try {
      trainingData.add(new LineGroupIterator(trainingFileReader, Pattern.compile("^\\s*$"), true));
    } finally {
      try {
        trainingFileReader.close();
      } catch (java.io.IOException ignored) {
        // Failure to close a read-only stream cannot affect the data already read, and the
        // method signature only allows FileNotFoundException.
      }
    }

    // Create the CRF model.
    CRF4 crf = new CRF4(p, null);

    // Set various config options
    crf.setGaussianPriorVariance(crfInfo.gaussianVariance);
    crf.setTransductionType(crfInfo.transductionType);

    // Set up the model's states: an explicit state list wins over a named structure.
    if (crfInfo.stateInfoList != null) {
      Iterator stateIter = crfInfo.stateInfoList.iterator();
      while (stateIter.hasNext()) {
        CRFInfo.StateInfo state = (CRFInfo.StateInfo) stateIter.next();
        crf.addState(
            state.name,
            state.initialCost,
            state.finalCost,
            state.destinationNames,
            state.labelNames,
            state.weightNames);
      }
    } else if (crfInfo.stateStructure == CRFInfo.FULLY_CONNECTED_STRUCTURE) {
      crf.addStatesForLabelsConnectedAsIn(trainingData);
    } else if (crfInfo.stateStructure == CRFInfo.HALF_CONNECTED_STRUCTURE) {
      crf.addStatesForHalfLabelsConnectedAsIn(trainingData);
    } else if (crfInfo.stateStructure == CRFInfo.THREE_QUARTERS_CONNECTED_STRUCTURE) {
      crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingData);
    } else if (crfInfo.stateStructure == CRFInfo.BILABELS_STRUCTURE) {
      crf.addStatesForBiLabelsConnectedAsIn(trainingData);
    } else {
      throw new RuntimeException("Unexpected state structure " + crfInfo.stateStructure);
    }

    // Set up the weight groups.
    if (crfInfo.weightGroupInfoList != null) {
      Iterator wgIter = crfInfo.weightGroupInfoList.iterator();
      while (wgIter.hasNext()) {
        CRFInfo.WeightGroupInfo wg = (CRFInfo.WeightGroupInfo) wgIter.next();
        FeatureSelection fs =
            FeatureSelection.createFromRegex(
                crf.getInputAlphabet(), Pattern.compile(wg.featureSelectionRegex));
        crf.setFeatureSelection(crf.getWeightsIndex(wg.name), fs);
      }
    }

    // Train the CRF.
    crf.train(trainingData, null, null, null, crfInfo.maxIterations);

    return crf;
  }