@Override
public Split computeSplit(Data data, int attr) {
  if (data.getDataset().isNumerical(attr)) {
    return numericalSplit(data, attr);
  } else {
    return categoricalSplit(data, attr);
  }
}
/** Computes the best split for a NUMERICAL attribute. */
static Split numericalSplit(Data data, int attr) {
  double[] values = data.values(attr).clone();
  Arrays.sort(values);

  double[] splitPoints = chooseNumericSplitPoints(values);

  int numLabels = data.getDataset().nblabels();
  int[][] counts = new int[splitPoints.length][numLabels];
  int[] countAll = new int[numLabels];
  int[] countLess = new int[numLabels];

  computeFrequencies(data, attr, splitPoints, counts, countAll);

  int size = data.size();
  double hy = entropy(countAll, size);
  double invDataSize = 1.0 / size;

  int best = -1;
  double bestIg = -1.0;

  // try each possible split point
  for (int index = 0; index < splitPoints.length; index++) {
    double ig = hy;

    DataUtils.add(countLess, counts[index]);
    DataUtils.dec(countAll, counts[index]);

    // instances with attribute value <= splitPoints[index]
    size = DataUtils.sum(countLess);
    ig -= size * invDataSize * entropy(countLess, size);

    // instances with attribute value > splitPoints[index]
    size = DataUtils.sum(countAll);
    ig -= size * invDataSize * entropy(countAll, size);

    if (ig > bestIg) {
      bestIg = ig;
      best = index;
    }
  }

  if (best == -1) {
    throw new IllegalStateException("no best split found!");
  }

  return new Split(attr, bestIg, splitPoints[best]);
}
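/*
 * Illustrative sketch (not part of the original class): the incremental information-gain
 * bookkeeping from numericalSplit() reproduced on a toy label-count table with plain arrays.
 * The class name IgSketch and the local entropy() helper are hypothetical stand-ins for the
 * class's own entropy()/DataUtils helpers, whose exact implementation is assumed here.
 */
class IgSketch {

  /** Shannon entropy (base 2) of a label-count vector; assumed to mirror the real helper. */
  static double entropy(int[] counts, int total) {
    double h = 0.0;
    for (int c : counts) {
      if (c > 0) {
        double p = (double) c / total;
        h -= p * Math.log(p) / Math.log(2.0);
      }
    }
    return h;
  }

  public static void main(String[] args) {
    // counts[i][label]: per-split-point label counts, as computeFrequencies() would fill them
    int[][] counts = {{4, 0}, {1, 3}, {0, 2}};
    int[] countAll = {5, 5};      // overall label counts
    int[] countLess = new int[2]; // accumulated counts at or below the current split point
    int size = 10;
    double hy = entropy(countAll, size); // H(Y) of the unsplit data

    for (int index = 0; index < counts.length; index++) {
      for (int label = 0; label < 2; label++) { // the DataUtils.add / DataUtils.dec step
        countLess[label] += counts[index][label];
        countAll[label] -= counts[index][label];
      }
      int lo = countLess[0] + countLess[1];
      int hi = countAll[0] + countAll[1];
      double ig = hy
          - (double) lo / size * entropy(countLess, lo)
          - (double) hi / size * entropy(countAll, hi);
      System.out.printf("split %d: ig = %.3f%n", index, ig); // split 0 wins with ig = 0.610
    }
  }
}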
static void computeFrequencies(Data data, int attr, double[] splitPoints,
                               int[][] counts, int[] countAll) {
  Dataset dataset = data.getDataset();

  for (int index = 0; index < data.size(); index++) {
    Instance instance = data.get(index);
    int label = (int) dataset.getLabel(instance);
    double value = instance.get(attr);

    // find the first split point that the value does not exceed
    int split = 0;
    while (split < splitPoints.length && value > splitPoints[split]) {
      split++;
    }
    if (split < splitPoints.length) {
      counts[split][label]++;
    } // otherwise the value lies beyond the last split point and needs no per-split count
    countAll[label]++;
  }
}
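/*
 * Illustrative sketch (not part of the original class): how the while-loop in
 * computeFrequencies() assigns a value to a bucket given sorted split points. The split points
 * and values below are made up for illustration.
 */
class BucketSketch {
  public static void main(String[] args) {
    double[] splitPoints = {2.5, 5.0, 7.5};
    for (double value : new double[] {1.0, 2.5, 3.0, 6.0, 9.0}) {
      int split = 0;
      while (split < splitPoints.length && value > splitPoints[split]) {
        split++;
      }
      // prints buckets 0, 0, 1, 2, 3; bucket 3 (== splitPoints.length) lies beyond the last split point
      System.out.printf("value %.1f -> bucket %d%n", value, split);
    }
  }
}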
/**
 * Checks whether all the vectors have identical attribute values, ignoring the
 * already-selected attributes.
 *
 * @return true if all the vectors are identical or the data is empty,
 *         false otherwise
 */
private boolean isIdentical(Data data) {
  if (data.isEmpty()) {
    return true;
  }

  Instance instance = data.get(0);
  for (int attr = 0; attr < selected.length; attr++) {
    if (selected[attr]) {
      continue;
    }
    for (int index = 1; index < data.size(); index++) {
      if (data.get(index).get(attr) != instance.get(attr)) {
        return false;
      }
    }
  }
  return true;
}
/** Computes the split for a CATEGORICAL attribute. */
private static Split categoricalSplit(Data data, int attr) {
  double[] values = data.values(attr).clone();
  double[] splitPoints = chooseCategoricalSplitPoints(values);

  int numLabels = data.getDataset().nblabels();
  int[][] counts = new int[splitPoints.length][numLabels];
  int[] countAll = new int[numLabels];

  computeFrequencies(data, attr, splitPoints, counts, countAll);

  int size = data.size();
  double hy = entropy(countAll, size); // H(Y)
  double hyx = 0.0;                    // H(Y|X)
  double invDataSize = 1.0 / size;

  for (int index = 0; index < splitPoints.length; index++) {
    size = DataUtils.sum(counts[index]);
    hyx += size * invDataSize * entropy(counts[index], size);
  }

  double ig = hy - hyx;
  return new Split(attr, ig);
}
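/*
 * Illustrative sketch (not part of the original class): the H(Y) - H(Y|X) computation from
 * categoricalSplit() on a small made-up contingency table (rows = category value, columns =
 * label). It reuses the hypothetical entropy() helper from the IgSketch example above.
 */
class CategoricalIgSketch {
  public static void main(String[] args) {
    int[][] counts = {{3, 1}, {0, 4}, {2, 2}}; // counts[category][label]
    int[] countAll = {5, 7};                   // column sums over all categories
    int size = 12;

    double hy = IgSketch.entropy(countAll, size); // H(Y)
    double hyx = 0.0;                             // H(Y|X), weighted by category frequency
    for (int[] row : counts) {
      int rowSize = row[0] + row[1];
      hyx += (double) rowSize / size * IgSketch.entropy(row, rowSize);
    }
    System.out.printf("ig = %.3f%n", hy - hyx);
  }
}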
@Override
public Node build(Random rng, Data data) {
  if (selected == null) {
    selected = new boolean[data.getDataset().nbAttributes()];
    selected[data.getDataset().getLabelId()] = true; // never select the label
  }
  if (m == 0) {
    // set the default number of attributes tried per node
    double e = data.getDataset().nbAttributes() - 1;
    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
      // regression
      m = (int) Math.ceil(e / 3.0);
    } else {
      // classification
      m = (int) Math.ceil(Math.sqrt(e));
    }
  }

  if (data.isEmpty()) {
    return new Leaf(Double.NaN);
  }

  double sum = 0.0;
  if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
    // regression: compute the sum and sum of squares of the labels
    double sumSquared = 0.0;
    for (int i = 0; i < data.size(); i++) {
      double label = data.getDataset().getLabel(data.get(i));
      sum += label;
      sumSquared += label * label;
    }

    // compute the variance
    double var = sumSquared - (sum * sum) / data.size();

    // compute the minimum variance once, on the first (root) call
    if (Double.compare(minVariance, Double.NaN) == 0) {
      minVariance = var / data.size() * minVarianceProportion;
      log.debug("minVariance:{}", minVariance);
    }

    // compare the variance with the minimum variance
    if ((var / data.size()) < minVariance) {
      log.debug("variance({}) < minVariance({}) Leaf({})",
          var / data.size(), minVariance, sum / data.size());
      return new Leaf(sum / data.size());
    }
  } else {
    // classification
    if (isIdentical(data)) {
      return new Leaf(data.majorityLabel(rng));
    }
    if (data.identicalLabel()) {
      return new Leaf(data.getDataset().getLabel(data.get(0)));
    }
  }

  // store the full data set (used when complementing categorical nodes)
  if (fullSet == null) {
    fullSet = data;
  }

  int[] attributes = randomAttributes(rng, selected, m);
  if (attributes == null || attributes.length == 0) {
    // we tried all the attributes and could not split the data anymore
    double label;
    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
      // regression
      label = sum / data.size();
    } else {
      // classification
      label = data.majorityLabel(rng);
    }
    log.warn("attribute which can be selected is not found Leaf({})", label);
    return new Leaf(label);
  }

  if (igSplit == null) {
    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
      // regression
      igSplit = new RegressionSplit();
    } else {
      // classification
      igSplit = new OptIgSplit();
    }
  }

  // find the best split
  Split best = null;
  for (int attr : attributes) {
    Split split = igSplit.computeSplit(data, attr);
    if (best == null || best.getIg() < split.getIg()) {
      best = split;
    }
  }

  // the information gain is near zero: do not split any further
  if (best.getIg() < EPSILON) {
    double label;
    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
      label = sum / data.size();
    } else {
      label = data.majorityLabel(rng);
    }
    log.debug("ig is near to zero Leaf({})", label);
    return new Leaf(label);
  }

  log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg());

  boolean alreadySelected = selected[best.getAttr()];
  if (alreadySelected) {
    // attribute already selected
    log.warn("attribute {} already selected in a parent node", best.getAttr());
  }

  Node childNode;
  if (data.getDataset().isNumerical(best.getAttr())) {
    boolean[] temp = null;

    Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
    Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));

    if (loSubset.isEmpty() || hiSubset.isEmpty()) {
      // the selected attribute did not change the data, avoid using it in the child nodes
      selected[best.getAttr()] = true;
    } else {
      // the data changed, so we can unselect all previously selected NUMERICAL attributes
      temp = selected;
      selected = cloneCategoricalAttributes(data.getDataset(), selected);
    }

    // one of the subsets is smaller than minSplitNum: the branch is not split
    if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
      double label;
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        label = sum / data.size();
      } else {
        label = data.majorityLabel(rng);
      }
      log.debug("branch is not split Leaf({})", label);
      return new Leaf(label);
    }

    Node loChild = build(rng, loSubset);
    Node hiChild = build(rng, hiSubset);

    // restore the selection state of the attributes
    if (temp != null) {
      selected = temp;
    } else {
      selected[best.getAttr()] = alreadySelected;
    }

    childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
  } else { // CATEGORICAL attribute
    double[] values = data.values(best.getAttr());

    // if the tree is complemented, grow a child for every value seen in the full data set
    Collection<Double> subsetValues = null;
    if (complemented) {
      subsetValues = new HashSet<>();
      for (double value : values) {
        subsetValues.add(value);
      }
      values = fullSet.values(best.getAttr());
    }

    int cnt = 0;
    Data[] subsets = new Data[values.length];
    for (int index = 0; index < values.length; index++) {
      if (complemented && !subsetValues.contains(values[index])) {
        continue;
      }
      subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
      if (subsets[index].size() >= minSplitNum) {
        cnt++;
      }
    }

    // fewer than two subsets reach minSplitNum: the branch is not split
    if (cnt < 2) {
      double label;
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        label = sum / data.size();
      } else {
        label = data.majorityLabel(rng);
      }
      log.debug("branch is not split Leaf({})", label);
      return new Leaf(label);
    }

    selected[best.getAttr()] = true;

    Node[] children = new Node[values.length];
    for (int index = 0; index < values.length; index++) {
      if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
        // value not present in this node's data: complement it with a leaf
        double label;
        if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
          label = sum / data.size();
        } else {
          label = data.majorityLabel(rng);
        }
        log.debug("complemented Leaf({})", label);
        children[index] = new Leaf(label);
        continue;
      }
      children[index] = build(rng, subsets[index]);
    }

    selected[best.getAttr()] = alreadySelected;

    childNode = new CategoricalNode(best.getAttr(), values, children);
  }

  return childNode;
}
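/*
 * Illustrative sketch (not part of the original class): the regression stopping test from
 * build() in isolation. sumSquared - sum * sum / n is n times the population variance, so
 * dividing by n gives the quantity build() compares against minVariance; the labels and the
 * minVarianceProportion value below are made up for illustration.
 */
class VarianceStopSketch {

  /** n times the population variance, i.e. the same "var" quantity build() computes. */
  static double varTimesN(double[] labels) {
    double sum = 0.0;
    double sumSquared = 0.0;
    for (double label : labels) {
      sum += label;
      sumSquared += label * label;
    }
    return sumSquared - sum * sum / labels.length;
  }

  public static void main(String[] args) {
    double[] rootLabels = {1.0, 5.0, 2.0, 8.0, 3.0, 7.0}; // labels seen at the root node
    double[] nodeLabels = {2.0, 2.1, 1.9, 2.0};           // labels at a deep, nearly pure node
    double minVarianceProportion = 1.0e-3;                // hypothetical value

    // build() fixes minVariance once, from the variance of the root node's labels
    double minVariance = varTimesN(rootLabels) / rootLabels.length * minVarianceProportion;

    double nodeVariance = varTimesN(nodeLabels) / nodeLabels.length;
    boolean leaf = nodeVariance < minVariance; // true: the node becomes Leaf(mean label)
    System.out.printf("node variance = %.5f, minVariance = %.5f -> leaf = %b%n",
        nodeVariance, minVariance, leaf);
  }
}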