/**
 * Global redistribution of a Frame (load balancing of chunks), done by the calling process
 * (all-to-one + one-to-all).
 *
 * @param fr Input frame
 * @param splits Target number of chunks
 * @param seed RNG seed (only used if shuffle is enabled)
 * @param local Whether to keep the data node-local (currently unused in this implementation)
 * @param shuffle Whether to shuffle the rows globally across the new chunks
 * @return Rebalanced (and optionally shuffled) frame
 */
public static Frame shuffleAndBalance(
    final Frame fr, int splits, long seed, final boolean local, final boolean shuffle) {
  if ((fr.vecs()[0].nChunks() < splits || shuffle) && fr.numRows() > splits) {
    Vec[] vecs = fr.vecs().clone();
    Log.info("Load balancing dataset, splitting it into up to " + splits + " chunks.");
    long[] idx = null;
    if (shuffle) {
      idx = new long[splits];
      for (int r = 0; r < idx.length; ++r) idx[r] = r;
      Utils.shuffleArray(idx, seed);
    }
    Key[] keys = new Vec.VectorGroup().addVecs(vecs.length);
    final long rows_per_new_chunk = (long) (Math.ceil((double) fr.numRows() / splits));

    // loop over cols (same indexing for each column)
    Futures fs = new Futures();
    for (int col = 0; col < vecs.length; col++) {
      AppendableVec vec = new AppendableVec(keys[col]);
      // create outgoing chunks for this col
      NewChunk[] outCkg = new NewChunk[splits];
      for (int i = 0; i < splits; ++i) outCkg[i] = new NewChunk(vec, i);
      // loop over all incoming chunks
      for (int ckg = 0; ckg < vecs[col].nChunks(); ckg++) {
        final Chunk inCkg = vecs[col].chunkForChunkIdx(ckg);
        // loop over local rows of incoming chunks (fast path)
        for (int row = 0; row < inCkg._len; ++row) {
          int outCkgIdx =
              (int) ((inCkg._start + row) / rows_per_new_chunk); // destination chunk idx
          if (shuffle)
            outCkgIdx = (int) (idx[outCkgIdx]); // shuffle: choose a different output chunk
          assert (outCkgIdx >= 0 && outCkgIdx < splits);
          outCkg[outCkgIdx].addNum(inCkg.at0(row));
        }
      }
      for (int i = 0; i < outCkg.length; ++i) outCkg[i].close(i, fs);
      Vec t = vec.close(fs);
      t._domain = vecs[col]._domain;
      vecs[col] = t;
    }
    fs.blockForPending();
    Log.info("Load balancing done.");
    return new Frame(fr.names(), vecs);
  }
  return fr;
}
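// Usage sketch for shuffleAndBalance (illustrative only, not part of the original source).
// It shows the typical call pattern: rebalance a frame into roughly one chunk per desired
// split and shuffle rows globally so later per-chunk work sees a random row order. The
// frame, split count, and seed below are assumptions made for the example.
static Frame exampleRebalance(Frame trainFrame, int desiredSplits, long seed) {
  // local=false: allow redistribution across nodes; shuffle=true: randomize the row order
  return shuffleAndBalance(trainFrame, desiredSplits, seed, /* local */ false, /* shuffle */ true);
}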
// Internal version with a repeat counter.
// Currently hardcoded to do up to 10 tries to get a row from each class, which can be
// impossible for certain (wrong) sampling ratios.
private static Frame sampleFrameStratified(
    final Frame fr,
    Vec label,
    final float[] sampling_ratios,
    final long seed,
    final boolean debug,
    int count) {
  if (fr == null) return null;
  assert (label.isEnum());
  assert (sampling_ratios != null && sampling_ratios.length == label.domain().length);
  final int labelidx = fr.find(label); // which column is the label?
  assert (labelidx >= 0);
  final boolean poisson = false; // beta feature

  Frame r =
      new MRTask2() {
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
          final Random rng = getDeterRNG(seed + cs[0].cidx());
          for (int r = 0; r < cs[0]._len; r++) {
            if (cs[labelidx].isNA0(r)) continue; // skip missing labels
            final int label = (int) cs[labelidx].at80(r);
            assert (sampling_ratios.length > label && label >= 0);
            int sampling_reps;
            if (poisson) {
              sampling_reps = Utils.getPoisson(sampling_ratios[label], rng);
            } else {
              final float remainder = sampling_ratios[label] - (int) sampling_ratios[label];
              sampling_reps =
                  (int) sampling_ratios[label] + (rng.nextFloat() < remainder ? 1 : 0);
            }
            for (int i = 0; i < ncs.length; i++) {
              for (int j = 0; j < sampling_reps; ++j) {
                ncs[i].addNum(cs[i].at0(r));
              }
            }
          }
        }
      }.doAll(fr.numCols(), fr).outputFrame(fr.names(), fr.domains());

  // Confirm the validity of the distribution
  long[] dist = new ClassDist(r.vecs()[labelidx]).doAll(r.vecs()[labelidx]).dist();
  // If there are no training labels in the test set, there is no point in sampling the test set.
  if (dist == null) return fr;

  if (debug) {
    long sumdist = Utils.sum(dist);
    Log.info("After stratified sampling: " + sumdist + " rows.");
    for (int i = 0; i < dist.length; ++i) {
      Log.info(
          "Class "
              + r.vecs()[labelidx].domain(i)
              + ": count: "
              + dist[i]
              + " sampling ratio: "
              + sampling_ratios[i]
              + " actual relative frequency: "
              + (float) dist[i] / sumdist * dist.length);
    }
  }

  // Re-try if we didn't get at least one example from each class
  if (Utils.minValue(dist) == 0 && count < 10) {
    Log.info(
        "Re-doing stratified sampling because not all classes were represented (unlucky draw).");
    r.delete();
    return sampleFrameStratified(fr, label, sampling_ratios, seed + 1, debug, ++count);
  }

  // shuffle intra-chunk
  Frame shuffled = shuffleFramePerChunk(r, seed + 0x580FF13);
  r.delete();
  return shuffled;
}
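// Illustrative sketch (not original code): the per-row replication rule used above, isolated.
// A sampling ratio of e.g. 2.3 keeps 2 copies of a row plus a 3rd copy with probability 0.3,
// so the expected number of copies equals the ratio; a ratio of 0.25 keeps the row with
// probability 0.25. The method name is an assumption made for this example.
static int exampleSamplingReps(float samplingRatio, Random rng) {
  final float remainder = samplingRatio - (int) samplingRatio; // fractional part of the ratio
  return (int) samplingRatio + (rng.nextFloat() < remainder ? 1 : 0);
}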
/**
 * Stratified sampling for classifiers.
 *
 * @param fr Input frame
 * @param label Label vector (must be enum)
 * @param sampling_ratios Optional: array containing the requested sampling ratios per class (in
 *     order of domains); will be overwritten if it contains all 0s
 * @param maxrows Maximum number of rows in the returned frame
 * @param seed RNG seed for sampling
 * @param allowOversampling Allow oversampling of minority classes
 * @param verbose Whether to print verbose info
 * @return Sampled frame, with approximately the same number of samples from each class (or as
 *     given by the requested sampling ratios)
 */
public static Frame sampleFrameStratified(
    final Frame fr,
    Vec label,
    float[] sampling_ratios,
    long maxrows,
    final long seed,
    final boolean allowOversampling,
    final boolean verbose) {
  if (fr == null) return null;
  assert (label.isEnum());
  assert (maxrows >= label.domain().length);

  long[] dist = new ClassDist(label).doAll(label).dist();
  assert (dist.length > 0);
  Log.info(
      "Doing stratified sampling for data set containing "
          + fr.numRows()
          + " rows from "
          + dist.length
          + " classes. Oversampling: "
          + (allowOversampling ? "on" : "off"));
  if (verbose) {
    for (int i = 0; i < dist.length; ++i) {
      Log.info(
          "Class "
              + label.domain(i)
              + ": count: "
              + dist[i]
              + " prior: "
              + (float) dist[i] / fr.numRows());
    }
  }

  // Create sampling_ratios for class balance with at most maxrows rows (fill the existing array
  // if not null).
  if (sampling_ratios == null
      || (Utils.minValue(sampling_ratios) == 0 && Utils.maxValue(sampling_ratios) == 0)) {
    // compute sampling ratios to achieve class balance
    if (sampling_ratios == null) {
      sampling_ratios = new float[dist.length];
    }
    assert (sampling_ratios.length == dist.length);
    for (int i = 0; i < dist.length; ++i) {
      sampling_ratios[i] =
          ((float) fr.numRows() / label.domain().length) / dist[i]; // prior^-1 / num_classes
    }
    // The majority class has the lowest required oversampling factor to achieve balance.
    final float inv_scale = Utils.minValue(sampling_ratios);
    if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale))
      // want a sampling_ratio of 1.0 for the majority class (no downsampling)
      Utils.div(sampling_ratios, inv_scale);
  }

  if (!allowOversampling) {
    for (int i = 0; i < sampling_ratios.length; ++i) {
      sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]);
    }
  }

  // Given these sampling ratios and the original class distribution, this is the expected number
  // of resulting rows.
  float numrows = 0;
  for (int i = 0; i < sampling_ratios.length; ++i) {
    numrows += sampling_ratios[i] * dist[i];
  }
  final long actualnumrows = Math.min(maxrows, Math.round(numrows)); // cap #rows at maxrows
  // Can have no matching rows in case of sparse data where we had to fill in a makeZero() vector.
  assert (actualnumrows >= 0);
  Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows.");

  if (actualnumrows != numrows) {
    // adjust the sampling_ratios by the global rescaling factor
    Utils.mult(sampling_ratios, (float) actualnumrows / numrows);
    if (verbose)
      Log.info(
          "Downsampling majority class by "
              + (float) actualnumrows / numrows
              + " to limit number of rows to "
              + String.format("%,d", maxrows));
  }
  Log.info(
      "Majority class ("
          + label.domain()[Utils.minIndex(sampling_ratios)].toString()
          + ") sampling ratio: "
          + Utils.minValue(sampling_ratios));
  Log.info(
      "Minority class ("
          + label.domain()[Utils.maxIndex(sampling_ratios)].toString()
          + ") sampling ratio: "
          + Utils.maxValue(sampling_ratios));

  return sampleFrameStratified(fr, label, sampling_ratios, seed, verbose);
}
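// Usage sketch for the public stratified sampler (illustrative only, not part of the original
// source). Passing null (or an all-zero array) for sampling_ratios asks the method to compute
// class-balancing ratios itself; maxrows caps the size of the returned frame. The names `train`
// and `response` and the 100000-row cap are assumptions made for the example.
static Frame exampleBalanceClasses(Frame train, Vec response, long seed) {
  return sampleFrameStratified(
      train, response, /* sampling_ratios */ null, /* maxrows */ 100000L, seed,
      /* allowOversampling */ true, /* verbose */ false);
}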