/** * Stratified sampling for classifiers * * @param fr Input frame * @param label Label vector (must be enum) * @param sampling_ratios Optional: array containing the requested sampling ratios per class (in * order of domains), will be overwritten if it contains all 0s * @param maxrows Maximum number of rows in the returned frame * @param seed RNG seed for sampling * @param allowOversampling Allow oversampling of minority classes * @param verbose Whether to print verbose info * @return Sampled frame, with approximately the same number of samples from each class (or given * by the requested sampling ratios) */ public static Frame sampleFrameStratified( final Frame fr, Vec label, float[] sampling_ratios, long maxrows, final long seed, final boolean allowOversampling, final boolean verbose) { if (fr == null) return null; assert (label.isEnum()); assert (maxrows >= label.domain().length); long[] dist = new ClassDist(label).doAll(label).dist(); assert (dist.length > 0); Log.info( "Doing stratified sampling for data set containing " + fr.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off")); if (verbose) { for (int i = 0; i < dist.length; ++i) { Log.info( "Class " + label.domain(i) + ": count: " + dist[i] + " prior: " + (float) dist[i] / fr.numRows()); } } // create sampling_ratios for class balance with max. 
maxrows rows (fill existing array if not // null) if (sampling_ratios == null || (Utils.minValue(sampling_ratios) == 0 && Utils.maxValue(sampling_ratios) == 0)) { // compute sampling ratios to achieve class balance if (sampling_ratios == null) { sampling_ratios = new float[dist.length]; } assert (sampling_ratios.length == dist.length); for (int i = 0; i < dist.length; ++i) { sampling_ratios[i] = ((float) fr.numRows() / label.domain().length) / dist[i]; // prior^-1 / num_classes } final float inv_scale = Utils.minValue( sampling_ratios); // majority class has lowest required oversampling factor to achieve // balance if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale)) Utils.div( sampling_ratios, inv_scale); // want sampling_ratio 1.0 for majority class (no downsampling) } if (!allowOversampling) { for (int i = 0; i < sampling_ratios.length; ++i) { sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]); } } // given these sampling ratios, and the original class distribution, this is the expected number // of resulting rows float numrows = 0; for (int i = 0; i < sampling_ratios.length; ++i) { numrows += sampling_ratios[i] * dist[i]; } final long actualnumrows = Math.min(maxrows, Math.round(numrows)); // cap #rows at maxrows assert (actualnumrows >= 0); // can have no matching rows in case of sparse data where we had to fill in a // makeZero() vector Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows."); if (actualnumrows != numrows) { Utils.mult( sampling_ratios, (float) actualnumrows / numrows); // adjust the sampling_ratios by the global rescaling factor if (verbose) Log.info( "Downsampling majority class by " + (float) actualnumrows / numrows + " to limit number of rows to " + String.format("%,d", maxrows)); } Log.info( "Majority class (" + label.domain()[Utils.minIndex(sampling_ratios)].toString() + ") sampling ratio: " + Utils.minValue(sampling_ratios)); Log.info( "Minority class (" + 
label.domain()[Utils.maxIndex(sampling_ratios)].toString() + ") sampling ratio: " + Utils.maxValue(sampling_ratios)); return sampleFrameStratified(fr, label, sampling_ratios, seed, verbose); }
/**
 * Picks an integer imputation value for the given vector.
 *
 * @param v Vector to derive the value from
 * @return {@code v.mode()} when {@code v.isCategorical()} is true; otherwise {@code v.mean()}
 *     rounded with {@link Math#round} and narrowed to {@code int}
 */
public static int imputeCat(Vec v) {
  // Categorical vectors use their mode; everything else falls back to the rounded mean.
  return v.isCategorical() ? v.mode() : (int) Math.round(v.mean());
}