/** * On-the-fly version for varimp. After generation a new tree, its tree votes are collected on * shuffled OOB rows and variable importance is recomputed. * * <p>The <a * href="http://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#varimp">page</a> says: * <cite> "In every tree grown in the forest, put down the oob cases and count the number of votes * cast for the correct class. Now randomly permute the values of variable m in the oob cases and * put these cases down the tree. Subtract the number of votes for the correct class in the * variable-m-permuted oob data from the number of votes for the correct class in the untouched * oob data. The average of this number over all trees in the forest is the raw importance score * for variable m." </cite> */ @Override protected VarImp doVarImpCalc( final DRFModel model, DTree[] ktrees, final int tid, final Frame fTrain, boolean scale) { // Check if we have already serialized 'ktrees'-trees in the model assert model.ntrees() - 1 == tid : "Cannot compute DRF varimp since 'ktrees' are not serialized in the model! tid=" + tid; assert _treeMeasuresOnOOB.npredictors() - 1 == tid : "Tree votes over OOB rows for this tree (var ktrees) were not found!"; // Compute tree votes over shuffled data final CompressedTree[ /*nclass*/] theTree = model.ctree(tid); // get the last tree FIXME we should pass only keys final int nclasses = model.nclasses(); Futures fs = new Futures(); for (int var = 0; var < _ncols; var++) { final int variable = var; H2OCountedCompleter task4var = classification ? new H2OCountedCompleter() { @Override public void compute2() { // Compute this tree votes over all data over given variable TreeVotes cd = TreeMeasuresCollector.collectVotes( theTree, nclasses, fTrain, _ncols, sample_rate, variable); assert cd.npredictors() == 1; asVotes(_treeMeasuresOnSOOB[variable]).append(cd); tryComplete(); } } : /* regression */ new H2OCountedCompleter() { @Override public void compute2() { // Compute this tree votes over all data over given variable TreeSSE cd = TreeMeasuresCollector.collectSSE( theTree, nclasses, fTrain, _ncols, sample_rate, variable); assert cd.npredictors() == 1; asSSE(_treeMeasuresOnSOOB[variable]).append(cd); tryComplete(); } }; H2O.submitTask(task4var); // Fork computation fs.add(task4var); } fs.blockForPending(); // Wait for results // Compute varimp for individual features (_ncols) final float[] varimp = new float[_ncols]; // output variable importance final float[] varimpSD = new float[_ncols]; // output variable importance sd for (int var = 0; var < _ncols; var++) { double[ /*2*/] imp = classification ? asVotes(_treeMeasuresOnSOOB[var]).imp(asVotes(_treeMeasuresOnOOB)) : asSSE(_treeMeasuresOnSOOB[var]).imp(asSSE(_treeMeasuresOnOOB)); varimp[var] = (float) imp[0]; varimpSD[var] = (float) imp[1]; } return new VarImp.VarImpMDA(varimp, varimpSD, model.ntrees()); }
@Override protected DRFModel buildModel( DRFModel model, final Frame fr, String names[], String domains[][], final Timer t_build) { // Append number of trees participating in on-the-fly scoring fr.add("OUT_BAG_TREES", response.makeZero()); // The RNG used to pick split columns Random rand = createRNG(_seed); // Prepare working columns new SetWrkTask().doAll(fr); int tid; DTree[] ktrees = null; // Prepare tree statistics TreeStats tstats = new TreeStats(); // Build trees until we hit the limit for (tid = 0; tid < ntrees; tid++) { // Building tid-tree model = doScoring( model, fr, ktrees, tid, tstats, tid == 0, !hasValidation(), build_tree_one_node); // At each iteration build K trees (K = nclass = response column domain size) // TODO: parallelize more? build more than k trees at each time, we need to care about // temporary data // Idea: launch more DRF at once. Timer kb_timer = new Timer(); ktrees = buildNextKTrees(fr, _mtry, sample_rate, rand, tid); Log.info(Sys.DRF__, (tid + 1) + ". tree was built " + kb_timer.toString()); if (!Job.isRunning(self())) break; // If canceled during building, do not bulkscore // Check latest predictions tstats.updateBy(ktrees); } model = doScoring(model, fr, ktrees, tid, tstats, true, !hasValidation(), build_tree_one_node); // Make sure that we did not miss any votes assert !importance || _treeMeasuresOnOOB.npredictors() == _treeMeasuresOnSOOB[0 /*variable*/].npredictors() : "Missing some tree votes in variable importance voting?!"; return model; }