コード例 #1
0
ファイル: TrainCRF.java プロジェクト: carriercomm/PrologMUD
  /**
   * Builds the label-to-label transition matrix observed in a training set.
   *
   * <p>Adapted (mostly) from CRF4.java. Entry {@code [a][b]} is {@code true} when some training
   * target sequence contains label {@code a} immediately followed by label {@code b}. If {@code
   * start} is non-null, every transition out of the {@code start} label is also marked.
   *
   * <p>Labels are looked up without being added to {@code outputAlphabet}, so this method never
   * grows the alphabet. (Growing it would produce indices outside the matrix anyway.)
   *
   * @param outputAlphabet alphabet indexing the output labels; its current size fixes the matrix
   *     dimensions
   * @param trainingSet instances whose targets are {@code FeatureSequence}s of labels
   * @param start label of the start state, or {@code null} for none
   * @return a {@code numLabels x numLabels} matrix of observed connections
   * @throws IllegalArgumentException if a target label or {@code start} is not already in {@code
   *     outputAlphabet}
   */
  public boolean[][] labelConnectionsIn(
      Alphabet outputAlphabet, InstanceList trainingSet, String start) {
    final int numLabels = outputAlphabet.size();
    final boolean[][] connections = new boolean[numLabels][numLabels];
    for (int i = 0; i < trainingSet.size(); i++) {
      Instance instance = trainingSet.getInstance(i);
      FeatureSequence output = (FeatureSequence) instance.getTarget();
      for (int j = 1; j < output.size(); j++) {
        int sourceIndex = existingLabelIndex(outputAlphabet, output.get(j - 1));
        int destIndex = existingLabelIndex(outputAlphabet, output.get(j));
        connections[sourceIndex][destIndex] = true;
      }
    }

    // Handle start state: it may transition to any label. Bound the loop by the matrix width
    // rather than re-reading the alphabet size, which is not guaranteed to match it.
    if (start != null) {
      int startIndex = existingLabelIndex(outputAlphabet, start);
      for (int j = 0; j < numLabels; j++) {
        connections[startIndex][j] = true;
      }
    }

    return connections;
  }

  /**
   * Returns the index of {@code label} in {@code alphabet} without adding it to the alphabet.
   *
   * @throws IllegalArgumentException if {@code label} is not already present in {@code alphabet}
   */
  private static int existingLabelIndex(Alphabet alphabet, Object label) {
    int index = alphabet.lookupIndex(label, false);
    if (index < 0) {
      throw new IllegalArgumentException("Label not found in output alphabet: " + label);
    }
    return index;
  }
コード例 #2
0
  /**
   * Marks tokens that match an entry of the lexicon.
   *
   * <p>Each lexicon entry is an array of patterns, one per consecutive token. Scanning left to
   * right, at each position the first entry whose patterns all match the following tokens wins.
   * Every token covered by that match gets the feature {@code name} (suffixed with the entry's
   * index when {@code indvMatch} is set) with value 1.0. Scanning then resumes after the match.
   * When {@code ignoreCase} is set, token text is lower-cased before matching.
   *
   * <p>An entry matches only if all of its patterns fit within the remaining tokens. A prefix of a
   * multi-token entry at the end of the sequence is not a match. Empty entries are skipped; they
   * would otherwise "match" without consuming a token and never advance the scan.
   *
   * @param carrier instance whose data is a {@code TokenSequence}; its tokens are modified in place
   * @return {@code carrier}
   */
  public Instance pipe(Instance carrier) {
    TokenSequence ts = (TokenSequence) carrier.getData();
    final int numTokens = ts.size();
    for (int i = 0; i < numTokens; i++) {
      for (int j = 0; j < lexicon.size(); j++) {
        Pattern[] pats = (Pattern[]) lexicon.elementAt(j);
        // Skip empty entries and entries longer than the tokens left from position i.
        if (pats.length == 0 || i + pats.length > numTokens) {
          continue;
        }

        boolean matched = true;
        for (int k = 0; k < pats.length; k++) {
          String text = ts.getToken(i + k).getText();
          if (!pats[k].matcher(ignoreCase ? text.toLowerCase() : text).matches()) {
            matched = false;
            break;
          }
        }

        if (matched) {
          String featureName = name + (indvMatch ? "" + j : "");
          for (int k = 0; k < pats.length; k++) {
            ts.getToken(i + k).setFeatureValue(featureName, 1.0);
          }
          // Jump past the matched span; the outer loop's i++ supplies the final step.
          i += pats.length - 1;
          break;
        }
      }
    }
    return carrier;
  }
コード例 #3
0
  /**
   * Tags every line that contains a lexicon entry with this pipe's feature name.
   *
   * <p>The carrier's data must be a {@code LineInfo[]}. For each line for which {@code
   * containsLexicon} returns true, {@code m_featureName} is added to that line's {@code
   * presentFeatures}. The array is modified in place.
   *
   * @param carrier instance holding a {@code LineInfo[]} as its data
   * @return {@code carrier}
   */
  public Instance pipe(Instance carrier) {
    final LineInfo[] lines = (LineInfo[]) carrier.getData();
    for (LineInfo line : lines) {
      if (containsLexicon(line)) {
        line.presentFeatures.add(m_featureName);
      }
    }
    return carrier;
  }
コード例 #4
0
ファイル: Winnow.java プロジェクト: VivianLuwenHuangfu/banner
  /**
   * Classifies an instance using Winnow's weights.
   *
   * <p>The score of each class is the sum of that class's weights over the features present in
   * the instance's feature vector. Only the feature indices are read; feature values are not
   * multiplied in, so each present feature contributes its full weight.
   *
   * @param instance an instance to be classified
   * @return an object containing the classifier's guess, built from a {@code LabelVector} of the
   *     summed per-class scores
   */
  public Classification classify(Instance instance) {
    final int classCount = getLabelAlphabet().size();
    final double[] classScores = new double[classCount];
    final FeatureVector vector = (FeatureVector) instance.getData(this.instancePipe);
    // The vector must be indexed by the same feature alphabet the weights were trained on.
    assert (instancePipe == null || vector.getAlphabet() == this.instancePipe.getDataAlphabet());

    final int locationCount = vector.numLocations();
    for (int location = 0; location < locationCount; location++) {
      final int featureIndex = vector.indexAtLocation(location);
      // Add this feature's weight to every class's running total.
      for (int c = 0; c < classCount; c++) {
        classScores[c] += this.weights[c][featureIndex];
      }
    }

    final LabelVector labelScores = new LabelVector(getLabelAlphabet(), classScores);
    return new Classification(instance, this, labelScores);
  }
コード例 #5
0
 /**
  * Replaces the carrier's {@code TokenSequence} data with a {@code TokenIterator} over it.
  *
  * @param carrier instance whose data is a {@code TokenSequence}; its data is replaced in place
  * @return {@code carrier}
  */
 public Instance pipe(Instance carrier) {
   final TokenSequence tokens = (TokenSequence) carrier.getData();
   final TokenIterator iterator = new TokenIterator(tokens);
   carrier.setData(iterator);
   return carrier;
 }
コード例 #6
0
 /**
  * Renders this entry as {@code "<index> : <instance name>\t// <formatted score>"}, with the
  * score formatted by {@code DF}.
  */
 @Override
 public String toString() {
   return new StringBuilder()
       .append(index)
       .append(" : ")
       .append(inst.getName())
       .append("\t// ")
       .append(DF.format(score))
       .toString();
 }
コード例 #7
0
  /**
   * Evaluates a transducer's Viterbi segmentation against the gold targets. For each field, it
   * logs segment counts, precision, recall and F1.
   *
   * <p>For each tag in {@code segmentTags}, a predicted segment counts as correct when an equal
   * segment appears among the true segments of the same instance. Ratios whose denominator is
   * zero (a field never predicted, never present, or with zero precision and recall) are reported
   * as 0 rather than NaN.
   *
   * @param transducer model used to decode each instance via {@code viterbiPath}
   * @param data instances whose data and target are {@code Sequence}s of equal length
   * @param description prefix for the log header line
   * @param viterbiOutputStream not used by this implementation
   */
  public void test(
      Transducer transducer,
      InstanceList data,
      String description,
      PrintStream viterbiOutputStream) {
    int[] ntrue = new int[segmentTags.length];
    int[] npred = new int[segmentTags.length];
    int[] ncorr = new int[segmentTags.length];

    LabelAlphabet dict = (LabelAlphabet) transducer.getInputPipe().getTargetAlphabet();

    for (int i = 0; i < data.size(); i++) {
      Instance instance = data.getInstance(i);
      Sequence input = (Sequence) instance.getData();
      Sequence trueOutput = (Sequence) instance.getTarget();
      assert (input.size() == trueOutput.size());
      Sequence predOutput = transducer.viterbiPath(input).output();
      assert (predOutput.size() == trueOutput.size());

      List trueSegs = new ArrayList();
      List predSegs = new ArrayList();

      addSegs(trueSegs, trueOutput);
      addSegs(predSegs, predOutput);

      // A predicted segment is correct iff an equal segment exists in the gold segmentation.
      for (Iterator it = predSegs.iterator(); it.hasNext(); ) {
        Segment seg = (Segment) it.next();
        npred[seg.tag]++;
        if (trueSegs.contains(seg)) {
          ncorr[seg.tag]++;
        }
      }

      for (Iterator it = trueSegs.iterator(); it.hasNext(); ) {
        Segment seg = (Segment) it.next();
        ntrue[seg.tag]++;
      }
    }

    DecimalFormat f = new DecimalFormat("0.####");
    logger.info(description + " per-field F1");
    for (int tag = 0; tag < segmentTags.length; tag++) {
      // Guard zero denominators so that fields which were never predicted, or never present,
      // report 0 instead of logging NaN.
      double precision = npred[tag] == 0 ? 0.0 : ((double) ncorr[tag]) / npred[tag];
      double recall = ntrue[tag] == 0 ? 0.0 : ((double) ncorr[tag]) / ntrue[tag];
      double f1 =
          (precision + recall) == 0.0 ? 0.0 : (2 * precision * recall) / (precision + recall);
      Label name = dict.lookupLabel(segmentTags[tag]);
      logger.info(
          " segments "
              + name
              + "  true = "
              + ntrue[tag]
              + "  pred = "
              + npred[tag]
              + "  correct = "
              + ncorr[tag]);
      logger.info(
          " precision="
              + f.format(precision)
              + " recall="
              + f.format(recall)
              + " f1="
              + f.format(f1));
    }
  }