Example #1
0
  /**
   * Builds a label-to-label connection matrix from the transitions seen in the training targets.
   * This is (mostly) copied from CRF4.java.
   *
   * <p>For each pair of adjacent entries in every instance's target sequence, the cell
   * {@code [source][destination]} is set to {@code true}. If {@code start} is non-null, every
   * column of the start label's row is set to {@code true} as well.
   *
   * @param outputAlphabet alphabet used to translate target entries into matrix indices
   * @param trainingSet instances whose targets are cast to {@code FeatureSequence}
   * @param start the start label, or {@code null} when no start state should be handled
   * @return a square matrix dimensioned by the alphabet's size at the time of the call
   */
  public boolean[][] labelConnectionsIn(
      Alphabet outputAlphabet, InstanceList trainingSet, String start) {
    final int labelCount = outputAlphabet.size();
    final boolean[][] adjacency = new boolean[labelCount][labelCount];

    // Record every observed (previous label -> current label) transition.
    int instanceIdx = 0;
    while (instanceIdx < trainingSet.size()) {
      Instance current = trainingSet.getInstance(instanceIdx);
      FeatureSequence labels = (FeatureSequence) current.getTarget();
      for (int pos = 1; pos < labels.size(); pos++) {
        int from = outputAlphabet.lookupIndex(labels.get(pos - 1));
        int to = outputAlphabet.lookupIndex(labels.get(pos));
        assert (from >= 0 && to >= 0);
        adjacency[from][to] = true;
      }
      instanceIdx++;
    }

    // Without a start label there is nothing more to mark.
    if (start == null) {
      return adjacency;
    }

    // The start state is connected to every label.
    int startRow = outputAlphabet.lookupIndex(start);
    for (int col = 0; col < outputAlphabet.size(); col++) {
      adjacency[startRow][col] = true;
    }
    return adjacency;
  }
  /**
   * Command-line entry point: loads an instance list, splits it randomly into training,
   * validation and test portions, trains a classifier, optionally prints accuracies, and
   * optionally serializes the trained classifier.
   *
   * <p>The classifier is written to the file named by {@code outputFilenameOption}; the input
   * instance-list file is only read, never written.
   *
   * @param args command-line arguments, handed to {@code commandOptions.process}
   * @throws bsh.EvalError declared for option processing
   * @throws java.io.IOException declared for option processing and instance-list loading
   * @throws IllegalArgumentException if the classifier cannot be written to the output file
   */
  public static void main(String[] args) throws bsh.EvalError, java.io.IOException {
    // Process the command-line options
    commandOptions.process(args);

    System.out.println("Trainer = " + trainerConstructorOption.value.toString());
    ClassifierTrainer trainer = (ClassifierTrainer) trainerConstructorOption.value;
    InstanceList ilist = InstanceList.load(new File(instanceListFilenameOption.value));

    // Use a seeded generator only when a seed was given explicitly.
    Random r = randomSeedOption.wasInvoked() ? new Random(randomSeedOption.value) : new Random();
    double t = trainingProportionOption.value;
    double v = validationProportionOption.value;
    // ilists[0] = training, ilists[1] = validation, ilists[2] = remaining (test) portion.
    InstanceList[] ilists = ilist.split(r, new double[] {t, v, 1 - t - v});
    System.err.println("Training...");
    Classifier c =
        trainer.train(
            ilists[0],
            ilists[1],
            null,
            (ClassifierEvaluating) classifierEvaluatorOption.value,
            null);
    if (printTrainAccuracyOption.value) {
      System.out.print("Train accuracy = " + c.getAccuracy(ilists[0]) + "  ");
    }
    if (printTestAccuracyOption.value) {
      System.out.print("Test accuracy  = " + c.getAccuracy(ilists[2]));
    }
    if (printTrainAccuracyOption.value || printTestAccuracyOption.value) {
      System.out.println("");
    }
    if (outputFilenameOption.wasInvoked()) {
      try {
        // Write to the requested output file, not to the instance-list file we loaded from.
        ObjectOutputStream oos =
            new ObjectOutputStream(new FileOutputStream(outputFilenameOption.value));
        try {
          oos.writeObject(c);
        } finally {
          // Release the file handle even when serialization fails.
          oos.close();
        }
      } catch (Exception e) {
        // Keep the original failure as the cause so it is not lost.
        throw new IllegalArgumentException(
            "Couldn't write classifier to filename " + outputFilenameOption.value, e);
      }
    }
  }
Example #3
0
  /**
   * Creates a {@code CRF4} configured from {@code crfInfo} and trains it on the tagged data in
   * {@code trainingFile}.
   *
   * <p>The reader opened on the training file is closed before this method returns, whether
   * training succeeds or an exception is thrown.
   *
   * @param trainingFile file containing the tagged training sequences, with groups separated by
   *     lines matching {@code ^\s*$}
   * @param crfInfo configuration: default label, Gaussian prior variance, transduction type,
   *     state structure or explicit state list, optional weight groups, and max iterations
   * @return the trained CRF
   * @throws FileNotFoundException if {@code trainingFile} cannot be opened for reading
   * @throws RuntimeException if no explicit state list is given and
   *     {@code crfInfo.stateStructure} is not one of the structures handled here
   */
  public static CRF4 createCRF(File trainingFile, CRFInfo crfInfo) throws FileNotFoundException {
    Reader trainingFileReader = new FileReader(trainingFile);
    try {
      // Create a pipe that we can use to convert the training
      // file to a feature vector sequence.
      Pipe p = new SimpleTagger.SimpleTaggerSentence2FeatureVectorSequence();

      // The training file does contain tags (aka targets)
      p.setTargetProcessing(true);

      // Register the default tag with the pipe, by looking it up
      // in the targetAlphabet before we look up any other tag.
      p.getTargetAlphabet().lookupIndex(crfInfo.defaultLabel);

      // Create a new instancelist to hold the training data.
      InstanceList trainingData = new InstanceList(p);

      // Read in the training data.
      trainingData.add(
          new LineGroupIterator(trainingFileReader, Pattern.compile("^\\s*$"), true));

      // Create the CRF model.
      CRF4 crf = new CRF4(p, null);

      // Set various config options
      crf.setGaussianPriorVariance(crfInfo.gaussianVariance);
      crf.setTransductionType(crfInfo.transductionType);

      // Set up the model's states: an explicit state list takes precedence over
      // the predefined state structures.
      if (crfInfo.stateInfoList != null) {
        Iterator stateIter = crfInfo.stateInfoList.iterator();
        while (stateIter.hasNext()) {
          CRFInfo.StateInfo state = (CRFInfo.StateInfo) stateIter.next();
          crf.addState(
              state.name,
              state.initialCost,
              state.finalCost,
              state.destinationNames,
              state.labelNames,
              state.weightNames);
        }
      } else if (crfInfo.stateStructure == CRFInfo.FULLY_CONNECTED_STRUCTURE) {
        crf.addStatesForLabelsConnectedAsIn(trainingData);
      } else if (crfInfo.stateStructure == CRFInfo.HALF_CONNECTED_STRUCTURE) {
        crf.addStatesForHalfLabelsConnectedAsIn(trainingData);
      } else if (crfInfo.stateStructure == CRFInfo.THREE_QUARTERS_CONNECTED_STRUCTURE) {
        crf.addStatesForThreeQuarterLabelsConnectedAsIn(trainingData);
      } else if (crfInfo.stateStructure == CRFInfo.BILABELS_STRUCTURE) {
        crf.addStatesForBiLabelsConnectedAsIn(trainingData);
      } else {
        throw new RuntimeException("Unexpected state structure " + crfInfo.stateStructure);
      }

      // Set up the weight groups.
      if (crfInfo.weightGroupInfoList != null) {
        Iterator wgIter = crfInfo.weightGroupInfoList.iterator();
        while (wgIter.hasNext()) {
          CRFInfo.WeightGroupInfo wg = (CRFInfo.WeightGroupInfo) wgIter.next();
          FeatureSelection fs =
              FeatureSelection.createFromRegex(
                  crf.getInputAlphabet(), Pattern.compile(wg.featureSelectionRegex));
          crf.setFeatureSelection(crf.getWeightsIndex(wg.name), fs);
        }
      }

      // Train the CRF.
      crf.train(trainingData, null, null, null, crfInfo.maxIterations);

      return crf;
    } finally {
      // Always release the training file handle; it was previously leaked.
      try {
        trainingFileReader.close();
      } catch (java.io.IOException ignored) {
        // The reader was only read from, so a close failure is ignored rather than
        // masking the result or the exception already in flight.
      }
    }
  }
  /**
   * Evaluates a transducer on {@code data} and logs, for each tag in {@code segmentTags}, the
   * number of true, predicted and correctly predicted segments together with the resulting
   * precision, recall and F1.
   *
   * <p>A predicted segment counts as correct when the list of true segments for the same
   * instance contains it.
   *
   * @param transducer model whose Viterbi output is compared against each instance's target
   * @param data instances to evaluate; data and target are cast to {@code Sequence}
   * @param description prefix for the header line written to the log
   * @param viterbiOutputStream not used by this method
   */
  public void test(
      Transducer transducer,
      InstanceList data,
      String description,
      PrintStream viterbiOutputStream) {
    final int tagCount = segmentTags.length;
    int[] trueCounts = new int[tagCount];
    int[] predictedCounts = new int[tagCount];
    int[] correctCounts = new int[tagCount];

    LabelAlphabet labels = (LabelAlphabet) transducer.getInputPipe().getTargetAlphabet();

    int instanceIdx = 0;
    while (instanceIdx < data.size()) {
      Instance current = data.getInstance(instanceIdx);
      Sequence observed = (Sequence) current.getData();
      Sequence gold = (Sequence) current.getTarget();
      assert (observed.size() == gold.size());
      Sequence decoded = transducer.viterbiPath(observed).output();
      assert (decoded.size() == gold.size());

      // Extract segments from the gold and the decoded label sequences.
      List goldSegments = new ArrayList();
      List decodedSegments = new ArrayList();
      addSegs(goldSegments, gold);
      addSegs(decodedSegments, decoded);

      // Tally predictions, crediting those that also appear among the gold segments.
      Iterator decodedIter = decodedSegments.iterator();
      while (decodedIter.hasNext()) {
        Segment predicted = (Segment) decodedIter.next();
        predictedCounts[predicted.tag]++;
        if (goldSegments.contains(predicted)) {
          correctCounts[predicted.tag]++;
        }
      }

      // Tally gold segments.
      Iterator goldIter = goldSegments.iterator();
      while (goldIter.hasNext()) {
        Segment expected = (Segment) goldIter.next();
        trueCounts[expected.tag]++;
      }

      instanceIdx++;
    }

    DecimalFormat fourPlaces = new DecimalFormat("0.####");
    logger.info(description + " per-field F1");
    for (int t = 0; t < tagCount; t++) {
      double p = ((double) correctCounts[t]) / predictedCounts[t];
      double r = ((double) correctCounts[t]) / trueCounts[t];
      double harmonic = (2 * p * r) / (p + r);
      Label tagLabel = labels.lookupLabel(segmentTags[t]);

      String countsLine =
          " segments "
              + tagLabel
              + "  true = "
              + trueCounts[t]
              + "  pred = "
              + predictedCounts[t]
              + "  correct = "
              + correctCounts[t];
      logger.info(countsLine);

      String scoresLine =
          " precision="
              + fourPlaces.format(p)
              + " recall="
              + fourPlaces.format(r)
              + " f1="
              + fourPlaces.format(harmonic);
      logger.info(scoresLine);
    }
  }