示例#1
0
文件: DRF.java 项目: rohit2412/h2o
 /**
  * Scores one row by post-processing the raw accumulation returned by {@code super.score0}.
  *
  * <p>A single-element result is treated as regression: the accumulated value is divided by the
  * number of trees, which is skipped when there are no trees. A longer result is treated as
  * classification: the entries are divided by their sum, which is skipped when the sum is not
  * positive. Slot 0 is then overwritten with the label chosen by
  * {@code ModelUtils.getPrediction}.
  *
  * @param data  feature values of the row being scored
  * @param preds buffer passed straight through to {@code super.score0}
  * @return the array returned by {@code super.score0}, post-processed as described above
  */
 @Override
 protected float[] score0(double data[], float preds[]) {
   final float[] scores = super.score0(data, preds);
   final int treeCount = ntrees();
   final boolean regression = scores.length == 1;
   if (regression) {
     // Turn the summed tree outputs into their average.
     if (treeCount > 0) {
       div(scores, treeCount);
     }
     return scores;
   }
   // Classification: normalize accumulated class votes, then record the predicted class in slot 0.
   final float total = sum(scores);
   if (total > 0) {
     div(scores, total);
   }
   scores[0] = ModelUtils.getPrediction(scores, data);
   return scores;
 }
示例#2
0
文件: DRF.java 项目: rohit2412/h2o
 /**
  * Accumulates out-of-bag (OOB) predictions from the trees in {@code _trees} for every row of this
  * chunk.
  *
  * <p>For each row, each class tree is examined unless it is skipped. A tree is skipped when it is
  * {@code null} or when its root is already a {@code LeafNode}. For an examined tree, a row whose
  * stored node id is flagged as OOB is routed to a leaf, and two columns are updated:
  * <ul>
  *   <li>the leaf prediction is added to that class's working column;
  *   <li>the OOB counter column is set to 1 (classification) or incremented (regression).
  * </ul>
  * The node-id column of every examined tree is reset to 0 for the row.
  *
  * <p>When {@code importance} is set and the row was OOB with a non-NA response, the row's
  * prediction is compared with the response. This updates {@code rightVotes} (classification) or
  * {@code sse} (regression), and always increments {@code allRows}.
  *
  * @param chks chunks of the working frame for this task
  */
 @Override
 public void map(Chunk[] chks) {
   final Chunk y = importance ? chk_resp(chks) : null; // Response
   final float[] rpred = importance ? new float[1 + _nclass] : null; // Row prediction
   final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
   final Chunk oobt = chk_oobt(chks); // Out-of-bag rows counter over all trees
   for (int row = 0; row < oobt._len; row++) {
     boolean wasOOBRow = false;
     // Becomes true once any non-skipped tree has been examined for this row. The consistency
     // assertion keys on this rather than on k == 0. Leading classes may be skipped (null tree or
     // constant-response root), so the first examined tree can have k > 0.
     boolean visitedTree = false;
     // For all trees (i.e., k classes)
     for (int k = 0; k < _nclass; k++) {
       final DTree tree = _trees[k];
       if (tree == null) continue; // Empty class is ignored
       // With all-constant responses the root is never split, so there is nothing to route.
       if (tree.root() instanceof LeafNode) continue;
       final Chunk nids = chk_nids(chks, k); // Node ids for this tree/class
       final Chunk ct = chk_tree(chks, k); // k-tree working column holding votes for given row
       int nid = (int) nids.at80(row); // Node to decide from
       // Only out-of-bag rows are updated; for them we track the on-the-fly prediction.
       if (isOOBRow(nid)) { // The row should be OOB for all k-trees
         assert !visitedTree || wasOOBRow
             : "Something is wrong: k-class trees oob row computing is broken! All k-trees should agree on oob row!";
         wasOOBRow = true;
         nid = oob2Nid(nid);
         if (tree.node(nid) instanceof UndecidedNode) { // Bottomed out the tree
           nid = tree.node(nid).pid(); // Take the parent's decision instead
         }
         DecidedNode dn = tree.decided(nid); // Must have a decision point
         if (dn._split.col() == -1) { // Unable to decide?
           dn = tree.decided(tree.node(nid).pid()); // Take the parent's decision instead
         }
         final int leafnid = dn.ns(chks, row); // Decide down to a leaf node
         // Working column semantics:
         //   - classification: cumulative votes for this row
         //   - regression: cumulative sum of per-tree predictions (normalized later by tree count)
         final double prediction = ((LeafNode) tree.node(leafnid)).pred();
         if (importance) {
           rpred[1 + k] = (float) prediction; // Same slot layout for regression and classification
         }
         ct.set0(row, (float) (ct.at0(row) + prediction));
         // Classification only needs a boolean "voted" flag; regression counts contributing trees.
         oobt.set0(row, _nclass > 1 ? 1 : oobt.at0(row) + 1);
       }
       visitedTree = true;
       // Reset the helper node-id column for this row and this class.
       nids.set0(row, 0);
     }
     if (importance && wasOOBRow && !y.isNA0(row)) {
       if (classification) {
         final int treePred = ModelUtils.getPrediction(rpred, data_row(chks, row, rowdata));
         final int actuPred = (int) y.at80(row);
         if (treePred == actuPred) {
           rightVotes++; // No miss
         }
       } else { // regression
         final float treePred = rpred[1];
         final float actuPred = (float) y.at0(row);
         sse += (actuPred - treePred) * (actuPred - treePred);
       }
       allRows++;
     }
   }
 }