/** * @param predictions the predictions to use * @param classIndex the class index * @return the probabilities */ private double[] getProbabilities(FastVector predictions, int classIndex) { // sort by predicted probability of the desired class. double[] probs = new double[predictions.size()]; for (int i = 0; i < probs.length; i++) { NominalPrediction pred = (NominalPrediction) predictions.elementAt(i); probs[i] = pred.distribution()[classIndex]; } return probs; }
/**
 * Calculates the performance stats for the desired class and returns the results as a set of
 * Instances.
 *
 * <p>Predictions whose actual class is missing, or whose weight is negative, are ignored; a
 * message is written to {@code System.err} each time such a prediction is encountered.
 *
 * @param predictions the predictions to base the curve on
 * @param classIndex index of the class of interest.
 * @return datapoints as a set of instances, or {@code null} if {@code predictions} is empty or
 *     the first prediction's distribution has no entry at {@code classIndex}.
 */
public Instances getCurve(FastVector predictions, int classIndex) {

  // The empty case is tested on its own so that element 0 is only inspected when it exists;
  // previously a diagnostic print dereferenced element 0 of an empty vector and threw.
  if (predictions.size() == 0) {
    return null;
  }
  if (((NominalPrediction) predictions.elementAt(0)).distribution().length <= classIndex) {
    return null;
  }

  double totPos = 0, totNeg = 0;
  double[] probs = getProbabilities(predictions, classIndex);

  // First pass: total weight of positives (actual == classIndex) and negatives (everything else).
  for (int i = 0; i < probs.length; i++) {
    NominalPrediction pred = (NominalPrediction) predictions.elementAt(i);
    if (pred.actual() == Prediction.MISSING_VALUE) {
      System.err.println(getClass().getName() + " Skipping prediction with missing class value");
      continue;
    }
    if (pred.weight() < 0) {
      System.err.println(getClass().getName() + " Skipping prediction with negative weight");
      continue;
    }
    if (pred.actual() == classIndex) {
      totPos += pred.weight();
    } else {
      totNeg += pred.weight();
    }
  }

  Instances insts = makeHeader();
  // NOTE(review): the loop below relies on Utils.sort returning indices that order probs
  // ascending -- confirm against Utils.
  int[] sorted = Utils.sort(probs);
  // Start with every prediction counted as positive: all positives are true positives and all
  // negatives are false positives (argument order presumed to be tp, fp, tn, fn).
  TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0);
  double threshold = 0;
  double cumulativePos = 0;
  double cumulativeNeg = 0;

  // Second pass: walk the predictions by increasing probability. Weight seen since the last
  // emitted point is accumulated and moved from the "predicted positive" cells to the
  // "predicted negative" cells whenever the probability rises above the current threshold.
  for (int i = 0; i < sorted.length; i++) {
    if ((i == 0) || (probs[sorted[i]] > threshold)) {
      tc.setTruePositive(tc.getTruePositive() - cumulativePos);
      tc.setFalseNegative(tc.getFalseNegative() + cumulativePos);
      tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg);
      tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg);
      threshold = probs[sorted[i]];
      insts.add(makeInstance(tc, threshold));
      cumulativePos = 0;
      cumulativeNeg = 0;
      // The last prediction opened a new threshold; its point has just been emitted, so there
      // is nothing left to accumulate.
      if (i == sorted.length - 1) {
        break;
      }
    }
    NominalPrediction pred = (NominalPrediction) predictions.elementAt(sorted[i]);
    if (pred.actual() == Prediction.MISSING_VALUE) {
      System.err.println(getClass().getName() + " Skipping prediction with missing class value");
      continue;
    }
    if (pred.weight() < 0) {
      System.err.println(getClass().getName() + " Skipping prediction with negative weight");
      continue;
    }
    if (pred.actual() == classIndex) {
      cumulativePos += pred.weight();
    } else {
      cumulativeNeg += pred.weight();
    }
  }

  // make sure a zero point gets into the curve: if the last emitted stats do not already have
  // everything predicted negative, add such a point just above the highest probability
  // (10e-6, i.e. 1e-5, is the offset used).
  if (tc.getFalseNegative() != totPos || tc.getTrueNegative() != totNeg) {
    tc = new TwoClassStats(0, 0, totNeg, totPos);
    threshold = probs[sorted[sorted.length - 1]] + 10e-6;
    insts.add(makeInstance(tc, threshold));
  }

  return insts;
}