Пример #1
0
  /**
   * Collects, for every prediction, the probability it assigns to a single class.
   *
   * @param predictions the predictions to read; each element is cast to NominalPrediction
   * @param classIndex the index into each prediction's distribution
   * @return an array, in the same order as {@code predictions}, holding
   *     {@code distribution()[classIndex]} of each prediction
   */
  private double[] getProbabilities(FastVector predictions, int classIndex) {
    final int total = predictions.size();
    final double[] classProbs = new double[total];

    // Values are stored in input order; no sorting happens in this method.
    int idx = 0;
    while (idx < total) {
      final double[] dist = ((NominalPrediction) predictions.elementAt(idx)).distribution();
      classProbs[idx] = dist[classIndex];
      idx++;
    }
    return classProbs;
  }
Пример #2
0
  /**
   * Calculates the performance stats for the desired class and returns the results as a set of
   * Instances.
   *
   * <p>One data point is added for the first prediction and for every subsequent prediction
   * whose class probability is strictly greater than the current threshold. Predictions with a
   * missing actual class value or a negative weight are reported on {@code System.err} and
   * ignored when accumulating weights.
   *
   * @param predictions the predictions to base the curve on; each element is cast to
   *     NominalPrediction
   * @param classIndex index of the class of interest.
   * @return datapoints as a set of instances, or {@code null} if {@code predictions} is empty or
   *     the first prediction's distribution has no entry at {@code classIndex}
   */
  public Instances getCurve(FastVector predictions, int classIndex) {

    // An empty vector has no first element to inspect, so it must be handled before any
    // elementAt(0) call; previously the diagnostic below threw on empty input.
    int numPredictions = predictions.size();
    if (numPredictions == 0) {
      return null;
    }

    int numClasses = ((NominalPrediction) predictions.elementAt(0)).distribution().length;
    if (numClasses <= classIndex) {
      System.out.println("Foooobared " + numPredictions + " " + numClasses + " " + classIndex);
      return null;
    }

    double totPos = 0, totNeg = 0;
    double[] probs = getProbabilities(predictions, classIndex);

    // Get distribution of positive/negatives
    for (int i = 0; i < probs.length; i++) {
      NominalPrediction pred = (NominalPrediction) predictions.elementAt(i);
      if (pred.actual() == Prediction.MISSING_VALUE) {
        System.err.println(getClass().getName() + " Skipping prediction with missing class value");
        continue;
      }
      if (pred.weight() < 0) {
        System.err.println(getClass().getName() + " Skipping prediction with negative weight");
        continue;
      }
      if (pred.actual() == classIndex) {
        totPos += pred.weight();
      } else {
        totNeg += pred.weight();
      }
    }

    Instances insts = makeHeader();
    // Indices into probs; presumably ordered by ascending probability -- confirm against
    // Utils.sort.
    int[] sorted = Utils.sort(probs);
    // Start with all positive weight counted as true positives and all negative weight as false
    // positives (constructor argument order presumably tp, fp, tn, fn -- confirm).
    TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0);
    double threshold = 0;
    // Weight seen since the last emitted point, split by whether the actual class matches.
    double cumulativePos = 0;
    double cumulativeNeg = 0;

    for (int i = 0; i < sorted.length; i++) {

      if ((i == 0) || (probs[sorted[i]] > threshold)) {
        // New threshold: move the weight accumulated since the previous point from the
        // "predicted positive" cells to the "predicted negative" cells, then emit a point.
        tc.setTruePositive(tc.getTruePositive() - cumulativePos);
        tc.setFalseNegative(tc.getFalseNegative() + cumulativePos);
        tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg);
        tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg);
        threshold = probs[sorted[i]];
        insts.add(makeInstance(tc, threshold));
        cumulativePos = 0;
        cumulativeNeg = 0;
        // The last prediction opened a new threshold: its point has just been added and no
        // later point would consume its weight, so stop here.
        if (i == sorted.length - 1) {
          break;
        }
      }

      NominalPrediction pred = (NominalPrediction) predictions.elementAt(sorted[i]);

      if (pred.actual() == Prediction.MISSING_VALUE) {
        System.err.println(getClass().getName() + " Skipping prediction with missing class value");
        continue;
      }
      if (pred.weight() < 0) {
        System.err.println(getClass().getName() + " Skipping prediction with negative weight");
        continue;
      }
      if (pred.actual() == classIndex) {
        cumulativePos += pred.weight();
      } else {
        cumulativeNeg += pred.weight();
      }
    }

    // make sure a zero point gets into the curve
    if (tc.getFalseNegative() != totPos || tc.getTrueNegative() != totNeg) {
      tc = new TwoClassStats(0, 0, totNeg, totPos);
      // Place the extra point just above the largest probability in the sorted order.
      threshold = probs[sorted[sorted.length - 1]] + 10e-6;
      insts.add(makeInstance(tc, threshold));
    }

    return insts;
  }