Code example #1
File: DecisionTree.java Project: debmalyaroy/smile
    /**
     * Finds the best split cutoff for attribute j at the current node.
     *
     * @param n the number of instances in this node.
     * @param count the sample count in each class.
     * @param falseCount an array to store the sample count in each class for the false child node.
     * @param impurity the impurity of this node.
     * @param j the attribute to split on.
     */
    public Node findBestSplit(int n, int[] count, int[] falseCount, double impurity, int j) {
      Node splitNode = new Node();

      if (attributes[j].getType() == Attribute.Type.NOMINAL) {
        int m = ((NominalAttribute) attributes[j]).size();
        int[][] trueCount = new int[m][k];

        for (int i = 0; i < x.length; i++) {
          if (samples[i] > 0) {
            trueCount[(int) x[i][j]][y[i]] += samples[i];
          }
        }

        for (int l = 0; l < m; l++) {
          int tc = Math.sum(trueCount[l]);
          int fc = n - tc;

          // If either child would have fewer than nodeSize samples, skip this split.
          if (tc < nodeSize || fc < nodeSize) {
            continue;
          }

          for (int q = 0; q < k; q++) {
            falseCount[q] = count[q] - trueCount[l][q];
          }

          int trueLabel = Math.whichMax(trueCount[l]);
          int falseLabel = Math.whichMax(falseCount);
          double gain =
              impurity
                  - (double) tc / n * impurity(trueCount[l], tc)
                  - (double) fc / n * impurity(falseCount, fc);

          if (gain > splitNode.splitScore) {
            // new best split
            splitNode.splitFeature = j;
            splitNode.splitValue = l;
            splitNode.splitScore = gain;
            splitNode.trueChildOutput = trueLabel;
            splitNode.falseChildOutput = falseLabel;
          }
        }
      } else if (attributes[j].getType() == Attribute.Type.NUMERIC) {
        int[] trueCount = new int[k];
        double prevx = Double.NaN;
        int prevy = -1;

        for (int i : order[j]) {
          if (samples[i] > 0) {
            if (Double.isNaN(prevx) || x[i][j] == prevx || y[i] == prevy) {
              prevx = x[i][j];
              prevy = y[i];
              trueCount[y[i]] += samples[i];
              continue;
            }

            int tc = Math.sum(trueCount);
            int fc = n - tc;

            // If either child would have fewer than nodeSize samples, skip this split.
            if (tc < nodeSize || fc < nodeSize) {
              prevx = x[i][j];
              prevy = y[i];
              trueCount[y[i]] += samples[i];
              continue;
            }

            for (int l = 0; l < k; l++) {
              falseCount[l] = count[l] - trueCount[l];
            }

            int trueLabel = Math.whichMax(trueCount);
            int falseLabel = Math.whichMax(falseCount);
            double gain =
                impurity
                    - (double) tc / n * impurity(trueCount, tc)
                    - (double) fc / n * impurity(falseCount, fc);

            if (gain > splitNode.splitScore) {
              // new best split
              splitNode.splitFeature = j;
              splitNode.splitValue = (x[i][j] + prevx) / 2;
              splitNode.splitScore = gain;
              splitNode.trueChildOutput = trueLabel;
              splitNode.falseChildOutput = falseLabel;
            }

            prevx = x[i][j];
            prevy = y[i];
            trueCount[y[i]] += samples[i];
          }
        }
      } else {
        throw new IllegalStateException("Unsupported attribute type: " + attributes[j].getType());
      }

      return splitNode;
    }
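In this snippet, Math.sum and Math.whichMax are static helpers from Smile's own math utility class rather than java.lang.Math, and impurity(count, n) scores a class histogram according to the tree's configured split rule. As a point of reference, the following standalone sketch computes the Gini variant of that score together with the weighted gain expression used above; since the split rule is configurable, treat this as an illustration, not the library's exact implementation.

/** Standalone sketch: Gini impurity of a class histogram and the weighted gain used to score a split. */
public class GiniSketch {

    /** Gini impurity 1 - sum_i p_i^2 of a class histogram whose counts sum to n. */
    static double gini(int[] count, int n) {
        double impurity = 1.0;
        for (int c : count) {
            double p = (double) c / n;
            impurity -= p * p;
        }
        return impurity;
    }

    /** Weighted impurity decrease, mirroring the gain expression in findBestSplit. */
    static double gain(int[] parent, int n, int[] trueCount, int tc, int[] falseCount, int fc) {
        return gini(parent, n)
                - (double) tc / n * gini(trueCount, tc)
                - (double) fc / n * gini(falseCount, fc);
    }

    public static void main(String[] args) {
        int[] parent = {40, 10};              // 40 samples of class 0, 10 of class 1
        int[] trueCount = {35, 2};            // class histogram of the true (left) child
        int[] falseCount = {5, 8};            // class histogram of the false (right) child
        System.out.println(gini(parent, 50)); // 1 - (0.8^2 + 0.2^2) = 0.32
        System.out.println(gain(parent, 50, trueCount, 37, falseCount, 13));
    }
}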
Code example #2
File: DecisionTree.java Project: debmalyaroy/smile
    /** Splits the node into two child nodes. Returns true if the split succeeds. */
    public boolean split(PriorityQueue<TrainNode> nextSplits) {
      if (node.splitFeature < 0) {
        throw new IllegalStateException("Split a node with invalid feature.");
      }

      int n = x.length;
      int tc = 0;
      int fc = 0;
      int[] trueSamples = new int[n];
      int[] falseSamples = new int[n];

      if (attributes[node.splitFeature].getType() == Attribute.Type.NOMINAL) {
        for (int i = 0; i < n; i++) {
          if (samples[i] > 0) {
            if (x[i][node.splitFeature] == node.splitValue) {
              trueSamples[i] = samples[i];
              tc += samples[i];
            } else {
              falseSamples[i] = samples[i];
              fc += samples[i];
            }
          }
        }
      } else if (attributes[node.splitFeature].getType() == Attribute.Type.NUMERIC) {
        for (int i = 0; i < n; i++) {
          if (samples[i] > 0) {
            if (x[i][node.splitFeature] <= node.splitValue) {
              trueSamples[i] = samples[i];
              tc += samples[i];
            } else {
              falseSamples[i] = samples[i];
              fc += samples[i];
            }
          }
        }
      } else {
        throw new IllegalStateException(
            "Unsupported attribute type: " + attributes[node.splitFeature].getType());
      }

      if (tc < nodeSize || fc < nodeSize) {
        node.splitFeature = -1;
        node.splitValue = Double.NaN;
        node.splitScore = 0.0;
        return false;
      }

      double[] trueChildPosteriori = new double[k];
      double[] falseChildPosteriori = new double[k];
      for (int i = 0; i < n; i++) {
        int yi = y[i];
        trueChildPosteriori[yi] += trueSamples[i];
        falseChildPosteriori[yi] += falseSamples[i];
      }

      // Laplace (add-one) smoothing of the posterior probabilities.
      for (int i = 0; i < k; i++) {
        trueChildPosteriori[i] = (trueChildPosteriori[i] + 1) / (tc + k);
        falseChildPosteriori[i] = (falseChildPosteriori[i] + 1) / (fc + k);
      }

      node.trueChild = new Node(node.trueChildOutput, trueChildPosteriori);
      node.falseChild = new Node(node.falseChildOutput, falseChildPosteriori);

      TrainNode trueChild = new TrainNode(node.trueChild, x, y, trueSamples);
      if (tc > nodeSize && trueChild.findBestSplit()) {
        if (nextSplits != null) {
          nextSplits.add(trueChild);
        } else {
          trueChild.split(null);
        }
      }

      TrainNode falseChild = new TrainNode(node.falseChild, x, y, falseSamples);
      if (fc > nodeSize && falseChild.findBestSplit()) {
        if (nextSplits != null) {
          nextSplits.add(falseChild);
        } else {
          falseChild.split(null);
        }
      }

      importance[node.splitFeature] += node.splitScore;

      return true;
    }
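The nextSplits queue passed into split points at a best-first growth strategy: candidate leaves are ranked by their split score and the most promising one is expanded next until some node budget is exhausted. Below is a minimal, self-contained sketch of such a driver loop under that assumption; Candidate, maxNodes, and the printed output are hypothetical stand-ins, not Smile's API.

import java.util.Comparator;
import java.util.PriorityQueue;

/** Hypothetical sketch of best-first tree growth driven by a priority queue of candidate splits. */
public class BestFirstGrowth {

    /** Stand-in for a trainable leaf: carries the score of its best candidate split. */
    static final class Candidate {
        final String name;
        final double splitScore;

        Candidate(String name, double splitScore) {
            this.name = name;
            this.splitScore = splitScore;
        }

        /** Pretend to split: just reports the expansion. A real split would enqueue two children. */
        boolean split(PriorityQueue<Candidate> nextSplits) {
            System.out.println("expanding " + name + " (score " + splitScore + ")");
            return true;
        }
    }

    public static void main(String[] args) {
        // Highest split score first, analogous to ordering TrainNode by node.splitScore.
        PriorityQueue<Candidate> nextSplits =
                new PriorityQueue<>(Comparator.comparingDouble((Candidate c) -> c.splitScore).reversed());
        nextSplits.add(new Candidate("root", 0.32));
        nextSplits.add(new Candidate("left", 0.10));
        nextSplits.add(new Candidate("right", 0.21));

        int maxNodes = 3;  // hypothetical node budget
        for (int leaves = 1; leaves < maxNodes && !nextSplits.isEmpty(); ) {
            Candidate best = nextSplits.poll();   // most promising leaf so far
            if (best.split(nextSplits)) {
                leaves++;                         // a successful split adds one leaf
            }
        }
    }
}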
Code example #3
File: DecisionTree.java Project: debmalyaroy/smile
    /**
     * Finds the best attribute to split on at the current node. Returns true if a split exists to
     * reduce the impurity, false otherwise.
     */
    public boolean findBestSplit() {
      int label = -1;
      boolean pure = true;
      for (int i = 0; i < x.length; i++) {
        if (samples[i] > 0) {
          if (label == -1) {
            label = y[i];
          } else if (y[i] != label) {
            pure = false;
            break;
          }
        }
      }

      // Since all instances have the same label, stop splitting.
      if (pure) {
        return false;
      }

      int n = 0;
      for (int s : samples) {
        n += s;
      }

      if (n <= nodeSize) {
        return false;
      }

      // Sample count in each class.
      int[] count = new int[k];
      int[] falseCount = new int[k];
      for (int i = 0; i < x.length; i++) {
        if (samples[i] > 0) {
          count[y[i]] += samples[i];
        }
      }

      double impurity = impurity(count, n);

      int p = attributes.length;
      int[] variables = new int[p];
      for (int i = 0; i < p; i++) {
        variables[i] = i;
      }

      if (mtry < p) {
        Math.permutate(variables);

        // Random forest already runs in parallel, so evaluate the selected features sequentially here.
        for (int j = 0; j < mtry; j++) {
          Node split = findBestSplit(n, count, falseCount, impurity, variables[j]);
          if (split.splitScore > node.splitScore) {
            node.splitFeature = split.splitFeature;
            node.splitValue = split.splitValue;
            node.splitScore = split.splitScore;
            node.trueChildOutput = split.trueChildOutput;
            node.falseChildOutput = split.falseChildOutput;
          }
        }
      } else {

        List<SplitTask> tasks = new ArrayList<>(mtry);
        for (int j = 0; j < mtry; j++) {
          tasks.add(new SplitTask(n, count, impurity, variables[j]));
        }

        try {
          for (Node split : MulticoreExecutor.run(tasks)) {
            if (split.splitScore > node.splitScore) {
              node.splitFeature = split.splitFeature;
              node.splitValue = split.splitValue;
              node.splitScore = split.splitScore;
              node.trueChildOutput = split.trueChildOutput;
              node.falseChildOutput = split.falseChildOutput;
            }
          }
        } catch (Exception ex) {
          for (int j = 0; j < mtry; j++) {
            Node split = findBestSplit(n, count, falseCount, impurity, variables[j]);
            if (split.splitScore > node.splitScore) {
              node.splitFeature = split.splitFeature;
              node.splitValue = split.splitValue;
              node.splitScore = split.splitScore;
              node.trueChildOutput = split.trueChildOutput;
              node.falseChildOutput = split.falseChildOutput;
            }
          }
        }
      }

      return (node.splitFeature != -1);
    }
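Math.permutate(variables) and MulticoreExecutor.run(tasks) above are Smile utilities, not JDK classes. The shuffle-then-take-the-first-mtry pattern is the per-node feature subsampling used by random forests: when mtry < p, only a random subset of features is considered for the split. The sketch below reproduces that idea with the standard library only; the helper name pickFeatures is hypothetical.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Sketch of random-forest style feature subsampling: examine only mtry of the p features per node. */
public class FeatureSubsampling {

    /** Returns mtry distinct feature indices drawn uniformly at random from 0..p-1 (hypothetical helper). */
    static List<Integer> pickFeatures(int p, int mtry) {
        List<Integer> variables = new ArrayList<>(p);
        for (int i = 0; i < p; i++) {
            variables.add(i);
        }
        Collections.shuffle(variables);               // plays the role of Math.permutate(variables)
        return variables.subList(0, Math.min(mtry, p));
    }

    public static void main(String[] args) {
        // With p = 10 features and mtry = 3, each call yields a different random subset.
        System.out.println(pickFeatures(10, 3));
    }
}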