/** * Collect the correct segments for this example. These are defined as all segments with * non-NEGATIVE labels, and all unit-length negative labels not inside a positives label. */ private Segmentation correctSegments( CandidateSegmentGroup g, ExampleSchema schema, int maxSegmentSize) { Segmentation result = new Segmentation(schema); int pos, len; for (pos = 0; pos < g.getSequenceLength(); ) { boolean addedASegmentStartingAtPos = false; for (len = 1; !addedASegmentStartingAtPos && len <= maxSegmentSize; len++) { Instance inst = g.getSubsequenceInstance(pos, pos + len); ClassLabel label = g.getSubsequenceLabel(pos, pos + len); if (inst != null && !label.isNegative()) { result.add( new Segmentation.Segment( pos, pos + len, schema.getClassIndex(label.bestClassName()))); addedASegmentStartingAtPos = true; pos += len; } } if (!addedASegmentStartingAtPos) { // Instance inst = g.getSubsequenceInstance(pos,pos+1); // ClassLabel label = g.getSubsequenceLabel(pos,pos+1); result.add( new Segmentation.Segment( pos, pos + 1, schema.getClassIndex(ExampleSchema.NEG_CLASS_NAME))); pos += 1; } } return result; }
@Override public Segmenter batchTrain(SegmentDataset dataset) { ExampleSchema schema = dataset.getSchema(); innerLearner = SequenceUtils.duplicatePrototypeLearner(innerLearnerPrototype, schema.getNumberOfClasses()); ProgressCounter pc = new ProgressCounter( "training segments " + innerLearnerPrototype.toString(), "sequence", numberOfEpochs * dataset.getNumberOfSegmentGroups()); for (int epoch = 0; epoch < numberOfEpochs; epoch++) { // dataset.shuffle(); // statistics for curious researchers int sequenceErrors = 0; int transitionErrors = 0; int transitions = 0; for (Iterator<CandidateSegmentGroup> i = dataset.candidateSegmentGroupIterator(); i.hasNext(); ) { Classifier c = new SequenceUtils.MultiClassClassifier(schema, innerLearner); if (DEBUG) log.debug("classifier is: " + c); CandidateSegmentGroup g = i.next(); Segmentation viterbi = new SegmentCollinsPerceptronLearner.ViterbiSearcher(c, schema, maxSegmentSize) .bestSegments(g); if (DEBUG) log.debug("viterbi " + maxSegmentSize + "\n" + viterbi); Segmentation correct = correctSegments(g, schema, maxSegmentSize); if (DEBUG) log.debug("correct segments:\n" + correct); boolean errorOnThisSequence = false; // accumulate weights for transitions associated with each class k Hyperplane[] accumPos = new Hyperplane[schema.getNumberOfClasses()]; Hyperplane[] accumNeg = new Hyperplane[schema.getNumberOfClasses()]; for (int k = 0; k < schema.getNumberOfClasses(); k++) { accumPos[k] = new Hyperplane(); accumNeg[k] = new Hyperplane(); } int fp = compareSegmentsAndIncrement(schema, viterbi, correct, accumNeg, +1, g); if (fp > 0) errorOnThisSequence = true; int fn = compareSegmentsAndIncrement(schema, correct, viterbi, accumPos, +1, g); if (fn > 0) errorOnThisSequence = true; if (errorOnThisSequence) sequenceErrors++; transitionErrors += fp + fn; if (errorOnThisSequence) { sequenceErrors++; String subPopId = g.getSubpopulationId(); Object source = "no source"; for (int k = 0; k < schema.getNumberOfClasses(); k++) { // System.out.println("adding class="+k+" example: "+accumPos[k]); innerLearner[k].addExample( new Example( new HyperplaneInstance(accumPos[k], subPopId, source), ClassLabel.positiveLabel(+1.0))); innerLearner[k].addExample( new Example( new HyperplaneInstance(accumNeg[k], subPopId, source), ClassLabel.negativeLabel(-1.0))); } } transitions += correct.size(); pc.progress(); } // sequence i System.out.println( "Epoch " + epoch + ": sequenceErr=" + sequenceErrors + " transitionErrors=" + transitionErrors + "/" + transitions); if (transitionErrors == 0) break; } // epoch pc.finished(); for (int k = 0; k < schema.getNumberOfClasses(); k++) { innerLearner[k].completeTraining(); } Classifier c = new SequenceUtils.MultiClassClassifier(schema, innerLearner); return new SegmentCollinsPerceptronLearner.ViterbiSegmenter(c, schema, maxSegmentSize); }