コード例 #1
0
ファイル: RandomForestTest.java プロジェクト: NidhiMehta/h2o
  /**
   * Parses the Kaggle credit sample (the {@code creditsample-training} file), trains a small
   * Gini-based distributed random forest on it, and checks the parsed dataset shape and the
   * resulting model's class and tree counts.
   *
   * <p>All keys created by the test are removed in a {@code finally} block so that a failed
   * assertion or a failed DRF job does not leave data or model keys behind for later tests.
   */
  @org.junit.Test
  public void kaggle_credit() {
    final Key okey = loadAndParseFile("credit.hex", "smalldata/kaggle/creditsample-training.csv.gz");
    final Key modelKey = Key.make("model");
    RFModel model = null;
    try {
      // Remove the "_UNZIPPED" keys under both path-separator spellings; presumably these are
      // intermediates left behind by parsing the .gz input -- TODO confirm in loadAndParseFile.
      UKV.remove(Key.make("smalldata/kaggle/creditsample-training.csv.gz_UNZIPPED"));
      UKV.remove(Key.make("smalldata\\kaggle\\creditsample-training.csv.gz_UNZIPPED"));
      ValueArray val = DKV.get(okey).get();

      // Check parsed dataset. The expected chunk count depends on ValueArray.LOG_CHK; this table
      // only covers LOG_CHK values 20..22 (anything else throws ArrayIndexOutOfBoundsException).
      final int n = new int[] {4, 2, 1}[ValueArray.LOG_CHK - 20];
      assertEquals("Number of chunks", n, val.chunks());
      assertEquals("Number of rows", 150000, val.numRows());
      assertEquals("Number of cols", 12, val.numCols());

      // DRF settings for this test.
      final int ntrees = 3;
      final int depth = 30;
      final int seed = 42;
      final StatType statType = StatType.GINI;
      // Feature columns first, response column last: ignore column 6, classify column 1.
      final int[] cols = {0, 2, 3, 4, 5, 7, 8, 9, 10, 11, 1};

      // Start the distributed Random Forest
      DRFJob result =
          hex.rf.DRF.execute(
              modelKey,
              cols,
              val,
              ntrees,
              depth,
              1024,
              statType,
              seed,
              true,
              null,
              -1,
              Sampling.Strategy.RANDOM,
              1.0f,
              null,
              0,
              0,
              false);
      // Wait for completion on all nodes
      model = result.get();

      assertEquals("Number of classes", 2, model.classes());
      assertEquals("Number of trees", ntrees, model.size());
    } finally {
      // Always clean up, even when an assertion above failed.
      if (model != null) {
        model.deleteKeys();
      }
      UKV.remove(modelKey);
      UKV.remove(okey);
    }
  }
コード例 #2
0
ファイル: RandomForestTest.java プロジェクト: NidhiMehta/h2o
  /**
   * Trains an entropy-based distributed random forest on a 10k-row MNIST sample and checks the
   * resulting model's class and tree counts.
   *
   * <p>Despite its name, this currently exercises MNIST rather than covtype; the covtype inputs
   * are left commented out below. The test is disabled because its {@code @org.junit.Test}
   * annotation is commented out.
   */
  /*@org.junit.Test*/ public void covtype() {
    // Key okey = loadAndParseFile("covtype.hex", "smalldata/covtype/covtype.20k.data");
    // Key okey = loadAndParseFile("covtype.hex", "../datasets/UCI/UCI-large/covtype/covtype.data");
    // Key okey = loadAndParseFile("covtype.hex", "/home/0xdiag/datasets/standard/covtype.data");
    final Key okey = loadAndParseFile("mnist.hex", "smalldata/mnist/mnist8m.10k.csv.gz");
    // Key okey = loadAndParseFile("mnist.hex", "/home/0xdiag/datasets/mnist/mnist8m.csv");
    final Key modelKey = Key.make("model");
    RFModel model = null;
    try {
      ValueArray val = UKV.get(okey);

      // DRF settings for this test.
      final int ntrees = 8;
      final int depth = 999;
      final int seed = 42;
      final StatType statType = StatType.ENTROPY;

      // Column layout: every feature column 1..numCols-1 first, then the response last.
      // Class is in column 0 for mnist, so column 0 must not also appear as a feature.
      final int[] cols = new int[val.numCols()];
      for (int i = 0; i < cols.length - 1; i++) {
        cols[i] = i + 1;
      }
      cols[cols.length - 1] = 0;

      // Start the distributed Random Forest
      DRFJob result =
          hex.rf.DRF.execute(
              modelKey,
              cols,
              val,
              ntrees,
              depth,
              1024,
              statType,
              seed,
              true,
              null,
              -1,
              Sampling.Strategy.RANDOM,
              1.0f,
              null,
              0,
              0,
              false);
      // Wait for completion on all nodes
      model = result.get();

      assertEquals("Number of classes", 10, model.classes());
      assertEquals("Number of trees", ntrees, model.size());
    } finally {
      // Always clean up, even when an assertion above failed.
      if (model != null) {
        model.deleteKeys();
      }
      UKV.remove(modelKey);
      UKV.remove(okey);
    }
  }
コード例 #3
0
ファイル: RFView.java プロジェクト: jayfans3/h2o
 /**
  * Finds the position of the model's response column within the model's training data.
  *
  * <p>Looks up the data under {@code model._dataKey} and compares each column name against
  * {@code model.responseName()}.
  *
  * @param model the random-forest model whose response column is located
  * @return the index of the first column whose name equals the response name, or -1 if none does
  */
 static final int findResponseIdx(RFModel model) {
   final String responseName = model.responseName();
   final ValueArray data = UKV.get(model._dataKey);
   final ValueArray.Column[] columns = data._cols;
   for (int i = 0; i < columns.length; i++) {
     if (responseName.equals(columns[i]._name)) {
       return i;
     }
   }
   return -1;
 }
コード例 #4
0
ファイル: RFView.java プロジェクト: jayfans3/h2o
  /**
   * Builds the RF view response for the model referenced by {@code _modelKey}.
   *
   * <p>Returns an error response if the associated job was cancelled. Otherwise it attaches a
   * confusion matrix when one is requested (i.e. {@code _noCM} is not set) and available, always
   * attaches tree statistics, and wraps the result either as a finished response or as a poll
   * response reporting {@code finished} of {@code tasks} trees.
   *
   * @return a finished response once all trees are built and the requested confusion matrix (if
   *     any) is valid; otherwise a poll response
   */
  @Override
  protected Response serve() {
    final RFModel model = _modelKey.value();
    final double[] weights = _weights.value();
    final int classCol = _classCol.specified() ? _classCol.value() : findResponseIdx(model);

    final int tasks = model._totalTrees;
    final int finished = model.size();

    // Handle cancelled/aborted jobs
    if (_job.value() != null) {
      Job jjob = Job.findJob(_job.value());
      if (jjob != null && jjob.isCancelled()) {
        return Response.error(
            jjob.exception == null ? "Job was cancelled by user!" : jjob.exception);
      }
    }

    JsonObject response = defaultJsonResponse();
    // Finish refreshing only after the rf model is done and, when requested, the confusion matrix
    // for all trees has been computed.
    boolean done = false;
    // CM return and possible computation is requested
    if (!_noCM.value() && (finished == tasks || _iterativeCM.value()) && finished > 0) {
      // Score with the largest multiple of the refresh step (a percentage of all trees) that does
      // not exceed the number of finished trees; once every tree is built, use all of them.
      int modelSize = tasks * _refreshThresholdCM.value() / 100;
      modelSize =
          modelSize == 0 || finished == tasks ? finished : modelSize * (finished / modelSize);
      // NOTE(review): when fewer trees than one refresh step are finished, modelSize is 0 here,
      // yet a confusion task is still requested below; its result is never reported because of
      // the modelSize > 0 check. Confirm whether ConfusionTask.make should be skipped then.

      // Get the job computing the matrix - if no job is computing, then start a new job
      Job cmJob =
          ConfusionTask.make(
              model, modelSize, _dataKey.value()._key, classCol, weights, _oobee.value());
      // The job is running; the CM it saved may be finished or still in an invalid state.
      CMFinal confusion = UKV.get(cmJob.dest());
      // Report the matrix in the JSON only if it is valid
      if (confusion != null && confusion.valid() && modelSize > 0) {
        response.add(JSON_CM, confusionMatrixJson(confusion, modelSize));
        // Signal the end if and only if all trees were generated and the confusion matrix is valid
        done = finished == tasks;
      }
    } else if (_noCM.value() && finished == tasks) {
      done = true;
    }

    // Trees
    JsonObject trees = new JsonObject();
    trees.addProperty(Constants.TREE_COUNT, model.size());
    if (model.size() > 0) {
      trees.add(Constants.TREE_DEPTH, model.depth().toJson());
      trees.add(Constants.TREE_LEAVES, model.leaves().toJson());
    }
    response.add(Constants.TREES, trees);

    // Build a response
    Response r;
    if (done) {
      r = jobDone(response);
      r.addHeader(
          "<div class='alert'>"
              + GeneratePredictionsPage.link(model._key, "Predict!")
              + " </div>");
    } else {
      r = Response.poll(response, finished, tasks);
    }
    r.setBuilder(JSON_CM, new ConfusionMatrixBuilder());
    r.setBuilder(TREES, new TreeListBuilder());
    return r;
  }

  /**
   * Serializes a valid confusion matrix into the JSON object reported under {@code JSON_CM}.
   *
   * @param confusion the confusion matrix to report
   * @param modelSize the number of trees the matrix was computed with
   * @return a JSON object holding the scoring type, overall class error, row counts, class
   *     header, the matrix, per-class error rates, and the tree count
   */
  private JsonObject confusionMatrixJson(CMFinal confusion, int modelSize) {
    JsonObject cm = new JsonObject();
    cm.addProperty(JSON_CM_TYPE, _oobee.value() ? "OOB error estimate" : "full scoring");
    cm.addProperty(JSON_CM_CLASS_ERR, confusion.classError());
    cm.addProperty(JSON_CM_ROWS_SKIPPED, confusion.skippedRows());
    cm.addProperty(JSON_CM_ROWS, confusion.rows());
    // create the header
    JsonArray cmHeader = new JsonArray();
    for (String s : cfDomain(confusion, 1024)) {
      cmHeader.add(new JsonPrimitive(s));
    }
    cm.add(JSON_CM_HEADER, cmHeader);
    // add the matrix row by row, computing each row's class error rate along the way
    final int nclasses = confusion.dimension();
    JsonArray matrix = new JsonArray();
    JsonArray classErrors = new JsonArray();
    for (int crow = 0; crow < nclasses; ++crow) {
      JsonArray row = new JsonArray();
      long correct = 0;
      // long, not int: "int += long" would silently narrow large per-class counts
      long misclassified = 0;
      for (int ccol = 0; ccol < nclasses; ++ccol) {
        long count = confusion.matrix(crow, ccol);
        row.add(new JsonPrimitive(count));
        if (crow == ccol) {
          correct = count;
        } else {
          misclassified += count;
        }
      }
      // A row with no entries yields 0f/0, i.e. NaN, as its error rate.
      classErrors.add(new JsonPrimitive((float) misclassified / (misclassified + correct)));
      matrix.add(row);
    }
    cm.add(JSON_CM_CLASSES_ERRORS, classErrors);
    cm.add(JSON_CM_MATRIX, matrix);
    cm.addProperty(JSON_CM_TREES, modelSize);
    return cm;
  }