コード例 #1
0
  /**
   * Runs the model on the input table and emits one error rate per prediction column.
   *
   * <p>Pulls a {@code PredictionModelModule} from input 0 and a {@code Table} from input 1, and
   * calls {@code predict} on the table. For each prediction column {@code preds[i]} it then
   * compares, row by row and as strings, the predicted value against the actual value in
   * output column {@code outputs[i]}. The error rate is {@code 1 - correct / rows}.
   *
   * <p>The prediction column labels and their error rates are packed into a
   * {@code ParameterPoint}, which is pushed to output 0.
   *
   * <p>NOTE(review): this assumes the prediction set is parallel to the output features, i.e.
   * {@code preds[i]} predicts {@code outputs[i]}. Confirm against PredictionTable's contract.
   * For a table with zero rows the computed error rate is {@code 1 - 0.0/0.0}, i.e. {@code NaN}.
   *
   * @throws Exception as declared; presumably propagated from the module framework or the model
   */
  public void doit() throws Exception {
    PredictionModelModule pmm = (PredictionModelModule) pullInput(0);
    Table t = (Table) pullInput(1);

    PredictionTable pt = pmm.predict(t);
    int[] outputs = pt.getOutputFeatures();
    int[] preds = pt.getPredictionSet();
    int numRows = pt.getNumRows();

    String[] names = new String[preds.length];
    double[] errors = new double[preds.length];
    for (int i = 0; i < preds.length; i++) {
      int actualCol = outputs[i];
      int predCol = preds[i];

      int numCorrect = 0;
      for (int row = 0; row < numRows; row++) {
        String actual = pt.getString(row, actualCol);
        String predicted = pt.getString(row, predCol);
        // Null-safe comparison: a missing (null) actual value must not abort the whole run
        // with a NullPointerException, as orig.equals(pred) would.
        if (java.util.Objects.equals(actual, predicted)) {
          numCorrect++;
        }
      }

      names[i] = pt.getColumnLabel(predCol);
      errors[i] = 1 - ((double) numCorrect) / ((double) numRows);
    }

    ParameterPoint pp = ParameterPointImpl.getParameterPoint(names, errors);
    pushOutput(pp, 0);
  }
コード例 #2
0
  /**
   * Computes a weighted sum of squared residuals for the first output feature.
   *
   * <p>For each row {@code i}, the residual is {@code actual - pred}. Here {@code actual} is the
   * row's value of the first output feature and {@code pred} is prediction column 0.
   *
   * <p>Each squared residual is multiplied by {@code calcWeight(x)}. The argument {@code x} is
   * the row's value of the first input feature, used as the "distance" from the origin. It is
   * the raw, signed value, not its absolute value. The weighted terms are summed and returned.
   *
   * <p>Despite the method name, this is not an R-squared statistic, and no mean or square root
   * is taken. Only the first input feature and the first output feature are used. Points closer
   * to the origin are weighted more heavily only if {@code calcWeight} is implemented that way.
   * NOTE(review): confirm against calcWeight's implementation.
   *
   * @param pt prediction table holding the input values, actual outputs, and predictions
   * @return the weighted sum of squared residuals; {@code 0} for a table with no rows
   */
  public double calcRSquared(PredictionTable pt) {
    int input = pt.getInputFeatures()[0];
    int output = pt.getOutputFeatures()[0];
    int numRows = pt.getNumRows();

    // Single pass: no need to buffer residuals and distances in temporary arrays. Rows are
    // visited in the same order and each term is weight * error * error, so the floating-point
    // accumulation is unchanged.
    double r2 = 0;
    for (int i = 0; i < numRows; i++) {
      double actual = pt.getDouble(i, output);
      double pred = pt.getDoublePrediction(i, 0);
      double error = actual - pred;
      double weight = calcWeight(pt.getDouble(i, input));
      r2 += weight * error * error;
    }
    return r2;
  }