Example #1
 @Override
 protected void execImpl() {
   Vec va = null, vp = null, avp = null;
   try {
     if (classification) {
       // Create new vectors - this is cheap, since they are only adaptation vectors
       va = vactual.toEnum(); // always returns TransfVec
       actual_domain = va._domain;
       vp = vpredict.toEnum(); // always returns TransfVec
       predicted_domain = vp._domain;
       if (!Arrays.equals(actual_domain, predicted_domain)) {
         domain = Utils.domainUnion(actual_domain, predicted_domain);
         int[][] vamap = Model.getDomainMapping(domain, actual_domain, true);
         va = TransfVec.compose((TransfVec) va, vamap, domain, false); // delete original va
         int[][] vpmap = Model.getDomainMapping(domain, predicted_domain, true);
         vp = TransfVec.compose((TransfVec) vp, vpmap, domain, false); // delete original vp
       } else domain = actual_domain;
       // The vectors are from different groups => align them, and delete the aligned copy
       // after computation
       if (!va.group().equals(vp.group())) {
         avp = vp;
         vp = va.align(vp);
       }
       cm = new CM(domain.length).doAll(va, vp)._cm;
     } else {
       mse = new CM(1).doAll(vactual, vpredict).mse();
     }
     return;
   } finally { // Delete adaptation vectors
     if (va != null) UKV.remove(va._key);
     if (vp != null) UKV.remove(vp._key);
     if (avp != null) UKV.remove(avp._key);
   }
 }
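For intuition, the domain-union step above merges the actual and predicted label sets before the confusion matrix is built. A toy sketch of that step (hedged; assumes Utils.domainUnion behaves as used above, and the values are illustrative only):

 // Hedged illustration of the domain-union step, not code from the project.
 String[] actual    = {"cat", "dog"};
 String[] predicted = {"dog", "mouse"};
 // Expected to yield the merged, deduplicated domain, e.g. {"cat", "dog", "mouse"}
 String[] domain = Utils.domainUnion(actual, predicted);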
Example #2
File: GBMTest.java, Project: hihihippp/h2o
  // Adapt a trained model to a test dataset with different enums
  /*@Test*/ public void testModelAdapt() {
    File file1 = TestUtil.find_test_file("./smalldata/kaggle/KDDTrain.arff.gz");
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("KDDTrain.hex");
    File file2 = TestUtil.find_test_file("./smalldata/kaggle/KDDTest.arff.gz");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("KDDTest.hex");
    GBM gbm = null;
    Frame ftest = null, preds = null;
    try {
      gbm = new GBM();
      gbm.source = ParseDataset2.parse(dest1, new Key[] {fkey1});
      UKV.remove(fkey1);
      gbm.response = gbm.source.remove(41); // Response is col 41
      gbm.ntrees = 2;
      gbm.max_depth = 8;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 50;
      gbm.invoke();

      // The test data set has a few more enums than the train set
      ftest = ParseDataset2.parse(dest2, new Key[] {fkey2});
      UKV.remove(fkey2);
      preds = gbm.score(ftest);

    } finally {
      UKV.remove(dest1); // Remove original hex frame key
      UKV.remove(dest2); // Remove test hex frame key
      if (ftest != null) ftest.remove(); // Remove test frame vectors
      if (preds != null) preds.remove(); // Remove prediction vectors
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
      }
    }
  }
Example #3
  // Test kaggle/creditsample-test data
  @org.junit.Test
  public void kaggle_credit() {
    Key okey = loadAndParseFile("credit.hex", "smalldata/kaggle/creditsample-training.csv.gz");
    UKV.remove(Key.make("smalldata/kaggle/creditsample-training.csv.gz_UNZIPPED"));
    UKV.remove(Key.make("smalldata\\kaggle\\creditsample-training.csv.gz_UNZIPPED"));
    ValueArray val = DKV.get(okey).get();

    // Check parsed dataset
    final int n = new int[] {4, 2, 1}[ValueArray.LOG_CHK - 20];
    assertEquals("Number of chunks", n, val.chunks());
    assertEquals("Number of rows", 150000, val.numRows());
    assertEquals("Number of cols", 12, val.numCols());

    // setup default values for DRF
    int ntrees = 3;
    int depth = 30;
    int seed = 42;
    StatType statType = StatType.GINI;
    final int cols[] =
        new int[] {0, 2, 3, 4, 5, 7, 8, 9, 10, 11, 1}; // ignore column 6, classify column 1

    // Start the distributed Random Forest
    final Key modelKey = Key.make("model");
    DRFJob result =
        hex.rf.DRF.execute(
            modelKey,
            cols,
            val,
            ntrees,
            depth,
            1024,
            statType,
            seed,
            true,
            null,
            -1,
            Sampling.Strategy.RANDOM,
            1.0f,
            null,
            0,
            0,
            false);
    // Wait for completion on all nodes
    RFModel model = result.get();

    assertEquals("Number of classes", 2, model.classes());
    assertEquals("Number of trees", ntrees, model.size());

    model.deleteKeys();
    UKV.remove(modelKey);
    UKV.remove(okey);
  }
Example #4
 @Test
 public void testFullVectAssignment() {
   Key k = loadAndParseKey("cars.hex", "smalldata/cars.csv");
   Key k2 = executeExpression("cars.hex");
   testDataFrameStructure(k2, 406, 8);
   UKV.remove(k2);
   k2 = executeExpression("a5 = cars.hex[2]");
   testVectorExpression("a5", 8, 8, 8, 4, 6, 6);
   UKV.remove(k2);
   UKV.remove(k);
   UKV.remove(Key.make("a5"));
 }
Example #5
 @Test
 public void testColumnSelectors() {
   Key k = loadAndParseKey("cars.hex", "smalldata/cars.csv");
   Key k2 = executeExpression("cars.hex[2]");
   testDataFrameStructure(k2, 406, 1);
   testKeyValues(k2, 8, 8, 8, 4, 6, 6);
   UKV.remove(k2);
   k2 = executeExpression("cars.hex$year");
   testDataFrameStructure(k2, 406, 1);
   testKeyValues(k2, 73, 70, 72, 76, 78, 81);
   UKV.remove(k2);
   UKV.remove(k);
 }
Example #6
  /*@org.junit.Test*/ public void covtype() {
    // Key okey = loadAndParseFile("covtype.hex", "smalldata/covtype/covtype.20k.data");
    // Key okey = loadAndParseFile("covtype.hex", "../datasets/UCI/UCI-large/covtype/covtype.data");
    // Key okey = loadAndParseFile("covtype.hex", "/home/0xdiag/datasets/standard/covtype.data");
    Key okey = loadAndParseFile("mnist.hex", "smalldata/mnist/mnist8m.10k.csv.gz");
    // Key okey = loadAndParseFile("mnist.hex", "/home/0xdiag/datasets/mnist/mnist8m.csv");
    ValueArray val = UKV.get(okey);

    // setup default values for DRF
    int ntrees = 8;
    int depth = 999;
    int seed = 42;
    StatType statType = StatType.ENTROPY;
    final int cols[] = new int[val.numCols()];
    for (int i = 0; i < cols.length - 1; i++) cols[i] = i + 1; // Features are columns 1..N-1
    cols[cols.length - 1] = 0; // Class is in column 0 for mnist, so classify it last

    // Start the distributed Random Forest
    final Key modelKey = Key.make("model");
    DRFJob result =
        hex.rf.DRF.execute(
            modelKey,
            cols,
            val,
            ntrees,
            depth,
            1024,
            statType,
            seed,
            true,
            null,
            -1,
            Sampling.Strategy.RANDOM,
            1.0f,
            null,
            0,
            0,
            false);
    // Wait for completion on all nodes
    RFModel model = result.get();

    assertEquals("Number of classes", 10, model.classes());
    assertEquals("Number of trees", ntrees, model.size());

    model.deleteKeys();
    UKV.remove(modelKey);
    UKV.remove(okey);
  }
Example #7
 // Write-lock & delete 'k'.  Will fail if 'k' is locked by anybody other than 'job_key'
 public static void delete(Key k, Key job_key) {
   if (k == null) return;
   Value val = DKV.get(k);
   if (val == null) return; // Or just nothing there to delete
   if (!val.isLockable()) UKV.remove(k); // Simple things being deleted
   else ((Lockable) val.get()).delete(job_key, 0.0f); // Lockable being deleted
 }
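A minimal usage sketch for the helper above (the key names are hypothetical; per the comment, the call fails if the key is locked by anyone other than the job):

 // Hedged sketch: clean up a job's output key, whether or not the value is Lockable.
 Key jobKey  = Key.make("my-job");       // hypothetical job key
 Key destKey = Key.make("my-model.hex"); // hypothetical output key
 delete(destKey, jobKey); // plain values go through UKV.remove, Lockables through Lockable.delete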
Example #8
 @Test
 public void testDifferentSizeOps() {
   Key cars = loadAndParseKey("cars.hex", "smalldata/cars.csv");
   Key poker = loadAndParseKey("p.hex", "smalldata/poker/poker-hand-testing.data");
   testVectorExpression("cars.hex$year + p.hex[1]", 74, 82, 81, 84, 86, 81);
   testVectorExpression("cars.hex$year - p.hex[1]", 72, 58, 63, 62, 64, 71);
   testVectorExpression("cars.hex$year * p.hex[1]", 73, 840, 648, 803, 825, 380);
   // testVectorExpression("cars.hex$year / p.hex[1]", 73, 70/12, 8, 76/11, 78/11, 15.2); // hard
   // to get the numbers right + not needed no new coverage
   testVectorExpression("p.hex[1] + cars.hex$year", 74, 82, 81, 84, 86, 81);
   testVectorExpression("p.hex[1] - cars.hex$year", -72, -58, -63, -62, -64, -71);
   testVectorExpression("p.hex[1] * cars.hex$year", 73, 840, 648, 803, 825, 380);
   // testVectorExpression("p.hex[1] / cars.hex$year", 1/73, 12/70, 0.125, 11/76, 11/78, 5/81);
   UKV.remove(poker);
   UKV.remove(cars);
 }
Example #9
 public static ValueArray loadAndParseKey(Key okey, String path) {
    FileIntegrityChecker c = FileIntegrityChecker.check(new File(path), false);
    Key k = c.syncDirectory(null, null, null, null);
    ParseDataset.forkParseDataset(okey, new Key[] {k}, null).get();
   UKV.remove(k);
   ValueArray res = DKV.get(okey).get();
   return res;
 }
Example #10
File: Frame.java, Project: vmlaker/h2o
 /** Actually remove/delete all Vecs from memory, not just from the Frame. */
 public void remove(Futures fs) {
   if (vecs().length > 0) {
     for (Vec v : _vecs) UKV.remove(v._key, fs);
   }
   _names = new String[0];
   _vecs = new Vec[0];
   _keys = new Key[0];
 }
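The Futures argument lets a caller queue all per-Vec removals and block only once at the end; a minimal sketch (assuming Futures.blockForPending() as in classic H2O):

 // Hedged sketch: batch-remove a frame's vectors with a single synchronization point.
 Futures fs = new Futures();
 fr.remove(fs);        // queues UKV.remove(v._key, fs) for every Vec in fr
 fs.blockForPending(); // wait for all pending removals to finish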
Example #11
 protected void testScalarExpression(String expr, double result) {
   Key key = executeExpression(expr);
   ValueArray va = ValueArray.value(key);
    assertEquals(1, va.numRows());
    assertEquals(1, va.numCols());
   assertEquals(result, va.datad(0, 0), 0.0);
   UKV.remove(key);
 }
Example #12
 @Test
 public void testLargeDataOps() {
   Key poker = loadAndParseKey("p.hex", "smalldata/poker/poker-hand-testing.data");
   testVectorExpression("p.hex[1] + p.hex[2]", 2, 15, 13, 15, 12, 7);
   testVectorExpression("p.hex[1] - p.hex[2]", 0, 9, 5, 7, 10, 3);
   testVectorExpression("p.hex[1] * p.hex[2]", 1, 36, 36, 44, 11, 10);
   testVectorExpression("p.hex[1] / p.hex[2]", 1.0, 4.0, 2.25, 2.75, 11.0, 2.5);
   UKV.remove(poker);
 }
Example #13
 @Test
 public void testVectorOperators() {
   Key k = loadAndParseKey("cars.hex", "smalldata/cars.csv");
   testVectorExpression("cars.hex[2] + cars.hex$year", 81, 78, 80, 80, 84, 87);
   testVectorExpression("cars.hex[2] - cars.hex$year", -65, -62, -64, -72, -72, -75);
   testVectorExpression("cars.hex[2] * cars.hex$year", 584, 560, 576, 304, 468, 486);
   testVectorExpression("cars.hex$year / cars.hex[2]", 9.125, 8.75, 9.0, 19.0, 13.0, 13.5);
   UKV.remove(k);
 }
Example #14
File: Frame.java, Project: NidhiMehta/h2o
 /** Actually remove/delete all Vecs from memory, not just from the Frame. */
 public void remove(Futures fs) {
   if (_vecs.length > 0) {
     VectorGroup vg = _vecs[0].group();
     for (Vec v : _vecs) UKV.remove(v._key, fs);
     DKV.remove(vg._key);
   }
   _names = new String[0];
   _vecs = new Vec[0];
 }
Example #15
File: GBMTest.java, Project: hihihippp/h2o
  // ==========================================================================
  public void basicGBM(String fname, String hexname, PrepData prep) {
    File file = TestUtil.find_test_file(fname);
    if (file == null) return; // Silently abort test if the file is missing
    Key fkey = NFSFileVec.make(file);
    Key dest = Key.make(hexname);
    GBM gbm = null;
    Frame fr = null;
    try {
      gbm = new GBM();
      gbm.source = fr = ParseDataset2.parse(dest, new Key[] {fkey});
      UKV.remove(fkey);
      int idx = prep.prep(fr);
      if (idx < 0) {
        gbm.classification = false;
        idx = ~idx;
      }
      String rname = fr._names[idx];
      gbm.response = fr.vecs()[idx];
      fr.remove(idx); // Move response to the end
      fr.add(rname, gbm.response);
      gbm.ntrees = 4;
      gbm.max_depth = 4;
      gbm.min_rows = 1;
      gbm.nbins = 50;
      gbm.cols = new int[fr.numCols()];
      for (int i = 0; i < gbm.cols.length; i++) gbm.cols[i] = i;
      gbm.learn_rate = .2f;
      gbm.invoke();

      fr = gbm.score(gbm.source);

      GBM.GBMModel gbmmodel = UKV.get(gbm.dest());
      // System.out.println(gbmmodel.toJava());

    } finally {
      UKV.remove(dest); // Remove original hex frame key
      if (gbm != null) {
        UKV.remove(gbm.dest()); // Remove the model
        UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
        if (fr != null) fr.remove();
      }
    }
  }
Example #16
File: Job.java, Project: shjgiser/h2o
 public void onException(Throwable ex) {
   UKV.remove(dest());
   Value v = DKV.get(progressKey());
    if (v != null) {
     ChunkProgress p = v.get();
     p = p.error(ex.getMessage());
     DKV.put(progressKey(), p);
   }
   cancel(ex);
 }
Example #17
File: Job.java, Project: shjgiser/h2o
 protected String[] getVectorDomain(final Vec v) {
    assert v == null || v.isInt() || v.isEnum() : "Cannot get vector domain!";
    if (v == null) return null;
   String[] r = null;
   if (v.isEnum()) {
     r = v.domain();
   } else {
     Vec tmp = v.toEnum();
     r = tmp.domain();
     UKV.remove(tmp._key);
   }
   return r;
 }
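Because toEnum() materializes a temporary TransfVec, a defensive variant of the else-branch would release it even if domain() throws; a sketch, not the project's actual code:

 // Hedged variant: guarantee cleanup of the temporary enum vector.
 Vec tmp = v.toEnum();
 try {
   return tmp.domain();
 } finally {
   UKV.remove(tmp._key); // delete the adaptation vector even on an exception
 }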
Example #18
 @Test
 public void testBigLargeExpression() {
   Key poker = loadAndParseKey("p.hex", "smalldata/poker/poker-hand-testing.data");
    testVectorExpression(
        "p.hex[1] / p.hex[2] + p.hex[3] * p.hex[1] - p.hex[5] + (2* p.hex[1] - (p.hex[2]+3))",
        8, 35, 63.25, 85.75, 116.0, 43.5);
   UKV.remove(poker);
 }
Example #19
File: AUC.java, Project: Jrobinso09/h2o
  @Override
  protected void execImpl() {
    Vec va = null, vp = null, avp = null;
    try {
      va = vactual.toEnum(); // always returns TransfVec
      vp = vpredict;
      // The vectors are from different groups => align them, and delete the aligned copy
      // after computation
      if (!va.group().equals(vp.group())) {
        avp = va.align(vp);
        vp = avp;
      }
      // compute thresholds, if not user-given
      if (thresholds != null) {
        sort(thresholds);
        if (Utils.minValue(thresholds) < 0)
          throw new IllegalArgumentException("Minimum threshold cannot be negative.");
        if (Utils.maxValue(thresholds) > 1)
          throw new IllegalArgumentException("Maximum threshold cannot be greater than 1.");
      } else {
        HashSet<Float> hs = new HashSet<Float>();
        final int bins = (int) Math.min(vpredict.length(), 200L);
        final long stride = Math.max(vpredict.length() / bins, 1);
        // data-driven thresholds; TODO: use percentiles (from Summary2?)
        for (int i = 0; i < bins; ++i) hs.add(new Float(vpredict.at(i * stride)));
        // always add 0.02-spaced thresholds from 0 to 1
        for (int i = 0; i < 51; ++i) hs.add(new Float(i / 50.));

        // create a sorted vector of unique thresholds
        thresholds = new float[hs.size()];
        int i = 0;
        for (Float h : hs) thresholds[i++] = h;
        sort(thresholds);
      }
      // compute CMs
      aucdata =
          new AUCData()
              .compute(
                  new AUCTask(thresholds, va.mean()).doAll(va, vp).getCMs(),
                  thresholds,
                  va._domain,
                  threshold_criterion);
    } finally { // Delete adaptation vectors
      if (va != null) UKV.remove(va._key);
      if (avp != null) UKV.remove(avp._key);
    }
  }
Example #20
  // ==========================================================================
  /*@Test*/ public void testBasicCRUD() {
    // Parse a file with many broken enum/string columns
    Key k = Key.make("zip.hex");
    try {
      Frame fr = TestUtil.parseFrame(k, "smalldata/zip_code/zip_code_database.csv.gz");
      System.out.println(fr);

      StringBuilder sb = new StringBuilder();
      String[] fs = fr.toStringHdr(sb);
      int lim = Math.min(40, (int) fr.numRows());
      for (int i = 0; i < lim; i++) fr.toString(sb, fs, i);
      System.out.println(sb.toString());
    } finally {
      UKV.remove(k);
    }
  }
Example #21
  public static String store2Hdfs(Key srcKey) {
    assert srcKey._kb[0] != Key.ARRAYLET_CHUNK;
    assert PersistHdfs.getPathForKey(srcKey) != null; // Validate key name
    Value v = DKV.get(srcKey);
    if (v == null) return "Key " + srcKey + " not found";
    if (v._isArray == 0) { // Simple chunk?
      v.setHdfs(); // Set to HDFS and be done
      return null; // Success
    }

    // For ValueArrays, make the .hex header
    ValueArray ary = ValueArray.value(v);
    String err = PersistHdfs.freeze(srcKey, ary);
    if (err != null) return err;

    // The task managing which chunks to write next,
    // store in a known key
    TaskStore2HDFS ts = new TaskStore2HDFS(srcKey);
    Key selfKey = ts.selfKey();
    UKV.put(selfKey, ts);

    // Then start writing chunks in-order with the zero chunk
    RPC.call(ts.chunkHome(), ts);

    // Watch the progress key until it gets removed or an error appears
    long idx = 0;
    while (UKV.get(selfKey, ts) != null) {
      if (ts._indexFrom != idx) {
        System.out.print(" " + idx + "/" + ary.chunks());
        idx = ts._indexFrom;
      }
      if (ts._err != null) { // Found an error?
        UKV.remove(selfKey); // Cleanup & report
        return ts._err;
      }
      try {
        Thread.sleep(100);
      } catch (InterruptedException e) {
        // Ignore the interrupt and keep polling
      }
    }
    System.out.println(" " + ary.chunks() + "/" + ary.chunks());

    // PersistHdfs.refreshHDFSKeys();
    return null;
  }
Example #22
 protected void testExecFail(String expr, int errorPos) {
   DKV.write_barrier();
   int keys = H2O.store_size();
   try {
      int i = UNIQUE.getAndIncrement();
      System.err.println("result" + i + ": " + expr);
      Key key = Exec.exec(expr, "result" + i);
      UKV.remove(key);
      fail("An exception should have been thrown.");
    } catch (ParserException e) {
      fail("Unexpected parser exception: " + e.getMessage());
   } catch (EvaluationException e) {
     if (errorPos != -1) assertEquals(errorPos, e._pos);
   }
   DKV.write_barrier();
   assertEquals("Keys were not properly deleted for expression " + expr, keys, H2O.store_size());
 }
Example #23
  @Override
  public void compute() {
    String path = null; // getPathFromValue(val);
    ValueArray ary = ValueArray.value(_arykey);
    Key self = selfKey();

    while (_indexFrom < ary.chunks()) {
      Key ckey = ary.getChunkKey(_indexFrom++);
      if (!ckey.home()) { // Next chunk not At Home?
        RPC.call(chunkHome(), this); // Hand the baton off to the next node/chunk
        return;
      }
      Value val = DKV.get(ckey); // It IS home, so get the data
      _err = PersistHdfs.appendChunk(_arykey, val);
      if (_err != null) return;
      UKV.put(self, this); // Update the progress/self key
    }
    // We did the last chunk.  Removing the selfKey is the signal to the web
    // thread that All Done.
    UKV.remove(self);
  }
Example #24
File: Job.java, Project: shjgiser/h2o
 @Override public void remove() {
   super.remove();
   UKV.remove(_progress);
 }
Example #25
File: Job.java, Project: shjgiser/h2o
 /** Delete all vectors in given trash. */
 private void cleanupTrash(HashSet<Key> trash, Futures fs) {
   for (Key k : trash) UKV.remove(k, fs);
 }
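A sketch of how such a trash set might be fed and flushed (hedged; the temporary vector name is hypothetical):

 // Hedged sketch: collect temporary keys during computation, then purge them in one pass.
 HashSet<Key> trash = new HashSet<Key>();
 trash.add(tmpVec._key); // hypothetical temporary Vec created during adaptation
 Futures fs = new Futures();
 cleanupTrash(trash, fs); // queues UKV.remove(k, fs) for each key
 fs.blockForPending();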
Example #26
File: Frame.java, Project: vmlaker/h2o
  public Frame deepSlice(Object orows, Object ocols) {
    // ocols is either a long[] or a Frame-of-1-Vec
    long[] cols;
    if (ocols == null) {
      cols = null;
    } else {
      if (ocols instanceof long[]) {
        cols = (long[]) ocols;
      } else if (ocols instanceof Frame) {
        Frame fr = (Frame) ocols;
        if (fr.numCols() != 1) {
          throw new IllegalArgumentException(
              "Columns Frame must have only one column (actually has "
                  + fr.numCols()
                  + " columns)");
        }

        long n = fr.anyVec().length();
        if (n > MAX_EQ2_COLS) {
          throw new IllegalArgumentException(
              "Too many requested columns (requested " + n + ", max " + MAX_EQ2_COLS + ")");
        }

        cols = new long[(int) n];
        Vec v = fr._vecs[0];
        for (long i = 0; i < v.length(); i++) {
          cols[(int) i] = v.at8(i);
        }
      } else {
        throw new IllegalArgumentException(
            "Columns is specified by an unsupported data type ("
                + ocols.getClass().getName()
                + ")");
      }
    }

    // Since cols is probably short, convert it to a positive, zero-based column list.
    int c2[] = null;
    if (cols == null) {
      c2 = new int[numCols()];
      for (int i = 0; i < c2.length; i++) c2[i] = i;
    } else if (cols.length == 0) {
      c2 = new int[0];
    } else if (cols[0] > 0) {
      c2 = new int[cols.length];
      for (int i = 0; i < cols.length; i++)
        c2[i] = (int) cols[i] - 1; // Convert 1-based cols to zero-based
    } else {
      c2 = new int[numCols() - cols.length];
      int j = 0;
      for (int i = 0; i < numCols(); i++) {
        if (j >= cols.length || i < (-cols[j] - 1)) c2[i - j] = i;
        else j++;
      }
    }
    for (int i = 0; i < c2.length; i++)
      if (c2[i] >= numCols())
        throw new IllegalArgumentException(
            "Trying to select column " + c2[i] + " but only " + numCols() + " present.");
    if (c2.length == 0)
      throw new IllegalArgumentException(
          "No columns selected (did you try to select column 0 instead of column 1?)");

    // Do Da Slice
    // orows is either a long[] or a Vec
    if (orows == null)
      return new DeepSlice((long[]) orows, c2)
          .doAll(c2.length, this)
          .outputFrame(names(c2), domains(c2));
    else if (orows instanceof long[]) {
      final long CHK_ROWS = 1000000;
      long[] rows = (long[]) orows;
      if (rows.length == 0)
        return new DeepSlice(rows, c2).doAll(c2.length, this).outputFrame(names(c2), domains(c2));
      if (rows[0] < 0)
        return new DeepSlice(rows, c2).doAll(c2.length, this).outputFrame(names(c2), domains(c2));
      // Vec'ize the index array
      AppendableVec av = new AppendableVec("rownames");
      int r = 0;
      int c = 0;
      while (r < rows.length) {
        NewChunk nc = new NewChunk(av, c);
        long end = Math.min(r + CHK_ROWS, rows.length);
        for (; r < end; r++) {
          nc.addNum(rows[r]);
        }
        nc.close(c++, null);
      }
      Vec c0 = av.close(null); // c0 is the row index vec
      Frame fr2 =
          new Slice(c2, this)
              .doAll(c2.length, new Frame(new String[] {"rownames"}, new Vec[] {c0}))
              .outputFrame(names(c2), domains(c2));
      UKV.remove(c0._key); // Remove hidden vector
      return fr2;
    }
    Frame frows = (Frame) orows;
    Vec vrows = frows.anyVec();
    // It's a compatible Vec; use it as boolean selector.
    // Build column names for the result.
    Vec[] vecs = new Vec[c2.length + 1];
    String[] names = new String[c2.length + 1];
    for (int i = 0; i < c2.length; ++i) {
      vecs[i] = _vecs[c2[i]];
      names[i] = _names[c2[i]];
    }
    vecs[c2.length] = vrows;
    names[c2.length] = "predicate";
    return new DeepSelect()
        .doAll(c2.length, new Frame(names, vecs))
        .outputFrame(names(c2), domains(c2));
  }
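A usage sketch for deepSlice (hedged; per the conversion above, positive column indices are 1-based, and rows are assumed 1-based as well):

 // Hedged sketch: materialize rows 1..3 of column 1 as a new Frame.
 Frame sub = fr.deepSlice(new long[] {1, 2, 3}, new long[] {1});
 // ... use sub ...
 sub.remove(); // delete the slice's vectors when done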
Example #27
File: Expr2Test.java, Project: jmcclell/h2o
  @Test
  public void testBasicExpr1() {
    Key dest = Key.make("h.hex");
    try {
      File file = TestUtil.find_test_file("smalldata/tnc3_10.csv");
      // File file = TestUtil.find_test_file("smalldata/iris/iris_wheader.csv");
      // File file = TestUtil.find_test_file("smalldata/cars.csv");
      Key fkey = NFSFileVec.make(file);
      ParseDataset2.parse(dest, new Key[] {fkey});
      UKV.remove(fkey);

      checkStr("1.23"); // 1.23
      checkStr(" 1.23 + 2.34"); // 3.57
      checkStr(" 1.23 + 2.34 * 3"); // 10.71, L2R eval order
      checkStr(" 1.23 2.34"); // Syntax error
      checkStr("1.23 < 2.34"); // 1
      checkStr("1.23 <=2.34"); // 1
      checkStr("1.23 > 2.34"); // 0
      checkStr("1.23 >=2.34"); // 0
      checkStr("1.23 ==2.34"); // 0
      checkStr("1.23 !=2.34"); // 1
      checkStr("h.hex"); // Simple ref
      checkStr("+(1.23,2.34)"); // prefix 3.57
      checkStr("+(1.23)"); // Syntax error, not enuf args
      checkStr("+(1.23,2,3)"); // Syntax error, too many args
      checkStr("h.hex[2,3]"); // Scalar selection
      checkStr("h.hex[2,+]"); // Function not allowed
      checkStr("h.hex[2+4,-4]"); // Select row 6, all-cols but 4
      checkStr("h.hex[1,-1]; h.hex[2,-2]; h.hex[3,-3]"); // Partial results are freed
      checkStr("h.hex[2+3,h.hex]"); // Error: col selector has too many columns
      checkStr("h.hex[2,]"); // Row 2 all cols
      checkStr("h.hex[,3]"); // Col 3 all rows
      checkStr("h.hex+1"); // Broadcast scalar over ary
      checkStr("h.hex-h.hex");
      checkStr("1.23+(h.hex-h.hex)");
      checkStr("(1.23+h.hex)-h.hex");
      checkStr("min(h.hex,1+2)");
      checkStr("max(h.hex,1+2)");
      checkStr("is.na(h.hex)");
      checkStr("nrow(h.hex)*3");
      checkStr("h.hex[nrow(h.hex)-1,ncol(h.hex)-1]");
      checkStr("1=2");
      checkStr("x");
      checkStr("x+2");
      checkStr("2+x");
      checkStr("x=1");
      checkStr("x<-1"); // Alternative R assignment syntax
      checkStr("x=1;x=h.hex"); // Allowed to change types via shadowing at REPL level
      checkStr("a=h.hex"); // Top-level assignment back to H2O.STORE
      checkStr("x<-+");
      checkStr("(h.hex+1)<-2");
      checkStr("h.hex[nrow(h.hex=1),]");
      checkStr("h.hex[2,3]<-4;");
      checkStr("c(1,3,5)");
      checkStr("function(=){x+1}(2)");
      checkStr("function(x,=){x+1}(2)");
      checkStr("function(x,<-){x+1}(2)");
      checkStr("function(x,x){x+1}(2)");
      checkStr("function(x,y,z){x[]}(h.hex,1,2)");
      checkStr("function(x){x[]}(2)");
      checkStr("function(x){x+1}(2)");
      checkStr("function(x){y=x+y}(2)");
      checkStr("function(x){}(2)");
      checkStr("function(x){y=x*2; y+1}(2)");
      checkStr("function(x){y=1+2}(2)");
      checkStr("function(x){y=1+2;y=c(1,2)}"); // Not allowed to change types in inner scopes
      checkStr("sum(1,2,3)");
      checkStr("sum(c(1,3,5))");
      checkStr("sum(4,c(1,3,5),2,6)");
      checkStr("sum(1,h.hex,3)");
      checkStr("h.hex[,c(1,3,5)]");
      checkStr("h.hex[c(1,3,5),]");
      checkStr("a=c(11,22,33,44,55,66); a[c(2,6,1),]");
      checkStr("function(a){a[];a=1}");
      checkStr("a=1;a=2;function(x){x=a;a=3}");
      checkStr("a=h.hex;function(x){x=a;a=3;nrow(x)*a}(a)");
      checkStr("a=h.hex;a[,1]=(a[,1]==8)");
      // Higher-order function typing: fun is typed in the body of function(x)
      checkStr("function(funy){function(x){funy(x)*funy(x)}}(sgn)(-2)");
      // Filter/selection
      checkStr("h.hex[h.hex[,2]>4,]");
      checkStr("a=c(1,2,3);a[a[,1]>10,1]");
      checkStr("apply(h.hex,2,sum)");
      checkStr("y=5;apply(h.hex,2,function(x){x[]+y})");
      checkStr("apply(h.hex,2,function(x){x=1;h.hex})");
      checkStr("apply(h.hex,2,function(x){h.hex})");
      checkStr("mean=function(x){apply(x,2,sum)/nrow(x)};mean(h.hex)");

      // Conditional selection;
      checkStr("ifelse(0,1,2)");
      checkStr("ifelse(0,h.hex+1,h.hex+2)");
      checkStr("ifelse(h.hex>3,99,h.hex)"); // Broadcast selection
      checkStr("ifelse(0,+,*)(1,2)"); // Select functions
      checkStr("(0 ? + : *)(1,2)"); // Trinary select
      checkStr("(1? h.hex : (h.hex+1))[1,2]"); // True (vs false) test
      // Impute the mean
      checkStr(
          "apply(h.hex,2,function(x){total=sum(ifelse(is.na(x),0,x)); rcnt=nrow(x)-sum(is.na(x)); mean=total / rcnt; ifelse(is.na(x),mean,x)})");
      checkStr("factor(h.hex[,5])");

      // Slice assignment & map
      checkStr("h.hex[,2]");
      checkStr("h.hex[,2]+1");
      checkStr("h.hex[,3]=3.3;h.hex"); // Replace a col with a constant
      checkStr("h.hex[,3]=h.hex[,2]+1"); // Replace a col
      checkStr("h.hex[,ncol(h.hex)+1]=4"); // Extend a col
      checkStr("a=ncol(h.hex);h.hex[,c(a+1,a+2)]=5"); // Extend two cols
      checkStr("table(h.hex)");
      checkStr("table(h.hex[,3])");
      checkStr("h.hex[,4] != 29 || h.hex[,2] < 305 && h.hex[,2] < 81");
      checkStr("a=cbind(c(1,2,3), c(4,5,6))");
      // checkStr("h.hex[h.hex[,2]>4,]=-99");
      // checkStr("h.hex[2,]=h.hex[7,]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4,6),2]");
      // checkStr("h.hex[c(1,3,5),1] = h.hex[c(2,4),2]");
      // checkStr("map()");
      // checkStr("map(1)");
      // checkStr("map(+,h.hex,1)");
      // checkStr("map(+,1,2)");
      // checkStr("map(function(x){x[];1},h.hex)");
      // checkStr("map(function(a,b,d){a+b+d},h.hex,h.hex,1)");
      // checkStr("map(function(a,b){a+ncol(b)},h.hex,h.hex)");

      checkStr("a=0;x=0"); // Delete keys from global scope

    } finally {
      UKV.remove(dest); // Remove original hex frame key
    }
  }
Example #28
  void testModelAdaptation(String train, String test, PrepData dprep, boolean exactAdaptation) {
    DRFModel model = null;
    Frame frTest = null;
    Frame frTrain = null;
    Key trainKey = Key.make("train.hex");
    Key testKey = Key.make("test.hex");
    Frame[] frAdapted = null;
    try {
      // Prepare a simple model
      frTrain = parseFrame(trainKey, train);
      model = runDRF(frTrain, dprep);
      // Load test dataset - test data contains input columns matching train data,
      // BUT each input requires adaptation. Moreover, test data contains additional columns
      // containing correct value mapping.
      frTest = parseFrame(testKey, test);
      Assert.assertEquals(
          "TEST CONF ERROR: The test dataset should contain 2*<number of input columns>+1!",
          2 * (frTrain.numCols() - 1) + 1,
          frTest.numCols());
      // Adapt test dataset
      frAdapted = model.adapt(frTest, exactAdaptation); // do/do not perform translation to enums
      Assert.assertEquals("Adapt method should return two frames", 2, frAdapted.length);
      Assert.assertEquals(
          "Test expects that all columns in  test dataset has to be adapted",
          dprep.needAdaptation(frTrain),
          frAdapted[1].numCols());

      // Compare vectors
      Frame adaptedFrame = frAdapted[0];
      // System.err.println(frTest.toStringAll());
      // System.err.println(adaptedFrame.toStringAll());

      for (int av = 0; av < frTrain.numCols() - 1; av++) {
        int ev = av + frTrain.numCols();
        Vec actV = adaptedFrame.vecs()[av];
        Vec expV = frTest.vecs()[ev];
        Assert.assertEquals(
            "Different number of rows in test vectors", expV.length(), actV.length());
        for (long r = 0; r < expV.length(); r++) {
          if (expV.isNA(r))
            Assert.assertTrue(
                "Badly adapted vector - expected NA! Col: " + av + ", row: " + r, actV.isNA(r));
          else {
            Assert.assertTrue(
                "Badly adapted vector - expected value but get NA! Col: " + av + ", row: " + r,
                !actV.isNA(r));
            Assert.assertEquals(
                "Badly adapted vector - wrong values! Col: " + av + ", row: " + r,
                expV.at8(r),
                actV.at8(r));
          }
        }
      }

    } finally {
      // Test cleanup
      if (model != null) UKV.remove(model._selfKey);
      if (frTrain != null) frTrain.remove();
      UKV.remove(trainKey);
      if (frTest != null) frTest.remove();
      UKV.remove(testKey);
      // Remove adapted vectors which were saved into the KV-store; the rest of the vectors are
      // removed by frTest.remove()
      if (frAdapted != null) frAdapted[1].remove();
    }
  }
Example #29
 // ---
 // Test some basic expressions on "cars.csv"
 @Test
 public void testBasicCrud() {
   Key k = loadAndParseKey("cars.hex", "smalldata/cars.csv");
   testVectorExpression("cars.hex[1] + cars.hex$cylinders", 21, 23, 25, 24, 23, 36.7);
   UKV.remove(k);
 }
Example #30
 public void testVectorExpression(
     String expr, double n1, double n2, double n3, double nx3, double nx2, double nx1) {
   Key key = executeExpression(expr);
   testKeyValues(key, n1, n2, n3, nx3, nx2, nx1);
   UKV.remove(key);
 }