コード例 #1
0
ファイル: Request2.java プロジェクト: jmcclell/h2o
 /**
  * Appends {@code value} to {@code values} and returns the (possibly new) array.
  * A value containing ':' is expanded through {@code NumberSequence.parseGenerator} and every
  * generated number is appended in its string form; an empty value is ignored.
  */
 private static String[] addSplit(String[] values, String value) {
   String[] result = values;
   if (!value.contains(":")) {
     if (!value.isEmpty()) result = Utils.append(result, value);
     return result;
   }
   for (double generated : NumberSequence.parseGenerator(value, false, 1))
     result = Utils.append(result, String.valueOf(generated));
   return result;
 }
コード例 #2
0
ファイル: DRFModelAdaptTest.java プロジェクト: jamesliu/h2o
 /**
  * Checks Utils.mapping: for every integer v from min(values) to max(values) the result holds
  * the position of v inside values, or -1 when v does not occur.
  */
 @Test
 public void testBasics_1() {
   // Dense domain starting at zero: identity mapping.
   Assert.assertArrayEquals(a(0, 1, 2, 3), Utils.mapping(a(0, 1, 2, 3)));
   // A hole (value 3 missing) is reported as -1.
   Assert.assertArrayEquals(a(0, 1, 2, -1, 3), Utils.mapping(a(0, 1, 2, 4)));
   // Negative minimum: slot 0 corresponds to value -1, slot 1 to the missing value 0.
   Assert.assertArrayEquals(a(0, -1, 1), Utils.mapping(a(-1, 1)));
   // Several holes at once.
   Assert.assertArrayEquals(a(0, -1, 1, -1, 2), Utils.mapping(a(-1, 1, 3)));
 }
コード例 #3
0
ファイル: MRUtils.java プロジェクト: Jrobinso09/h2o
 /**
  * Returns a float copy of {@code _ys} divided (via Utils.div) by the sum of its elements.
  * The sum is asserted to be non-zero.
  */
 public final float[] rel_dist() {
   final int n = _ys.length;
   float[] dist = new float[n];
   int k = 0;
   while (k < n) {
     dist[k] = (float) _ys[k];
     k++;
   }
   final float total = Utils.sum(dist);
   assert total != 0.;
   Utils.div(dist, total);
   return dist;
 }
コード例 #4
0
ファイル: KMeansModel.java プロジェクト: raghavendrabhat/h2o
 /**
  * Merges another KMeansScore into this one. The first result seen is adopted as-is; later
  * results are added into the existing {@code _rows} and {@code _dist} arrays via Utils.add.
  */
 @Override
 public void reduce(DRemoteTask rt) {
   final KMeansScore other = (KMeansScore) rt;
   if (_rows != null) {
     Utils.add(_rows, other._rows);
     Utils.add(_dist, other._dist);
     return;
   }
   // Nothing accumulated yet: take over the other task's arrays.
   _rows = other._rows;
   _dist = other._dist;
 }
コード例 #5
0
ファイル: JUnitRunner.java プロジェクト: hihihippp/h2o
  /**
   * Launches a local 3-node test cloud (one master running {@code Master}, two worker nodes)
   * on free port pairs, waits for the master and exits with its status. On failure, the
   * master's captured stdout/stderr are echoed to the console.
   */
  public static void main(String[] args) throws Exception {
    // Can be necessary to run in parallel to other clouds, so find open ports.
    // Each node needs both port and port + 1 free.
    int[] ports = new int[3];
    int candidate = 54321;
    int found = 0;
    while( found < ports.length ) {
      if( isOpen(candidate) && isOpen(candidate + 1) ) {
        ports[found++] = candidate;
        candidate += 2;
      } else {
        candidate++;
      }
    }

    StringBuilder flatfile = new StringBuilder();
    for( int p : ports )
      flatfile.append("127.0.0.1:").append(p).append("\n");
    // Force all IPs to local so that users can run with a firewall
    String[] localArgs = { "-ip", "127.0.0.1", "-flatfile", Utils.writeFile(flatfile.toString()).getAbsolutePath() };
    H2O.OPT_ARGS.ip = "127.0.0.1";
    args = (String[]) ArrayUtils.addAll(localArgs, args);

    ArrayList<Node> nodes = new ArrayList<Node>();
    for( int i = 1; i < ports.length; i++ )
      nodes.add(new NodeVM(Utils.append(args, "-port", Integer.toString(ports[i]))));

    String[] masterArgs = Utils.append(new String[] { "-mainClass", Master.class.getName() }, args);
    Node master = new NodeVM(Utils.append(masterArgs, "-port", Integer.toString(ports[0])));
    nodes.add(master); // added last, so the last out/err files below belong to the master

    File sandbox = new File("sandbox");
    sandbox.mkdirs();
    Utils.clearFolder(sandbox);
    File out = null, err = null;
    int idx = 0;
    for( Node node : nodes ) {
      out = File.createTempFile("junit-" + idx + "-out-", null, sandbox);
      err = File.createTempFile("junit-" + idx + "-err-", null, sandbox);
      node.persistIO(out.getAbsolutePath(), err.getAbsolutePath());
      node.start();
      idx++;
    }

    int exit = master.waitFor();
    if( exit != 0 ) {
      Log.log(out, System.out);
      Thread.sleep(100); // short pause, presumably so stdout and stderr don't interleave
      Log.log(err, System.err);
    }
    for( Node node : nodes )
      node.kill();
    if( exit == 0 )
      System.out.println("OK");
    System.exit(exit);
  }
コード例 #6
0
ファイル: AUC.java プロジェクト: Jrobinso09/h2o
  /**
   * Computes AUC data for {@code vactual} (converted to an enum vector) against {@code vpredict}.
   *
   * <p>User-supplied {@code thresholds} are sorted and must lie in [0, 1]. Otherwise thresholds
   * are built from up to 200 strided values of {@code vpredict} plus the grid 0, 0.02, ..., 1,
   * then de-duplicated and sorted. Temporary vectors created here (the enum adaptation of the
   * actual vector and, when the vector groups differ, the aligned prediction vector) are removed
   * from the K/V store afterwards; {@code vpredict} itself is never removed.
   *
   * @throws IllegalArgumentException if a user-supplied threshold is negative or greater than 1
   */
  @Override
  protected void execImpl() {
    Vec va = null, vp, alignedVp = null;
    try {
      va = vactual.toEnum(); // always returns TransfVec
      vp = vpredict;
      // The vectors are from different groups => align them. The aligned copy is a temporary
      // and is deleted in the finally block.
      if (!va.group().equals(vp.group())) {
        alignedVp = va.align(vp);
        vp = alignedVp;
      }
      // compute thresholds, if not user-given
      if (thresholds != null) {
        sort(thresholds);
        if (Utils.minValue(thresholds) < 0)
          throw new IllegalArgumentException("Minimum threshold cannot be negative.");
        if (Utils.maxValue(thresholds) > 1)
          throw new IllegalArgumentException("Maximum threshold cannot be greater than 1.");
      } else {
        HashSet<Float> hs = new HashSet<Float>();
        final int bins = (int) Math.min(vpredict.length(), 200L);
        // bins == 0 only for an empty vpredict; avoid dividing by zero in that case.
        final long stride = bins == 0 ? 1 : Math.max(vpredict.length() / bins, 1);
        // data-driven thresholds TODO: use percentiles (from Summary2?)
        for (int i = 0; i < bins; ++i) hs.add(Float.valueOf((float) vpredict.at(i * stride)));
        // always add 0.02-spaced thresholds from 0 to 1
        for (int i = 0; i < 51; ++i) hs.add(Float.valueOf((float) (i / 50.)));

        // create sorted vector of unique thresholds
        thresholds = new float[hs.size()];
        int i = 0;
        for (Float h : hs) thresholds[i++] = h;
        sort(thresholds);
      }
      // compute CMs
      aucdata =
          new AUCData()
              .compute(
                  new AUCTask(thresholds, va.mean()).doAll(va, vp).getCMs(),
                  thresholds,
                  va._domain,
                  threshold_criterion);
    } finally { // Delete adaptation vectors
      if (va != null) UKV.remove(va._key);
      if (alignedVp != null) UKV.remove(alignedVp._key);
    }
  }
コード例 #7
0
ファイル: TypeMap.java プロジェクト: plumcube/h2o
 /**
  * Regenerates {@code src/main/java/water/TypeMapGen.java}. The generated {@code CLAZZES} array
  * has two fixed leading entries, followed by the sorted names of every class from
  * {@code Boot.getClasses()} that is assignable to {@code Freezable}.
  *
  * <p>{@code water.api.RequestServer}, {@code water.External} and every name starting with
  * {@code water.r.} are skipped without being loaded.
  */
 public static void main(String[] args) throws Exception {
   Log._dontDie = true; // Ignore fatal class load error
   ArrayList<String> freezables = new ArrayList<String>();
   for (String name : Boot.getClasses()) {
     boolean excluded =
         name.equals("water.api.RequestServer")
             || name.equals("water.External")
             || name.startsWith("water.r.");
     if (excluded) continue;
     Class<?> clazz = Class.forName(name);
     if (Freezable.class.isAssignableFrom(clazz)) freezables.add(clazz.getName());
   }
   Collections.sort(freezables);

   StringBuilder src = new StringBuilder();
   src.append("package water;\n");
   src.append("\n");
   src.append("// Do not edit - generated\n");
   src.append("public class TypeMapGen {\n");
   src.append("  static final String[] CLAZZES = {\n");
   src.append("    \" BAD\",                     // 0: BAD\n");
   src.append("    \"[B\",                       // 1: Array of Bytes\n");
   for (String clz : freezables) src.append("    \"").append(clz).append("\",\n");
   src.append("  };\n");
   src.append("}");
   Utils.writeFile(new File("src/main/java/water/TypeMapGen.java"), src.toString());
   Log.info("Generated TypeMap");
 }
コード例 #8
0
ファイル: ConfusionMatrix.java プロジェクト: BhaskarPros/h2o
 /**
  * For regression, computes the MSE of {@code vactual} vs {@code vpredict}. For classification,
  * converts both to enum adaptation vectors, maps them onto the union of their domains when the
  * domains differ, aligns them if they belong to different vector groups, and computes the
  * confusion matrix. All adaptation vectors created here are removed from the K/V store.
  */
 @Override
 protected void execImpl() {
   Vec actualVec = null, predVec = null, unalignedPred = null;
   try {
     if (!classification) {
       mse = new CM(1).doAll(vactual, vpredict).mse();
       return;
     }
     // Create new vectors - cheap, since these are only adaptation vectors
     actualVec = vactual.toEnum(); // always returns TransfVec
     actual_domain = actualVec._domain;
     predVec = vpredict.toEnum(); // always returns TransfVec
     predicted_domain = predVec._domain;
     if (Arrays.equals(actual_domain, predicted_domain)) {
       domain = actual_domain;
     } else {
       domain = Utils.domainUnion(actual_domain, predicted_domain);
       // compose(..., false) deletes the original adaptation vector it wraps
       actualVec = TransfVec.compose((TransfVec) actualVec,
           Model.getDomainMapping(domain, actual_domain, true), domain, false);
       predVec = TransfVec.compose((TransfVec) predVec,
           Model.getDomainMapping(domain, predicted_domain, true), domain, false);
     }
     // Vectors from different groups => align them; remember the unaligned one for deletion
     if (!actualVec.group().equals(predVec.group())) {
       unalignedPred = predVec;
       predVec = actualVec.align(predVec);
     }
     cm = new CM(domain.length).doAll(actualVec, predVec)._cm;
   } finally { // Delete adaptation vectors
     if (actualVec != null) UKV.remove(actualVec._key);
     if (predVec != null) UKV.remove(predVec._key);
     if (unalignedPred != null) UKV.remove(unalignedPred._key);
   }
 }
コード例 #9
0
ファイル: Job.java プロジェクト: chouclee/h2o
    /**
     * Validates the cross-validation settings, allocates the cross-validation model keys,
     * resolves the response column name, selects the training vectors and, for classification,
     * computes the response domains used for the confusion matrix. When a validation frame is
     * given, that domain is the union of the training and validation response domains.
     *
     * @throws UnsupportedOperationException if a validation frame is combined with non-zero
     *     n_folds, or if n_folds is negative
     * @throws IllegalArgumentException if the validation frame lacks the response column
     */
    @Override
    protected void init() {
      if (validation != null && n_folds != 0)
        throw new UnsupportedOperationException(
            "Cannot specify a validation dataset and non-zero number of cross-validation folds.");
      if (n_folds < 0)
        throw new UnsupportedOperationException(
            "The number of cross-validation folds must be >= 0.");
      super.init();
      xval_models = new Key[n_folds];
      for (int fold = 0; fold < xval_models.length; ++fold)
        xval_models[fold] = Key.make(dest().toString() + "_xval" + fold);

      // Position of the response among the source vectors (first match).
      // NOTE(review): rIndex stays 0 if response is not one of source.vecs(), which also makes
      // the "rIndex >= 0" test below always true - confirm response always comes from source.
      int rIndex = 0;
      for (int v = 0; v < source.vecs().length; v++) {
        if (source.vecs()[v] != response) continue;
        rIndex = v;
        break;
      }
      _responseName = (source._names == null || rIndex < 0) ? "response" : source._names[rIndex];

      _train = selectVecs(source);
      _names = new String[cols.length];
      for (int c = 0; c < _names.length; c++) _names[c] = source._names[cols[c]];

      if (classification) _sourceResponseDomain = getVectorDomain(response);
      if (validation == null) {
        if (classification) _cmDomain = _sourceResponseDomain;
        return;
      }
      // Extract the validation response
      int idx = validation.find(source.names()[rIndex]);
      if (idx == -1)
        throw new IllegalArgumentException(
            "Validation set does not have a response column called " + _responseName);
      _validResponse = validation.vecs()[idx];
      if (!classification) return;

      _validResponseDomain = getVectorDomain(_validResponse);
      if (_validResponseDomain == null) {
        _cmDomain = _sourceResponseDomain;
        return;
      }
      _cmDomain = Utils.domainUnion(_sourceResponseDomain, _validResponseDomain);
      if (!Arrays.deepEquals(_sourceResponseDomain, _validResponseDomain)) {
        // model-produced response ~> cmDomain
        _fromModel2CM = Model.getDomainMapping(_cmDomain, _sourceResponseDomain, false);
        // validation response domain ~> cmDomain
        _fromValid2CM = Model.getDomainMapping(_cmDomain, _validResponseDomain, false);
      }
    }
コード例 #10
0
ファイル: Job.java プロジェクト: chouclee/h2o
 /**
  * Rejects a user-supplied key name when Utils.contains reports any of
  * {@code Key.ILLEGAL_USER_KEY_CHARS} in it.
  *
  * @throws IllegalArgumentException if {@code value} contains an illegal character
  */
 @Override
 public void validateRaw(String value) {
   if (!Utils.contains(value, Key.ILLEGAL_USER_KEY_CHARS)) return;
   String message = "Key '" + value + "' contains illegal character! Please avoid these characters: "
       + Key.ILLEGAL_USER_KEY_CHARS;
   throw new IllegalArgumentException(message);
 }
コード例 #11
0
ファイル: DRFModelAdaptTest.java プロジェクト: jamesliu/h2o
 /**
  * Checks Utils.compose(mapping, values): each mapping entry m becomes values[m], and a
  * negative entry (missing value) stays -1.
  */
 @Test
 public void testBasics_2() {
   // Identity mapping passes the values through unchanged.
   Assert.assertArrayEquals(a(2, 30, 400, 5000),
       Utils.compose(Utils.mapping(a(0, 1, 2, 3)), a(2, 30, 400, 5000)));
   // A hole in the mapping yields -1 at that position.
   Assert.assertArrayEquals(a(2, 30, 400, -1, 5000),
       Utils.compose(Utils.mapping(a(0, 1, 2, 4)), a(2, 30, 400, 5000)));
   // Mappings shorter than the value array only consume the leading values.
   Assert.assertArrayEquals(a(2, -1, 30),
       Utils.compose(Utils.mapping(a(-1, 1)), a(2, 30, 400, 5000)));
   Assert.assertArrayEquals(a(2, -1, 30, -1, 400),
       Utils.compose(Utils.mapping(a(-1, 1, 3)), a(2, 30, 400, 5000)));
 }
コード例 #12
0
ファイル: ConfusionMatrix.java プロジェクト: BhaskarPros/h2o
 /**
  * Merges another partial CM result into this one. When both sides carry a confusion matrix,
  * its cells are added via Utils.add. Otherwise both must be regression results; their
  * {@code _mse} accumulators and {@code _count}s are summed.
  */
 @Override
 public void reduce(CM cm) {
   if (_cm != null && cm._cm != null) {
     Utils.add(_cm, cm._cm);
   } else {
     // NaN compares unequal to everything (x != NaN is always true), so isNaN is required here.
     assert !Double.isNaN(_mse) && !Double.isNaN(cm._mse);
     assert (_cm == null && cm._cm == null);
     _mse += cm._mse;
     _count += cm._count;
   }
 }
コード例 #13
0
ファイル: FrameSplitter.java プロジェクト: Jfeng3/h2o
 /**
  * Creates a splitter that divides {@code dataset} into {@code ratios.length + 1} parts.
  *
  * @param dataset frame to split
  * @param ratios split ratios; between 1 and 99 entries
  * @param destKeys one destination key per part, or null to derive them from the dataset key
  * @param jobKey job key stored on this splitter
  */
 public FrameSplitter(Frame dataset, float[] ratios, Key[] destKeys, Key jobKey) {
   assert ratios.length > 0 : "No ratio specified!";
   assert ratios.length < 100 : "Too many frame splits demanded!";
   this.dataset = dataset;
   this.ratios = ratios;
   if (destKeys == null) destKeys = Utils.generateNumKeys(dataset._key, ratios.length + 1);
   this.destKeys = destKeys;
   assert this.destKeys.length == this.ratios.length + 1 : "Unexpected number of destination keys.";
   this.jobKey = jobKey;
 }
コード例 #14
0
ファイル: FrameSplitter.java プロジェクト: Jfeng3/h2o
 /**
  * Locates where this partition ({@code _partIdx}) starts inside the source vector: sets
  * {@code _pcidx} to the index of the input chunk containing the partition's first row and
  * {@code _psrow} to that row's offset within the chunk.
  */
 @Override
 protected void setupLocal() {
   // Precompute the first input chunk index and start row inside that chunk for this partition
   Vec anyInVec = _srcVecs[0];
   // Rows per partition (presumably Utils.partitione splits the length by the ratios).
   long[] partSizes = Utils.partitione(anyInVec.length(), _ratios);
   // Global row offset of this partition = sum of the sizes of all preceding partitions.
   long pnrows = 0;
   for (int p = 0; p < _partIdx; p++) pnrows += partSizes[p];
   long[] espc = anyInVec._espc;
   // Subtract whole chunk lengths until the offset falls inside chunk _pcidx. Afterwards pnrows
   // is <= 0 and (pnrows + length of chunk _pcidx) is the offset inside that chunk.
   // NOTE(review): assumes _pcidx starts at 0. If the offset lands exactly on a chunk end,
   // the loop stops on that chunk and _psrow equals the chunk length; confirm callers handle it.
   while (_pcidx < espc.length - 1 && (pnrows -= (espc[_pcidx + 1] - espc[_pcidx])) > 0)
     _pcidx++;
   assert pnrows <= 0;
   _psrow = (int) (pnrows + espc[_pcidx + 1] - espc[_pcidx]);
 }
コード例 #15
0
ファイル: Utils.java プロジェクト: raghavendrabhat/h2o
 /**
  * Decompresses {@code bs} according to {@code cmp}.
  *
  * <p>NONE returns {@code bs} unchanged. ZIP reads only the first entry; when there is no entry
  * or the first entry is a directory, {@code bs} is returned unchanged. Reading stops once the
  * output buffer is full and at least {@code ValueArray.CHUNK_SZ} bytes long. I/O errors are
  * logged (not thrown) and whatever was decompressed up to that point is returned.
  *
  * @param bs the (possibly compressed) input bytes
  * @param cmp compression of {@code bs}
  * @return the decompressed bytes, trimmed to the number of bytes actually read, or {@code bs}
  *     itself when nothing was decompressed
  */
 public static byte [] unzipBytes(byte [] bs, Compression cmp) {
   InputStream is = null;
   byte [] out = null; // decompression buffer; stays null if decompression never started
   int off = 0;
   try {
     switch(cmp) {
     case NONE: // No compression
       return bs;
     case ZIP: {
       ZipInputStream zis = new ZipInputStream(new ByteArrayInputStream(bs));
       is = zis; // set before getNextEntry() so the finally block closes it on every path
       ZipEntry ze = zis.getNextEntry(); // Get the *FIRST* entry
       // Need at least one entry in the zip file, and it must not be a directory.
       if( ze == null || ze.isDirectory() )
         return bs; // Don't crash, ignore file if cannot unzip
       break;
     }
     case GZIP:
       is = new GZIPInputStream(new ByteArrayInputStream(bs));
       break;
     default:
       assert false:"cmp = " + cmp;
     }
     // If reading from a compressed stream, estimate we can read 2x uncompressed
     assert( is != null ):"is is NULL, cmp = " + cmp;
     out = new byte[bs.length * 2];
     // Now read from the (possibly compressed) stream
     while( off < out.length ) {
       int len = is.read(out, off, out.length - off);
       if( len < 0 )
         break;
       off += len;
       if( off == out.length ) { // Dataset is uncompressing a lot! Need more space...
         if( out.length >= ValueArray.CHUNK_SZ )
           break; // Already got enough
         out = Arrays.copyOf(out, out.length * 2);
       }
     }
   } catch( IOException ioe ) { // Stop at any io error
     Log.err(ioe);
   } finally {
     Utils.close(is);
   }
   if( out == null )
     return bs;
   // Trim the unused tail so callers don't mistake trailing zero bytes for data.
   return off < out.length ? Arrays.copyOf(out, off) : out;
 }
コード例 #16
0
ファイル: Job.java プロジェクト: shjgiser/h2o
    /**
     * Validates the model-building setup. The source must be non-empty, and the response vector
     * is removed from the predictor columns {@code cols}. The response must not be constant: an
     * enum with at most one level, or a numeric column with min == max.
     *
     * @throws IllegalArgumentException if the dataset is empty or the response is constant
     */
    @Override protected void init() {
      super.init();
      if (source.numRows() == 0)
        throw new IllegalArgumentException("Cannot build a model on empty dataset!");
      // The response is not converted to an enum here; the classification flag alone decides
      // between classification and regression.
      final Vec[] vecs = source.vecs();
      for (int c = cols.length; c-- > 0; ) {
        if (vecs[cols[c]] == response) cols = Utils.remove(cols, c);
      }
      boolean constantResponse;
      if (response.isEnum()) constantResponse = response.domain().length <= 1;
      else constantResponse = response.min() == response.max();
      if (constantResponse)
        throw new IllegalArgumentException("Constant response column!");
    }
コード例 #17
0
ファイル: JUnitRunnerDebug.java プロジェクト: BersaKAIN/h2o
  /**
   * Starts NODES in-process nodes on ports 64321, 64323, ... sharing one generated flatfile.
   * Node 0 runs {@code UserCode}; every other node runs {@code H2O}.
   */
  public static void main(String[] args) throws Exception {
    int[] ports = new int[NODES];
    StringBuilder flatfile = new StringBuilder();
    for (int n = 0; n < ports.length; n++) {
      ports[n] = 64321 + 2 * n;
      flatfile.append("127.0.0.1:").append(ports[n]).append("\n");
    }
    String flatPath = Utils.writeFile(flatfile.toString()).getAbsolutePath();

    for (int n = 0; n < ports.length; n++) {
      Class entry;
      if (n == 0) entry = UserCode.class;
      else entry = H2O.class;
      // Double precision; insert "-single_precision" after "-ip 127.0.0.1" for single precision.
      String cmd = "-ip 127.0.0.1 -port " + ports[n] + " -flatfile " + flatPath;
      new NodeCL(entry, cmd.split(" ")).start();
    }
  }
コード例 #18
0
ファイル: MRUtils.java プロジェクト: Jrobinso09/h2o
 /**
  * Global redistribution of a Frame (balancing of chunks). Every column is rewritten into
  * {@code splits} new chunks, each receiving a contiguous block of about
  * ceil(numRows / splits) rows.
  *
  * <p>Balancing is only done when the frame has more than {@code splits} rows and either has
  * fewer than {@code splits} chunks or {@code shuffle} is set; otherwise {@code fr} is returned
  * unchanged. With {@code shuffle}, the destination chunk of each contiguous row block is
  * permuted, but rows inside a block keep their original relative order.
  *
  * @param fr Input frame
  * @param splits number of output chunks
  * @param seed RNG seed for the chunk permutation (only used when shuffling)
  * @param local not read by this method
  * @param shuffle whether to permute the row blocks across output chunks
  * @return a new Frame with the same column names and rebalanced vectors, or {@code fr} itself
  */
 public static Frame shuffleAndBalance(
     final Frame fr, int splits, long seed, final boolean local, final boolean shuffle) {
   if ((fr.vecs()[0].nChunks() < splits || shuffle) && fr.numRows() > splits) {
     Vec[] vecs = fr.vecs().clone();
     Log.info("Load balancing dataset, splitting it into up to " + splits + " chunks.");
     // idx: output-chunk permutation, initialized to 0..splits-1 then shuffled with seed
     long[] idx = null;
     if (shuffle) {
       idx = new long[splits];
       for (int r = 0; r < idx.length; ++r) idx[r] = r;
       Utils.shuffleArray(idx, seed);
     }
     Key keys[] = new Vec.VectorGroup().addVecs(vecs.length);
     // NOTE(review): with the ceil, trailing output chunks can receive no rows (e.g. 9 rows into
     // 4 splits gives 3,3,3,0); confirm empty chunks are acceptable downstream.
     final long rows_per_new_chunk = (long) (Math.ceil((double) fr.numRows() / splits));
     // loop over cols (same indexing for each column)
     Futures fs = new Futures();
     for (int col = 0; col < vecs.length; col++) {
       AppendableVec vec = new AppendableVec(keys[col]);
       // create outgoing chunks for this col
       NewChunk[] outCkg = new NewChunk[splits];
       for (int i = 0; i < splits; ++i) outCkg[i] = new NewChunk(vec, i);
       // loop over all incoming chunks
       for (int ckg = 0; ckg < vecs[col].nChunks(); ckg++) {
         final Chunk inCkg = vecs[col].chunkForChunkIdx(ckg);
         // loop over local rows of incoming chunks (fast path)
         for (int row = 0; row < inCkg._len; ++row) {
           int outCkgIdx =
               (int) ((inCkg._start + row) / rows_per_new_chunk); // destination chunk idx
           if (shuffle)
             outCkgIdx = (int) (idx[outCkgIdx]); // shuffle: choose a different output chunk
           assert (outCkgIdx >= 0 && outCkgIdx < splits);
           outCkg[outCkgIdx].addNum(inCkg.at0(row));
         }
       }
       for (int i = 0; i < outCkg.length; ++i) outCkg[i].close(i, fs);
       Vec t = vec.close(fs);
       // keep the original column's categorical domain on the rebuilt vector
       t._domain = vecs[col]._domain;
       vecs[col] = t;
     }
     fs.blockForPending();
     Log.info("Load balancing done.");
     return new Frame(fr.names(), vecs);
   }
   return fr;
 }
コード例 #19
0
ファイル: FrameSplitter.java プロジェクト: Jfeng3/h2o
 /**
  * Computes the ESPC (element-start-per-chunk) array of every split part.
  *
  * <p>Rows {@code [0, len)} are divided into {@code ratios.length + 1} consecutive parts, with
  * sizes given by {@code Utils.partitione(len, ratios)}. Each returned array starts with 0 and
  * contains every source chunk boundary inside the part, relative to the part's first row. It
  * ends with the part size.
  *
  * @param espc chunk start offsets of the source vector; espc[0] == 0 and the last entry == len
  * @param len number of rows of the source vector
  * @param ratios split ratios
  * @return one ESPC array per part
  */
 static long[ /*nsplits*/][ /*nchunks*/] computeEspcPerSplit(
     long[] espc, long len, float[] ratios) {
   assert espc.length > 0 && espc[0] == 0;
   assert espc[espc.length - 1] == len;
   long[] partSizes = Utils.partitione(len, ratios); // Split of whole vector
   int nparts = ratios.length + 1;
   long[][] r = new long[nparts][espc.length]; // espc for each partition
   long start = 0; // first row of the current part within the whole vector
   for (int p = 0, c = 0; p < nparts; p++) {
     int nc = 0; // number of chunks for this partition
     // Take every source chunk boundary that still lies inside this part (c carries over).
     for (; c < espc.length - 1 && (espc[c + 1] - start) <= partSizes[p]; c++)
       r[p][++nc] = espc[c + 1] - start;
     if (r[p][nc] < partSizes[p])
       r[p][++nc] = partSizes[p]; // last item in espc contains number of rows
     r[p] = Arrays.copyOf(r[p], nc + 1);
     // The next part starts right after this one.
     start += partSizes[p];
   }
   return r;
 }
コード例 #20
0
ファイル: Job.java プロジェクト: shjgiser/h2o
    /**
     * Resolves the response column name, selects the training vectors and, for classification,
     * computes the response domains used for the confusion matrix. When a validation frame is
     * given, that domain is the union of the training and validation response domains.
     *
     * @throws IllegalArgumentException if the validation frame lacks the response column
     */
    @Override protected void init() {
      super.init();

      // Position of the response among the source vectors (the last match wins).
      // NOTE(review): rIndex stays 0 when response is not one of source.vecs(), so the
      // "rIndex >= 0" test below is always true - confirm response always comes from source.
      int rIndex = 0;
      for (int v = 0; v < source.vecs().length; v++) {
        if (source.vecs()[v] == response) rIndex = v;
      }
      _responseName = (source._names == null || rIndex < 0) ? "response" : source._names[rIndex];

      _train = selectVecs(source);
      _names = new String[cols.length];
      for (int c = 0; c < _names.length; c++) _names[c] = source._names[cols[c]];

      if (classification) _sourceResponseDomain = getVectorDomain(response);
      if (validation == null) {
        if (classification) _cmDomain = _sourceResponseDomain;
        return;
      }
      // Extract the validation response
      int idx = validation.find(source.names()[rIndex]);
      if (idx == -1) throw new IllegalArgumentException("Validation set does not have a response column called "+_responseName);
      _validResponse = validation.vecs()[idx];
      if (!classification) return;

      _validResponseDomain = getVectorDomain(_validResponse);
      if (_validResponseDomain == null) {
        _cmDomain = _sourceResponseDomain;
        return;
      }
      _cmDomain = Utils.domainUnion(_sourceResponseDomain, _validResponseDomain);
      if (!Arrays.deepEquals(_sourceResponseDomain, _validResponseDomain)) {
        _fromModel2CM = Model.getDomainMapping(_cmDomain, _sourceResponseDomain, false); // model response ~> cmDomain
        _fromValid2CM = Model.getDomainMapping(_cmDomain, _validResponseDomain, false); // validation response ~> cmDomain
      }
    }
コード例 #21
0
ファイル: GBM.java プロジェクト: shjgiser/h2o
 /** Merges another GammaPass into this one by adding its _gss and _rss arrays into ours. */
 @Override
 public void reduce(GammaPass gp) {
   final GammaPass other = gp;
   Utils.add(_gss, other._gss); // accumulate _gss partials
   Utils.add(_rss, other._rss); // accumulate _rss partials
 }
コード例 #22
0
ファイル: GridSearch.java プロジェクト: hihihippp/h2o
    /**
     * Renders the grid-search jobs as an HTML table.
     *
     * <p>The header row lists the job arguments (minus a few filtered ones), then run time, the
     * optional speed column, model key, prediction error and F1 score. Below it comes one row
     * per job, sorted by ascending prediction error. Nothing is rendered when there are no jobs.
     *
     * @param sb builder the HTML is appended to
     * @return always {@code true}
     */
    @Override
    public boolean toHTML(StringBuilder sb) {
      if (jobs != null && jobs.length > 0) {
        DocGen.HTML.arrayHead(sb);
        sb.append("<tr class='warning'>");
        // Copy so that filtering does not modify the list returned by arguments().
        ArrayList<Argument> args = new ArrayList<Argument>(jobs[0].arguments());
        // Filter some keys to simplify UI
        filter(
            args,
            "destination_key",
            "source",
            "cols",
            "ignored_cols_by_name",
            "response",
            "classification",
            "validation");
        for (Argument arg : args) sb.append("<td><b>").append(arg._name).append("</b></td>");
        sb.append("<td><b>").append("run time").append("</b></td>");
        String perf = jobs[0].speedDescription();
        if (perf != null) sb.append("<td><b>").append(perf).append("</b></td>");
        sb.append("<td><b>").append("model key").append("</b></td>");
        sb.append("<td><b>").append("prediction error").append("</b></td>");
        sb.append("<td><b>").append("F1 score").append("</b></td>");
        sb.append("</tr>");

        ArrayList<JobInfo> infos = new ArrayList<JobInfo>();
        for (Job job : jobs) {
          JobInfo info = new JobInfo();
          info._job = job;
          Object value = UKV.get(job.destination_key);
          info._model = value instanceof Model ? (Model) value : null;
          if (info._model != null) info._cm = info._model.cm();
          if (info._cm != null) info._error = info._cm.err();
          infos.add(info);
        }
        Collections.sort(
            infos,
            new Comparator<JobInfo>() {
              @Override
              public int compare(JobInfo a, JobInfo b) {
                return Double.compare(a._error, b._error);
              }
            });

        for (JobInfo info : infos) {
          sb.append("<tr>");
          for (Argument a : args) {
            try {
              Object value = a._field.get(info._job);
              String s;
              if (value instanceof int[]) s = Utils.sampleToString((int[]) value, 20);
              else s = "" + value;
              sb.append("<td>").append(s).append("</td>");
            } catch (Exception e) {
              throw new RuntimeException(e);
            }
          }
          String runTime = "Pending", speed = "";
          if (info._job.start_time != 0) {
            runTime = PrettyPrint.msecs(info._job.runTimeMs(), true);
            speed = perf != null ? PrettyPrint.msecs(info._job.speedValue(), true) : "";
          }
          sb.append("<td>").append(runTime).append("</td>");
          if (perf != null) sb.append("<td>").append(speed).append("</td>");

          String link = info._job.destination_key.toString();
          if (info._job.start_time != 0 && DKV.get(info._job.destination_key) != null) {
            // Exactly one link kind per model type; Inspect is the fallback for all other models.
            if (info._model instanceof GBMModel)
              link = GBMModelView.link(link, info._job.destination_key);
            else if (info._model instanceof NeuralNetModel)
              link = NeuralNetProgress.link(info._job.self(), info._job.destination_key, link);
            else if (info._model instanceof KMeans2Model)
              link = KMeans2ModelView.link(link, info._job.destination_key);
            else link = Inspect.link(link, info._job.destination_key);
          }
          sb.append("<td>").append(link).append("</td>");

          String pct = "", f1 = "";
          if (info._cm != null) {
            pct = String.format("%.2f", 100 * info._error) + "%";
            if (info._cm._arr.length == 2)
              f1 = String.format("%.2f", info._cm.precisionAndRecall());
          }
          sb.append("<td><b>").append(pct).append("</b></td>");
          sb.append("<td><b>").append(f1).append("</b></td>");
          sb.append("</tr>");
        }
        DocGen.HTML.arrayTail(sb);
      }
      return true;
    }
コード例 #23
0
ファイル: GLM2.java プロジェクト: jayfans3/h2o
 /**
  * Handles the completion of one GLM iteration for the current lambda ({@code _lambdaIdx}).
  *
  * <p>If the iteration carries a beta and validation results (and the family is not tweedie),
  * it records the validation, checks the optimality conditions, and may start a line search
  * and return early. It then solves for a new beta with ADMM, stores it (de-normalized when
  * the data were standardized) as the lambda's submodel, and either launches the next
  * iteration asynchronously or finishes this lambda via {@code nextLambda}.
  *
  * @throws JobCancelledException if the job is no longer running
  */
 @Override
 public void callback(final GLMIterationTask glmt) {
   if (!isRunning(self())) throw new JobCancelledException();
   boolean converged = false;
   if (glmt._beta != null && glmt._val != null && _glm.family != Family.tweedie) {
     glmt._val.finalize_AIC_AUC();
     _model.setAndTestValidation(_lambdaIdx, glmt._val); // .store();
     _model.clone().update(self());
     // Gradient/subgradient optimality check; any coefficient failing it clears converged.
     converged = true;
     double l1pen = alpha[0] * lambda[_lambdaIdx] * glmt._n;
     double l2pen = (1 - alpha[0]) * lambda[_lambdaIdx] * glmt._n;
     final double eps = 1e-2;
     // The last gradient entry is skipped (presumably the unpenalized intercept - confirm).
     // Note: glmt._grad is modified in place.
     for (int i = 0; i < glmt._grad.length - 1; ++i) { // add l2 reg. term to the gradient
       glmt._grad[i] += l2pen * glmt._beta[i];
       if (glmt._beta[i] < 0) converged &= Math.abs(glmt._grad[i] - l1pen) < eps;
       else if (glmt._beta[i] > 0) converged &= Math.abs(glmt._grad[i] + l1pen) < eps;
       // zero coefficient: converged only if shrinkage maps the gradient to 0
       else converged &= LSMSolver.shrinkage(glmt._grad[i], l1pen + eps) == 0;
     }
     if (converged) Log.info("GLM converged by reaching 0 gradient/subgradient.");
     double objval = glmt._val.residual_deviance + 0.5 * l2pen * l2norm(glmt._beta);
     // If not converged and needLineSearch says so, run a line search between the last
     // accepted beta and this one; that task continues the iteration asynchronously.
     if (!converged && _lastResult != null && needLineSearch(glmt._beta, objval, 1)) {
       new GLMTask.GLMLineSearchTask(
               GLM2.this,
               _dinfo,
               _glm,
               _lastResult._glmt._beta,
               glmt._beta,
               1e-8,
               new LineSearchIteration())
           .asyncExec(_dinfo._adaptedFrame);
       return;
     }
     // Remember this iteration as the last accepted result.
     _lastResult = new IterationInfo(GLM2.this._iter - 1, objval, glmt);
   }
   // Solve for the next beta, starting from the current one (or zeros if there is none yet).
   // This also runs when the optimality check above already succeeded.
   double[] newBeta =
       glmt._beta != null ? glmt._beta.clone() : MemoryManager.malloc8d(glmt._xy.length);
   double[] newBetaDeNorm = null;
   ADMMSolver slvr = new ADMMSolver(lambda[_lambdaIdx], alpha[0], _addedL2);
   slvr.solve(glmt._gram, glmt._xy, glmt._yy, newBeta);
   _addedL2 = slvr._addedL2;
   if (Utils.hasNaNsOrInfs(newBeta)) {
     Log.info("GLM forcibly converged by getting NaNs and/or Infs in beta");
   } else {
     if (_dinfo._standardize) {
       newBetaDeNorm = newBeta.clone();
       double norm = 0.0; // Reverse any normalization on the intercept
       // denormalize only the numeric coefs (categoricals are not normalized)
       // numeric coefficients occupy the _dinfo._nums slots just before the last (intercept) one
       final int numoff = newBeta.length - _dinfo._nums - 1;
       for (int i = numoff; i < newBeta.length - 1; i++) {
         double b = newBetaDeNorm[i] * _dinfo._normMul[i - numoff];
         norm += b * _dinfo._normSub[i - numoff]; // Also accumulate the intercept adjustment
         newBetaDeNorm[i] = b;
       }
       newBetaDeNorm[newBetaDeNorm.length - 1] -= norm;
     }
     // Store the submodel: de-normalized beta first when available, plus the normalized one.
     _model.setLambdaSubmodel(
         _lambdaIdx,
         newBetaDeNorm == null ? newBeta : newBetaDeNorm,
         newBetaDeNorm == null ? null : newBeta,
         _iter);
     if (beta_diff(glmt._beta, newBeta) < beta_epsilon) {
       Log.info("GLM converged by reaching fixed-point.");
       converged = true;
     }
     // Iterate again unless converged, gaussian (never iterated further here), or out of iterations.
     if (!converged && _glm.family != Family.gaussian && _iter < max_iter) {
       ++_iter;
       new GLMIterationTask(GLM2.this, _dinfo, glmt._glm, newBeta, _ymu, _reg, new Iteration())
           .asyncExec(_dinfo._adaptedFrame);
       return;
     }
   }
   // done with this lambda
   nextLambda(glmt);
 }
コード例 #24
0
ファイル: GLMRunner.java プロジェクト: NidhiMehta/h2o
 /**
  * Simple GLM wrapper to enable launching GLM from command line.
  *
  * <p>Example input: java -jar target/h2o.jar -name=test -runMethod water.util.GLMRunner
  * -file=smalldata/logreg/prostate.csv -y=CAPSULE -family=binomial
  *
  * <p>{@code -y} is a column index or name. {@code -xs} is either "all" (every column except y)
  * or a comma-separated list of column indices or names. The fitted coefficients are printed,
  * and the cloud is always shut down at the end. Errors are logged, not rethrown.
  *
  * @param args command-line arguments, extracted into a {@code GLMArgs}
  * @throws InterruptedException declared for compatibility
  */
 public static void main(String[] args) throws InterruptedException {
   try {
     GLMArgs ARGS = new GLMArgs();
     new Arguments(args).extract(ARGS);
     System.out.println("==================<GLMRunner START>===================");
     ValueArray ary = Utils.loadAndParseKey(ARGS.file);
     int ycol;
     try {
       ycol = Integer.parseInt(ARGS.y);
     } catch (NumberFormatException e) {
       ycol = ary.getColumnIds(new String[] {ARGS.y})[0];
     }
     final int ncols = ary.numCols();
     if (ycol < 0 || ycol >= ncols) {
       System.err.println("invalid y column: " + ycol);
       H2O.exit(-1);
     }
     int[] xcols;
     if (ARGS.xs.equalsIgnoreCase("all")) {
       // Every column except the response, in column order.
       xcols = new int[ncols - 1];
       for (int i = 0; i < ycol; ++i) xcols[i] = i;
       for (int i = ycol; i < ncols - 1; ++i) xcols[i] = i + 1;
     } else {
       System.out.println("xs = " + ARGS.xs);
       String[] names = ARGS.xs.split(",");
       xcols = new int[names.length];
       try {
         for (int i = 0; i < names.length; ++i) xcols[i] = Integer.parseInt(names[i]);
       } catch (NumberFormatException e) {
         // Not all numeric: resolve the entries as column names instead.
         xcols = ary.getColumnIds(names);
       }
     }
     for (int x : xcols)
       if (x < 0 || x >= ncols) {
         System.err.println("Invalid predictor specification " + ARGS.xs);
         H2O.exit(-1);
       }
     GLMJob j =
         DGLM.startGLMJob(
             DGLM.getData(ary, xcols, ycol, null, true),
             new ADMMSolver(ARGS.lambda, ARGS._alpha),
             new GLMParams(Family.valueOf(ARGS.family)),
             null,
             ARGS.xval,
             true);
     System.out.print("[GLM] computing model...");
     int progress = 0;
     while (!j.isDone()) {
       int p = (int) (100 * j.progress());
       int dots = p - progress;
       progress = p;
       for (int i = 0; i < dots; ++i) System.out.print('.');
       Thread.sleep(250);
     }
     Log.debug(Sys.GENLM, "DONE.");
     GLMModel m = j.get();
     String[] colnames = ary.colNames();
     // NOTE(review): assumes the intercept is the last entry of _beta and that beta[i] belongs
     // to predictor xcols[i] (no categorical expansion) - confirm against DGLM.
     System.out.println("Intercept" + " = " + m._beta[m._beta.length - 1]);
     for (int i = 0; i < xcols.length; ++i) {
       System.out.println(colnames[xcols[i]] + " = " + m._beta[i]);
     }
   } catch (Throwable t) {
     Log.err(t);
   } finally { // we're done. shutdown the cloud
     Log.debug(Sys.GENLM, "==================<GLMRunner DONE>===================");
     UDPRebooted.suicide(UDPRebooted.T.shutdown, H2O.SELF);
   }
 }
コード例 #25
0
ファイル: GLM2.java プロジェクト: rohit2412/h2o
    /**
     * Completion callback for one GLM2 iteration: inspects the finished {@code glmt} and launches
     * the next asynchronous step, or moves on to the next lambda.
     *
     * <p>In order, it:
     *
     * <ol>
     *   <li>throws {@link JobCancelledException} if the job is no longer running;
     *   <li>if validation is present and the residual deviance is not below the null deviance, and
     *       high-accuracy mode is not on yet, turns it on and restarts from the last stored result,
     *       from the previous lambda's submodel (when {@code _lambdaIdx > 2}), or from scratch
     *       (if high accuracy is already on, processing simply continues);
     *   <li>if validation is present, a gradient was computed and its max |subgradient| is at most
     *       {@code GLM_GRAD_EPS}, accepts {@code glmt._beta} and moves to the next lambda;
     *   <li>for non-tweedie families, if {@code needLineSearch} rejects the step from the last
     *       stored result, launches a line search (or, right after switching to high accuracy,
     *       restarts from the last result when it is more than 2 iterations old); otherwise the
     *       current task is stored as the last result;
     *   <li>finally solves for a new beta with {@link ADMMSolver} and either moves to the next
     *       lambda (NaN/Inf beta, Gaussian family, beta change below {@code beta_epsilon}, or
     *       {@code max_iter} reached) or launches the next iteration.
     * </ol>
     *
     * Follow-up tasks are started with {@code asyncExec}, so this method does not wait for them.
     *
     * @param glmt the finished iteration task
     */
    @Override
    public void callback(final GLMIterationTask glmt) {
      _model.stop_training();
      Log.info(
          "GLM2 iteration("
              + _iter
              + ") done in "
              + (System.currentTimeMillis() - _iterationStartTime)
              + "ms");
      if (!isRunning(self())) throw new JobCancelledException();
      currentLambdaIter++;
      // Sanity check on validation: a fit no better than the null model is treated as a failure.
      if (glmt._val != null) {
        if (!(glmt._val.residual_deviance
            < glmt._val
                .null_deviance)) { // complete failure: try to restart with high-accuracy settings
          if (!highAccuracy()) {
            Log.info(
                "GLM2 reached negative explained deviance without line-search, rerunning with high accuracy settings.");
            setHighAccuracy();
            if (_lastResult != null)
              new GLMIterationTask(
                      GLM2.this,
                      _activeData,
                      glmt._glm,
                      true,
                      true,
                      true,
                      _lastResult._glmt._beta,
                      _ymu,
                      _reg,
                      new Iteration())
                  .asyncExec(_activeData._adaptedFrame);
            else if (_lambdaIdx > 2) // > 2 because 0 is null model, we don't want to run with that
            new GLMIterationTask(
                      GLM2.this,
                      _activeData,
                      glmt._glm,
                      true,
                      true,
                      true,
                      _model.submodels[_lambdaIdx - 1].norm_beta,
                      _ymu,
                      _reg,
                      new Iteration())
                  .asyncExec(_activeData._adaptedFrame);
            else // no sane solution to go back to, start from scratch!
            new GLMIterationTask(
                      GLM2.this,
                      _activeData,
                      glmt._glm,
                      true,
                      false,
                      false,
                      null,
                      _ymu,
                      _reg,
                      new Iteration())
                  .asyncExec(_activeData._adaptedFrame);
            // The stored result belongs to the failed run; drop it before the restart proceeds.
            _lastResult = null;
            return;
          }
        }
        _model.setAndTestValidation(_lambdaIdx, glmt._val);
        _model.clone().update(self());
      }

      if (glmt._val != null && glmt._computeGradient) { // check gradient
        final double[] grad = glmt.gradient(l2pen());
        ADMMSolver.subgrad(alpha[0], lambda[_lambdaIdx], glmt._beta, grad);
        // err = max |subgradient| over all coefficients
        double err = 0;
        for (double d : grad)
          if (d > err) err = d;
          else if (d < -err) err = -d;
        Log.info("GLM2 gradient after " + _iter + " iterations = " + err);
        if (err <= GLM_GRAD_EPS) {
          Log.info(
              "GLM2 converged by reaching small enough gradient, with max |subgradient| = " + err);
          setNewBeta(glmt._beta);
          nextLambda(glmt, glmt._beta);
          return;
        }
      }
      if (glmt._beta != null
          && glmt._val != null
          && glmt._computeGradient
          && _glm.family != Family.tweedie) {
        if (_lastResult != null && needLineSearch(glmt._beta, objval(glmt), 1)) {
          if (!highAccuracy()) {
            setHighAccuracy();
            if (_lastResult._iter
                < (_iter - 2)) { // there is a gap from last result...return to it and start again
              final double[] prevBeta =
                  _lastResult._activeCols != _activeCols
                      ? resizeVec(_lastResult._glmt._beta, _activeCols, _lastResult._activeCols)
                      : _lastResult._glmt._beta;
              new GLMIterationTask(
                      GLM2.this,
                      _activeData,
                      glmt._glm,
                      true,
                      true,
                      true,
                      prevBeta,
                      _ymu,
                      _reg,
                      new Iteration())
                  .asyncExec(_activeData._adaptedFrame);
              return;
            }
          }
          // NOTE(review): `b` is only used by the assert below; the line-search task recomputes
          // the same resizeVec(...) instead of reusing `b`.
          final double[] b =
              resizeVec(_lastResult._glmt._beta, _activeCols, _lastResult._activeCols);
          assert (b.length == glmt._beta.length)
              : b.length + " != " + glmt._beta.length + ", activeCols = " + _activeCols.length;
          // Search along the segment from the last stored beta to the current beta.
          new GLMTask.GLMLineSearchTask(
                  GLM2.this,
                  _activeData,
                  _glm,
                  resizeVec(_lastResult._glmt._beta, _activeCols, _lastResult._activeCols),
                  glmt._beta,
                  1e-4,
                  glmt._nobs,
                  alpha[0],
                  lambda[_lambdaIdx],
                  new LineSearchIteration())
              .asyncExec(_activeData._adaptedFrame);
          return;
        }
        // Step accepted without line search: remember it as the last result.
        _lastResult = new IterationInfo(GLM2.this._iter - 1, glmt, _activeCols);
      }
      // Solve for the next beta from this iteration's glmt._gram / glmt._xy via ADMM.
      final double[] newBeta = MemoryManager.malloc8d(glmt._xy.length);
      ADMMSolver slvr = new ADMMSolver(lambda[_lambdaIdx], alpha[0], ADMM_GRAD_EPS, _addedL2);
      slvr.solve(glmt._gram, glmt._xy, glmt._yy, newBeta);
      _addedL2 = slvr._addedL2;
      if (Utils.hasNaNsOrInfs(newBeta)) {
        Log.info("GLM2 forcibly converged by getting NaNs and/or Infs in beta");
        nextLambda(glmt, glmt._beta);
      } else {
        setNewBeta(newBeta);
        final double bdiff = beta_diff(glmt._beta, newBeta);
        if (_glm.family == Family.gaussian
            || bdiff < beta_epsilon
            || _iter
                == max_iter) { // Gaussian is non-iterative and gradient is ADMMSolver's gradient =>
          // just validate and move on to the next lambda
          // NOTE(review): when bdiff == 0, Math.log10 returns -Infinity and the int cast yields
          // Integer.MIN_VALUE, so the logged precision is meaningless in that case.
          int diff = (int) Math.log10(bdiff);
          int nzs = 0;
          for (int i = 0; i < newBeta.length; ++i) if (newBeta[i] != 0) ++nzs;
          // Debug output for small models only.
          if (newBeta.length < 20) System.out.println("beta = " + Arrays.toString(newBeta));
          // NOTE(review): the message below lacks a space before "iterations".
          Log.info(
              "GLM2 (lambda_"
                  + _lambdaIdx
                  + "="
                  + lambda[_lambdaIdx]
                  + ") converged (reached a fixed point with ~ 1e"
                  + diff
                  + " precision) after "
                  + _iter
                  + "iterations, got "
                  + nzs
                  + " nzs");
          nextLambda(glmt, newBeta);
        } else { // not done yet, launch next iteration
          // Validate (and compute the gradient) every 5th iteration of this lambda, or always in
          // high-accuracy mode.
          // NOTE(review): this reads the field `higher_accuracy`, while the rest of the method
          // uses highAccuracy(); confirm they are kept in sync.
          final boolean validate = higher_accuracy || (currentLambdaIter % 5) == 0;
          ++_iter;
          System.out.println("Iter = " + _iter);
          new GLMIterationTask(
                  GLM2.this,
                  _activeData,
                  glmt._glm,
                  true,
                  validate,
                  validate,
                  newBeta,
                  _ymu,
                  _reg,
                  new Iteration())
              .asyncExec(_activeData._adaptedFrame);
        }
      }
    }
コード例 #26
0
ファイル: MRUtils.java プロジェクト: Jrobinso09/h2o
  /**
   * Internal worker for stratified sampling, with a retry counter.
   *
   * <p>Every row with a non-missing label is emitted {@code floor(ratio)} times, plus once more
   * with probability {@code ratio - floor(ratio)}, where {@code ratio} is the sampling ratio of the
   * row's class. If some class ends up with no rows and {@code count < 10}, sampling is retried
   * with {@code seed + 1}. For certain wrong sampling ratios (e.g. 0) getting a row from every
   * class is impossible, and the last attempt is kept.
   *
   * @param fr input frame; {@code null} is returned unchanged
   * @param label enum label vector; must be one of the columns of {@code fr}
   * @param sampling_ratios per-class sampling ratios, in domain order
   * @param seed RNG seed; each chunk derives its RNG from {@code seed + chunk index}
   * @param debug whether to log the class distribution of the sample
   * @param count number of attempts made so far
   * @return the sampled frame, shuffled within each chunk, or {@code fr} itself when the class
   *     distribution of the sample could not be computed
   */
  private static Frame sampleFrameStratified(
      final Frame fr,
      Vec label,
      final float[] sampling_ratios,
      final long seed,
      final boolean debug,
      int count) {
    if (fr == null) return null;
    assert (label.isEnum());
    assert (sampling_ratios != null && sampling_ratios.length == label.domain().length);
    final int labelidx = fr.find(label); // which column is the label?
    assert (labelidx >= 0);

    final boolean poisson = false; // beta feature: Poisson-distributed replication counts

    Frame sampled =
        new MRTask2() {
          @Override
          public void map(Chunk[] cs, NewChunk[] ncs) {
            // Per-chunk RNG derived from the seed and the chunk index.
            final Random rng = getDeterRNG(seed + cs[0].cidx());
            for (int row = 0; row < cs[0]._len; row++) {
              if (cs[labelidx].isNA0(row)) continue; // skip missing labels
              final int cls = (int) cs[labelidx].at80(row);
              assert (sampling_ratios.length > cls && cls >= 0);
              int sampling_reps;
              if (poisson) {
                sampling_reps = Utils.getPoisson(sampling_ratios[cls], rng);
              } else {
                // floor(ratio) copies, plus one more with probability frac(ratio)
                final float remainder = sampling_ratios[cls] - (int) sampling_ratios[cls];
                sampling_reps = (int) sampling_ratios[cls] + (rng.nextFloat() < remainder ? 1 : 0);
              }
              for (int i = 0; i < ncs.length; i++) {
                for (int j = 0; j < sampling_reps; ++j) {
                  ncs[i].addNum(cs[i].at0(row));
                }
              }
            }
          }
        }.doAll(fr.numCols(), fr).outputFrame(fr.names(), fr.domains());

    // Confirm the validity of the distribution
    long[] dist = new ClassDist(sampled.vecs()[labelidx]).doAll(sampled.vecs()[labelidx]).dist();

    // If there are no training labels in the test set, there is no point in sampling the test set.
    // The sampled frame is not returned in that case, so release it instead of leaking it.
    if (dist == null) {
      sampled.delete();
      return fr;
    }

    if (debug) {
      long sumdist = Utils.sum(dist);
      Log.info("After stratified sampling: " + sumdist + " rows.");
      for (int i = 0; i < dist.length; ++i) {
        Log.info(
            "Class "
                + sampled.vecs()[labelidx].domain(i)
                + ": count: "
                + dist[i]
                + " sampling ratio: "
                + sampling_ratios[i]
                + " actual relative frequency: "
                + (float) dist[i] / sumdist * dist.length);
      }
    }

    // Re-try if we didn't get at least one example from each class
    if (Utils.minValue(dist) == 0 && count < 10) {
      Log.info(
          "Re-doing stratified sampling because not all classes were represented (unlucky draw).");
      sampled.delete();
      return sampleFrameStratified(fr, label, sampling_ratios, seed + 1, debug, ++count);
    }

    // shuffle intra-chunk
    Frame shuffled = shuffleFramePerChunk(sampled, seed + 0x580FF13);
    sampled.delete();

    return shuffled;
  }
コード例 #27
0
ファイル: MRUtils.java プロジェクト: Jrobinso09/h2o
  /**
   * Stratified sampling for classifiers.
   *
   * <p>If no sampling ratios are given (or all are 0), ratios are computed so that every class is
   * expected to contribute as many rows as the majority class, whose ratio is 1.0. Without
   * oversampling, ratios are capped at 1.0. All ratios are then rescaled so that the expected
   * number of output rows equals the rounded expected count, capped at {@code maxrows}. Classes
   * that have no rows in {@code label} contribute no rows to that expected count.
   *
   * @param fr Input frame
   * @param label Label vector (must be enum)
   * @param sampling_ratios Optional: array containing the requested sampling ratios per class (in
   *     order of domains). It is filled in if it contains all 0s, and is modified in place by the
   *     capping and rescaling steps.
   * @param maxrows Maximum number of rows in the returned frame
   * @param seed RNG seed for sampling
   * @param allowOversampling Allow oversampling of minority classes
   * @param verbose Whether to print verbose info
   * @return Sampled frame, with approximately the same number of samples from each class (or given
   *     by the requested sampling ratios)
   */
  public static Frame sampleFrameStratified(
      final Frame fr,
      Vec label,
      float[] sampling_ratios,
      long maxrows,
      final long seed,
      final boolean allowOversampling,
      final boolean verbose) {
    if (fr == null) return null;
    assert (label.isEnum());
    assert (maxrows >= label.domain().length);

    long[] dist = new ClassDist(label).doAll(label).dist();
    assert (dist.length > 0);
    Log.info(
        "Doing stratified sampling for data set containing "
            + fr.numRows()
            + " rows from "
            + dist.length
            + " classes. Oversampling: "
            + (allowOversampling ? "on" : "off"));
    if (verbose) {
      for (int i = 0; i < dist.length; ++i) {
        Log.info(
            "Class "
                + label.domain(i)
                + ": count: "
                + dist[i]
                + " prior: "
                + (float) dist[i] / fr.numRows());
      }
    }

    // Create sampling_ratios for class balance with at most maxrows rows (fill the existing array
    // if it is not null).
    if (sampling_ratios == null
        || (Utils.minValue(sampling_ratios) == 0 && Utils.maxValue(sampling_ratios) == 0)) {
      // compute sampling ratios to achieve class balance
      if (sampling_ratios == null) {
        sampling_ratios = new float[dist.length];
      }
      assert (sampling_ratios.length == dist.length);
      for (int i = 0; i < dist.length; ++i) {
        // prior^-1 / num_classes; +Infinity for a class with no rows
        sampling_ratios[i] = ((float) fr.numRows() / label.domain().length) / dist[i];
      }
      // The majority class has the lowest oversampling factor needed to achieve balance.
      final float inv_scale = Utils.minValue(sampling_ratios);
      if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale))
        // want sampling_ratio 1.0 for majority class (no downsampling)
        Utils.div(sampling_ratios, inv_scale);
    }

    if (!allowOversampling) {
      for (int i = 0; i < sampling_ratios.length; ++i) {
        sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]);
      }
    }

    // Given these sampling ratios and the original class distribution, this is the expected
    // number of resulting rows. Empty classes are skipped: their balance ratio is infinite, and
    // Infinity * 0 = NaN would poison the total (and every rescaled ratio below).
    float numrows = 0;
    for (int i = 0; i < sampling_ratios.length; ++i) {
      if (dist[i] > 0) numrows += sampling_ratios[i] * dist[i];
    }
    final long actualnumrows = Math.min(maxrows, Math.round(numrows)); // cap #rows at maxrows
    // Can have no matching rows in case of sparse data where we had to fill in a makeZero() vector.
    assert (actualnumrows >= 0);
    Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows.");

    if (actualnumrows != numrows) {
      // adjust the sampling_ratios by the global rescaling factor
      Utils.mult(sampling_ratios, (float) actualnumrows / numrows);
      if (verbose)
        Log.info(
            "Downsampling majority class by "
                + (float) actualnumrows / numrows
                + " to limit number of rows to "
                + String.format("%,d", maxrows));
    }
    Log.info(
        "Majority class ("
            + label.domain()[Utils.minIndex(sampling_ratios)].toString()
            + ") sampling ratio: "
            + Utils.minValue(sampling_ratios));
    Log.info(
        "Minority class ("
            + label.domain()[Utils.maxIndex(sampling_ratios)].toString()
            + ") sampling ratio: "
            + Utils.maxValue(sampling_ratios));

    return sampleFrameStratified(fr, label, sampling_ratios, seed, verbose);
  }
コード例 #28
0
ファイル: MRUtils.java プロジェクト: Jrobinso09/h2o
 /** Folds the per-class tallies of {@code that} into this task's tallies, element by element. */
 @Override
 public void reduce(ClassDist that) {
   final ClassDist other = that;
   Utils.add(this._ys, other._ys);
 }
コード例 #29
0
ファイル: FrameTask.java プロジェクト: BhaskarPros/h2o
  /**
   * Decodes every row of one chunk into numeric, categorical and response arrays and passes it to
   * {@code processRow}.
   *
   * <p>Column layout: the first {@code _dinfo._cats} columns are categoricals, then the numeric
   * columns, then the last {@code _dinfo._responses} columns are responses. For each row:
   *
   * <ul>
   *   <li>categorical level 0 is skipped; any other level {@code c} of column {@code i} is stored
   *       as {@code c + _catOffsets[i] - 1};
   *   <li>numerics and responses are standardized as {@code (x - sub) * mul} when the
   *       corresponding normalization arrays are present;
   *   <li>rows with an NA in any column, rows of the held-out fold, and, when {@code _useFraction}
   *       is below 1, a random subset of rows are skipped.
   * </ul>
   *
   * Rows are optionally visited in shuffled order. The row index passed on is the chunk start
   * plus the local row index.
   *
   * @throws JobCancelledException if the owning job is no longer running
   */
  @Override
  public final void map(Chunk[] chunks, NewChunk[] outputs) {
    if (_job != null && _job.self() != null && !Job.isRunning(_job.self()))
      throw new JobCancelledException();
    final int nrows = chunks[0]._len;
    final long offset = chunks[0]._start;
    chunkInit();
    double[] nums = MemoryManager.malloc8d(_dinfo._nums);
    int[] cats = MemoryManager.malloc4(_dinfo._cats);
    double[] response = MemoryManager.malloc8d(_dinfo._responses);
    int start = 0;
    int end = nrows;

    boolean contiguous = false;
    Random skip_rng = null; // random generator for skipping rows
    if (_useFraction < 1.0) {
      // NOTE: seeded from a fresh java.util.Random, so the subsample is not reproducible.
      skip_rng = water.util.Utils.getDeterRNG(new Random().nextLong());
      if (contiguous) {
        final int howmany = (int) Math.ceil(_useFraction * nrows);
        if (howmany > 0) {
          // Valid window starts are 0..nrows-howmany inclusive. nextInt's bound is exclusive,
          // hence the +1; without it nextInt(0) throws when howmany == nrows.
          start = skip_rng.nextInt(nrows - howmany + 1);
          end = start + howmany;
        }
        assert (start < nrows);
        assert (end <= nrows);
        // NOTE(review): skip_rng is still used for per-row skipping below in contiguous mode,
        // which subsamples the window a second time; confirm this is intended.
      }
    }

    long[] shuf_map = null;
    if (_shuffle) {
      shuf_map = new long[end - start];
      for (int i = 0; i < shuf_map.length; ++i) shuf_map[i] = start + i;
      Utils.shuffleArray(shuf_map, new Random().nextLong());
    }

    OUTER:
    for (int rr = start; rr < end; ++rr) {
      final int r = shuf_map != null ? (int) shuf_map[rr - start] : rr;
      // NOTE(review): the fold is chosen from the chunk-local row index r, not offset + r.
      if ((_dinfo._nfolds > 0 && (r % _dinfo._nfolds) == _dinfo._foldId)
          || (skip_rng != null && skip_rng.nextFloat() > _useFraction)) continue;
      for (Chunk c : chunks) if (c.isNA0(r)) continue OUTER; // skip rows with NAs!
      int i = 0, ncats = 0;
      for (; i < _dinfo._cats; ++i) {
        int c = (int) chunks[i].at80(r);
        if (c != 0) cats[ncats++] = c + _dinfo._catOffsets[i] - 1;
      }
      final int n = chunks.length - _dinfo._responses;
      for (; i < n; ++i) {
        double d = chunks[i].at0(r);
        if (_dinfo._normMul != null)
          d = (d - _dinfo._normSub[i - _dinfo._cats]) * _dinfo._normMul[i - _dinfo._cats];
        nums[i - _dinfo._cats] = d;
      }
      for (i = 0; i < _dinfo._responses; ++i) {
        response[i] = chunks[chunks.length - _dinfo._responses + i].at0(r);
        if (_dinfo._normRespMul != null)
          response[i] = (response[i] - _dinfo._normRespSub[i]) * _dinfo._normRespMul[i];
      }
      if (outputs != null && outputs.length > 0)
        processRow(offset + r, nums, ncats, cats, response, outputs);
      else processRow(offset + r, nums, ncats, cats, response);
    }
    chunkDone();
  }
コード例 #30
0
ファイル: DRF.java プロジェクト: rohit2412/h2o
// Random Forest Trees
public class DRF extends SharedTreeModelBuilder<DRF.DRFModel> {
  static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
  public static DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

  // Set to true for a deterministic DRF: when the user leaves seed at -1, init() then uses the
  // same fixed seed for every execution. TODO: read this flag from a system property instead.
  static final boolean DEBUG_DETERMINISTIC =
      false;

  // Number of columns sampled as split candidates at each level. With -1, init() picks
  // max(sqrt(#cols), 1) for classification and max(#cols / 3, 1) for regression.
  @API(
      help = "Columns to randomly select at each level, or -1 for sqrt(#cols)",
      filter = Default.class,
      lmin = -1,
      lmax = 100000)
  int mtries = -1;

  // Fraction of rows sampled per tree; init() rejects values outside (0, 1].
  @API(
      help = "Sample rate, from 0. to 1.0",
      filter = Default.class,
      dmin = 0,
      dmax = 1,
      importance = ParamImportance.SECONDARY)
  float sample_rate = 0.6666667f;

  // -1 (the default) requests an autogenerated seed, so that, following R semantics, each call of
  // RF gets a different seed.
  @API(help = "Seed for the random number generator (autogenerated)", filter = Default.class)
  long seed =
      -1;

  @API(
      help =
          "Run on one node only; no network overhead but fewer cpus used.  Suitable for small datasets.",
      filter = myClassFilter.class,
      importance = ParamImportance.SECONDARY)
  public boolean build_tree_one_node = false;

  // Input filter for build_tree_one_node; presumably binds it to the "source" argument — see
  // DRFCopyDataBoolean.
  class myClassFilter extends DRFCopyDataBoolean {
    myClassFilter() {
      super("source");
    }
  }

  // Effective mtries computed in init().
  @API(help = "Computed number of split features", importance = ParamImportance.EXPERT)
  protected int _mtry; // FIXME remove and replace by mtries

  // Effective seed computed in init(): the user seed, or an autogenerated one when seed == -1.
  @API(help = "Autogenerated seed", importance = ParamImportance.EXPERT)
  protected long _seed; // FIXME remove and replace by seed

  // Source of autogenerated seeds. It is seeded with a constant, so the sequence of autogenerated
  // seeds is the same in every JVM.
  private static final Random _seedGenerator = Utils.getDeterRNG(0xd280524ad7fe0602L);

  // --- Private data handled only on master node
  // Classification or Regression:
  // Tree votes (classification) / SSE (regression) of individual trees on OOB rows
  private transient TreeMeasures _treeMeasuresOnOOB;
  // Tree votes/SSE per individual feature on permuted OOB rows (one entry per feature)
  private transient TreeMeasures[ /*features*/] _treeMeasuresOnSOOB;

  /** DRF model holding the serialized trees and implementing the logic for scoring a row. */
  public static class DRFModel extends DTree.TreeModel {
    static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
    public static DocGen.FieldDoc[] DOC_FIELDS; // Initialized from Auto-Gen code.

    @API(help = "Model parameters", json = true)
    private final DRF parameters; // This is used purely for printing values out.

    @Override
    public final DRF get_params() {
      return parameters;
    }

    @Override
    public final Request2 job() {
      return get_params();
    }

    @API(help = "Number of columns picked at each split")
    final int mtries;

    @API(help = "Sample rate")
    final float sample_rate;

    @API(help = "Seed")
    final long seed;

    /**
     * Creates a new, empty DRF model. The DRF-specific values ({@code mtries},
     * {@code sample_rate}, {@code seed}) are stored for reporting alongside the generic tree-model
     * settings passed to the super class.
     */
    public DRFModel(
        DRF params,
        Key key,
        Key dataKey,
        Key testKey,
        String[] names,
        String[][] domains,
        String[] cmDomain,
        int ntrees,
        int max_depth,
        int min_rows,
        int nbins,
        int mtries,
        float sample_rate,
        long seed) {
      super(key, dataKey, testKey, names, domains, cmDomain, ntrees, max_depth, min_rows, nbins);
      this.parameters = params;
      this.mtries = mtries;
      this.sample_rate = sample_rate;
      this.seed = seed;
    }

    /** Copies {@code prior} with a new set of trees and tree statistics. */
    private DRFModel(DRFModel prior, DTree[] trees, TreeStats tstats) {
      super(prior, trees, tstats);
      this.parameters = prior.parameters;
      this.mtries = prior.mtries;
      this.sample_rate = prior.sample_rate;
      this.seed = prior.seed;
    }

    /** Copies {@code prior} with updated error, confusion matrix, variable importance and AUC. */
    private DRFModel(
        DRFModel prior, double err, ConfusionMatrix cm, VarImp varimp, water.api.AUC validAUC) {
      super(prior, err, cm, varimp, validAUC);
      this.parameters = prior.parameters;
      this.mtries = prior.mtries;
      this.sample_rate = prior.sample_rate;
      this.seed = prior.seed;
    }

    @Override
    protected TreeModelType getTreeModelType() {
      return TreeModelType.DRF;
    }

    /**
     * Scores one row. For regression (a single output), the summed tree predictions are averaged
     * over the number of trees. For classification, the class votes are normalized to sum to 1
     * (when their sum is positive) and {@code p[0]} is set to the predicted class.
     */
    @Override
    protected float[] score0(double[] data, float[] preds) {
      float[] p = super.score0(data, preds);
      int ntrees = ntrees();
      if (p.length == 1) { // regression - compute avg over all trees
        if (ntrees > 0) div(p, ntrees);
      } else { // classification
        // NOTE(review): sum(p) also includes p[0]; assumes super.score0 leaves p[0] at 0.
        float s = sum(p);
        if (s > 0) div(p, s); // unify over all classes
        p[0] = ModelUtils.getPrediction(p, data);
      }
      return p;
    }

    /**
     * Appends the DRF parameters to the model description. It also warns when there is neither a
     * test set nor any out-of-bag data.
     */
    @Override
    protected void generateModelDescription(StringBuilder sb) {
      DocGen.HTML.paragraph(
          sb, "mtries: " + mtries + ", Sample rate: " + sample_rate + ", Seed: " + seed);
      if (testKey == null && sample_rate == 1f) {
        sb.append(
            "<div class=\"alert alert-danger\">There are no OOB data to report out-of-bag error, since sampling rate is 100%!</div>");
      }
    }

    /**
     * Emits the Java source that post-processes {@code preds} in generated scoring code. For
     * classifiers it normalizes {@code preds[1..]} to sum to 1; for regression it divides
     * {@code preds[1]} by {@code NTREES}.
     */
    @Override
    protected void toJavaUnifyPreds(SB bodySb) {
      if (isClassifier()) {
        bodySb.i().p("float sum = 0;").nl();
        bodySb.i().p("for(int i=1; i<preds.length; i++) sum += preds[i];").nl();
        bodySb.i().p("if (sum>0) for(int i=1; i<preds.length; i++) preds[i] /= sum;").nl();
      } else bodySb.i().p("preds[1] = preds[1]/NTREES;").nl();
    }
  }

  /** Scores {@code fr} with the model stored under this job's destination key. */
  public Frame score(Frame fr) {
    final DRFModel model = (DRFModel) UKV.get(dest());
    return model.score(fr);
  }

  /** DRF messages are logged under the {@code DRF__} system tag. */
  @Override
  protected Log.Tag.Sys logTag() {
    final Log.Tag.Sys tag = Sys.DRF__;
    return tag;
  }

  /**
   * Creates the initial, tree-less DRF model for this job. The test key is only kept when a
   * validation frame was supplied.
   */
  @Override
  protected DRFModel makeModel(
      Key outputKey,
      Key dataKey,
      Key testKey,
      String[] names,
      String[][] domains,
      String[] cmDomain) {
    final Key effectiveTestKey = (validation != null) ? testKey : null;
    // NOTE(review): this passes the user-facing `mtries` (possibly -1), not the computed `_mtry`,
    // while the seed passed is the computed `_seed`.
    return new DRFModel(
        this, outputKey, dataKey, effectiveTestKey, names, domains, cmDomain,
        ntrees, max_depth, min_rows, nbins,
        mtries, sample_rate, _seed);
  }

  /** Returns a copy of {@code model} carrying the given error metrics and variable importance. */
  @Override
  protected DRFModel makeModel(
      DRFModel model, double err, ConfusionMatrix cm, VarImp varimp, water.api.AUC validAUC) {
    final DRFModel updated = new DRFModel(model, err, cm, varimp, validAUC);
    return updated;
  }

  /** Returns a copy of {@code model} holding the trees in {@code ktrees} and updated stats. */
  @Override
  protected DRFModel makeModel(DRFModel model, DTree[] ktrees, TreeStats tstats) {
    final DRFModel withTrees = new DRFModel(model, ktrees, tstats);
    return withTrees;
  }

  /** Creates a DRF job with the default forest shape: 50 trees, depth 20, min 1 row per leaf. */
  public DRF() {
    ntrees = 50;
    max_depth = 20;
    min_rows = 1;
    description = "Distributed RF";
  }

  /**
   * Builds an HTML anchor to the DRF query page with {@code k} as the source.
   *
   * @param k key inserted as the {@code source} query argument
   * @param content text of the anchor
   * @return the anchor markup
   */
  public static String link(Key k, String content) {
    final RString template = new RString("<a href='/2/DRF.query?source=%$key'>%content</a>");
    final String keyText = k.toString();
    template.replace("key", keyText);
    template.replace("content", content);
    return template.toString();
  }

  // ==========================================================================

  /**
   * Runs the DRF job: logs the start and builds the model.
   *
   * <p>Each tree is built by first splitting all the data by a criterion that minimizes the
   * variance at the leaves. Every row records which split it falls into, and receives a split
   * number for the next pass. That pass uses the split number to build one histogram per split,
   * with a variance per histogram bucket.
   */
  @Override
  protected void execImpl() {
    this.logStart();
    this.buildModel(this.seed);
  }

  /** Sends the user to the DRF progress page for this job and its destination model. */
  @Override
  protected Response redirect() {
    final Key jobKey = self();
    final Key modelKey = dest();
    return DRFProgressPage.redirect(this, jobKey, modelKey);
  }

  /**
   * Validates the DRF-specific parameters and derives the effective {@code _mtry} and
   * {@code _seed}.
   *
   * <p>When {@code mtries == -1}, {@code _mtry} defaults to {@code max(sqrt(#cols), 1)} for
   * classification and {@code max(#cols / 3, 1)} for regression. When {@code seed == -1}, a seed
   * is drawn from {@code _seedGenerator}, or a fixed constant is used when
   * {@code DEBUG_DETERMINISTIC} is set.
   *
   * @throws IllegalArgumentException if the computed mtry is not in {@code [1, #cols]} or the
   *     sample rate is not in {@code (0, 1]}
   */
  @SuppressWarnings("unused")
  @Override
  protected void init() {
    super.init();
    // Initialize local variables
    _mtry =
        (mtries == -1)
            ? // classification: mtry=sqrt(_ncols), regression: mtry=_ncols/3
            (classification ? Math.max((int) Math.sqrt(_ncols), 1) : Math.max(_ncols / 3, 1))
            : mtries;
    if (!(1 <= _mtry && _mtry <= _ncols))
      throw new IllegalArgumentException(
          "Computed mtry should be in interval <1,#cols> but it is " + _mtry);
    if (!(0.0 < sample_rate && sample_rate <= 1.0))
      throw new IllegalArgumentException(
          "Sample rate should be interval (0,1> but it is " + sample_rate);
    if (DEBUG_DETERMINISTIC && seed == -1) _seed = 0x1321e74a0192470cL; // fixed version of seed
    else if (seed == -1) _seed = _seedGenerator.nextLong();
    else _seed = seed;
    // With 100% sampling there are no out-of-bag rows. Validation is only impossible when no
    // validation frame was supplied either (this matches DRFModel.generateModelDescription).
    if (sample_rate == 1f && validation == null)
      Log.warn(
          Sys.DRF__,
          "Sample rate is 100% and no validation dataset is provided. There are no OOB data to perform validation!");
  }

  /**
   * Returns the out-of-bag trees counter chunk. There is a single counter column shared by all
   * k-trees. It sits after the first {@code _ncols + 1} columns and three groups of
   * {@code _nclass} columns.
   */
  protected Chunk chk_oobt(Chunk[] chks) {
    final int oobtIdx = _ncols + 1 + 3 * _nclass;
    return chks[oobtIdx];
  }

  /**
   * Prepares per-tree measurements when variable importance is requested: votes for
   * classification, SSE for regression.
   */
  @Override
  protected void initAlgo(DRFModel initialModel) {
    if (!importance) return;
    initTreeMeasurements();
  }

  /**
   * For classification, stores the relative class distribution of the response in the initial
   * model. Regression needs no preparation here.
   */
  @Override
  protected void initWorkFrame(DRFModel initialModel, Frame fr) {
    if (!classification) return;
    initialModel.setModelClassDistribution(
        new MRUtils.ClassDist(response).doAll(response).rel_dist());
  }

  /**
   * Grows the forest on {@code fr}, scoring the model between trees and once more at the end.
   *
   * <p>Appends an "OUT_BAG_TREES" column (initialized with {@code response.makeZero()}) to
   * {@code fr}, fills the work columns via {@link SetWrkTask}, then builds up to {@code ntrees}
   * rounds of K trees (one per class) using an RNG seeded from {@code _seed}. Building stops
   * early if the job is no longer running.
   *
   * <p>NOTE(review): {@code names}, {@code domains} and {@code t_build} are not used here.
   * NOTE(review): on cancellation, the loop breaks before {@code tid} is incremented and before
   * {@code tstats} is updated, yet the final {@code doScoring} still receives the freshly built
   * {@code ktrees}; confirm this is the intended count.
   *
   * @return the final scored model
   */
  @Override
  protected DRFModel buildModel(
      DRFModel model, final Frame fr, String names[], String domains[][], final Timer t_build) {
    // Append number of trees participating in on-the-fly scoring
    fr.add("OUT_BAG_TREES", response.makeZero());

    // The RNG used to pick split columns
    Random rand = createRNG(_seed);

    // Prepare working columns
    new SetWrkTask().doAll(fr);

    int tid;
    DTree[] ktrees = null;
    // Prepare tree statistics
    TreeStats tstats = new TreeStats();
    // Build trees until we hit the limit
    for (tid = 0; tid < ntrees; tid++) { // Building tid-tree
      // Score the trees built in the previous round (ktrees is null on the first round).
      model =
          doScoring(
              model, fr, ktrees, tid, tstats, tid == 0, !hasValidation(), build_tree_one_node);
      // At each iteration build K trees (K = nclass = response column domain size)

      // TODO: parallelize more? build more than k trees at each time, we need to care about
      // temporary data
      // Idea: launch more DRF at once.
      Timer kb_timer = new Timer();
      ktrees = buildNextKTrees(fr, _mtry, sample_rate, rand, tid);
      Log.info(Sys.DRF__, (tid + 1) + ". tree was built " + kb_timer.toString());
      if (!Job.isRunning(self())) break; // If canceled during building, do not bulkscore

      // Check latest predictions
      tstats.updateBy(ktrees);
    }

    // Final scoring pass over the last round of trees.
    model = doScoring(model, fr, ktrees, tid, tstats, true, !hasValidation(), build_tree_one_node);
    // Make sure that we did not miss any votes
    assert !importance
            || _treeMeasuresOnOOB.npredictors() == _treeMeasuresOnSOOB[0 /*variable*/].npredictors()
        : "Missing some tree votes in variable importance voting?!";

    return model;
  }

  /**
   * Preallocates the per-tree measurements used for variable importance. It creates one
   * collection for the unpermuted OOB rows plus one per feature, each sized for {@code ntrees}
   * trees: {@code TreeVotes} for classification, {@code TreeSSE} for regression.
   */
  private void initTreeMeasurements() {
    assert importance
        : "Tree votes should be initialized only if variable importance is requested!";
    if (!classification) {
      _treeMeasuresOnOOB = new TreeSSE(ntrees);
      final TreeSSE[] perFeature = new TreeSSE[_ncols];
      for (int col = 0; col < _ncols; col++) {
        perFeature[col] = new TreeSSE(ntrees);
      }
      _treeMeasuresOnSOOB = perFeature;
      return;
    }
    _treeMeasuresOnOOB = new TreeVotes(ntrees);
    final TreeVotes[] perFeature = new TreeVotes[_ncols];
    for (int col = 0; col < _ncols; col++) {
      perFeature[col] = new TreeVotes(ntrees);
    }
    _treeMeasuresOnSOOB = perFeature;
  }

  /**
   * On-the-fly variable importance. After a new tree has been built, its votes (classification)
   * or SSE (regression) are collected on shuffled OOB rows for every feature, and variable
   * importance is recomputed over all trees seen so far.
   *
   * <p>The <a
   * href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">page</a> says:
   * <cite> "In every tree grown in the forest, put down the oob cases and count the number of votes
   * cast for the correct class. Now randomly permute the values of variable m in the oob cases and
   * put these cases down the tree. Subtract the number of votes for the correct class in the
   * variable-m-permuted oob data from the number of votes for the correct class in the untouched
   * oob data. The average of this number over all trees in the forest is the raw importance score
   * for variable m." </cite>
   *
   * <p>One task per feature is forked with {@code H2O.submitTask}; the method blocks until all of
   * them finish. Each task appends only to its own feature's entry in
   * {@code _treeMeasuresOnSOOB}. NOTE(review): {@code ktrees} and {@code scale} are not used here.
   *
   * @param model model that already contains tree {@code tid} (asserted)
   * @param tid index of the tree just built; must be the last tree in {@code model}
   * @param fTrain training frame the measurements are collected over
   * @return per-feature importance and its standard deviation, over {@code model.ntrees()} trees
   */
  @Override
  protected VarImp doVarImpCalc(
      final DRFModel model, DTree[] ktrees, final int tid, final Frame fTrain, boolean scale) {
    // Check if we have already serialized 'ktrees'-trees in the model
    assert model.ntrees() - 1 == tid
        : "Cannot compute DRF varimp since 'ktrees' are not serialized in the model! tid=" + tid;
    assert _treeMeasuresOnOOB.npredictors() - 1 == tid
        : "Tree votes over OOB rows for this tree (var ktrees) were not found!";
    // Compute tree votes over shuffled data
    final CompressedTree[ /*nclass*/] theTree =
        model.ctree(tid); // get the last tree FIXME we should pass only keys
    final int nclasses = model.nclasses();
    Futures fs = new Futures();
    for (int var = 0; var < _ncols; var++) {
      final int variable = var;
      H2OCountedCompleter task4var =
          classification
              ? new H2OCountedCompleter() {
                @Override
                public void compute2() {
                  // Collect this tree's votes over all data, with the given variable
                  // (presumably) permuted by the collector
                  TreeVotes cd =
                      TreeMeasuresCollector.collectVotes(
                          theTree, nclasses, fTrain, _ncols, sample_rate, variable);
                  assert cd.npredictors() == 1;
                  asVotes(_treeMeasuresOnSOOB[variable]).append(cd);
                  tryComplete();
                }
              }
              : /* regression */ new H2OCountedCompleter() {
                @Override
                public void compute2() {
                  // Collect this tree's SSE over all data, with the given variable
                  // (presumably) permuted by the collector
                  TreeSSE cd =
                      TreeMeasuresCollector.collectSSE(
                          theTree, nclasses, fTrain, _ncols, sample_rate, variable);
                  assert cd.npredictors() == 1;
                  asSSE(_treeMeasuresOnSOOB[variable]).append(cd);
                  tryComplete();
                }
              };
      H2O.submitTask(task4var); // Fork computation
      fs.add(task4var);
    }
    fs.blockForPending(); // Wait for results
    // Compute varimp for individual features (_ncols): imp[0] is the importance, imp[1] its sd,
    // both derived by comparing permuted-feature measures against the untouched OOB measures
    final float[] varimp = new float[_ncols]; // output variable importance
    final float[] varimpSD = new float[_ncols]; // output variable importance sd
    for (int var = 0; var < _ncols; var++) {
      double[ /*2*/] imp =
          classification
              ? asVotes(_treeMeasuresOnSOOB[var]).imp(asVotes(_treeMeasuresOnOOB))
              : asSSE(_treeMeasuresOnSOOB[var]).imp(asSSE(_treeMeasuresOnOOB));
      varimp[var] = (float) imp[0];
      varimpSD[var] = (float) imp[1];
    }
    return new VarImp.VarImpMDA(varimp, varimpSD, model.ntrees());
  }

  /**
   * Fills the work columns from the response column. Rows whose response is NA are left untouched.
   *
   * <ul>
   *   <li>classification: writes 1 into the work column whose index is the row's class;
   *   <li>regression: copies the response (narrowed to float) into work column 0, the only one.
   * </ul>
   */
  private class SetWrkTask extends MRTask2<SetWrkTask> {
    @Override
    public void map(Chunk chks[]) {
      final Chunk response = chk_resp(chks);
      final int nrows = response._len;
      for (int r = 0; r < nrows; r++) {
        if (response.isNA0(r)) continue;
        if (!classification) {
          // Regression: single work column holding the raw response value.
          chk_work(chks, 0).set0(r, (float) response.at0(r));
          continue;
        }
        // Classification: one-hot style marker in the column of this row's class.
        chk_work(chks, (int) response.at80(r)).set0(r, 1L);
      }
    }
  }

  // --------------------------------------------------------------------------
  /**
   * Builds the next random k-trees representing the tid-th tree: one {@link DRFTree} per class
   * (a single tree for regression), grown one layer per pass up to {@code max_depth}.
   *
   * @param fr training frame including the per-tree helper columns (nids, tree, work, oobt)
   * @param mtrys number of columns each tree may pick from when splitting a node
   * @param sample_rate fraction of rows drawn in-bag for this tree
   * @param rand RNG providing the seed shared by all k-trees of this iteration
   * @param tid index of the tree being built (not referenced in this method body)
   * @return the k-trees; entries are {@code null} for classes with zero count in the response
   *     distribution. Returns {@code null} when the job stops running while the trees are grown.
   */
  private DTree[] buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) {
    // We're going to build K (nclass) trees - each focused on correcting
    // errors for a single class.
    final DTree[] ktrees = new DTree[_nclass];

    // Initial set of histograms.  All trees; one leaf per tree (the root
    // leaf); all columns
    DHistogram hcs[][][] = new DHistogram[_nclass][1 /*just root leaf*/][_ncols];

    // Use the same seed for all k-trees. NOTE: this only gives all k-trees
    // a fair (identical) view of the data.
    long rseed = rand.nextLong();
    // Initially setup as-if an empty-split had just happened
    for (int k = 0; k < _nclass; k++) {
      assert (_distribution != null && classification)
          || (_distribution == null && !classification);
      if (_distribution == null || _distribution[k] != 0) { // Ignore missing classes
        // The Boolean Optimization
        // This optimization assumes the 2nd tree of a 2-class system is the
        // inverse of the first.  This is false for DRF (and true for GBM) -
        // DRF picks a random different set of columns for the 2nd tree.
        // if( k==1 && _nclass==2 ) continue;
        ktrees[k] = new DRFTree(fr, _ncols, (char) nbins, (char) _nclass, min_rows, mtrys, rseed);
        boolean isBinom = classification;
        new DRFUndecidedNode(
            ktrees[k],
            -1,
            DHistogram.initialHist(fr, _ncols, nbins, hcs[k][0], isBinom)); // The "root" node
      }
    }

    // Sample - mark the lines by putting 'OUT_OF_BAG' into nid(<klass>) vector.
    // All sampling tasks are forked first (dfork) and only then joined (getResult).
    Timer t_1 = new Timer();
    Sample ss[] = new Sample[_nclass];
    for (int k = 0; k < _nclass; k++) {
      if (ktrees[k] != null) {
        ss[k] =
            new Sample((DRFTree) ktrees[k], sample_rate)
                .dfork(0, new Frame(vec_nids(fr, k), vec_resp(fr, k)), build_tree_one_node);
      }
    }
    for (int k = 0; k < _nclass; k++) {
      if (ss[k] != null) ss[k].getResult();
    }
    Log.debug(Sys.DRF__, "Sampling took: " + t_1);

    // "Working set" of leaf splits: from leafs[i] to tree.len() for each tree i.
    int[] leafs = new int[_nclass];

    // ----
    // One Big Loop till the ktrees are of proper depth.
    // Adds a layer to the trees each pass.
    Timer t_2 = new Timer();
    int depth = 0;
    for (; depth < max_depth; depth++) {
      if (!Job.isRunning(self())) return null;

      hcs = buildLayer(fr, ktrees, leafs, hcs, true, build_tree_one_node);

      // If we did not make any new splits, then the tree is split-to-death
      if (hcs == null) break;
    }
    Log.debug(Sys.DRF__, "Tree build took: " + t_2);

    // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
    // LeafNodes to hold predictions.
    Timer t_3 = new Timer();
    for (int k = 0; k < _nclass; k++) {
      DTree tree = ktrees[k];
      if (tree == null) continue;
      int leaf = leafs[k] = tree.len();
      for (int nid = 0; nid < leaf; nid++) {
        if (tree.node(nid) instanceof DecidedNode) {
          DecidedNode dn = tree.decided(nid);
          for (int i = 0; i < dn._nids.length; i++) {
            int cnid = dn._nids[i];
            if (cnid == -1 // Bottomed out (predictors or responses known constant)
                || tree.node(cnid) instanceof UndecidedNode // Or chopped off for depth
                || (tree.node(cnid) instanceof DecidedNode // Or not possible to split
                    && ((DecidedNode) tree.node(cnid))._split.col() == -1)) {
              LeafNode ln = new DRFLeafNode(tree, nid);
              ln._pred = dn.pred(i); // Set prediction into the leaf
              dn._nids[i] = ln.nid(); // Mark a leaf here
            }
          }
          // Handle the trivial non-splitting tree (root could not be split).
          if (nid == 0 && dn._split.col() == -1) new DRFLeafNode(tree, -1, 0);
        }
      }
    } // -- k-trees are done
    Log.debug(Sys.DRF__, "Nodes propagation: " + t_3);

    // ----
    // Move rows into the final leaf rows
    Timer t_4 = new Timer();
    CollectPreds cp = new CollectPreds(ktrees, leafs).doAll(fr, build_tree_one_node);
    if (importance) {
      if (classification) {
        // Track right votes over OOB rows for this tree
        asVotes(_treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows);
      } else { // regression
        asSSE(_treeMeasuresOnOOB).append(cp.sse, cp.allRows);
      }
    }
    Log.debug(Sys.DRF__, "CollectPreds done: " + t_4);

    // Collect leaves stats: nodes appended after the working-set boundary are the leaves.
    for (int i = 0; i < ktrees.length; i++) {
      if (ktrees[i] != null) ktrees[i].leaves = ktrees[i].len() - leafs[i];
    }
    // DEBUG: Print the generated K trees
    // printGenerateTrees(ktrees);

    return ktrees;
  }

  /**
   * Reads the per-class 'tree' columns of a row into {@code fs[1..nclass]} and returns their sum.
   * Dividing any {@code fs[]} element by the sum turns the results into a probability distribution.
   *
   * <p>For regression ({@code _nclass == 1}) the returned value is the accumulated prediction
   * divided by the row's out-of-bag tree counter, i.e. the average over the trees that voted on
   * this row.
   */
  @Override
  protected float score1(Chunk chks[], float fs[ /*nclass*/], int row) {
    float total = 0;
    for (int k = 0; k < _nclass; k++) {
      final float votes = (float) chk_tree(chks, k).at0(row);
      fs[k + 1] = votes;
      total += votes; // Sum across likelihoods
    }
    if (_nclass != 1) return total;
    // Regression: average over the trees that had this row out-of-bag.
    final float treesVoted = (float) chk_oobt(chks).at0(row);
    return total / treesVoted;
  }

  /** A row counts as in-bag while its out-of-bag counter ({@code chk_oobt}) is still zero. */
  @Override
  protected boolean inBagRow(Chunk[] chks, int row) {
    final long oobVotes = chk_oobt(chks).at80(row);
    return oobVotes == 0;
  }

  /**
   * Collect and write predictions into leafs.
   *
   * <p>For each row that the freshly built k-trees left out-of-bag, every non-skipped tree routes
   * the row down to a leaf. The task then:
   *
   * <ul>
   *   <li>adds the leaf prediction to the row's k-th tree column ({@code chk_tree});
   *   <li>updates the out-of-bag counter ({@code chk_oobt}): set to 1 for classification,
   *       incremented for regression.
   * </ul>
   *
   * The node-id helper column is reset to 0 for every tree that is not skipped. Null trees and
   * trees whose root is a leaf are skipped.
   *
   * <p>When {@code importance} is enabled, it also accumulates statistics for this tree alone over
   * OOB rows with a non-NA response: the number of correct votes (classification) or the sum of
   * squared errors (regression).
   */
  private class CollectPreds extends MRTask2<CollectPreds> {
    /* @IN  */ final DTree _trees[]; // Read-only, shared (except at the histograms in the Nodes)
    /* @OUT */ long rightVotes; // number of right votes over OOB rows (performed by this tree)
    /* @OUT */ long allRows; // number of all OOB rows (sampled by this tree)
    /* @OUT */ float sse; // Sum of squares for this tree only

    /**
     * @param trees the k-trees of the current iteration; {@code null} entries are skipped
     * @param leafs per-tree leaf working-set boundaries; currently not used by this task
     */
    CollectPreds(DTree trees[], int leafs[]) {
      _trees = trees;
    }

    @Override
    public void map(Chunk[] chks) {
      final Chunk y = importance ? chk_resp(chks) : null; // Response
      final float[] rpred = importance ? new float[1 + _nclass] : null; // Row prediction
      final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
      final Chunk oobt = chk_oobt(chks); // Out-of-bag rows counter over all trees
      // Iterate over all rows
      for (int row = 0; row < oobt._len; row++) {
        boolean wasOOBRow = false;
        // Whether an earlier, non-skipped k-tree already examined this row. "k == 0" is not a
        // reliable "first tree" marker: trees of missing classes are null and trees whose root
        // is a leaf are skipped, so the first evaluated tree may have k > 0.
        boolean seenTree = false;
        // For all tree (i.e., k-classes)
        for (int k = 0; k < _nclass; k++) {
          final DTree tree = _trees[k];
          if (tree == null) continue; // Empty class is ignored
          // If we have all constant responses, then we do not split even the
          // root and the residuals should be zero.
          if (tree.root() instanceof LeafNode) continue;
          final Chunk nids = chk_nids(chks, k); // Node-ids  for this tree/class
          final Chunk ct = chk_tree(chks, k); // k-tree working column holding votes for given row
          int nid = (int) nids.at80(row); // Get Node to decide from
          // Update only out-of-bag rows; track the on-the-fly prediction for such a row.
          if (isOOBRow(nid)) { // The row should be OOB for all k-trees !!!
            assert !seenTree || wasOOBRow
                : "Something is wrong: k-class trees oob row computing is broken! All k-trees should agree on oob row!";
            wasOOBRow = true;
            nid = oob2Nid(nid);
            if (tree.node(nid) instanceof UndecidedNode) { // If we bottomed out the tree
              nid = tree.node(nid).pid(); // Then take parent's decision
            }
            DecidedNode dn = tree.decided(nid); // Must have a decision point
            if (dn._split.col() == -1) { // Unable to decide?
              dn = tree.decided(tree.node(nid).pid()); // Then take parent's decision
            }
            int leafnid = dn.ns(chks, row); // Decide down to a leafnode
            // Tree(k) column accumulates the on-the-fly prediction for this row:
            //   - classification: cumulative number of votes for this row
            //   - regression: cumulative sum of predictions (normalized later by tree count)
            double prediction =
                ((LeafNode) tree.node(leafnid)).pred(); // Prediction for this k-class and this row
            if (importance) {
              rpred[1 + k] = (float) prediction; // for both regression and classification
            }
            ct.set0(row, (float) (ct.at0(row) + prediction));
            // This tree voted for this row. Regression tracks the number of voting trees;
            // for classification a boolean flag is enough.
            oobt.set0(row, _nclass > 1 ? 1 : oobt.at0(row) + 1);
          }
          seenTree = true;
          // reset help column for this row and this k-class
          nids.set0(row, 0);
        } /* end of k-trees iteration */
        if (importance) {
          if (wasOOBRow && !y.isNA0(row)) {
            if (classification) {
              int treePred = ModelUtils.getPrediction(rpred, data_row(chks, row, rowdata));
              int actuPred = (int) y.at80(row);
              if (treePred == actuPred) rightVotes++; // No miss !
            } else { // regression
              float treePred = rpred[1];
              float actuPred = (float) y.at0(row);
              sse += (actuPred - treePred) * (actuPred - treePred);
            }
            allRows++;
          }
        }
      }
    }

    @Override
    public void reduce(CollectPreds mrt) {
      rightVotes += mrt.rightVotes;
      allRows += mrt.allRows;
      sse += mrt.sse;
    }
  }

  /**
   * A standard {@link DTree} extended with DRF-specific state.
   *
   * <p>Supports sampling during training and replaying the sample later on the identical dataset
   * (e.g. to compute OOBEE). All per-chunk sampling seeds are drawn up front from an RNG seeded
   * with the tree seed, so a given (seed, chunk) pair always yields the same RNG stream.
   */
  static class DRFTree extends DTree {
    final int _mtrys; // Number of columns to choose amongst in splits
    final long _seeds[]; // One seed for each chunk, for sampling
    final transient Random _rand; // RNG for split decisions & sampling

    DRFTree(Frame fr, int ncols, char nbins, char nclass, int min_rows, int mtrys, long seed) {
      super(fr._names, ncols, nbins, nclass, min_rows, seed);
      _mtrys = mtrys;
      _rand = createRNG(seed);
      final int nchunks = fr.vecs()[0].nChunks();
      _seeds = new long[nchunks];
      int c = 0;
      while (c < nchunks) {
        _seeds[c++] = _rand.nextLong();
      }
    }

    /** Returns a fresh, deterministic RNG for chunk {@code cidx}; creating it may be costly. */
    @Override
    public Random rngForChunk(int cidx) {
      final long chunkSeed = _seeds[cidx];
      final Random chunkRng = createRNG(chunkSeed);
      return chunkRng;
    }
  }

  /** Creates the DRF-specific decision node for the given undecided node and its histograms. */
  @Override
  protected DecidedNode makeDecided(UndecidedNode udn, DHistogram hs[]) {
    final DRFDecidedNode decided = new DRFDecidedNode(udn, hs);
    return decided;
  }

  /**
   * DRF decision node: a regular {@link DecidedNode} whose split search only looks at the columns
   * listed in the undecided node's {@code _scoreCols} and keeps the lowest-error split.
   */
  static class DRFDecidedNode extends DecidedNode {
    DRFDecidedNode(UndecidedNode n, DHistogram hs[]) {
      super(n, hs);
    }

    /** Children of a DRF decision node are DRF undecided nodes. */
    @Override
    public DRFUndecidedNode makeUndecidedNode(DHistogram hs[]) {
      return new DRFUndecidedNode(_tree, _nid, hs);
    }

    /**
     * Returns the split with the lowest {@code se()} among the candidate columns.
     *
     * <p>Falls back to a sentinel split (built from -1 indices and {@code Double.MAX_VALUE}
     * errors) when {@code hs} is null or no candidate column produces a split.
     */
    @Override
    public DTree.Split bestCol(UndecidedNode u, DHistogram hs[]) {
      DTree.Split winner =
          new DTree.Split(-1, -1, false, Double.MAX_VALUE, Double.MAX_VALUE, 0L, 0L, 0, 0);
      if (hs != null) {
        for (final int col : u._scoreCols) {
          final DTree.Split candidate = hs[col].scoreMSE(col);
          if (candidate == null) continue;
          if (candidate.se() < winner.se()) winner = candidate;
          if (candidate.se() <= 0) break; // Zero error: no point in looking further
        }
      }
      return winner;
    }
  }

  /**
   * DRF undecided node: a regular {@link UndecidedNode} that picks, at random and without
   * replacement, up to {@code mtry} of the active columns to score in the next pass over the data.
   */
  static class DRFUndecidedNode extends UndecidedNode {
    DRFUndecidedNode(DTree tree, int pid, DHistogram[] hs) {
      super(tree, pid, hs);
    }

    /**
     * Returns the randomly selected column indices to score. Columns whose histogram is null are
     * not tracked and never chosen. The selection uses the owning tree's RNG.
     */
    @Override
    public int[] scoreCols(DHistogram[] hs) {
      final DRFTree owner = (DRFTree) _tree;
      final int[] pool = new int[hs.length];
      int remaining = 0;
      // Gather all active (tracked) columns to choose from.
      for (int c = 0; c < hs.length; c++) {
        final DHistogram h = hs[c];
        if (h == null) continue;
        assert h._min < h._maxEx && h.nbins() > 1 : "broken histo range " + h;
        pool[remaining++] = c;
      }
      final int available = remaining;
      assert available > 0;

      // Partial shuffle: each pick is swapped to the tail of the pool, just past 'remaining',
      // so the chosen columns end up in pool[remaining .. available).
      int drawn = 0;
      while (drawn < owner._mtrys && remaining > 0) {
        final int pick = owner._rand.nextInt(remaining);
        final int chosen = pool[pick];
        remaining--;
        pool[pick] = pool[remaining];
        pool[remaining] = chosen;
        drawn++;
      }
      assert available - remaining > 0;
      return Arrays.copyOfRange(pool, remaining, available);
    }
  }

  /** DRF leaf node; its serialized form is just the prediction as one 4-byte float. */
  static class DRFLeafNode extends LeafNode {
    DRFLeafNode(DTree tree, int pid) {
      super(tree, pid);
    }

    DRFLeafNode(DTree tree, int pid, int nid) {
      super(tree, pid, nid);
    }

    /** Writes the leaf prediction narrowed to float; the prediction must not be NaN. */
    @Override
    protected AutoBuffer compress(AutoBuffer ab) {
      final double p = pred();
      assert !Double.isNaN(p);
      return ab.put4f((float) p);
    }

    /** Number of bytes written by {@code compress}: a single float. */
    @Override
    protected int size() {
      final int floatBytes = 4;
      return floatBytes;
    }
  }

  /**
   * Deterministic row sampling for one DRF tree.
   *
   * <p>Each chunk draws from the tree's chunk-local RNG. A row is flagged {@code OUT_OF_BAG} in the
   * node-id column when its draw is not below {@code _rate}, or when its response is NaN.
   */
  static class Sample extends MRTask2<Sample> {
    final DRFTree _tree;
    final float _rate;

    Sample(DRFTree tree, float rate) {
      _tree = tree;
      _rate = rate;
    }

    @Override
    public void map(Chunk nids, Chunk ys) {
      final Random rng = _tree.rngForChunk(nids.cidx());
      for (int r = 0; r < nids._len; r++) {
        // One draw per row regardless of the response, so the RNG stream stays aligned with rows.
        final boolean notDrawn = rng.nextFloat() >= _rate;
        if (notDrawn || Double.isNaN(ys.at0(r))) {
          nids.set0(r, OUT_OF_BAG); // Flag row as being ignored by sampling
        }
      }
    }
  }
}