示例#1
0
  /**
   * Finds the most informative split point for a NUMERICAL attribute.
   *
   * <p>Candidate split points are obtained from {@code chooseNumericSplitPoints} applied to a sorted
   * copy of the attribute's values. The candidates are then swept from left to right. At candidate
   * {@code i}, the label counts of buckets {@code 0..i} form the "low" side and the remaining
   * buckets form the "high" side. The score is H(Y) minus the entropy of each side weighted by that
   * side's share of {@code data.size()}.
   *
   * @param data the instances to split
   * @param attr index of the numerical attribute to evaluate
   * @return a {@code Split} for {@code attr} holding the highest score found and the split point
   *     that produced it
   * @throws IllegalStateException if no split point was selected (for example, when no candidate
   *     split points were produced)
   */
  static Split numericalSplit(Data data, int attr) {
    double[] sortedValues = data.values(attr).clone();
    Arrays.sort(sortedValues);
    double[] candidates = chooseNumericSplitPoints(sortedValues);

    int labelCount = data.getDataset().nblabels();
    // bucketCounts[i] holds the per-label frequencies that computeFrequencies assigns to candidate i
    int[][] bucketCounts = new int[candidates.length][labelCount];
    // highSide is filled by computeFrequencies and serves as the label totals for H(Y).
    // It is then drained bucket by bucket into lowSide during the sweep.
    int[] highSide = new int[labelCount];
    int[] lowSide = new int[labelCount];

    computeFrequencies(data, attr, candidates, bucketCounts, highSide);

    int total = data.size();
    double parentEntropy = entropy(highSide, total);
    double weightPerInstance = 1.0 / total;

    int topIndex = -1;
    double topGain = -1.0;

    for (int i = 0; i < candidates.length; i++) {
      // Shift bucket i from the high side to the low side.
      // NOTE(review): bucket i is therefore scored on the low side; confirm this matches how the
      // returned split point is compared against instance values when the tree is applied.
      DataUtils.add(lowSide, bucketCounts[i]);
      DataUtils.dec(highSide, bucketCounts[i]);

      int lowSize = DataUtils.sum(lowSide);
      double lowTerm = lowSize * weightPerInstance * entropy(lowSide, lowSize);
      int highSize = DataUtils.sum(highSide);
      double highTerm = highSize * weightPerInstance * entropy(highSide, highSize);

      double gain = parentEntropy - lowTerm - highTerm;
      if (gain > topGain) {
        topGain = gain;
        topIndex = i;
      }
    }

    if (topIndex < 0) {
      throw new IllegalStateException("no best split found !");
    }
    return new Split(attr, topGain, candidates[topIndex]);
  }
示例#2
0
  /**
   * Scores a CATEGORICAL attribute by its information gain.
   *
   * <p>Instances are grouped by the split points returned from {@code chooseCategoricalSplitPoints}
   * (applied to a copy of the attribute's values). The gain is H(Y) - H(Y|X), where H(Y|X) sums
   * each group's label entropy weighted by that group's share of {@code data.size()}.
   *
   * @param data the instances to evaluate
   * @param attr index of the categorical attribute
   * @return a {@code Split} for {@code attr} holding the computed information gain
   */
  private static Split categoricalSplit(Data data, int attr) {
    double[] categories = chooseCategoricalSplitPoints(data.values(attr).clone());

    int labelCount = data.getDataset().nblabels();
    // perCategory[i] holds the per-label frequencies that computeFrequencies assigns to categories[i]
    int[][] perCategory = new int[categories.length][labelCount];
    int[] totals = new int[labelCount];
    computeFrequencies(data, attr, categories, perCategory, totals);

    int total = data.size();
    double labelEntropy = entropy(totals, total); // H(Y)
    double weightPerInstance = 1.0 / total;

    double conditionalEntropy = 0.0; // H(Y|X), accumulated one category at a time
    for (int[] labelCounts : perCategory) {
      int categorySize = DataUtils.sum(labelCounts);
      conditionalEntropy += categorySize * weightPerInstance * entropy(labelCounts, categorySize);
    }

    return new Split(attr, labelEntropy - conditionalEntropy);
  }