Exemplo n.º 1
0
  /**
   * Writes the given confusion matrix to the task's test output folder as a LaTeX file and as a
   * CSV file, and writes the textual evaluation report (the strings returned by {@code
   * printNiceResults()}, {@code printLabelPrecRecFm()} and {@code printClassDistributionGold()})
   * to the framework's evaluation file.
   *
   * @param context task context used to resolve the output folder
   * @param confusionMatrix confusion matrix to serialize
   * @param filePrefix prefix of the confusion-matrix output files; {@code null} is treated as an
   *     empty prefix
   * @throws java.io.IOException if any of the output files cannot be opened, written or closed
   */
  public static void writeOutputResults(
      TaskContext context, ConfusionMatrix confusionMatrix, String filePrefix) throws IOException {
    String prefix = filePrefix != null ? filePrefix : "";
    // All three output files go to the same folder, so resolve it once.
    File outputFolder =
        context.getFolder(Constants.TEST_TASK_OUTPUT_KEY, StorageService.AccessMode.READWRITE);

    // storing the results as latex confusion matrix
    File evaluationFileLaTeX = new File(outputFolder, prefix + "confusionMatrix.tex");
    FileUtils.writeStringToFile(evaluationFileLaTeX, confusionMatrix.toStringLatex());

    // as CSV confusion matrix
    File evaluationFileCSV = new File(outputFolder, prefix + "confusionMatrix.csv");
    // try-with-resources closes the writer even when printing fails, and lets flush/close
    // failures propagate instead of silently leaving a truncated CSV file behind.
    try (FileWriter csvWriter = new FileWriter(evaluationFileCSV);
        CSVPrinter csvPrinter = new CSVPrinter(csvWriter, CSVFormat.DEFAULT)) {
      csvPrinter.printRecords(confusionMatrix.toStringMatrix());
    }

    // and results
    String evalFileName =
        new SVMHMMAdapter()
            .getFrameworkFilename(TCMachineLearningAdapter.AdapterNameEntries.evaluationFile);
    File evaluationFile = new File(outputFolder, evalFileName);

    try (PrintWriter pw = new PrintWriter(evaluationFile)) {
      pw.println(confusionMatrix.printNiceResults());
      pw.println(confusionMatrix.printLabelPrecRecFm());
      pw.println(confusionMatrix.printClassDistributionGold());
      // PrintWriter never throws on write errors; checkError() flushes and reports them.
      if (pw.checkError()) {
        throw new IOException("Failed to write evaluation results to " + evaluationFile);
      }
    }
  }
Exemplo n.º 2
0
 /**
  * Adds the contribution of every predicted derivation of {@code ex} to the four cells
  * ({@code tp}, {@code fn}, {@code fp}, {@code tn}) of {@code m}.
  *
  * <p>For each derivation a gold value and a predicted value are formed and the cells are
  * incremented by the products {@code gold*pred}, {@code gold*(1-pred)}, {@code (1-gold)*pred}
  * and {@code (1-gold)*(1-pred)}.
  *
  * @param m confusion matrix whose cells are incremented in place
  * @param ex example supplying the predicted derivations
  * @param compatDecisionThreshold if exactly {@code -1.0}, the derivation's compatibility is used
  *     as the gold value as-is; otherwise gold is 1 when the compatibility is strictly greater
  *     than this threshold and 0 otherwise
  * @param probDecisionThreshold if exactly {@code -1.0}, the derivation's probability is used as
  *     the predicted value as-is; otherwise pred is 1 when the probability is strictly greater
  *     than this threshold and 0 otherwise
  */
 private void updateConfusionMatrix(
     ConfusionMatrix m, Example ex, double compatDecisionThreshold, double probDecisionThreshold) {
   final List<Derivation> predicted = ex.getPredDerivations();
   final double[] probabilities = Derivation.getProbs(predicted, 1.0d);
   // -1.0 is the sentinel meaning "do not binarize, use the raw value".
   final boolean rawGold = compatDecisionThreshold == -1.0d;
   final boolean rawPred = probDecisionThreshold == -1.0d;

   int idx = 0;
   for (Derivation candidate : predicted) {
     final double compatibility = candidate.getCompatibility();
     final double probability = probabilities[idx++];

     final double gold;
     if (rawGold) {
       gold = compatibility;
     } else if (compatibility > compatDecisionThreshold) {
       gold = 1.0d;
     } else {
       gold = 0.0d;
     }

     final double pred;
     if (rawPred) {
       pred = probability;
     } else if (probability > probDecisionThreshold) {
       pred = 1.0d;
     } else {
       pred = 0.0d;
     }

     m.tp += gold * pred;
     m.fn += gold * (1.0d - pred);
     m.fp += (1.0d - gold) * pred;
     m.tn += (1.0d - gold) * (1.0d - pred);
   }
 }
Exemplo n.º 3
0
  /**
   * Evaluates {@code (TP / (TP + FP) - (TP + FN) / N) / (1 - (TP + FN) / N)} on the confusion
   * matrix of the given rule, where TP, FP, FN and N are the matrix's true positives, false
   * positives, false negatives and number of samples.
   *
   * <p>NOTE(review): no guard is applied for {@code TP + FP == 0}, {@code N == 0} or
   * {@code TP + FN == N}; the result in those cases depends on the numeric type returned by the
   * getters - confirm against {@code ConfusionMatrix}.
   *
   * @param rule rule whose confusion matrix is evaluated
   * @return the value of the expression above
   */
  public double compute(Rule rule) {
    final ConfusionMatrix cm = rule.getConfusionMatrix();

    // The getter calls are deliberately left inline so the arithmetic keeps the getters' own
    // numeric type throughout.
    return (cm.getTruePositives() / (cm.getTruePositives() + cm.getFalsePositives())
            - (cm.getTruePositives() + cm.getFalseNegatives()) / cm.getNumberOfSamples())
        / (1 - ((cm.getTruePositives() + cm.getFalseNegatives()) / cm.getNumberOfSamples()));
  }
Exemplo n.º 4
0
 /**
  * Records one classified instance: bumps the correct/incorrect counter, forwards the instance to
  * the confusion matrix and, unless the result's log-likelihood equals {@code Double.MAX_VALUE},
  * adds that log-likelihood to the summarizer and marks log-likelihood data as present.
  *
  * @param correctLabel The correct label
  * @param classifiedResult The classified result
  * @return {@code true} if {@code correctLabel} equals the label of {@code classifiedResult},
  *     {@code false} otherwise
  */
 public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) {
   final boolean labelsMatch = correctLabel.equals(classifiedResult.getLabel());
   if (!labelsMatch) {
     incorrectlyClassified++;
   } else {
     correctlyClassified++;
   }

   confusionMatrix.addInstance(correctLabel, classifiedResult);

   // Double.MAX_VALUE is treated as "no log-likelihood available" and is skipped.
   final boolean noLogLikelihood = classifiedResult.getLogLikelihood() == Double.MAX_VALUE;
   if (!noLogLikelihood) {
     summarizer.add(classifiedResult.getLogLikelihood());
     hasLL = true;
   }
   return labelsMatch;
 }
Exemplo n.º 5
0
  /**
   * Builds a multi-line text report consisting of a "Summary" section (correct, incorrect and
   * total instance counts with percentages), the confusion matrix's own string form, a
   * "Statistics" section (kappa, accuracy, reliability, weighted precision/recall/F1) and, when
   * log-likelihoods were recorded, their mean and first/third quartiles.
   *
   * @return the formatted report
   */
  @Override
  public String toString() {
    final String heavyRule = "=======================================================\n";
    final String lightRule = "-------------------------------------------------------\n";

    final StringBuilder report = new StringBuilder();
    report.append('\n');
    report.append(heavyRule);
    report.append("Summary\n");
    report.append(lightRule);

    final int total = correctlyClassified + incorrectlyClassified;
    // Floating-point arithmetic: an empty analyzer (total == 0) yields NaN rather than throwing.
    final double pctCorrect = 100.0 * correctlyClassified / total;
    final double pctIncorrect = 100.0 * incorrectlyClassified / total;
    final NumberFormat fmt = new DecimalFormat("0.####");

    report.append(StringUtils.rightPad("Correctly Classified Instances", 40));
    report.append(": ");
    report.append(StringUtils.leftPad(Integer.toString(correctlyClassified), 10));
    report.append('\t');
    report.append(StringUtils.leftPad(fmt.format(pctCorrect), 10));
    report.append("%\n");

    report.append(StringUtils.rightPad("Incorrectly Classified Instances", 40));
    report.append(": ");
    report.append(StringUtils.leftPad(Integer.toString(incorrectlyClassified), 10));
    report.append('\t');
    report.append(StringUtils.leftPad(fmt.format(pctIncorrect), 10));
    report.append("%\n");

    report.append(StringUtils.rightPad("Total Classified Instances", 40));
    report.append(": ");
    report.append(StringUtils.leftPad(Integer.toString(total), 10));
    report.append('\n');
    report.append('\n');

    report.append(confusionMatrix);
    report.append(heavyRule);
    report.append("Statistics\n");
    report.append(lightRule);

    final RunningAverageAndStdDev normalized = confusionMatrix.getNormalizedStats();

    report.append(StringUtils.rightPad("Kappa", 40));
    report.append(StringUtils.leftPad(fmt.format(confusionMatrix.getKappa()), 10));
    report.append('\n');

    report.append(StringUtils.rightPad("Accuracy", 40));
    report.append(StringUtils.leftPad(fmt.format(confusionMatrix.getAccuracy()), 10));
    report.append("%\n");

    report.append(StringUtils.rightPad("Reliability", 40));
    report.append(StringUtils.leftPad(fmt.format(normalized.getAverage() * 100.00000001), 10));
    report.append("%\n");

    report.append(StringUtils.rightPad("Reliability (standard deviation)", 40));
    report.append(StringUtils.leftPad(fmt.format(normalized.getStandardDeviation()), 10));
    report.append('\n');

    report.append(StringUtils.rightPad("Weighted precision", 40));
    report.append(StringUtils.leftPad(fmt.format(confusionMatrix.getWeightedPrecision()), 10));
    report.append('\n');

    report.append(StringUtils.rightPad("Weighted recall", 40));
    report.append(StringUtils.leftPad(fmt.format(confusionMatrix.getWeightedRecall()), 10));
    report.append('\n');

    report.append(StringUtils.rightPad("Weighted F1 score", 40));
    report.append(StringUtils.leftPad(fmt.format(confusionMatrix.getWeightedF1score()), 10));
    report.append('\n');

    // The log-likelihood section is emitted only when at least one value was recorded.
    if (hasLL) {
      final String indent = StringUtils.rightPad("", 30);

      report.append(StringUtils.rightPad("Log-likelihood", 30));
      report.append("mean      : ");
      report.append(StringUtils.leftPad(fmt.format(summarizer.getMean()), 10));
      report.append('\n');

      report.append(indent);
      report.append(StringUtils.rightPad("25%-ile   : ", 10));
      report.append(StringUtils.leftPad(fmt.format(summarizer.getQuartile(1)), 10));
      report.append('\n');

      report.append(indent);
      report.append(StringUtils.rightPad("75%-ile   : ", 10));
      report.append(StringUtils.leftPad(fmt.format(summarizer.getQuartile(3)), 10));
      report.append('\n');
    }

    return report.toString();
  }
  /**
   * Test of evaluatePerformance method, of class ConfusionMatrixPerformanceEvaluator.
   *
   * <p>The test grows a single collection of (target, estimate) pairs one step at a time. After
   * every step it re-runs {@code summarize} on the whole collection and checks the total count,
   * the count of every (target, estimate) cell added so far, and the error rate. Each expected
   * value therefore depends on all pairs added in the preceding steps.
   */
  @Test
  public void testEvaluatePerformance() {
    // Tolerance used for error rates that are not exactly representable (e.g. 1/3).
    double epsilon = 1e-10;
    ConfusionMatrixPerformanceEvaluator<?, String> instance =
        new ConfusionMatrixPerformanceEvaluator<String, String>();

    Collection<TargetEstimatePair<String, String>> data =
        new ArrayList<TargetEstimatePair<String, String>>();

    // Empty input: the summary is expected to have a total count of zero.
    ConfusionMatrix<String> confusion = instance.summarize(data);
    assertEquals(0.0, confusion.getTotalCount(), 0.0);

    // One mismatching pair: 1 error out of 1.
    data.add(new DefaultTargetEstimatePair<String, String>("yes", "no"));
    confusion = instance.summarize(data);
    assertEquals(1.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0 / 1.0, confusion.getErrorRate(), 0.0);

    // Add a matching pair: 1 error out of 2.
    data.add(new DefaultTargetEstimatePair<String, String>("yes", "yes"));
    confusion = instance.summarize(data);
    assertEquals(2.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0 / 2.0, confusion.getErrorRate(), epsilon);

    // Add another matching pair with a different label: 1 error out of 3.
    data.add(new DefaultTargetEstimatePair<String, String>("no", "no"));
    confusion = instance.summarize(data);
    assertEquals(3.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0 / 3.0, confusion.getErrorRate(), epsilon);

    // Add a mismatching pair whose labels have not been seen before: 2 errors out of 4.
    data.add(new DefaultTargetEstimatePair<String, String>("something", "else"));
    confusion = instance.summarize(data);
    assertEquals(4.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("something", "else"), 0.0);
    assertEquals(2.0 / 4.0, confusion.getErrorRate(), epsilon);

    // Add a matching pair with a new label: 2 errors out of 5.
    data.add(new DefaultTargetEstimatePair<String, String>("same", "same"));
    confusion = instance.summarize(data);
    assertEquals(5.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("something", "else"), 0.0);
    assertEquals(1.0, confusion.getCount("same", "same"), 0.0);
    assertEquals(2.0 / 5.0, confusion.getErrorRate(), epsilon);

    // Add a mismatching pair with a new target but an already-seen estimate: 3 errors out of 6.
    data.add(new DefaultTargetEstimatePair<String, String>("oh", "no"));
    confusion = instance.summarize(data);
    assertEquals(6.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("something", "else"), 0.0);
    assertEquals(1.0, confusion.getCount("same", "same"), 0.0);
    assertEquals(1.0, confusion.getCount("oh", "no"), 0.0);
    assertEquals(3.0 / 6.0, confusion.getErrorRate(), epsilon);

    // Add another mismatching pair with new labels: 4 errors out of 7.
    data.add(new DefaultTargetEstimatePair<String, String>("this", "bad"));
    confusion = instance.summarize(data);
    assertEquals(7.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("something", "else"), 0.0);
    assertEquals(1.0, confusion.getCount("same", "same"), 0.0);
    assertEquals(1.0, confusion.getCount("oh", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("this", "bad"), 0.0);
    assertEquals(4.0 / 7.0, confusion.getErrorRate(), epsilon);

    // Null estimate with a non-null target: expected to be counted as an error (5 out of 8).
    data.add(new DefaultTargetEstimatePair<String, String>("not null", null));
    confusion = instance.summarize(data);
    assertEquals(8.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("something", "else"), 0.0);
    assertEquals(1.0, confusion.getCount("same", "same"), 0.0);
    assertEquals(1.0, confusion.getCount("oh", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("this", "bad"), 0.0);
    assertEquals(1.0, confusion.getCount("not null", null), 0.0);
    assertEquals(5.0 / 8.0, confusion.getErrorRate(), epsilon);

    // Null target with a non-null estimate: expected to be counted as an error (6 out of 9).
    data.add(new DefaultTargetEstimatePair<String, String>(null, "not null"));
    confusion = instance.summarize(data);
    assertEquals(9.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("something", "else"), 0.0);
    assertEquals(1.0, confusion.getCount("same", "same"), 0.0);
    assertEquals(1.0, confusion.getCount("oh", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("this", "bad"), 0.0);
    assertEquals(1.0, confusion.getCount("not null", null), 0.0);
    assertEquals(1.0, confusion.getCount(null, "not null"), 0.0);
    assertEquals(6.0 / 9.0, confusion.getErrorRate(), epsilon);

    // Null target with null estimate: expected to be counted as a match, so the number of
    // errors stays at 6 while the total grows to 10.
    data.add(new DefaultTargetEstimatePair<String, String>(null, null));
    confusion = instance.summarize(data);
    assertEquals(10.0, confusion.getTotalCount(), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(1.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("something", "else"), 0.0);
    assertEquals(1.0, confusion.getCount("same", "same"), 0.0);
    assertEquals(1.0, confusion.getCount("oh", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("this", "bad"), 0.0);
    assertEquals(1.0, confusion.getCount("not null", null), 0.0);
    assertEquals(1.0, confusion.getCount(null, "not null"), 0.0);
    assertEquals(1.0, confusion.getCount(null, null), 0.0);
    assertEquals(6.0 / 10.0, confusion.getErrorRate(), epsilon);

    // Repeat pairs that are already present: the existing cells are expected to accumulate
    // (("yes","no") -> 3, ("yes","yes") -> 2, ("no","no") -> 2), giving 8 errors out of 14.
    data.add(new DefaultTargetEstimatePair<String, String>("yes", "no"));
    data.add(new DefaultTargetEstimatePair<String, String>("yes", "no"));
    data.add(new DefaultTargetEstimatePair<String, String>("yes", "yes"));
    data.add(new DefaultTargetEstimatePair<String, String>("no", "no"));
    confusion = instance.summarize(data);
    assertEquals(14.0, confusion.getTotalCount(), 0.0);
    assertEquals(3.0, confusion.getCount("yes", "no"), 0.0);
    assertEquals(2.0, confusion.getCount("yes", "yes"), 0.0);
    assertEquals(2.0, confusion.getCount("no", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("something", "else"), 0.0);
    assertEquals(1.0, confusion.getCount("same", "same"), 0.0);
    assertEquals(1.0, confusion.getCount("oh", "no"), 0.0);
    assertEquals(1.0, confusion.getCount("this", "bad"), 0.0);
    assertEquals(1.0, confusion.getCount("not null", null), 0.0);
    assertEquals(1.0, confusion.getCount(null, "not null"), 0.0);
    assertEquals(1.0, confusion.getCount(null, null), 0.0);
    assertEquals(8.0 / 14.0, confusion.getErrorRate(), epsilon);
  }
Exemplo n.º 7
0
  /**
   * Publishes a progress message and, when the scoring condition below holds, saves the model,
   * scores it on the training frame (and on the validation frame if one is configured), stores
   * the resulting metrics in the model output, logs a summary and the confusion matrix, and saves
   * the model again.
   *
   * <p>Scoring runs when any of the following is true: {@code _parms._score_each_iteration} is
   * set, {@code finalScoring} is {@code true}, less than 4000 ms have passed since the first call,
   * or more than 4000 ms have passed since the last scoring started and the last scoring took less
   * than 10% of that interval.
   *
   * @param finalScoring forces scoring regardless of the time-based throttling
   * @param oob passed to the training {@code Score}; when {@code true} the training metrics are
   *     also labelled as out-of-bag metrics
   * @param build_tree_one_node passed through to {@code Score.doAll}
   * @return the value of {@code training_r2}; NOTE(review): this local is initialized to
   *     {@code Double.NaN} and never reassigned in this method, so the method always returns NaN
   */
  protected double doScoringAndSaveModel(
      boolean finalScoring, boolean oob, boolean build_tree_one_node) {
    double training_r2 = Double.NaN; // Training R^2 value, if computed
    long now = System.currentTimeMillis();
    // Remember when this method was first called; used for the "first 4 secs" rule below.
    if (_firstScore == 0) _firstScore = now;
    long sinceLastScore = now - _timeLastScoreStart;
    boolean updated = false;
    new ProgressUpdate(
            "Built " + _model._output._ntrees + " trees so far (out of " + _parms._ntrees + ").")
        .fork(_progressKey);
    // Now model already contains tid-trees in serialized form
    if (_parms._score_each_iteration
        || finalScoring
        || (now - _firstScore < 4000)
        || // Score every time for 4 secs
        // Throttle scoring to keep the cost sane; limit to a 10% duty cycle & every 4 secs
        (sinceLastScore > 4000
            && // Limit scoring updates to every 4sec
            (double) (_timeLastScoreEnd - _timeLastScoreStart) / sinceLastScore
                < 0.1)) { // 10% duty cycle

      checkMemoryFootPrint();

      // If validation is specified we use a model for scoring, so we need to
      // update it!  First we save model with trees (i.e., make them available
      // for scoring) and then update it with resulting error
      _model.update(_key);
      updated = true;

      Log.info("============================================================== ");
      SharedTreeModel.SharedTreeOutput out = _model._output;
      // Scoring start time is the time captured at method entry, not the current time.
      _timeLastScoreStart = now;
      // Score on training data
      new ProgressUpdate("Scoring the model.").fork(_progressKey);
      Score sc =
          new Score(this, true, oob, _model._output.getModelCategory())
              .doAll(train(), build_tree_one_node);
      ModelMetrics mm = sc.makeModelMetrics(_model, _parms.train());
      out._training_metrics = mm;
      if (oob)
        out._training_metrics._description = "Metrics reported on Out-Of-Bag training samples";
      // The scoring-history slot is indexed by the current number of trees.
      out._scored_train[out._ntrees].fillFrom(mm);
      if (out._ntrees > 0) Log.info("Training " + out._scored_train[out._ntrees].toString());

      // Score again on validation data
      if (_parms._valid != null) {
        Score scv =
            new Score(this, false, false, _model._output.getModelCategory())
                .doAll(valid(), build_tree_one_node);
        ModelMetrics mmv = scv.makeModelMetrics(_model, _parms.valid());
        out._validation_metrics = mmv;
        out._scored_valid[out._ntrees].fillFrom(mmv);
        if (out._ntrees > 0) Log.info("Validation " + out._scored_valid[out._ntrees].toString());
      }

      if (out._ntrees > 0) { // Compute variable importances
        out._model_summary = createModelSummaryTable(out);
        out._scoring_history = createScoringHistoryTable(out);
        out._varimp = new hex.VarImp(_improvPerVar, out._names);
        out._variable_importances = hex.ModelMetrics.calcVarImp(out._varimp);
        Log.info(out._model_summary.toString());
        // For Debugging:
        //        Log.info(out._scoring_history.toString());
        //        Log.info(out._variable_importances.toString());
      }

      // Log the training confusion matrix, unless it exceeds the configured maximum size.
      ConfusionMatrix cm = mm.cm();
      if (cm != null) {
        if (cm._cm.length <= _parms._max_confusion_matrix_size) {
          Log.info(cm.toASCII());
        } else {
          Log.info(
              "Confusion Matrix is too large (max_confusion_matrix_size="
                  + _parms._max_confusion_matrix_size
                  + "): "
                  + _nclass
                  + " classes.");
        }
      }
      // Recorded after all scoring work so the next call can compute the duty cycle.
      _timeLastScoreEnd = System.currentTimeMillis();
    }

    // Double update - after either scoring or variable importance
    if (updated) _model.update(_key);
    return training_r2;
  }