// Internal version with repeat counter.
// Currently hardcoded to do up to 10 tries to get a row from each class, which can be
// impossible for certain (invalid) sampling ratios.
private static Frame sampleFrameStratified(
    final Frame fr,
    Vec label,
    final float[] sampling_ratios,
    final long seed,
    final boolean debug,
    int count) {
  if (fr == null) return null;
  assert (label.isEnum());
  assert (sampling_ratios != null && sampling_ratios.length == label.domain().length);
  final int labelidx = fr.find(label); // which column is the label?
  assert (labelidx >= 0);
  final boolean poisson = false; // beta feature

  Frame r =
      new MRTask2() {
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
          final Random rng = getDeterRNG(seed + cs[0].cidx());
          for (int r = 0; r < cs[0]._len; r++) {
            if (cs[labelidx].isNA0(r)) continue; // skip missing labels
            final int label = (int) cs[labelidx].at80(r);
            assert (sampling_ratios.length > label && label >= 0);
            int sampling_reps;
            if (poisson) {
              sampling_reps = Utils.getPoisson(sampling_ratios[label], rng);
            } else {
              // Stochastic rounding: floor(ratio) copies, plus one more with probability
              // equal to the fractional part, so E[sampling_reps] == sampling_ratios[label].
              final float remainder = sampling_ratios[label] - (int) sampling_ratios[label];
              sampling_reps = (int) sampling_ratios[label] + (rng.nextFloat() < remainder ? 1 : 0);
            }
            for (int i = 0; i < ncs.length; i++) {
              for (int j = 0; j < sampling_reps; ++j) {
                ncs[i].addNum(cs[i].at0(r));
              }
            }
          }
        }
      }.doAll(fr.numCols(), fr).outputFrame(fr.names(), fr.domains());

  // Confirm the validity of the distribution.
  long[] dist = new ClassDist(r.vecs()[labelidx]).doAll(r.vecs()[labelidx]).dist();

  // If there are no training labels in the test set, then there is no point in sampling the
  // test set.
  if (dist == null) return fr;

  if (debug) {
    long sumdist = Utils.sum(dist);
    Log.info("After stratified sampling: " + sumdist + " rows.");
    for (int i = 0; i < dist.length; ++i) {
      Log.info(
          "Class " + r.vecs()[labelidx].domain(i)
              + ": count: " + dist[i]
              + " sampling ratio: " + sampling_ratios[i]
              + " actual relative frequency: " + (float) dist[i] / sumdist * dist.length);
    }
  }

  // Re-try if we didn't get at least one example from each class.
  if (Utils.minValue(dist) == 0 && count < 10) {
    Log.info("Re-doing stratified sampling because not all classes were represented (unlucky draw).");
    r.delete();
    return sampleFrameStratified(fr, label, sampling_ratios, seed + 1, debug, ++count);
  }

  // Shuffle intra-chunk.
  Frame shuffled = shuffleFramePerChunk(r, seed + 0x580FF13);
  r.delete();
  return shuffled;
}
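// The non-Poisson branch above relies on stochastic rounding: replicating a row floor(ratio)
// times plus one extra copy with probability equal to the fractional part makes the expected
// replication count equal the requested ratio. A minimal standalone sketch of that idea (plain
// Java, no H2O dependencies; the class and method names are illustrative only):
import java.util.Random;

public class StochasticRoundingDemo {
  // How many times a row with the given sampling ratio should be replicated.
  static int samplingReps(float ratio, Random rng) {
    final float remainder = ratio - (int) ratio; // fractional part, e.g. 0.3 for ratio 2.3
    return (int) ratio + (rng.nextFloat() < remainder ? 1 : 0); // 2 or 3 copies
  }

  public static void main(String[] args) {
    final Random rng = new Random(42);
    final int trials = 1_000_000;
    long total = 0;
    for (int i = 0; i < trials; ++i) total += samplingReps(2.3f, rng);
    // The mean replication count converges to the requested ratio (~2.3).
    System.out.println("mean reps: " + (double) total / trials);
  }
}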
public ClassDist(final Vec label) { super(label.domain().length); }
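// ClassDist (constructor above) is the pass that tallies how many rows fall into each factor
// level of the label; its dist() result drives the sampling ratios. A minimal single-node
// sketch of the same tally (illustrative only; assumes labels are already materialized as int
// indices into the domain, which is not how the distributed MRTask actually reads them):
public class ClassDistDemo {
  static long[] classDist(int[] labels, int numClasses) {
    long[] dist = new long[numClasses]; // one counter per factor level
    for (int l : labels) dist[l]++; // tally each row's class index
    return dist;
  }

  public static void main(String[] args) {
    long[] dist = classDist(new int[] {0, 1, 1, 2, 2, 2}, 3);
    System.out.println(java.util.Arrays.toString(dist)); // [1, 2, 3]
  }
}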
/**
 * Stratified sampling for classifiers.
 *
 * @param fr Input frame
 * @param label Label vector (must be enum)
 * @param sampling_ratios Optional: array containing the requested sampling ratios per class (in
 *     order of domains), will be overwritten if it contains all 0s
 * @param maxrows Maximum number of rows in the returned frame
 * @param seed RNG seed for sampling
 * @param allowOversampling Allow oversampling of minority classes
 * @param verbose Whether to print verbose info
 * @return Sampled frame, with approximately the same number of samples from each class (or given
 *     by the requested sampling ratios)
 */
public static Frame sampleFrameStratified(
    final Frame fr,
    Vec label,
    float[] sampling_ratios,
    long maxrows,
    final long seed,
    final boolean allowOversampling,
    final boolean verbose) {
  if (fr == null) return null;
  assert (label.isEnum());
  assert (maxrows >= label.domain().length);

  long[] dist = new ClassDist(label).doAll(label).dist();
  assert (dist.length > 0);
  Log.info(
      "Doing stratified sampling for data set containing " + fr.numRows()
          + " rows from " + dist.length + " classes."
          + " Oversampling: " + (allowOversampling ? "on" : "off"));
  if (verbose) {
    for (int i = 0; i < dist.length; ++i) {
      Log.info(
          "Class " + label.domain(i)
              + ": count: " + dist[i]
              + " prior: " + (float) dist[i] / fr.numRows());
    }
  }

  // Create sampling_ratios for class balance with max. maxrows rows (fill existing array if
  // not null).
  if (sampling_ratios == null
      || (Utils.minValue(sampling_ratios) == 0 && Utils.maxValue(sampling_ratios) == 0)) {
    // compute sampling ratios to achieve class balance
    if (sampling_ratios == null) {
      sampling_ratios = new float[dist.length];
    }
    assert (sampling_ratios.length == dist.length);
    for (int i = 0; i < dist.length; ++i) {
      sampling_ratios[i] =
          ((float) fr.numRows() / label.domain().length) / dist[i]; // prior^-1 / num_classes
    }
    // majority class has lowest required oversampling factor to achieve balance
    final float inv_scale = Utils.minValue(sampling_ratios);
    // want sampling_ratio 1.0 for majority class (no downsampling)
    if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale))
      Utils.div(sampling_ratios, inv_scale);
  }

  if (!allowOversampling) {
    for (int i = 0; i < sampling_ratios.length; ++i) {
      sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]);
    }
  }

  // Given these sampling ratios, and the original class distribution, this is the expected
  // number of resulting rows.
  float numrows = 0;
  for (int i = 0; i < sampling_ratios.length; ++i) {
    numrows += sampling_ratios[i] * dist[i];
  }
  final long actualnumrows = Math.min(maxrows, Math.round(numrows)); // cap #rows at maxrows
  // can have no matching rows in case of sparse data where we had to fill in a makeZero() vector
  assert (actualnumrows >= 0);
  Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows.");

  if (actualnumrows != numrows) {
    // adjust the sampling_ratios by the global rescaling factor
    Utils.mult(sampling_ratios, (float) actualnumrows / numrows);
    if (verbose)
      Log.info(
          "Downsampling majority class by " + (float) actualnumrows / numrows
              + " to limit number of rows to " + String.format("%,d", maxrows));
  }
  Log.info(
      "Majority class (" + label.domain()[Utils.minIndex(sampling_ratios)].toString()
          + ") sampling ratio: " + Utils.minValue(sampling_ratios));
  Log.info(
      "Minority class (" + label.domain()[Utils.maxIndex(sampling_ratios)].toString()
          + ") sampling ratio: " + Utils.maxValue(sampling_ratios));

  return sampleFrameStratified(fr, label, sampling_ratios, seed, verbose);
}
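// For intuition, here is the balanced-ratio computation above applied to a concrete class
// distribution, as a standalone sketch (plain Java, no H2O dependencies; Utils.minValue and
// Utils.div are inlined, and the class name is illustrative only):
import java.util.Arrays;

public class SamplingRatiosDemo {
  public static void main(String[] args) {
    long[] dist = {900, 90, 10}; // class counts: 1000 rows, 3 classes
    long numRows = Arrays.stream(dist).sum();
    float[] ratios = new float[dist.length];
    for (int i = 0; i < dist.length; ++i)
      ratios[i] = ((float) numRows / dist.length) / dist[i]; // prior^-1 / num_classes
    // Normalize so the majority class gets ratio 1.0 (no downsampling).
    float invScale = Float.MAX_VALUE;
    for (float r : ratios) invScale = Math.min(invScale, r);
    for (int i = 0; i < ratios.length; ++i) ratios[i] /= invScale;
    // Prints roughly [1.0, 10.0, 90.0]: each class is resampled to ~900 rows.
    System.out.println(Arrays.toString(ratios));
  }
}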
/**
 * The train/valid Frame instances are sorted so that categorical columns come first (themselves
 * ordered by cardinality, greatest to least), with all numerical columns following. The response
 * column(s) are placed at the end.
 *
 * <p>Interactions:
 *   1. Num-Num (Note: N(0,1) * N(0,1) ~ N(0,1))
 *   2. Num-Enum
 *   3. Enum-Enum
 *
 * <p>Interactions are produced on the fly and are dense (in all 3 cases). Consumers of DataInfo
 * should not have to care how these interactions are generated. Any heuristic using the fullN
 * value should continue functioning the same.
 *
 * <p>Interactions are specified in two ways:
 *   A. As a list of pairs of column indices.
 *   B. As a list of pairs of column indices with limited enums.
 */
public DataInfo(
    Frame train,
    Frame valid,
    int nResponses,
    boolean useAllFactorLevels,
    TransformType predictor_transform,
    TransformType response_transform,
    boolean skipMissing,
    boolean imputeMissing,
    boolean missingBucket,
    boolean weight,
    boolean offset,
    boolean fold,
    Model.InteractionPair[] interactions) {
  super(Key.<DataInfo>make());
  _valid = valid != null;
  assert predictor_transform != null;
  assert response_transform != null;
  _offset = offset;
  _weights = weight;
  _fold = fold;
  assert !(skipMissing && imputeMissing) : "skipMissing and imputeMissing cannot both be true";
  _skipMissing = skipMissing;
  _imputeMissing = imputeMissing;
  _predictor_transform = predictor_transform;
  _response_transform = response_transform;
  _responses = nResponses;
  _useAllFactorLevels = useAllFactorLevels;
  _interactions = interactions;

  // Create dummy InteractionWrappedVecs and shove them onto the front.
  if (_interactions != null) {
    _interactionVecs = new int[_interactions.length];
    train =
        Model.makeInteractions(
                train,
                false,
                _interactions,
                _useAllFactorLevels,
                _skipMissing,
                predictor_transform == TransformType.STANDARDIZE)
            .add(train);
    if (valid != null)
      valid =
          Model.makeInteractions(
                  valid,
                  true,
                  _interactions,
                  _useAllFactorLevels,
                  _skipMissing,
                  predictor_transform == TransformType.STANDARDIZE)
              .add(valid); // FIXME: should be using the training subs/muls!
  }

  _permutation = new int[train.numCols()];
  final Vec[] tvecs = train.vecs();

  // Count categorical-vs-numerical.
  final int n = tvecs.length - _responses - (offset ? 1 : 0) - (weight ? 1 : 0) - (fold ? 1 : 0);
  int[] nums = MemoryManager.malloc4(n);
  int[] cats = MemoryManager.malloc4(n);
  int nnums = 0, ncats = 0;
  for (int i = 0; i < n; ++i)
    if (tvecs[i].isCategorical()) cats[ncats++] = i;
    else nums[nnums++] = i;
  _nums = nnums;
  _cats = ncats;
  _catLvls = new int[ncats][];

  // Sort the cats in decreasing order of their cardinality.
  for (int i = 0; i < ncats; ++i)
    for (int j = i + 1; j < ncats; ++j)
      if (tvecs[cats[i]].domain().length < tvecs[cats[j]].domain().length) {
        int x = cats[i];
        cats[i] = cats[j];
        cats[j] = x;
      }

  String[] names = new String[train.numCols()];
  Vec[] tvecs2 = new Vec[train.numCols()];

  // Compute the cardinality of each cat.
  _catModes = new int[ncats];
  _catOffsets = MemoryManager.malloc4(ncats + 1);
  _catMissing = new boolean[ncats];
  int len = _catOffsets[0] = 0;
  int interactionIdx = 0; // simple index into the _interactionVecs array

  ArrayList<Integer> interactionIds;
  if (_interactions == null) {
    interactionIds = new ArrayList<>();
    for (int i = 0; i < tvecs.length; ++i)
      if (tvecs[i] instanceof InteractionWrappedVec) {
        interactionIds.add(i);
      }
    _interactionVecs = new int[interactionIds.size()];
    for (int i = 0; i < _interactionVecs.length; ++i)
      _interactionVecs[i] = interactionIds.get(i);
  }

  for (int i = 0; i < ncats; ++i) {
    names[i] = train._names[cats[i]];
    Vec v = (tvecs2[i] = tvecs[cats[i]]);
    _catMissing[i] = missingBucket; // needed for test time
    if (v instanceof InteractionWrappedVec) {
      if (_interactions != null) _interactions[interactionIdx].vecIdx = i;
      // i (and not cats[i]) because this is the index in _adaptedFrame
      _interactionVecs[interactionIdx++] = i;
      _catOffsets[i + 1] = (len += v.domain().length + (missingBucket ? 1 : 0));
    } else
      _catOffsets[i + 1] =
          (len +=
              v.domain().length
                  - (useAllFactorLevels ? 0 : 1)
                  + (missingBucket ? 1 : 0)); // missing values turn into a new factor level
    _catModes[i] =
        imputeMissing
            ? imputeCat(train.vec(cats[i]))
            : _catMissing[i] ? v.domain().length : -100;
    _permutation[i] = cats[i];
  }

  _numMeans = new double[nnums];
  _numOffsets = MemoryManager.malloc4(nnums + 1);
  _numOffsets[0] = len;
  boolean isIWV; // is InteractionWrappedVec?
  for (int i = 0; i < nnums; ++i) {
    names[i + ncats] = train._names[nums[i]];
    Vec v = train.vec(nums[i]);
    tvecs2[i + ncats] = v;
    isIWV = v instanceof InteractionWrappedVec;
    if (isIWV) {
      if (null != _interactions) _interactions[interactionIdx].vecIdx = i + ncats;
      _interactionVecs[interactionIdx++] = i + ncats;
    }
    _numOffsets[i + 1] = (len += (isIWV ? ((InteractionWrappedVec) v).expandedLength() : 1));
    _numMeans[i] = train.vec(nums[i]).mean();
    _permutation[i + ncats] = nums[i];
  }

  // Copy the response/weight/offset/fold columns last, unchanged.
  for (int i = names.length - nResponses - (weight ? 1 : 0) - (offset ? 1 : 0) - (fold ? 1 : 0);
      i < names.length;
      ++i) {
    names[i] = train._names[i];
    tvecs2[i] = train.vec(i);
  }

  _adaptedFrame = new Frame(names, tvecs2);
  train.restructure(names, tvecs2);
  if (valid != null) valid.restructure(names, valid.vecs(names));
  // _adaptedFrame = train;
  setPredictorTransform(predictor_transform);
  if (_responses > 0) setResponseTransform(response_transform);
}
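// The _catOffsets bookkeeping above is easiest to see in isolation: each categorical column
// occupies a contiguous block of expanded (one-hot) columns whose width depends on its
// cardinality, on useAllFactorLevels (whether the first level is dropped), and on the missing
// bucket. A minimal standalone sketch (illustrative only; plain ints stand in for
// Vec.domain().length, and interaction columns are ignored):
import java.util.Arrays;

public class CatOffsetsDemo {
  static int[] catOffsets(int[] cardinalities, boolean useAllFactorLevels, boolean missingBucket) {
    int[] offsets = new int[cardinalities.length + 1];
    int len = 0;
    for (int i = 0; i < cardinalities.length; ++i) {
      // Each categorical expands to its cardinality (minus 1 if the first level is dropped),
      // plus one extra column when missing values get their own factor level.
      len += cardinalities[i] - (useAllFactorLevels ? 0 : 1) + (missingBucket ? 1 : 0);
      offsets[i + 1] = len;
    }
    return offsets;
  }

  public static void main(String[] args) {
    // Cardinalities already sorted greatest to least, as DataInfo arranges them.
    int[] cards = {5, 3, 2};
    System.out.println(Arrays.toString(catOffsets(cards, false, true))); // [0, 5, 8, 10]
  }
}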