コード例 #1
0
  /**
   * Prints a performance report for the {@code classifier} field on the given instances.
   *
   * <p>The report is printed in this order:
   * <ol>
   *   <li>the description line;</li>
   *   <li>the overall accuracy;</li>
   *   <li>one group of precision, recall and F1 lines per entry in the classifier's label
   *       alphabet, each group followed by a blank line;</li>
   *   <li>the aggregate precision, recall, F1 and fall-out reported by an {@code ExtendedTrial}
   *       over the same instances;</li>
   *   <li>whatever {@code ExtendedTrial.showErrors()} emits.</li>
   * </ol>
   *
   * @param instanceList the instances to evaluate the classifier on
   * @param description heading printed before any figures
   */
  public void printPerformance(InstanceList instanceList, String description) {
    System.out.println(description);
    System.out.println("Accuracy: " + classifier.getAccuracy(instanceList));

    for (Iterator<?> labels = classifier.getLabelAlphabet().iterator(); labels.hasNext(); ) {
      final Object label = labels.next();

      // All three metrics are evaluated before any of them is printed.
      final double precision = classifier.getPrecision(instanceList, label);
      final double recall = classifier.getRecall(instanceList, label);
      final double f1Score = classifier.getF1(instanceList, label);

      System.out.println("Precision[" + label + "] = " + precision);
      System.out.println("Recall[" + label + "] = " + recall);
      System.out.println("F1[" + label + "] = " + f1Score);
      System.out.println("");
    }

    final ExtendedTrial overall = new ExtendedTrial(classifier, instanceList);
    System.out.println("Overall performance\n=====");
    System.out.println("Precision: " + overall.getPrecision());
    System.out.println("Recall: " + overall.getRecall());
    System.out.println("F1: " + overall.getF1());
    System.out.println("Fall-out: " + overall.getFallOut());

    overall.showErrors();
  }
コード例 #2
0
  /**
   * Runs {@code i}-fold cross-validation of {@code trainer} over {@code ilist} and returns the
   * averaged test-fold metrics.
   *
   * <p>For every fold this prints:
   * <ul>
   *   <li>the train and test set sizes;</li>
   *   <li>the training accuracy and the test accuracy;</li>
   *   <li>test precision, recall and F1 for label-alphabet index {@code 1};</li>
   *   <li>correct/incorrect counts and a 2x2 confusion count over label indices 0 and 1.</li>
   * </ul>
   *
   * @param ilist the full data set to split into folds
   * @param i the number of cross-validation folds passed to {@code crossValidationIterator}
   * @param trainer trainer used to build a fresh classifier for each fold
   * @return a four-element list holding the mean test accuracy, precision, recall and F1; the
   *     last three are for label index 1. Every entry is {@code NaN} if the iterator produced
   *     no folds, because the sums are divided by a zero fold count.
   */
  public static ArrayList<Double> getAverageCrossValidationScore(
      InstanceList ilist, int i, ClassifierTrainer trainer) {
    // Label-alphabet index treated as the positive class for precision/recall/F1.
    final int positiveIndex = 1;

    double crossValidAccSum = 0;
    double crossValidPrcSum = 0;
    double crossValidRecSum = 0;
    double crossValidF1Sum = 0;
    int count = 0;

    // get cross validation folds
    CrossValidationIterator cvIlists = ilist.crossValidationIterator(i);

    while (cvIlists.hasNext()) {
      System.out.println("#############Performing " + count + " iteration###########");

      InstanceList[] ilists = cvIlists.next();
      InstanceList trainSet = ilists[0];
      InstanceList testSet = ilists[1];

      System.out.println("The train set size is " + trainSet.size());
      System.out.println("The test set size is " + testSet.size());
      Classifier classifier = trainer.train(trainSet);

      // Evaluate each metric once and reuse the value for both reporting and averaging.
      // NOTE(review): in MALLET each of these calls presumably re-classifies the whole set
      // (a new Trial per call), so the old duplicated calls were pure overhead — confirm.
      double trainAccuracy = classifier.getAccuracy(trainSet);
      double testAccuracy = classifier.getAccuracy(testSet);
      double testPrecision = classifier.getPrecision(testSet, positiveIndex);
      double testRecall = classifier.getRecall(testSet, positiveIndex);
      double testF1 = classifier.getF1(testSet, positiveIndex);

      System.out.println("The training accuracy is " + trainAccuracy);
      System.out.println("The testing accuracy is " + testAccuracy);
      System.out.println("The testing precision is " + testPrecision);
      System.out.println("The testing recall is " + testRecall);
      System.out.println("The testing f1score is " + testF1);

      crossValidAccSum += testAccuracy;
      crossValidPrcSum += testPrecision;
      crossValidRecSum += testRecall;
      crossValidF1Sum += testF1;
      count++;

      printFoldConfusionCounts(classifier.classify(testSet));
    }

    ArrayList<Double> results = new ArrayList<Double>(4);
    results.add(crossValidAccSum / count);
    results.add(crossValidPrcSum / count);
    results.add(crossValidRecSum / count);
    results.add(crossValidF1Sum / count);
    return results;
  }

  /**
   * Prints correct/incorrect counts and a binary confusion count for one fold's classifications.
   *
   * <p>In each counter name {@code pXlY}, X is the predicted label-alphabet index and Y is the
   * true label-alphabet index. Instances whose indices fall outside {0, 1} affect only the
   * correct/incorrect counts.
   *
   * @param outClassifications the classifier's output for one test fold
   */
  private static void printFoldConfusionCounts(ArrayList<Classification> outClassifications) {
    int p1l1 = 0;
    int p1l0 = 0;
    int p0l1 = 0;
    int p0l0 = 0;
    int countCorrect = 0;
    int countIncorrect = 0;

    System.out.println("Outclassification size " + outClassifications.size());
    for (Classification outClassification : outClassifications) {
      if (outClassification.bestLabelIsCorrect()) {
        countCorrect++;
      } else {
        countIncorrect++;
      }

      // Both sides are label-alphabet indices, so these counts agree with the index-based
      // precision/recall/F1 above. Parsing the target's text as a number would instead compare
      // a label *name* against an index, and it would fail on non-numeric labels.
      int predictedIndex = outClassification.getLabeling().getBestIndex();
      int targetIndex = outClassification.getInstance().getLabeling().getBestIndex();

      if (predictedIndex == 1 && targetIndex == 1) {
        p1l1++;
      } else if (predictedIndex == 1 && targetIndex == 0) {
        p1l0++;
      } else if (predictedIndex == 0 && targetIndex == 1) {
        p0l1++;
      } else if (predictedIndex == 0 && targetIndex == 0) {
        p0l0++;
      }
    }

    System.out.println("Count Correct " + countCorrect);
    System.out.println("Count Incorrect " + countIncorrect);
    System.out.println("p1l1 " + p1l1);
    System.out.println("p1l0 " + p1l0);
    System.out.println("p0l1 " + p0l1);
    System.out.println("p0l0 " + p0l0);
  }