Exemplo n.º 1
0
 /**
  * Scores the freshly built k-trees on their out-of-bag (OOB) rows.
  *
  * <p>For every row that is OOB for the k-trees of this iteration, each evaluated tree's leaf
  * prediction is added to that class's working column ({@code chk_tree}). The OOB counter column
  * ({@code chk_oobt}) is also updated: it is set to 1 when {@code _nclass > 1} (classification)
  * and incremented once per tree otherwise (regression). The per-class node-id helper column is
  * reset to 0 for every evaluated tree.
  *
  * <p>When {@code importance} is enabled, OOB rows with a non-NA response also update
  * {@code allRows} and either {@code rightVotes} (classification) or {@code sse} (regression).
  *
  * @param chks chunks of the training frame, addressed through the {@code chk_*} accessors
  */
 @Override
 public void map(Chunk[] chks) {
   final Chunk y = importance ? chk_resp(chks) : null; // Response
   final float[] rpred = importance ? new float[1 + _nclass] : null; // Row prediction
   final double[] rowdata = importance ? new double[_ncols] : null; // Pre-allocated row data
   final Chunk oobt = chk_oobt(chks); // Out-of-bag rows counter over all trees
   for (int row = 0; row < oobt._len; row++) {
     boolean wasOOBRow = false;
     // The OOB-agreement check must be relative to the first tree actually evaluated for this
     // row. Class 0 may have no tree (missing class) or a leaf-only root, and then the first
     // evaluated tree has k > 0, so "k == 0" is not a valid test for "first tree".
     boolean firstEvaluatedTree = true;
     for (int k = 0; k < _nclass; k++) { // one tree per class
       final DTree tree = _trees[k];
       if (tree == null) {
         continue; // Empty class is ignored
       }
       // If we have all constant responses, then we do not split even the
       // root and the residuals should be zero.
       if (tree.root() instanceof LeafNode) {
         continue;
       }
       final Chunk nids = chk_nids(chks, k); // Node-ids for this tree/class
       final Chunk ct = chk_tree(chks, k); // k-tree working column holding votes for given row
       int nid = (int) nids.at80(row); // Get Node to decide from
       // Only out-of-bag rows are scored; for them we track the on-the-fly prediction.
       if (isOOBRow(nid)) { // The row should be OOB for all k-trees !!!
         assert firstEvaluatedTree || wasOOBRow
             : "Something is wrong: k-class trees oob row computing is broken! All k-trees should agree on oob row!";
         wasOOBRow = true;
         nid = oob2Nid(nid);
         if (tree.node(nid) instanceof UndecidedNode) { // If we bottomed out the tree
           nid = tree.node(nid).pid(); // Then take parent's decision
         }
         DecidedNode dn = tree.decided(nid); // Must have a decision point
         if (dn._split.col() == -1) { // Unable to decide?
           dn = tree.decided(tree.node(nid).pid()); // Then take parent's decision
         }
         final int leafnid = dn.ns(chks, row); // Decide down to a leafnode
         // Tree(k) working column holds the on-the-fly prediction for this row:
         //   - classification: cumulative number of votes for this row
         //   - regression: cumulative sum of tree predictions (to be normalized by
         //     the number of trees)
         final double prediction = ((LeafNode) tree.node(leafnid)).pred();
         if (importance) {
           rpred[1 + k] = (float) prediction; // for both regression and classification
         }
         ct.set0(row, (float) (ct.at0(row) + prediction));
         // This tree voted for this OOB row. Classification only needs a boolean flag;
         // regression tracks the number of trees.
         oobt.set0(row, _nclass > 1 ? 1 : oobt.at0(row) + 1);
       }
       firstEvaluatedTree = false;
       // reset help column for this row and this k-class
       nids.set0(row, 0);
     } /* end of k-trees iteration */
     if (importance && wasOOBRow && !y.isNA0(row)) {
       if (classification) {
         final int treePred = ModelUtils.getPrediction(rpred, data_row(chks, row, rowdata));
         final int actuPred = (int) y.at80(row);
         if (treePred == actuPred) {
           rightVotes++; // No miss !
         }
       } else { // regression
         final float treePred = rpred[1];
         final float actuPred = (float) y.at0(row);
         sse += (actuPred - treePred) * (actuPred - treePred);
       }
       allRows++;
     }
   }
 }
Exemplo n.º 2
0
  // --------------------------------------------------------------------------
  /**
   * Builds the next random set of k-trees (one tree per class) representing the {@code tid}-th
   * tree of the forest.
   *
   * <p>Steps:
   * <ol>
   *   <li>Create a root node for every class present in the training distribution.
   *   <li>Mark out-of-bag rows with {@code Sample}.
   *   <li>Grow the trees one layer per pass, up to {@code max_depth}.
   *   <li>Insert leaf nodes below every bottomed-out decision.
   *   <li>Run {@code CollectPreds} over {@code fr}, which also records OOB measures when
   *       {@code importance} is set.
   * </ol>
   *
   * @param fr training frame; the per-class node-id columns are taken from it via
   *     {@code vec_nids}
   * @param mtrys forwarded to every {@code DRFTree} (presumably the number of columns tried per
   *     split)
   * @param sample_rate forwarded to {@code Sample}, which marks the out-of-bag rows
   * @param rand source of the single seed shared by all k-trees
   * @param tid index of the tree being built; not read by this method
   * @return one tree per class ({@code null} for ignored classes), or {@code null} if the job
   *     stopped running while layers were being built
   */
  private DTree[] buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) {
    // A class distribution exists exactly for classification models (loop-invariant check).
    assert (_distribution != null) == classification
        : "_distribution must be non-null exactly for classification models";

    // We're going to build K (nclass) trees - each focused on correcting
    // errors for a single class.
    final DTree[] ktrees = new DTree[_nclass];

    // Initial set of histograms. All trees; one leaf per tree (the root
    // leaf); all columns
    DHistogram[][][] hcs = new DHistogram[_nclass][1 /*just root leaf*/][_ncols];

    // Use the same seed for all k-trees. NOTE: this is only to give all
    // k-trees a fair (identical) view of the data.
    final long rseed = rand.nextLong();

    // Initially set up as if an empty split had just happened.
    for (int k = 0; k < _nclass; k++) {
      if (_distribution != null && _distribution[k] == 0) {
        continue; // Ignore missing classes
      }
      // The Boolean Optimization
      // This optimization assumes the 2nd tree of a 2-class system is the
      // inverse of the first. This is false for DRF (and true for GBM) -
      // DRF picks a random different set of columns for the 2nd tree,
      // so the 2nd tree of a 2-class problem is deliberately NOT skipped here.
      ktrees[k] = new DRFTree(fr, _ncols, (char) nbins, (char) _nclass, min_rows, mtrys, rseed);
      // The "root" node. The instance is not kept here; presumably it attaches
      // itself to ktrees[k] on construction.
      new DRFUndecidedNode(
          ktrees[k],
          -1,
          DHistogram.initialHist(fr, _ncols, nbins, hcs[k][0], /* isBinom= */ classification));
    }

    // Sample - mark the rows by putting 'OUT_OF_BAG' into the nid(<klass>) vector.
    final Timer t_1 = new Timer();
    final Sample[] ss = new Sample[_nclass];
    for (int k = 0; k < _nclass; k++) {
      if (ktrees[k] != null) {
        ss[k] =
            new Sample((DRFTree) ktrees[k], sample_rate)
                .dfork(0, new Frame(vec_nids(fr, k), vec_resp(fr, k)), build_tree_one_node);
      }
    }
    // Join all forked sampling tasks before building layers.
    for (int k = 0; k < _nclass; k++) {
      if (ss[k] != null) {
        ss[k].getResult();
      }
    }
    Log.debug(Sys.DRF__, "Sampling took: " + t_1);

    // Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i.
    final int[] leafs = new int[_nclass];

    // ----
    // One Big Loop till the ktrees are of proper depth.
    // Adds a layer to the trees each pass.
    final Timer t_2 = new Timer();
    for (int depth = 0; depth < max_depth; depth++) {
      if (!Job.isRunning(self())) {
        return null; // The job is no longer running; abandon this set of trees.
      }
      hcs = buildLayer(fr, ktrees, leafs, hcs, true, build_tree_one_node);
      // If we did not make any new splits, then the tree is split-to-death.
      if (hcs == null) {
        break;
      }
    }
    Log.debug(Sys.DRF__, "Tree build took: " + t_2);

    // Each tree bottomed-out in a DecidedNode; go 1 more level and insert
    // LeafNodes to hold predictions.
    final Timer t_3 = new Timer();
    for (int k = 0; k < _nclass; k++) {
      final DTree tree = ktrees[k];
      if (tree == null) {
        continue;
      }
      // Snapshot the node count: only nodes that exist now are visited, so leaves appended
      // below are not revisited. leafs[k] keeps this boundary for CollectPreds and the
      // leaf statistics at the end.
      final int leaf = leafs[k] = tree.len();
      for (int nid = 0; nid < leaf; nid++) {
        if (!(tree.node(nid) instanceof DecidedNode)) {
          continue;
        }
        final DecidedNode dn = tree.decided(nid);
        for (int i = 0; i < dn._nids.length; i++) {
          final int cnid = dn._nids[i];
          final boolean needsLeaf =
              cnid == -1 // Bottomed out (predictors or responses known constant)
                  || tree.node(cnid) instanceof UndecidedNode // Or chopped off for depth
                  || (tree.node(cnid) instanceof DecidedNode // Or not possible to split
                      && ((DecidedNode) tree.node(cnid))._split.col() == -1);
          if (needsLeaf) {
            final LeafNode ln = new DRFLeafNode(tree, nid);
            ln._pred = dn.pred(i); // Set prediction into the leaf
            dn._nids[i] = ln.nid(); // Mark a leaf here
          }
        }
        // Handle the trivial non-splitting tree
        if (nid == 0 && dn._split.col() == -1) {
          new DRFLeafNode(tree, -1, 0);
        }
      }
    } // -- k-trees are done
    Log.debug(Sys.DRF__, "Nodes propagation: " + t_3);

    // ----
    // Move rows into the final leaf rows
    final Timer t_4 = new Timer();
    final CollectPreds cp = new CollectPreds(ktrees, leafs).doAll(fr, build_tree_one_node);
    if (importance) {
      if (classification) {
        // Track right votes over OOB rows for this tree
        asVotes(_treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows);
      } else { // regression
        asSSE(_treeMeasuresOnOOB).append(cp.sse, cp.allRows);
      }
    }
    Log.debug(Sys.DRF__, "CollectPreds done: " + t_4);

    // Collect leaves stats: the number of nodes appended since leafs[i] was recorded.
    for (int i = 0; i < ktrees.length; i++) {
      if (ktrees[i] != null) {
        ktrees[i].leaves = ktrees[i].len() - leafs[i];
      }
    }
    return ktrees;
  }