示例#1
0
 /**
  * Trains a single-tree DRF model on the given frame and returns it.
  *
  * @param data frame used as the DRF source
  * @param dprep preparation callback; whatever {@code dprep.prep(data)} returns becomes the
  *     response of the DRF job
  * @return the model looked up in the key/value store under the job's destination key
  */
 private DRFModel runDRF(Frame data, PrepData dprep) {
   final DRF job = new DRF();
   // One tree is all the callers ask for here; keeps the run as small as possible.
   job.ntrees = 1;
   job.source = data;
   job.response = dprep.prep(data);
   // NOTE(review): presumably invoke() returns only once the model is built - confirm.
   job.invoke();
   final DRFModel trained = UKV.get(job.dest());
   return trained;
 }
示例#2
0
  // ==========================================================================
  /**
   * Parses a test file into a frame, trains a small GBM on it, scores the training frame with
   * the result, and finally removes what the run stored in the key/value store.
   *
   * <p>The method returns silently when the test file cannot be located.
   *
   * @param fname name of the dataset file, resolved through {@code TestUtil.find_test_file}
   * @param hexname name of the key under which the parsed frame is stored
   * @param prep callback choosing the response column index; a negative result {@code r}
   *     switches {@code gbm.classification} off and selects column {@code ~r}
   */
  public void basicGBM(String fname, String hexname, PrepData prep) {
    File file = TestUtil.find_test_file(fname);
    if (file == null) return; // Silently abort test if the file is missing
    Key fkey = NFSFileVec.make(file);
    Key dest = Key.make(hexname);
    GBM gbm = null;
    Frame fr = null;
    try {
      gbm = new GBM();
      gbm.source = fr = ParseDataset2.parse(dest, new Key[] {fkey});
      UKV.remove(fkey); // Raw file key is no longer needed once parsing is done
      int idx = prep.prep(fr);
      if (idx < 0) {
        // Negative index encodes "not a classification problem"; ~idx recovers the column
        gbm.classification = false;
        idx = ~idx;
      }
      String rname = fr._names[idx];
      gbm.response = fr.vecs()[idx];
      fr.remove(idx); // Move response to the end
      fr.add(rname, gbm.response);
      gbm.ntrees = 4;
      gbm.max_depth = 4;
      gbm.min_rows = 1;
      gbm.nbins = 50;
      // Use every column of the (reordered) frame
      gbm.cols = new int[fr.numCols()];
      for (int i = 0; i < gbm.cols.length; i++) gbm.cols[i] = i;
      gbm.learn_rate = .2f;
      gbm.invoke();

      // From here on fr refers to the scoring result, which is what gets removed below;
      // the parsed frame is cleaned up through its key (dest).
      fr = gbm.score(gbm.source);

      // Fetch the stored model; it is not inspected further by this test.
      GBM.GBMModel gbmmodel = UKV.get(gbm.dest());
      // System.out.println(gbmmodel.toJava());

    } finally {
      UKV.remove(dest); // Remove original hex frame key
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        // The response is still null when set-up failed before it was assigned. Dereferencing
        // it here would throw from the finally block and hide the original failure.
        if (gbm.response != null) UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
        if (fr != null) fr.remove();
      }
    }
  }
示例#3
0
  /**
   * Builds a DRF model on a training dataset, adapts a test dataset to that model and checks
   * the adapted columns against the expected columns carried in the test dataset itself.
   *
   * <p>The test dataset must have {@code 2 * (trainCols - 1) + 1} columns. Adapted column
   * {@code i} is compared with test column {@code i + trainCols}, row by row, including the
   * position of missing values.
   *
   * @param train training dataset, handed to {@code parseFrame}
   * @param test test dataset, handed to {@code parseFrame}
   * @param dprep preparation callback used to build the model and to report how many columns
   *     are expected to need adaptation
   * @param exactAdaptation forwarded to {@code model.adapt}; controls whether translation to
   *     enums is performed
   */
  void testModelAdaptation(String train, String test, PrepData dprep, boolean exactAdaptation) {
    DRFModel model = null;
    Frame frTest = null;
    Frame frTrain = null;
    Key trainKey = Key.make("train.hex");
    Key testKey = Key.make("test.hex");
    Frame[] frAdapted = null;
    try {
      // Prepare a simple model
      frTrain = parseFrame(trainKey, train);
      model = runDRF(frTrain, dprep);
      // Load test dataset - test data contains input columns matching train data,
      // BUT each input requires adaptation. Moreover, test data contains additional columns
      // containing correct value mapping.
      frTest = parseFrame(testKey, test);
      Assert.assertEquals(
          "TEST CONF ERROR: The test dataset should contain 2*<number of input columns>+1!",
          2 * (frTrain.numCols() - 1) + 1,
          frTest.numCols());
      // Adapt test dataset
      frAdapted = model.adapt(frTest, exactAdaptation); // do/do not perform translation to enums
      Assert.assertEquals("Adapt method should return two frames", 2, frAdapted.length);
      Assert.assertEquals(
          "Test expects that all columns in  test dataset has to be adapted",
          dprep.needAdaptation(frTrain),
          frAdapted[1].numCols());

      // Compare vectors
      Frame adaptedFrame = frAdapted[0];

      for (int av = 0; av < frTrain.numCols() - 1; av++) {
        // Expected values for adapted column av live frTrain.numCols() columns further right
        int ev = av + frTrain.numCols();
        Vec actV = adaptedFrame.vecs()[av];
        Vec expV = frTest.vecs()[ev];
        Assert.assertEquals(
            "Different number of rows in test vectors", expV.length(), actV.length());
        for (long r = 0; r < expV.length(); r++) {
          if (expV.isNA(r)) {
            Assert.assertTrue(
                "Badly adapted vector - expected NA! Col: " + av + ", row: " + r, actV.isNA(r));
          } else {
            Assert.assertTrue(
                "Badly adapted vector - expected value but get NA! Col: " + av + ", row: " + r,
                !actV.isNA(r));
            Assert.assertEquals(
                "Badly adapted vector - wrong values! Col: " + av + ", row: " + r,
                expV.at8(r),
                actV.at8(r));
          }
        }
      }

    } finally {
      // Test cleanup
      if (model != null) UKV.remove(model._selfKey);
      if (frTrain != null) frTrain.remove();
      UKV.remove(trainKey);
      if (frTest != null) frTest.remove();
      UKV.remove(testKey);
      // Remove adapted vectors which were saved into KV-store, rest of vectors are removed by
      // frTest.remove(). Guard the index: if adapt() returned fewer than two frames the
      // assertion above has already failed, and throwing from here would hide that failure.
      if (frAdapted != null && frAdapted.length > 1 && frAdapted[1] != null) {
        frAdapted[1].remove();
      }
    }
  }