Exemple #1
0
  /**
   * Computes the best split for a NUMERICAL attribute.
   *
   * <p>The attribute values are copied and sorted, candidate split points are derived from them
   * with {@code chooseNumericSplitPoints}, and every candidate is scored by its information gain.
   * The candidate with the strictly highest gain wins; on ties the earliest candidate is kept.
   *
   * @param data instances to split
   * @param attr index of the attribute being evaluated
   * @return a split holding {@code attr}, the best information gain and the matching split point
   * @throws IllegalStateException if no candidate scores above the initial sentinel gain of
   *     {@code -1.0} (in particular when there are no candidate split points at all)
   */
  static Split numericalSplit(Data data, int attr) {
    // Work on a private copy so sorting does not disturb the array returned by data.values().
    double[] sorted = data.values(attr).clone();
    Arrays.sort(sorted);
    double[] candidates = chooseNumericSplitPoints(sorted);

    int labelCount = data.getDataset().nblabels();
    int[][] perCandidate = new int[candidates.length][labelCount];
    int[] upper = new int[labelCount]; // filled by computeFrequencies, then drained in the loop
    int[] lower = new int[labelCount]; // accumulates the rows already passed

    computeFrequencies(data, attr, candidates, perCandidate, upper);

    int total = data.size();
    double baseEntropy = entropy(upper, total);
    double invTotal = 1.0 / total;

    int bestIndex = -1;
    double bestGain = -1.0;

    // Score every candidate: move its counts from the upper side to the lower side, then subtract
    // the size-weighted entropy of both sides from the base entropy.
    int candidate = 0;
    for (int[] bucket : perCandidate) {
      DataUtils.add(lower, bucket);
      DataUtils.dec(upper, bucket);

      // lower side (per the original author's note: attribute value below the split point)
      int lowerSize = DataUtils.sum(lower);
      double gain = baseEntropy - lowerSize * invTotal * entropy(lower, lowerSize);

      // upper side (per the original author's note: attribute value at or above the split point)
      int upperSize = DataUtils.sum(upper);
      gain -= upperSize * invTotal * entropy(upper, upperSize);

      if (gain > bestGain) {
        bestGain = gain;
        bestIndex = candidate;
      }
      candidate++;
    }

    if (bestIndex == -1) {
      throw new IllegalStateException("no best split found !");
    }
    return new Split(attr, bestGain, candidates[bestIndex]);
  }
Exemple #2
0
  /**
   * Computes the split for a CATEGORICAL attribute.
   *
   * <p>The information gain is {@code H(Y) - H(Y|X)}: the entropy of the overall label counts
   * minus the size-weighted sum of the entropies of the per-split-point label counts.
   *
   * @param data instances to split
   * @param attr index of the attribute being evaluated
   * @return a split holding {@code attr} and its information gain (no split value)
   */
  private static Split categoricalSplit(Data data, int attr) {
    double[] attrValues = data.values(attr).clone();
    double[] categories = chooseCategoricalSplitPoints(attrValues);

    int labelCount = data.getDataset().nblabels();
    int[][] perCategory = new int[categories.length][labelCount];
    int[] totals = new int[labelCount];
    computeFrequencies(data, attr, categories, perCategory, totals);

    int total = data.size();
    double labelEntropy = entropy(totals, total); // H(Y)
    double invTotal = 1.0 / total;

    double conditionalEntropy = 0.0; // H(Y|X)
    for (int[] bucket : perCategory) {
      int bucketSize = DataUtils.sum(bucket);
      conditionalEntropy += bucketSize * invTotal * entropy(bucket, bucketSize);
    }

    return new Split(attr, labelEntropy - conditionalEntropy);
  }
  /**
   * Recursively builds a tree node for {@code data}.
   *
   * <p>Steps, in order:
   *
   * <ol>
   *   <li>Lazily initialises builder state on first use: {@code selected}, {@code m},
   *       {@code minVariance}, {@code fullSet} and {@code igSplit}.
   *   <li>Returns a leaf early when the data is empty, when the label variance is below
   *       {@code minVariance} (numerical label), or when the instances or their labels are
   *       identical (categorical label).
   *   <li>Draws candidate attributes with {@code randomAttributes}, scores each with
   *       {@code igSplit} and keeps the one with the highest information gain.
   *   <li>Returns a leaf if that gain is below {@code EPSILON} or if the resulting subsets are
   *       smaller than {@code minSplitNum}; otherwise recurses into the subsets.
   * </ol>
   *
   * <p>The {@code selected} flags are modified while children are built and restored afterwards.
   * Every early return happens before {@code selected} is modified, so a node that ends up as a
   * leaf never changes the selection state seen by its siblings or by later calls.
   *
   * @param rng random source, used for attribute sampling and by {@code Data.majorityLabel}
   * @param data instances reaching this node
   * @return a {@code Leaf}, a {@code NumericalNode} or a {@code CategoricalNode}
   */
  @Override
  public Node build(Random rng, Data data) {
    if (selected == null) {
      selected = new boolean[data.getDataset().nbAttributes()];
      selected[data.getDataset().getLabelId()] = true; // never select the label
    }
    if (m == 0) {
      // set default m
      double e = data.getDataset().nbAttributes() - 1;
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        // regression
        m = (int) Math.ceil(e / 3.0);
      } else {
        // classification
        m = (int) Math.ceil(Math.sqrt(e));
      }
    }

    if (data.isEmpty()) {
      return new Leaf(Double.NaN);
    }

    double sum = 0.0;
    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
      // regression
      // sum and sum squared of a label is computed
      double sumSquared = 0.0;
      for (int i = 0; i < data.size(); i++) {
        double label = data.getDataset().getLabel(data.get(i));
        sum += label;
        sumSquared += label * label;
      }

      // sum of squared deviations from the mean; divided by the size below to get the variance
      double var = sumSquared - (sum * sum) / data.size();

      // the minimum variance is derived once, from the first data set seen while it is still NaN
      if (Double.isNaN(minVariance)) {
        minVariance = var / data.size() * minVarianceProportion;
        log.debug("minVariance:{}", minVariance);
      }

      // variance is compared with minimum variance
      if ((var / data.size()) < minVariance) {
        log.debug(
            "variance({}) < minVariance({}) Leaf({})",
            var / data.size(),
            minVariance,
            sum / data.size());
        return new Leaf(sum / data.size());
      }
    } else {
      // classification
      if (isIdentical(data)) {
        return new Leaf(data.majorityLabel(rng));
      }
      if (data.identicalLabel()) {
        return new Leaf(data.getDataset().getLabel(data.get(0)));
      }
    }

    // store full set data
    if (fullSet == null) {
      fullSet = data;
    }

    int[] attributes = randomAttributes(rng, selected, m);
    if (attributes == null || attributes.length == 0) {
      // we tried all the attributes and could not split the data anymore
      double label = fallbackLeafLabel(rng, data, sum);
      log.warn("attribute which can be selected is not found Leaf({})", label);
      return new Leaf(label);
    }

    if (igSplit == null) {
      if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
        // regression
        igSplit = new RegressionSplit();
      } else {
        // classification
        igSplit = new OptIgSplit();
      }
    }

    // find the best split; attributes is non-empty here, so best is always assigned
    Split best = null;
    for (int attr : attributes) {
      Split split = igSplit.computeSplit(data, attr);
      if (best == null || best.getIg() < split.getIg()) {
        best = split;
      }
    }

    // information gain is near to zero.
    if (best.getIg() < EPSILON) {
      double label = fallbackLeafLabel(rng, data, sum);
      log.debug("ig is near to zero Leaf({})", label);
      return new Leaf(label);
    }

    log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg());

    boolean alreadySelected = selected[best.getAttr()];
    if (alreadySelected) {
      // attribute already selected
      log.warn("attribute {} already selected in a parent node", best.getAttr());
    }

    Node childNode;
    if (data.getDataset().isNumerical(best.getAttr())) {
      Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
      Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));

      // size of the subset is less than the minSplitNum
      // This check runs BEFORE the selection state is touched: returning a leaf after modifying
      // 'selected' would skip the restore below and leak the change to siblings and later calls.
      if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
        // branch is not split
        double label = fallbackLeafLabel(rng, data, sum);
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      boolean[] temp = null;
      if (loSubset.isEmpty() || hiSubset.isEmpty()) {
        // the selected attribute did not change the data, avoid using it in the child nodes
        selected[best.getAttr()] = true;
      } else {
        // the data changed, so we can unselect all previously selected NUMERICAL attributes
        temp = selected;
        selected = cloneCategoricalAttributes(data.getDataset(), selected);
      }

      Node loChild = build(rng, loSubset);
      Node hiChild = build(rng, hiSubset);

      // restore the selection state of the attributes
      if (temp != null) {
        selected = temp;
      } else {
        selected[best.getAttr()] = alreadySelected;
      }

      childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
    } else { // CATEGORICAL attribute
      double[] values = data.values(best.getAttr());

      // tree is complemented: branch on the values of the full data set, remembering which of
      // them actually occur in this node's data
      Collection<Double> subsetValues = null;
      if (complemented) {
        subsetValues = new HashSet<>();
        for (double value : values) {
          subsetValues.add(value);
        }
        values = fullSet.values(best.getAttr());
      }

      // count the subsets that are large enough; entries for absent values stay null
      int cnt = 0;
      Data[] subsets = new Data[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && !subsetValues.contains(values[index])) {
          continue;
        }
        subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
        if (subsets[index].size() >= minSplitNum) {
          cnt++;
        }
      }

      // fewer than two subsets reach minSplitNum
      if (cnt < 2) {
        // branch is not split
        double label = fallbackLeafLabel(rng, data, sum);
        log.debug("branch is not split Leaf({})", label);
        return new Leaf(label);
      }

      selected[best.getAttr()] = true;

      Node[] children = new Node[values.length];
      for (int index = 0; index < values.length; index++) {
        if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
          // tree is complemented: a value absent from this node's data gets a fallback leaf
          // (computed per child, as before, so any use of rng happens in the same order)
          double label = fallbackLeafLabel(rng, data, sum);
          log.debug("complemented Leaf({})", label);
          children[index] = new Leaf(label);
          continue;
        }
        children[index] = build(rng, subsets[index]);
      }

      selected[best.getAttr()] = alreadySelected;

      childNode = new CategoricalNode(best.getAttr(), values, children);
    }

    return childNode;
  }

  /**
   * Label used when a node becomes a leaf without a split: the mean label ({@code sum / size})
   * when the label attribute is numerical, otherwise the majority label of {@code data}.
   */
  private double fallbackLeafLabel(Random rng, Data data, double sum) {
    if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
      // regression
      return sum / data.size();
    }
    // classification
    return data.majorityLabel(rng);
  }