Code example #1
0
 /**
  * Builds the gold-standard segmentation for one candidate group: every subsequence whose
  * label is non-negative becomes a segment, and any position not covered by such a segment
  * is emitted as a unit-length negative segment.
  *
  * @param g the group of candidate subsequences to scan
  * @param schema maps class names to class indices for the emitted segments
  * @param maxSegmentSize longest subsequence length considered at each start position
  * @return a {@link Segmentation} covering every position of {@code g} exactly once
  */
 private Segmentation correctSegments(
     CandidateSegmentGroup g, ExampleSchema schema, int maxSegmentSize) {
   Segmentation result = new Segmentation(schema);
   int start = 0;
   while (start < g.getSequenceLength()) {
     boolean consumedAtStart = false;
     // Greedily take the shortest non-negative-labeled subsequence beginning at start.
     int length = 1;
     while (!consumedAtStart && length <= maxSegmentSize) {
       Instance inst = g.getSubsequenceInstance(start, start + length);
       ClassLabel label = g.getSubsequenceLabel(start, start + length);
       if (inst != null && !label.isNegative()) {
         int classIndex = schema.getClassIndex(label.bestClassName());
         result.add(new Segmentation.Segment(start, start + length, classIndex));
         consumedAtStart = true;
         start += length;
       }
       length++;
     }
     if (!consumedAtStart) {
       // No positive segment starts here: emit a unit-length negative segment and advance.
       int negIndex = schema.getClassIndex(ExampleSchema.NEG_CLASS_NAME);
       result.add(new Segmentation.Segment(start, start + 1, negIndex));
       start++;
     }
   }
   return result;
 }
Code example #2
0
  /**
   * Trains a segmenter with a Collins-style perceptron over whole-sequence segmentations.
   *
   * <p>For each sequence (candidate segment group) in each epoch: run a Viterbi search with the
   * current classifier, compare the best-scoring segmentation to the correct one, and — when they
   * differ — feed the accumulated false-positive/false-negative hyperplanes to the per-class inner
   * learners as negative/positive examples respectively.
   *
   * <p>Training stops early once an epoch produces zero transition errors.
   *
   * @param dataset the segment-annotated training data
   * @return a Viterbi-based segmenter wrapping the trained per-class classifiers
   */
  @Override
  public Segmenter batchTrain(SegmentDataset dataset) {
    ExampleSchema schema = dataset.getSchema();
    innerLearner =
        SequenceUtils.duplicatePrototypeLearner(innerLearnerPrototype, schema.getNumberOfClasses());

    ProgressCounter pc =
        new ProgressCounter(
            "training segments " + innerLearnerPrototype.toString(),
            "sequence",
            numberOfEpochs * dataset.getNumberOfSegmentGroups());

    for (int epoch = 0; epoch < numberOfEpochs; epoch++) {
      // dataset.shuffle();

      // statistics for curious researchers
      int sequenceErrors = 0;
      int transitionErrors = 0;
      int transitions = 0;

      for (Iterator<CandidateSegmentGroup> i = dataset.candidateSegmentGroupIterator();
          i.hasNext(); ) {
        // Rebuild the classifier each sequence so it reflects the latest learner state.
        Classifier c = new SequenceUtils.MultiClassClassifier(schema, innerLearner);
        if (DEBUG) log.debug("classifier is: " + c);

        CandidateSegmentGroup g = i.next();
        Segmentation viterbi =
            new SegmentCollinsPerceptronLearner.ViterbiSearcher(c, schema, maxSegmentSize)
                .bestSegments(g);
        if (DEBUG) log.debug("viterbi " + maxSegmentSize + "\n" + viterbi);
        Segmentation correct = correctSegments(g, schema, maxSegmentSize);
        if (DEBUG) log.debug("correct segments:\n" + correct);

        boolean errorOnThisSequence = false;

        // accumulate weights for transitions associated with each class k
        Hyperplane[] accumPos = new Hyperplane[schema.getNumberOfClasses()];
        Hyperplane[] accumNeg = new Hyperplane[schema.getNumberOfClasses()];
        for (int k = 0; k < schema.getNumberOfClasses(); k++) {
          accumPos[k] = new Hyperplane();
          accumNeg[k] = new Hyperplane();
        }

        // fp = segments predicted but not correct; fn = correct segments missed.
        // NOTE(review): assumes compareSegmentsAndIncrement returns the mismatch count
        // while accumulating the corresponding features — confirm against its definition.
        int fp = compareSegmentsAndIncrement(schema, viterbi, correct, accumNeg, +1, g);
        if (fp > 0) errorOnThisSequence = true;
        int fn = compareSegmentsAndIncrement(schema, correct, viterbi, accumPos, +1, g);
        if (fn > 0) errorOnThisSequence = true;
        transitionErrors += fp + fn;

        if (errorOnThisSequence) {
          // BUG FIX: previously sequenceErrors was incremented both here and in a separate
          // identical check above, double-counting every erroneous sequence in the stats.
          sequenceErrors++;
          String subPopId = g.getSubpopulationId();
          Object source = "no source";
          for (int k = 0; k < schema.getNumberOfClasses(); k++) {
            // System.out.println("adding class="+k+" example: "+accumPos[k]);
            innerLearner[k].addExample(
                new Example(
                    new HyperplaneInstance(accumPos[k], subPopId, source),
                    ClassLabel.positiveLabel(+1.0)));
            innerLearner[k].addExample(
                new Example(
                    new HyperplaneInstance(accumNeg[k], subPopId, source),
                    ClassLabel.negativeLabel(-1.0)));
          }
        }

        transitions += correct.size();
        pc.progress();
      } // sequence i

      System.out.println(
          "Epoch "
              + epoch
              + ": sequenceErr="
              + sequenceErrors
              + " transitionErrors="
              + transitionErrors
              + "/"
              + transitions);

      // Converged: the current model reproduces every training segmentation exactly.
      if (transitionErrors == 0) break;
    } // epoch
    pc.finished();

    for (int k = 0; k < schema.getNumberOfClasses(); k++) {
      innerLearner[k].completeTraining();
    }

    Classifier c = new SequenceUtils.MultiClassClassifier(schema, innerLearner);
    return new SegmentCollinsPerceptronLearner.ViterbiSegmenter(c, schema, maxSegmentSize);
  }