示例#1
0
 public HMM(Pipe inputPipe, Pipe outputPipe) {
   this.inputPipe = inputPipe;
   this.outputPipe = outputPipe;
   this.inputAlphabet = inputPipe.getDataAlphabet();
   this.outputAlphabet = inputPipe.getTargetAlphabet();
 }
示例#2
0
  /**
   * Command-line wrapper to train, test, or run a generic CRF-based tagger.
   *
   * @param args the command line arguments. Options (shell and Java quoting should be added as
   *     needed):
   *     <dl>
   *       <dt><code>--help</code> <em>boolean</em>
   *       <dd>Print this command line option usage information. Give <code>true</code> for longer
   *           documentation. Default is <code>false</code>.
   *       <dt><code>--prefix-code</code> <em>Java-code</em>
   *       <dd>Java code you want run before any other interpreted code. Note that the text is
   *           interpreted without modification, so unlike some other Java code options, you need to
   *           include any necessary 'new's. Default is null.
   *       <dt><code>--gaussian-variance</code> <em>positive-number</em>
   *       <dd>The Gaussian prior variance used for training. Default is 10.0.
   *       <dt><code>--train</code> <em>boolean</em>
   *       <dd>Whether to train. Default is <code>false</code>.
   *       <dt><code>--iterations</code> <em>positive-integer</em>
   *       <dd>Number of training iterations. Default is 500.
   *       <dt><code>--test</code> <code>lab</code> or <code>seg=</code><em>start-1</em><code>.
   *           </code><em>continue-1</em><code>,</code>...<code>,</code><em>start-n</em><code>.
   *           </code><em>continue-n</em>
   *       <dd>Test measuring labeling or segmentation (<em>start-i</em>, <em>continue-i</em>)
   *           accuracy. Default is no testing.
   *       <dt><code>--training-proportion</code> <em>number-between-0-and-1</em>
   *       <dd>Fraction of data to use for training in a random split. Default is 0.5.
   *       <dt><code>--model-file</code> <em>filename</em>
   *       <dd>The filename for reading (train/run) or saving (train) the model. Default is null.
   *       <dt><code>--random-seed</code> <em>integer</em>
   *       <dd>The random seed for randomly selecting a proportion of the instance list for
   *           training. Default is 0.
   *       <dt><code>--orders</code> <em>comma-separated-integers</em>
   *       <dd>List of label Markov orders (main and backoff). Default is 1.
   *       <dt><code>--forbidden</code> <em>regular-expression</em>
   *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> matches the expression, the
   *           corresponding transition is forbidden. Default is <code>\\s</code> (nothing
   *           forbidden).
   *       <dt><code>--allowed</code> <em>regular-expression</em>
   *       <dd>If <em>label-1</em><code>,</code><em>label-2</em> does not match the expression, the
   *           corresponding transition is forbidden. Default is <code>.*</code> (everything
   *           allowed).
   *       <dt><code>--default-label</code> <em>string</em>
   *       <dd>Label for initial context and uninteresting tokens. Default is <code>O</code>.
   *       <dt><code>--viterbi-output</code> <em>boolean</em>
   *       <dd>Print Viterbi output periodically during training. Default is <code>false</code>.
   *       <dt><code>--fully-connected</code> <em>boolean</em>
   *       <dd>Include all allowed transitions, even those not in training data. Default is <code>
   *           true</code>.
   *       <dt><code>--n-best</code> <em>positive-integer</em>
   *       <dd>Number of answers to output when applying model. Default is 1.
   *       <dt><code>--include-input</code> <em>boolean</em>
   *       <dd>Whether to include input features when printing decoding output. Default is <code>
   *           false</code>.
   *     </dl>
   *     Remaining arguments:
   *     <ul>
   *       <li><em>training-data-file</em> if training
   *       <li><em>training-and-test-data-file</em>, if training and testing with random split
   *       <li><em>training-data-file</em> <em>test-data-file</em> if training and testing from
   *           separate files
   *       <li><em>test-data-file</em> if testing
   *       <li><em>input-data-file</em> if applying to new data (unlabeled)
   *     </ul>
   *
   * @throws IllegalArgumentException if no data file is given, a model file is required but
   *     <code>--model-file</code> is missing, or the <code>--test</code> option is malformed
   * @throws Exception if reading data or model files, deserializing the model, or training fails
   */
  public static void main(String[] args) throws Exception {
    InstanceList trainingData = null, testData = null;
    int restArgs = commandOptions.processOptions(args);
    if (restArgs == args.length) {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Missing data file(s)");
    }
    // Separator handed to every LineGroupIterator below: whitespace-only lines split instances.
    // Compiled once and shared; Pattern instances are immutable.
    Pattern blankLine = Pattern.compile("^\\s*$");

    // A saved model is needed unless training from scratch (i.e. when running, testing, or
    // continuing training).
    Pipe p;
    CRF crf = null;
    if (continueTrainingOption.value || !trainOption.value) {
      if (modelOption.value == null) {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Missing model file option");
      }
      // SECURITY: Java native deserialization of untrusted bytes is unsafe; only load model
      // files from a trusted source.
      try (FileInputStream in = new FileInputStream(modelOption.value);
          ObjectInputStream s = new ObjectInputStream(in)) {
        crf = (CRF) s.readObject();
      }
      p = crf.getInputPipe();
    } else {
      p = new SimpleTaggerSentence2FeatureVectorSequence();
      // Register the default label in the target alphabet before any data is read
      // (presumably so it receives an index even if absent from the data — TODO confirm).
      p.getTargetAlphabet().lookupIndex(defaultOption.value);
    }

    // Each reader is opened where it is consumed and closed right after.
    if (trainOption.value) {
      p.setTargetProcessing(true);
      trainingData = new InstanceList(p);
      try (Reader trainingFile = new FileReader(new File(args[restArgs]))) {
        trainingData.addThruPipe(new LineGroupIterator(trainingFile, blankLine, true));
      }
      logger.info("Number of features in training data: " + p.getDataAlphabet().size());
      if (testOption.value != null) {
        if (restArgs < args.length - 1) {
          // A second data file was given: test on it.
          testData = new InstanceList(p);
          try (Reader testFile = new FileReader(new File(args[restArgs + 1]))) {
            testData.addThruPipe(new LineGroupIterator(testFile, blankLine, true));
          }
        } else {
          // Only one data file: split it randomly into training and testing portions.
          Random r = new Random(randomSeedOption.value);
          InstanceList[] trainingLists =
              trainingData.split(
                  r, new double[] {trainingFractionOption.value, 1 - trainingFractionOption.value});
          trainingData = trainingLists[0];
          testData = trainingLists[1];
        }
      }
    } else {
      // Labels are processed only when testing; plain decoding input is treated as unlabeled.
      p.setTargetProcessing(testOption.value != null);
      testData = new InstanceList(p);
      try (Reader testFile = new FileReader(new File(args[restArgs]))) {
        testData.addThruPipe(new LineGroupIterator(testFile, blankLine, true));
      }
    }
    logger.info("Number of predicates: " + p.getDataAlphabet().size());

    TransducerEvaluator eval = null;
    if (testOption.value != null) {
      if (testOption.value.startsWith("lab"))
        eval =
            new TokenAccuracyEvaluator(
                new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"});
      else if (testOption.value.startsWith("seg=")) {
        // Format: seg=start1.continue1,start2.continue2,...
        String[] pairs = testOption.value.substring(4).split(",");
        if (pairs.length < 1) {
          commandOptions.printUsage(true);
          throw new IllegalArgumentException(
              "Missing segment start/continue labels: " + testOption.value);
        }
        String[] startTags = new String[pairs.length];
        String[] continueTags = new String[pairs.length];
        for (int i = 0; i < pairs.length; i++) {
          String[] pair = pairs[i].split("\\.");
          if (pair.length != 2) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException(
                "Incorrectly-specified segment start and end labels: " + pairs[i]);
          }
          startTags[i] = pair[0];
          continueTags[i] = pair[1];
        }
        eval =
            new MultiSegmentationEvaluator(
                new InstanceList[] {trainingData, testData},
                new String[] {"Training", "Testing"},
                startTags,
                continueTags);
      } else {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Invalid test option: " + testOption.value);
      }
    }

    if (p.isTargetProcessing()) {
      Alphabet targets = p.getTargetAlphabet();
      StringBuilder buf = new StringBuilder("Labels:");
      for (int i = 0; i < targets.size(); i++)
        buf.append(" ").append(targets.lookupObject(i).toString());
      logger.info(buf.toString());
    }

    if (trainOption.value) {
      crf =
          train(
              trainingData,
              testData,
              eval,
              ordersOption.value,
              defaultOption.value,
              forbiddenOption.value,
              allowedOption.value,
              connectedOption.value,
              iterationsOption.value,
              gaussianVarianceOption.value,
              crf);
      if (modelOption.value != null) {
        try (FileOutputStream out = new FileOutputStream(modelOption.value);
            ObjectOutputStream s = new ObjectOutputStream(out)) {
          s.writeObject(crf);
        }
      }
    } else if (eval != null) {
      // Not training: crf was loaded from the model file above.
      test(new NoopTransducerTrainer(crf), eval, testData);
    } else {
      // Decode: one line per token with the k answers, then a blank line per sequence.
      boolean includeInput = includeInputOption.value();
      for (int i = 0; i < testData.size(); i++) {
        Sequence input = (Sequence) testData.get(i).getData();
        Sequence[] outputs = apply(crf, input, nBestOption.value);
        int k = outputs.length;
        boolean error = false;
        for (int a = 0; a < k; a++) {
          if (outputs[a].size() != input.size()) {
            System.err.println("Failed to decode input sequence " + i + ", answer " + a);
            error = true;
          }
        }
        if (!error) {
          for (int j = 0; j < input.size(); j++) {
            StringBuilder buf = new StringBuilder();
            for (int a = 0; a < k; a++) buf.append(outputs[a].get(j).toString()).append(" ");
            if (includeInput) {
              FeatureVector fv = (FeatureVector) input.get(j);
              buf.append(fv.toString(true));
            }
            System.out.println(buf.toString());
          }
          System.out.println();
        }
      }
    }
  }