// -------------------------------------------------------------------------- // Build the next random k-trees representing tid-th tree private void buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random rand, int tid) { // We're going to build K (nclass) trees - each focused on correcting // errors for a single class. final DTree[] ktrees = new DTree[_nclass]; // Initial set of histograms. All trees; one leaf per tree (the root // leaf); all columns DHistogram hcs[][][] = new DHistogram[_nclass][1 /*just root leaf*/][_ncols]; // Adjust real bins for the top-levels int adj_nbins = Math.max(_parms._nbins_top_level, _parms._nbins); // Use for all k-trees the same seed. NOTE: this is only to make a fair // view for all k-trees final double[] _distribution = _model._output._distribution; long rseed = rand.nextLong(); // Initially setup as-if an empty-split had just happened for (int k = 0; k < _nclass; k++) { if (_distribution[k] != 0) { // Ignore missing classes // The Boolean Optimization // This optimization assumes the 2nd tree of a 2-class system is the // inverse of the first (and that the same columns were picked) if (k == 1 && _nclass == 2 && _model.binomialOpt()) continue; ktrees[k] = new DRFTree( fr, _ncols, (char) _parms._nbins, (char) _parms._nbins_cats, (char) _nclass, _parms._min_rows, mtrys, rseed); new DRFUndecidedNode( ktrees[k], -1, DHistogram.initialHist( fr, _ncols, adj_nbins, _parms._nbins_cats, hcs[k][0])); // The "root" node } } // Sample - mark the lines by putting 'OUT_OF_BAG' into nid(<klass>) vector Timer t_1 = new Timer(); Sample ss[] = new Sample[_nclass]; for (int k = 0; k < _nclass; k++) if (ktrees[k] != null) ss[k] = new Sample((DRFTree) ktrees[k], sample_rate) .dfork(0, new Frame(vec_nids(fr, k), vec_resp(fr)), _parms._build_tree_one_node); for (int k = 0; k < _nclass; k++) if (ss[k] != null) ss[k].getResult(); Log.debug("Sampling took: + " + t_1); int[] leafs = new int [_nclass]; // Define a "working set" of leaf splits, from leafs[i] to tree._len for // each tree i // ---- // One Big Loop till the ktrees are of proper depth. // Adds a layer to the trees each pass. Timer t_2 = new Timer(); int depth = 0; for (; depth < _parms._max_depth; depth++) { if (!isRunning()) return; hcs = buildLayer( fr, _parms._nbins, _parms._nbins_cats, ktrees, leafs, hcs, true, _parms._build_tree_one_node); // If we did not make any new splits, then the tree is split-to-death if (hcs == null) break; } Log.debug("Tree build took: " + t_2); // Each tree bottomed-out in a DecidedNode; go 1 more level and insert // LeafNodes to hold predictions. Timer t_3 = new Timer(); for (int k = 0; k < _nclass; k++) { DTree tree = ktrees[k]; if (tree == null) continue; int leaf = leafs[k] = tree.len(); for (int nid = 0; nid < leaf; nid++) { if (tree.node(nid) instanceof DecidedNode) { DecidedNode dn = tree.decided(nid); if (dn._split._col == -1) { // No decision here, no row should have this NID now if (nid == 0) { // Handle the trivial non-splitting tree LeafNode ln = new DRFLeafNode(tree, -1, 0); ln._pred = (float) (isClassifier() ? _model._output._priorClassDist[k] : responseMean()); } continue; } for (int i = 0; i < dn._nids.length; i++) { int cnid = dn._nids[i]; if (cnid == -1 || // Bottomed out (predictors or responses known constant) tree.node(cnid) instanceof UndecidedNode || // Or chopped off for depth (tree.node(cnid) instanceof DecidedNode && // Or not possible to split ((DecidedNode) tree.node(cnid))._split.col() == -1)) { LeafNode ln = new DRFLeafNode(tree, nid); ln._pred = (float) dn.pred(i); // Set prediction into the leaf dn._nids[i] = ln.nid(); // Mark a leaf here } } } } } // -- k-trees are done Log.debug("Nodes propagation: " + t_3); // ---- // Move rows into the final leaf rows Timer t_4 = new Timer(); CollectPreds cp = new CollectPreds(ktrees, leafs, _model.defaultThreshold()) .doAll(fr, _parms._build_tree_one_node); if (isClassifier()) asVotes(_treeMeasuresOnOOB) .append(cp.rightVotes, cp.allRows); // Track right votes over OOB rows for this tree else /* regression */ asSSE(_treeMeasuresOnOOB).append(cp.sse, cp.allRows); Log.debug("CollectPreds done: " + t_4); // Grow the model by K-trees _model._output.addKTrees(ktrees); }