/** * A wrapper for sampling one tree during the Gibbs sampler * * @param sample_num The current sample number of the Gibbs sampler * @param t The current tree to be sampled * @param trees The trees in this Gibbs sampler * @param tree_array_illustration The tree array (for debugging purposes only) * @return The responses minus the sum of the trees' contribution up to this point */ protected double[] SampleTree( int sample_num, int t, CGMBARTTreeNode[] trees, TreeArrayIllustration tree_array_illustration) { // first copy the tree from the previous gibbs position final CGMBARTTreeNode copy_of_old_jth_tree = gibbs_samples_of_cgm_trees[sample_num - 1][t].clone(); // okay so first we need to get "y" that this tree sees. This is defined as R_j in formula 12 on // p274 // just go to sum_residual_vec and subtract it from y_trans double[] R_j = Tools.add_arrays( Tools.subtract_arrays(y_trans, sum_resids_vec), copy_of_old_jth_tree.yhats); // now, (important!) set the R_j's as this tree's data. copy_of_old_jth_tree.updateWithNewResponsesRecursively(R_j); // sample from T_j | R_j, \sigma // now we will run one M-H step on this tree with the y as the R_j CGMBARTTreeNode new_jth_tree = metroHastingsPosteriorTreeSpaceIteration( copy_of_old_jth_tree, t, accept_reject_mh, accept_reject_mh_steps); // add it to the vector of current sample's trees trees[t] = new_jth_tree; // now set the new trees in the gibbs sample pantheon gibbs_samples_of_cgm_trees[sample_num] = trees; tree_array_illustration.AddTree(new_jth_tree); // return the updated residuals return R_j; }
/** * A wrapper for sampling the mus (mean predictions at terminal nodes). This function implements * part of the "residual diffing" explained in the paper. * * @param sample_num The current sample number of the Gibbs sampler * @param t The tree index number in 1...<code>num_trees</code> * @see Section 3.1 of Kapelner, A and Bleich, J. bartMachine: A Powerful Tool for Machine * Learning in R. ArXiv e-prints, 2013 */ protected void SampleMusWrapper(int sample_num, int t) { CGMBARTTreeNode previous_tree = gibbs_samples_of_cgm_trees[sample_num - 1][t]; // subtract out previous tree's yhats sum_resids_vec = Tools.subtract_arrays(sum_resids_vec, previous_tree.yhats); CGMBARTTreeNode tree = gibbs_samples_of_cgm_trees[sample_num][t]; double current_sigsq = gibbs_samples_of_sigsq[sample_num - 1]; assignLeafValsBySamplingFromPosteriorMeanAndSigsqAndUpdateYhats(tree, current_sigsq); // after mus are sampled, we need to update the sum_resids_vec // add in current tree's yhats sum_resids_vec = Tools.add_arrays(sum_resids_vec, tree.yhats); }