/**
 * Constructor. Learns a classification tree for AdaBoost and Random Forest.
 *
 * @param attributes the attribute properties.
 * @param x the training instances.
 * @param y the response variable.
 * @param maxNodes the maximum number of leaf nodes in the tree.
 * @param nodeSize the minimum size of leaf nodes.
 * @param mtry the number of input variables to pick to split on at each node. It seems that
 *     sqrt(p) gives generally good performance, where p is the number of variables.
 * @param rule the splitting rule.
 * @param samples the sample set of instances for stochastic learning. samples[i] is the number of
 *     sampling for instance i.
 * @param order the index of training values in ascending order. Note that only numeric attributes
 *     need be sorted.
 */
public DecisionTree(
        Attribute[] attributes,
        double[][] x,
        int[] y,
        int maxNodes,
        int nodeSize,
        int mtry,
        SplitRule rule,
        int[] samples,
        int[][] order) {
    if (x.length != y.length) {
        throw new IllegalArgumentException(
                String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
    }

    if (mtry < 1 || mtry > x[0].length) {
        throw new IllegalArgumentException(
                "Invalid number of variables to split on at a node of the tree: " + mtry);
    }

    if (maxNodes < 2) {
        throw new IllegalArgumentException("Invalid maximum leaves: " + maxNodes);
    }

    if (nodeSize < 1) {
        throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
    }

    // Class label set. Labels must be 0, 1, ..., k-1 with no gaps.
    int[] labels = Math.unique(y);
    Arrays.sort(labels);

    for (int i = 0; i < labels.length; i++) {
        if (labels[i] < 0) {
            throw new IllegalArgumentException("Negative class label: " + labels[i]);
        }

        if (i > 0 && labels[i] - labels[i - 1] > 1) {
            throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
        }
    }

    k = labels.length;
    if (k < 2) {
        throw new IllegalArgumentException("Only one class.");
    }

    if (attributes == null) {
        int p = x[0].length;
        attributes = new Attribute[p];
        for (int i = 0; i < p; i++) {
            attributes[i] = new NumericAttribute("V" + (i + 1));
        }
    }

    this.attributes = attributes;
    this.mtry = mtry;
    this.nodeSize = nodeSize;
    this.maxNodes = maxNodes;
    this.rule = rule;
    importance = new double[attributes.length];

    if (order != null) {
        this.order = order;
    } else {
        // Pre-sort each numeric attribute once so that split search can scan
        // candidate thresholds in ascending order.
        int n = x.length;
        int p = x[0].length;

        double[] a = new double[n];
        this.order = new int[p][];

        for (int j = 0; j < p; j++) {
            if (attributes[j] instanceof NumericAttribute) {
                for (int i = 0; i < n; i++) {
                    a[i] = x[i][j];
                }
                this.order[j] = QuickSort.sort(a);
            }
        }
    }

    // Priority queue for best-first tree growing.
    PriorityQueue<TrainNode> nextSplits = new PriorityQueue<>();

    int n = y.length;
    int[] count = new int[k];
    if (samples == null) {
        samples = new int[n];
        for (int i = 0; i < n; i++) {
            samples[i] = 1;
            count[y[i]]++;
        }
    } else {
        for (int i = 0; i < n; i++) {
            count[y[i]] += samples[i];
        }
    }

    double[] posteriori = new double[k];
    for (int i = 0; i < k; i++) {
        posteriori[i] = (double) count[i] / n;
    }
    root = new Node(Math.whichMax(count), posteriori);

    TrainNode trainRoot = new TrainNode(root, x, y, samples);
    // Now add splits to the tree until the max tree size is reached.
    if (trainRoot.findBestSplit()) {
        nextSplits.add(trainRoot);
    }

    // Pop the best leaf from the priority queue, split it, and push the
    // children nodes into the queue if possible.
    for (int leaves = 1; leaves < this.maxNodes; leaves++) {
        // parent is the leaf to split
        TrainNode node = nextSplits.poll();
        if (node == null) {
            break;
        }

        node.split(nextSplits); // Split the parent node into two children nodes
    }
}
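// A minimal usage sketch, not part of the class (the data arrays, the node
// limits, and the choice of SplitRule.GINI below are assumptions for
// illustration). Passing null for attributes, samples, and order lets the
// constructor derive numeric attributes, unit sample weights, and the sorted
// index itself:
//
//     double[][] x = ...; // n training instances with p numeric features
//     int[] y = ...;      // class labels in {0, 1, ..., k-1}
//     int p = x[0].length;
//     int mtry = (int) java.lang.Math.floor(java.lang.Math.sqrt(p)); // sqrt(p) heuristic
//     DecisionTree tree = new DecisionTree(
//             null, x, y, 100, 1, mtry, SplitRule.GINI, null, null);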
/** Splits the node into two children nodes. Returns true if the split succeeds. */
public boolean split(PriorityQueue<TrainNode> nextSplits) {
    if (node.splitFeature < 0) {
        throw new IllegalStateException("Split a node with invalid feature.");
    }

    int n = x.length;
    int tc = 0;
    int fc = 0;
    int[] trueSamples = new int[n];
    int[] falseSamples = new int[n];

    if (attributes[node.splitFeature].getType() == Attribute.Type.NOMINAL) {
        // Nominal attributes split on equality with the split value.
        for (int i = 0; i < n; i++) {
            if (samples[i] > 0) {
                if (x[i][node.splitFeature] == node.splitValue) {
                    trueSamples[i] = samples[i];
                    tc += samples[i];
                } else {
                    falseSamples[i] = samples[i];
                    fc += samples[i];
                }
            }
        }
    } else if (attributes[node.splitFeature].getType() == Attribute.Type.NUMERIC) {
        // Numeric attributes split on a <= threshold comparison.
        for (int i = 0; i < n; i++) {
            if (samples[i] > 0) {
                if (x[i][node.splitFeature] <= node.splitValue) {
                    trueSamples[i] = samples[i];
                    tc += samples[i];
                } else {
                    falseSamples[i] = samples[i];
                    fc += samples[i];
                }
            }
        }
    } else {
        throw new IllegalStateException(
                "Unsupported attribute type: " + attributes[node.splitFeature].getType());
    }

    // Reject the split if either child would fall below the minimum leaf size.
    if (tc < nodeSize || fc < nodeSize) {
        node.splitFeature = -1;
        node.splitValue = Double.NaN;
        node.splitScore = 0.0;
        return false;
    }

    double[] trueChildPosteriori = new double[k];
    double[] falseChildPosteriori = new double[k];
    for (int i = 0; i < n; i++) {
        int yi = y[i];
        trueChildPosteriori[yi] += trueSamples[i];
        falseChildPosteriori[yi] += falseSamples[i];
    }

    // Laplace (add-one) smoothing of the posteriori probabilities.
    for (int i = 0; i < k; i++) {
        trueChildPosteriori[i] = (trueChildPosteriori[i] + 1) / (tc + k);
        falseChildPosteriori[i] = (falseChildPosteriori[i] + 1) / (fc + k);
    }

    node.trueChild = new Node(node.trueChildOutput, trueChildPosteriori);
    node.falseChild = new Node(node.falseChildOutput, falseChildPosteriori);

    // With a priority queue the tree grows best-first; with null it grows
    // depth-first by recursing immediately.
    TrainNode trueChild = new TrainNode(node.trueChild, x, y, trueSamples);
    if (tc > nodeSize && trueChild.findBestSplit()) {
        if (nextSplits != null) {
            nextSplits.add(trueChild);
        } else {
            trueChild.split(null);
        }
    }

    TrainNode falseChild = new TrainNode(node.falseChild, x, y, falseSamples);
    if (fc > nodeSize && falseChild.findBestSplit()) {
        if (nextSplits != null) {
            nextSplits.add(falseChild);
        } else {
            falseChild.split(null);
        }
    }

    importance[node.splitFeature] += node.splitScore;
    return true;
}
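// A worked example of the Laplace smoothing above (hypothetical numbers):
// with k = 3 classes, tc = 8 weighted samples in the true branch, and raw
// class counts (5, 3, 0), the smoothed posteriori become
//
//     (5 + 1) / (8 + 3) = 6/11,  (3 + 1) / (8 + 3) = 4/11,  (0 + 1) / (8 + 3) = 1/11,
//
// so no class in a child node is ever assigned exactly zero probability.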