/** Computes average class values for each attribute and value */
private void computeAverageClassValues() {

  double totalCounts, sum;
  Instance instance;
  double[] counts;

  double[][] avgClassValues = new double[getInputFormat().numAttributes()][0];
  m_Indices = new int[getInputFormat().numAttributes()][0];
  for (int j = 0; j < getInputFormat().numAttributes(); j++) {
    Attribute att = getInputFormat().attribute(j);
    if (att.isNominal()) {
      avgClassValues[j] = new double[att.numValues()];
      counts = new double[att.numValues()];
      for (int i = 0; i < getInputFormat().numInstances(); i++) {
        instance = getInputFormat().instance(i);
        if (!instance.classIsMissing() && (!instance.isMissing(j))) {
          counts[(int) instance.value(j)] += instance.weight();
          avgClassValues[j][(int) instance.value(j)] += instance.weight() * instance.classValue();
        }
      }
      sum = Utils.sum(avgClassValues[j]);
      totalCounts = Utils.sum(counts);
      if (Utils.gr(totalCounts, 0)) {
        for (int k = 0; k < att.numValues(); k++) {
          if (Utils.gr(counts[k], 0)) {
            avgClassValues[j][k] /= counts[k];
          } else {
            avgClassValues[j][k] = sum / totalCounts;
          }
        }
      }
      m_Indices[j] = Utils.sort(avgClassValues[j]);
    }
  }
}
/**
 * Normalizes branch sizes so they contain frequencies (stored in "props") instead of counts
 * (stored in "dist").
 *
 * <p>Overwrites the supplied "props"!
 *
 * <p>props.length must be == dist.length.
 */
protected static void countsToFreqs(double[][] dist, double[] props) {

  for (int k = 0; k < props.length; k++) {
    props[k] = Utils.sum(dist[k]);
  }
  if (Utils.eq(Utils.sum(props), 0)) {
    for (int k = 0; k < props.length; k++) {
      props[k] = 1.0 / (double) props.length;
    }
  } else {
    FastRfUtils.normalize(props);
  }
}
/**
 * Normalizes branch sizes so they contain frequencies (stored in "props") instead of counts
 * (stored in "dist"). Creates a new double[] which it returns.
 */
protected static double[] countsToFreqs(double[][] dist) {

  double[] props = new double[dist.length];

  for (int k = 0; k < props.length; k++) {
    props[k] = Utils.sum(dist[k]);
  }
  if (Utils.eq(Utils.sum(props), 0)) {
    for (int k = 0; k < props.length; k++) {
      props[k] = 1.0 / (double) props.length;
    }
  } else {
    FastRfUtils.normalize(props);
  }
  return props;
}
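A minimal standalone sketch of what the two countsToFreqs overloads compute, with the Utils.sum and FastRfUtils.normalize calls inlined so it runs without the surrounding library; the class name, the sample table and the inlined helpers are illustrative only, not part of the original source.

public class CountsToFreqsDemo {
  public static void main(String[] args) {
    // Per-branch class counts: branch 0 holds 3+1 instances, branch 1 holds 0+4.
    double[][] dist = { {3.0, 1.0}, {0.0, 4.0} };
    double[] props = new double[dist.length];
    for (int k = 0; k < props.length; k++) {
      double branchTotal = 0;
      for (double v : dist[k]) branchTotal += v; // stands in for Utils.sum(dist[k])
      props[k] = branchTotal;
    }
    double total = 0;
    for (double p : props) total += p;
    if (total == 0) {
      // Degenerate case: no instances at all -> uniform branch weights.
      java.util.Arrays.fill(props, 1.0 / props.length);
    } else {
      for (int k = 0; k < props.length; k++) props[k] /= total; // stands in for normalize(props)
    }
    System.out.println(java.util.Arrays.toString(props)); // prints [0.5, 0.5]
  }
}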
/**
 * Normalize the instance
 *
 * @param inst instance to be normalized
 * @return a new Instance with normalized values
 */
private Instance normalizeInstance(Instance inst) {

  double[] vals = inst.toDoubleArray();
  double sum = Utils.sum(vals);

  for (int i = 0; i < vals.length; i++) {
    vals[i] /= sum;
  }

  return new DenseInstance(inst.weight(), vals);
}
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if an error occurred during the prediction
 */
public double[] distributionForInstance(Instance instance) throws Exception {

  String debug = "(KStar.distributionForInstance) ";
  double transProb = 0.0, temp = 0.0;
  double[] classProbability = new double[m_NumClasses];
  double[] predictedValue = new double[1];

  // initialization ...
  for (int i = 0; i < classProbability.length; i++) {
    classProbability[i] = 0.0;
  }
  predictedValue[0] = 0.0;

  if (m_InitFlag == ON) {
    // need to compute them only once and will be used for all instances.
    // We are doing this because the evaluation module controls the calls.
    if (m_BlendMethod == B_ENTROPY) {
      generateRandomClassColomns();
    }
    m_Cache = new KStarCache[m_NumAttributes];
    for (int i = 0; i < m_NumAttributes; i++) {
      m_Cache[i] = new KStarCache();
    }
    m_InitFlag = OFF;
    // System.out.println("Computing...");
  }
  // init done.

  Instance trainInstance;
  Enumeration enu = m_Train.enumerateInstances();
  while (enu.hasMoreElements()) {
    trainInstance = (Instance) enu.nextElement();
    transProb = instanceTransformationProbability(instance, trainInstance);
    switch (m_ClassType) {
      case Attribute.NOMINAL:
        classProbability[(int) trainInstance.classValue()] += transProb;
        break;
      case Attribute.NUMERIC:
        predictedValue[0] += transProb * trainInstance.classValue();
        temp += transProb;
        break;
    }
  }

  if (m_ClassType == Attribute.NOMINAL) {
    double sum = Utils.sum(classProbability);
    if (sum <= 0.0) {
      for (int i = 0; i < classProbability.length; i++) {
        classProbability[i] = (double) 1 / (double) m_NumClasses;
      }
    } else {
      Utils.normalize(classProbability, sum);
    }
    return classProbability;
  } else {
    predictedValue[0] = (temp != 0) ? predictedValue[0] / temp : 0.0;
    return predictedValue;
  }
}
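Reading the nominal-class branch above: the prediction is the normalized sum of transformation probabilities from each training instance to the test instance, with a uniform fallback when every probability is zero; for numeric classes it is the transformation-probability-weighted mean of the training class values. In symbols (my reading of the loop, with P* denoting instanceTransformationProbability):

    P(c \mid a) \;\propto\; \sum_{b \in \mathcal{T},\, \mathrm{class}(b) = c} P^{*}(b \mid a),
    \qquad
    \hat{y}(a) = \frac{\sum_{b \in \mathcal{T}} P^{*}(b \mid a)\, y_b}{\sum_{b \in \mathcal{T}} P^{*}(b \mid a)}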
private Matrix getTransposedNormedMatrix(Instances data) {
  Matrix matrix = new Matrix(data.numAttributes(), data.numInstances());
  for (int i = 0; i < data.numInstances(); i++) {
    double[] vals = data.instance(i).toDoubleArray();
    double sum = Utils.sum(vals);
    for (int v = 0; v < vals.length; v++) {
      vals[v] /= sum;
      matrix.set(v, i, vals[v]);
    }
  }
  return matrix;
}
/**
 * Classifies a given instance. Either this or distributionForInstance() needs to be implemented
 * by subclasses.
 *
 * @param instance the instance to be assigned to a cluster
 * @return the number of the assigned cluster as an integer
 * @exception Exception if instance could not be clustered successfully
 */
@Override
public int clusterInstance(Instance instance) throws Exception {

  double[] dist = distributionForInstance(instance);

  if (dist == null) {
    throw new Exception("Null distribution predicted");
  }

  if (Utils.sum(dist) <= 0) {
    throw new Exception("Unable to cluster instance");
  }

  return Utils.maxIndex(dist);
}
/**
 * Calculates the class membership probabilities for the given test instance
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if there is a problem generating the prediction
 */
public double[] distributionForInstance(Instance instance) throws Exception {

  // default model?
  if (m_ZeroR != null) {
    return m_ZeroR.distributionForInstance(instance);
  }

  // Definition of local variables
  double[] probs = new double[m_NumClasses];
  double prob;
  double mutualInfoSum;

  // store instance's att values in an int array
  int[] attIndex = new int[m_NumAttributes];
  for (int att = 0; att < m_NumAttributes; att++) {
    if (att == m_ClassIndex) {
      attIndex[att] = -1;
    } else {
      attIndex[att] = m_StartAttIndex[att] + (int) instance.value(att);
    }
  }

  // calculate probabilities for each possible class value
  for (int classVal = 0; classVal < m_NumClasses; classVal++) {
    probs[classVal] = 0;
    prob = 1;
    mutualInfoSum = 0.0;
    for (int parent = 0; parent < m_NumAttributes; parent++) {
      if (attIndex[parent] == -1) {
        continue;
      }
      prob =
          (m_ClassAttAttCounts[classVal][attIndex[parent]][attIndex[parent]]
                  + 1.0 / (m_NumClasses * m_NumAttValues[parent]))
              / (m_NumInstances + 1.0);
      for (int son = 0; son < m_NumAttributes; son++) {
        if (attIndex[son] == -1 || son == parent) {
          continue;
        }
        prob *=
            (m_ClassAttAttCounts[classVal][attIndex[parent]][attIndex[son]]
                    + 1.0 / m_NumAttValues[son])
                / (m_ClassAttAttCounts[classVal][attIndex[parent]][attIndex[parent]] + 1.0);
      }
      mutualInfoSum += m_mutualInformation[parent];
      probs[classVal] += m_mutualInformation[parent] * prob;
    }
    probs[classVal] /= mutualInfoSum;
  }
  if (!Double.isNaN(Utils.sum(probs))) {
    Utils.normalize(probs);
  }
  return probs;
}
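The nested loops above appear to implement a mutual-information-weighted average of one-dependence estimates (an AODE/WAODE-style rule); this summary is an interpretation of the code, not taken from its documentation. With I(A_p; C) = m_mutualInformation[parent] and the \hat{P} terms the Laplace-smoothed ratios built from m_ClassAttAttCounts:

    P(c \mid a_1, \dots, a_n) \;\propto\;
    \frac{\sum_{p} I(A_p; C)\; \hat{P}(c, a_p)\, \prod_{s \neq p} \hat{P}(a_s \mid c, a_p)}
         {\sum_{p} I(A_p; C)}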
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @exception Exception if distribution can't be computed successfully
 */
public double[] distributionForInstance(Instance instance) throws Exception {
  if (instance.classAttribute().isNumeric()) {
    throw new UnsupportedClassTypeException("Decorate can't handle a numeric class!");
  }
  double[] sums = new double[instance.numClasses()], newProbs;
  Classifier curr;

  for (int i = 0; i < m_Committee.size(); i++) {
    curr = (Classifier) m_Committee.get(i);
    newProbs = curr.distributionForInstance(instance);
    for (int j = 0; j < newProbs.length; j++) {
      sums[j] += newProbs[j];
    }
  }
  if (Utils.eq(Utils.sum(sums), 0)) {
    return sums;
  } else {
    Utils.normalize(sums);
    return sums;
  }
}
/**
 * Compute the JS divergence between an instance and a cluster, used for test data
 *
 * @param inst instance to be clustered
 * @param t index of the cluster
 * @param pi1 weight of the instance's distribution in the mixture
 * @param pi2 weight of the cluster's distribution in the mixture
 * @return the JS divergence
 */
private double JS(Instance inst, int t, double pi1, double pi2) {
  if (Math.min(pi1, pi2) <= 0) {
    System.out.format(
        "Warning: zero or negative weights in JS calculation! (pi1 %s, pi2 %s)\n", pi1, pi2);
    return 0;
  }
  double sum = Utils.sum(inst.toDoubleArray());
  double kl1 = 0.0, kl2 = 0.0, tmp = 0.0;
  for (int i = 0; i < inst.numValues(); i++) {
    tmp = inst.valueSparse(i) / sum;
    if (tmp != 0) {
      kl1 += tmp * Math.log(tmp / (tmp * pi1 + pi2 * bestT.Py_t.get(inst.index(i), t)));
    }
  }
  for (int i = 0; i < m_numAttributes; i++) {
    if ((tmp = bestT.Py_t.get(i, t)) != 0) {
      kl2 += tmp * Math.log(tmp / (inst.value(i) * pi1 / sum + pi2 * tmp));
    }
  }
  return pi1 * kl1 + pi2 * kl2;
}
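For reference, the two loops compute the pi-weighted Jensen-Shannon divergence between the normalized instance distribution p and the cluster distribution q = Py_t(., t), using natural logarithms:

    \mathrm{JS}_{\pi}(p, q) = \pi_1\, \mathrm{KL}\bigl(p \,\|\, \pi_1 p + \pi_2 q\bigr)
                            + \pi_2\, \mathrm{KL}\bigl(q \,\|\, \pi_1 p + \pi_2 q\bigr),
    \qquad
    \mathrm{KL}(x \,\|\, y) = \sum_i x_i \ln \frac{x_i}{y_i}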
/**
 * Test using Kononenko's MDL criterion.
 *
 * @param priorCounts class counts before the split
 * @param bestCounts class counts for each branch of the candidate split
 * @param numInstances total number of instances
 * @param numCutPoints number of candidate cut points
 * @return true if the split is acceptable
 */
private boolean KononenkosMDL(
    double[] priorCounts, double[][] bestCounts, double numInstances, int numCutPoints) {

  double distPrior, instPrior, distAfter = 0, sum, instAfter = 0;
  double before, after;
  int numClassesTotal;

  // Number of classes occurring in the set
  numClassesTotal = 0;
  for (double priorCount : priorCounts) {
    if (priorCount > 0) {
      numClassesTotal++;
    }
  }

  // Encode distribution prior to split
  distPrior =
      SpecialFunctions.log2Binomial(numInstances + numClassesTotal - 1, numClassesTotal - 1);

  // Encode instances prior to split.
  instPrior = SpecialFunctions.log2Multinomial(numInstances, priorCounts);

  before = instPrior + distPrior;

  // Encode distributions and instances after split.
  for (double[] bestCount : bestCounts) {
    sum = Utils.sum(bestCount);
    distAfter += SpecialFunctions.log2Binomial(sum + numClassesTotal - 1, numClassesTotal - 1);
    instAfter += SpecialFunctions.log2Multinomial(sum, bestCount);
  }

  // Coding cost after split
  after = Utils.log2(numCutPoints) + distAfter + instAfter;

  // Check if split is to be accepted
  return (before > after);
}
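In MDL terms, the test accepts the split when the cost of encoding the class distribution and the instances before the split exceeds the cost after it. With N instances, k classes present, m candidate cut points, and N_b, n_{b1}, ..., n_{bk} the per-branch totals and class counts, the quantities computed above are:

    \text{before} = \log_2 \binom{N + k - 1}{k - 1} + \log_2 \frac{N!}{n_1!\,\cdots\,n_k!},
    \qquad
    \text{after} = \log_2 m + \sum_{b}\left[\log_2 \binom{N_b + k - 1}{k - 1}
                   + \log_2 \frac{N_b!}{n_{b1}!\,\cdots\,n_{bk}!}\right]

and the split is kept iff before > after.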
/**
 * Compute the entropy score based on an array of probabilities
 *
 * @param probs array of non-negative and normalized probabilities
 * @return the entropy value
 */
private double Entropy(double[] probs) {
  for (double prob : probs) {
    if (prob <= 0) {
      if (m_verbose) {
        System.out.println("Warning: Non-positive probability.");
      }
      return Double.NaN;
    }
  }

  // could be unnormalized, when normalization is not specified
  if (Math.abs(Utils.sum(probs) - 1) >= 1e-6) {
    if (m_verbose) {
      System.out.println("Warning: Not normalized.");
    }
    return Double.NaN;
  }

  double mi = 0.0;
  for (double prob : probs) {
    mi += prob * Math.log(prob);
  }
  mi = -mi;
  return mi;
}
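On the happy path this is plain Shannon entropy with natural logarithms, so the value is in nats rather than bits:

    H(p) = -\sum_i p_i \ln p_i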
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @exception Exception if the classifier has not been generated successfully
 */
@Override
public void buildClassifier(Instances instances) throws Exception {

  int attIndex = 0;
  double sum;

  // can classifier handle the data?
  getCapabilities().testWithFail(instances);

  // remove instances with missing class
  instances = new Instances(instances);
  instances.deleteWithMissingClass();

  m_Instances = new Instances(instances, 0);

  // Reserve space
  m_Counts = new double[instances.numClasses()][instances.numAttributes() - 1][0];
  m_Means = new double[instances.numClasses()][instances.numAttributes() - 1];
  m_Devs = new double[instances.numClasses()][instances.numAttributes() - 1];
  m_Priors = new double[instances.numClasses()];
  Enumeration<Attribute> enu = instances.enumerateAttributes();
  while (enu.hasMoreElements()) {
    Attribute attribute = enu.nextElement();
    if (attribute.isNominal()) {
      for (int j = 0; j < instances.numClasses(); j++) {
        m_Counts[j][attIndex] = new double[attribute.numValues()];
      }
    } else {
      for (int j = 0; j < instances.numClasses(); j++) {
        m_Counts[j][attIndex] = new double[1];
      }
    }
    attIndex++;
  }

  // Compute counts and sums
  Enumeration<Instance> enumInsts = instances.enumerateInstances();
  while (enumInsts.hasMoreElements()) {
    Instance instance = enumInsts.nextElement();
    if (!instance.classIsMissing()) {
      Enumeration<Attribute> enumAtts = instances.enumerateAttributes();
      attIndex = 0;
      while (enumAtts.hasMoreElements()) {
        Attribute attribute = enumAtts.nextElement();
        if (!instance.isMissing(attribute)) {
          if (attribute.isNominal()) {
            m_Counts[(int) instance.classValue()][attIndex][(int) instance.value(attribute)]++;
          } else {
            m_Means[(int) instance.classValue()][attIndex] += instance.value(attribute);
            m_Counts[(int) instance.classValue()][attIndex][0]++;
          }
        }
        attIndex++;
      }
      m_Priors[(int) instance.classValue()]++;
    }
  }

  // Compute means
  Enumeration<Attribute> enumAtts = instances.enumerateAttributes();
  attIndex = 0;
  while (enumAtts.hasMoreElements()) {
    Attribute attribute = enumAtts.nextElement();
    if (attribute.isNumeric()) {
      for (int j = 0; j < instances.numClasses(); j++) {
        if (m_Counts[j][attIndex][0] < 2) {
          throw new Exception(
              "attribute "
                  + attribute.name()
                  + ": less than two values for class "
                  + instances.classAttribute().value(j));
        }
        m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
      }
    }
    attIndex++;
  }

  // Compute standard deviations
  enumInsts = instances.enumerateInstances();
  while (enumInsts.hasMoreElements()) {
    Instance instance = enumInsts.nextElement();
    if (!instance.classIsMissing()) {
      enumAtts = instances.enumerateAttributes();
      attIndex = 0;
      while (enumAtts.hasMoreElements()) {
        Attribute attribute = enumAtts.nextElement();
        if (!instance.isMissing(attribute)) {
          if (attribute.isNumeric()) {
            m_Devs[(int) instance.classValue()][attIndex] +=
                (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute))
                    * (m_Means[(int) instance.classValue()][attIndex] - instance.value(attribute));
          }
        }
        attIndex++;
      }
    }
  }
  enumAtts = instances.enumerateAttributes();
  attIndex = 0;
  while (enumAtts.hasMoreElements()) {
    Attribute attribute = enumAtts.nextElement();
    if (attribute.isNumeric()) {
      for (int j = 0; j < instances.numClasses(); j++) {
        if (m_Devs[j][attIndex] <= 0) {
          throw new Exception(
              "attribute "
                  + attribute.name()
                  + ": standard deviation is 0 for class "
                  + instances.classAttribute().value(j));
        } else {
          m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
          m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
        }
      }
    }
    attIndex++;
  }

  // Normalize counts
  enumAtts = instances.enumerateAttributes();
  attIndex = 0;
  while (enumAtts.hasMoreElements()) {
    Attribute attribute = enumAtts.nextElement();
    if (attribute.isNominal()) {
      for (int j = 0; j < instances.numClasses(); j++) {
        sum = Utils.sum(m_Counts[j][attIndex]);
        for (int i = 0; i < attribute.numValues(); i++) {
          m_Counts[j][attIndex][i] =
              (m_Counts[j][attIndex][i] + 1) / (sum + attribute.numValues());
        }
      }
    }
    attIndex++;
  }

  // Normalize priors
  sum = Utils.sum(m_Priors);
  for (int j = 0; j < instances.numClasses(); j++) {
    m_Priors[j] = (m_Priors[j] + 1) / (sum + instances.numClasses());
  }
}
/**
 * Classifies the given test instance.
 *
 * @param instance the instance to be classified
 * @return the predicted class probability distribution for the instance
 * @throws Exception if the instance can't be classified
 */
public double[] distributionForInstance(Instance instance) throws Exception {
  double[] dist = new double[m_NumClasses];
  double[] temp = new double[m_NumClasses];
  double weight = 1.0;

  for (int i = 0; i < instance.numAttributes(); i++) {
    if (i != m_ClassIndex && !instance.isMissing(i)) {
      double val = instance.value(i);
      boolean ok = false;
      if (instance.attribute(i).isNumeric()) {
        int k;
        for (k = m_intervalBounds[i].length - 1; k >= 0; k--) {
          if (val > m_intervalBounds[i][k]) {
            for (int j = 0; j < m_NumClasses; j++) {
              if (m_globalCounts[j] > 0) {
                temp[j] = ((m_counts[i][k][j] + TINY) / (m_globalCounts[j] + TINY));
              }
            }
            ok = true;
            break;
          } else if (val == m_intervalBounds[i][k]) {
            for (int j = 0; j < m_NumClasses; j++) {
              if (m_globalCounts[j] > 0) {
                temp[j] = ((m_counts[i][k][j] + m_counts[i][k - 1][j]) / 2.0) + TINY;
                temp[j] /= (m_globalCounts[j] + TINY);
              }
            }
            ok = true;
            break;
          }
        }
        if (!ok) {
          throw new Exception("This shouldn't happen");
        }
      } else { // nominal attribute
        ok = true;
        for (int j = 0; j < m_NumClasses; j++) {
          if (m_globalCounts[j] > 0) {
            temp[j] = ((m_counts[i][(int) val][j] + TINY) / (m_globalCounts[j] + TINY));
          }
        }
      }

      double sum = Utils.sum(temp);
      if (sum <= 0) {
        for (int j = 0; j < temp.length; j++) {
          temp[j] = 1.0 / (double) temp.length;
        }
      } else {
        Utils.normalize(temp, sum);
      }

      if (m_weightByConfidence) {
        weight = weka.core.ContingencyTables.entropy(temp);
        weight = Math.pow(weight, m_bias);
        if (weight < 1.0) {
          weight = 1.0;
        }
      }

      for (int j = 0; j < m_NumClasses; j++) {
        dist[j] += (temp[j] * weight);
      }
    }
  }
  double sum = Utils.sum(dist);
  if (sum <= 0) {
    for (int j = 0; j < dist.length; j++) {
      dist[j] = 1.0 / (double) dist.length;
    }
    return dist;
  } else {
    Utils.normalize(dist, sum);
    return dist;
  }
}
private double[] calculateRegionProbs(int j, int i) throws Exception {
  double[] sumOfProbsForRegion = new double[m_trainingData.classAttribute().numValues()];

  for (int u = 0; u < m_numOfSamplesPerRegion; u++) {

    double[] sumOfProbsForLocation = new double[m_trainingData.classAttribute().numValues()];

    m_weightingAttsValues[m_xAttribute] = getRandomX(j);
    m_weightingAttsValues[m_yAttribute] = getRandomY(m_panelHeight - i - 1);

    m_dataGenerator.setWeightingValues(m_weightingAttsValues);

    double[] weights = m_dataGenerator.getWeights();
    double sumOfWeights = Utils.sum(weights);
    int[] indices = Utils.sort(weights);

    // Prune 1% of weight mass
    int[] newIndices = new int[indices.length];
    double sumSoFar = 0;
    double criticalMass = 0.99 * sumOfWeights;
    int index = weights.length - 1;
    int counter = 0;
    for (int z = weights.length - 1; z >= 0; z--) {
      newIndices[index--] = indices[z];
      sumSoFar += weights[indices[z]];
      counter++;
      if (sumSoFar > criticalMass) {
        break;
      }
    }
    indices = new int[counter];
    System.arraycopy(newIndices, index + 1, indices, 0, counter);

    for (int z = 0; z < m_numOfSamplesPerGenerator; z++) {

      m_dataGenerator.setWeightingValues(m_weightingAttsValues);
      double[][] values = m_dataGenerator.generateInstances(indices);

      for (int q = 0; q < values.length; q++) {
        if (values[q] != null) {
          System.arraycopy(values[q], 0, m_vals, 0, m_vals.length);
          m_vals[m_xAttribute] = m_weightingAttsValues[m_xAttribute];
          m_vals[m_yAttribute] = m_weightingAttsValues[m_yAttribute];

          // classify the instance
          m_dist = m_classifier.distributionForInstance(m_predInst);

          for (int k = 0; k < sumOfProbsForLocation.length; k++) {
            sumOfProbsForLocation[k] += (m_dist[k] * weights[q]);
          }
        }
      }
    }

    for (int k = 0; k < sumOfProbsForRegion.length; k++) {
      sumOfProbsForRegion[k] += (sumOfProbsForLocation[k] * sumOfWeights);
    }
  }

  // average
  Utils.normalize(sumOfProbsForRegion);

  // cache
  double[] tempDist = new double[sumOfProbsForRegion.length];
  System.arraycopy(sumOfProbsForRegion, 0, tempDist, 0, sumOfProbsForRegion.length);

  return tempDist;
}
/**
 * Recursively generates a tree. A derivative of the buildTree function from the
 * "weka.classifiers.trees.RandomTree" class, with the following changes made:
 *
 * <ul>
 *   <li>m_ClassProbs are now remembered only in leaves, not in every node of the tree
 *   <li>m_Distribution has been removed
 *   <li>members of dists, splits, props and vals arrays which are not used are dereferenced prior
 *       to recursion to reduce memory requirements
 *   <li>a check for "branch with no training instances" is now (FastRF 0.98) made before
 *       recursion; with the current implementation of splitData(), empty branches can appear only
 *       with nominal attributes with more than two categories
 *   <li>each new 'tree' (i.e. node or leaf) is passed a reference to its 'mother forest',
 *       necessary to look up parameters such as maxDepth and K
 *   <li>pre-split entropy is not recalculated unnecessarily
 *   <li>uses DataCache instead of weka.core.Instances; the reference to the DataCache is stored
 *       as a field of the FastRandomTree class and not passed recursively down new buildTree()
 *       calls
 *   <li>similarly, a reference to the random number generator is stored in a field of the
 *       DataCache
 *   <li>m_ClassProbs are now normalized by dividing by the number of instances in the leaf,
 *       instead of forcing the sum of class probabilities to 1.0; this has a large effect when
 *       class/instance weights are set by the user
 *   <li>a little imprecision is allowed in checking whether there was a decrease in entropy after
 *       splitting
 *   <li>0.99: the temporary arrays splits, props and vals are no longer as wide as the full
 *       number of attributes in the dataset (of which only "k" columns of randomly chosen
 *       attributes used to get filled). Now there is just a single array, which is overwritten as
 *       the k features are evaluated sequentially, and only when the next feature is better than
 *       a previous one.
 *   <li>0.99: the sortedIndices are no longer cut up into smaller arrays on every split, but
 *       re-sorted within the same array in splitDataNew() and passed down to buildTree() as the
 *       original large matrix, with start and end points explicitly specified
 * </ul>
 *
 * @param sortedIndices the indices of the instances of the whole bootstrap replicate
 * @param startAt First index of the instance to consider in this split; inclusive.
 * @param endAt Last index of the instance to consider; inclusive.
 * @param classProbs the class distribution
 * @param debug whether debugging is on
 * @param attIndicesWindow the attribute window to choose attributes from
 * @param depth the current depth
 */
protected void buildTree(
    int[][] sortedIndices,
    int startAt,
    int endAt,
    double[] classProbs,
    boolean debug,
    int[] attIndicesWindow,
    int depth) {

  m_Debug = debug;
  int sortedIndicesLength = endAt - startAt + 1;

  // Check if node doesn't contain enough instances, is pure,
  // or maximum depth is reached; if so, make a leaf.
  if ((sortedIndicesLength < Math.max(2, getMinNum())) // small
      || Utils.eq(classProbs[Utils.maxIndex(classProbs)], Utils.sum(classProbs)) // pure
      || ((getMaxDepth() > 0) && (depth >= getMaxDepth())) // deep
  ) {
    m_Attribute = -1; // indicates leaf (no useful attribute to split on)

    // normalize by dividing with the number of instances (as of ver. 0.97)
    // unless leaf is empty - this can happen with splits on nominal
    // attributes with more than two categories
    if (sortedIndicesLength != 0) {
      for (int c = 0; c < classProbs.length; c++) {
        classProbs[c] /= sortedIndicesLength;
      }
    }
    m_ClassProbs = classProbs;
    this.data = null;
    return;
  } // (leaf making)

  // new 0.99: all the following are for the best attribute only! they're updated while
  // iterating sequentially through the attributes
  double val = Double.NaN; // value of splitting criterion
  double[][] dist = new double[2][data.numClasses]; // class distributions (contingency table),
                                                    // indexed first by branch, then by class
  double[] prop = new double[2]; // the branch sizes (as fraction)
  double split = Double.NaN; // split point

  // Investigate K random attributes
  int attIndex = 0;
  int windowSize = attIndicesWindow.length;
  int k = getKValue();
  boolean sensibleSplitFound = false;
  double prior = Double.NaN;
  double bestNegPosterior = -Double.MAX_VALUE;
  int bestAttIdx = -1;

  while ((windowSize > 0) && (k-- > 0 || !sensibleSplitFound)) {

    int chosenIndex = data.reusableRandomGenerator.nextInt(windowSize);
    attIndex = attIndicesWindow[chosenIndex];

    // shift chosen attIndex out of window
    attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
    attIndicesWindow[windowSize - 1] = attIndex;
    windowSize--;

    // new: 0.99
    double candidateSplit =
        distributionSequentialAtt(
            prop, dist, bestNegPosterior, attIndex, sortedIndices[attIndex], startAt, endAt);

    if (Double.isNaN(candidateSplit)) {
      continue; // we did not improve over a previous attribute! "dist" is unchanged from before
    }
    // by this point we know we have an improvement, so we keep the new split point
    split = candidateSplit;
    bestAttIdx = attIndex;

    if (Double.isNaN(prior)) {
      // needs to be computed only once per branch - is same for all attributes (even
      // regardless of missing values)
      prior = SplitCriteria.entropyOverColumns(dist);
    }

    double negPosterior = -SplitCriteria.entropyConditionedOnRows(dist); // this is an updated dist
    if (negPosterior > bestNegPosterior) {
      bestNegPosterior = negPosterior;
    } else {
      throw new IllegalArgumentException("Very strange!");
    }

    val = prior - (-negPosterior); // we want the greatest reduction in entropy
    if (val > 1e-2) { // we allow some leeway here to compensate
      sensibleSplitFound = true; // for imprecision in entropy computation
    }
  } // feature by feature in window

  if (sensibleSplitFound) {

    m_Attribute = bestAttIdx; // find best attribute
    m_SplitPoint = split;
    m_Prop = prop;
    prop = null; // can be GC'ed

    // int[][][] subsetIndices =
    //     new int[dist.length][data.numAttributes][];
    // splitData( subsetIndices, m_Attribute,
    //     m_SplitPoint, sortedIndices );
    // int numInstancesBeforeSplit = sortedIndices[0].length;

    int belowTheSplitStartsAt =
        splitDataNew(m_Attribute, m_SplitPoint, sortedIndices, startAt, endAt);

    m_Successors = new FastRandomTree[dist.length]; // dist.length now always == 2
    for (int i = 0; i < dist.length; i++) {
      m_Successors[i] = new FastRandomTree();
      m_Successors[i].m_MotherForest = this.m_MotherForest;
      m_Successors[i].data = this.data;
      // new in 0.99 - used in distributionSequentialAtt()
      m_Successors[i].tempDists = this.tempDists;
      m_Successors[i].tempDistsOther = this.tempDistsOther;
      m_Successors[i].tempProps = this.tempProps;

      // check if we're about to make an empty branch - this can happen with
      // nominal attributes with more than two categories (as of ver. 0.98)
      if (belowTheSplitStartsAt - startAt == 0) {
        // in this case, modify the chosenAttDists[i] so that it contains
        // the current, before-split class probabilities, properly normalized
        // by the number of instances (as we won't be able to normalize
        // after the split)
        for (int j = 0; j < dist[i].length; j++) {
          dist[i][j] = classProbs[j] / sortedIndicesLength;
        }
      }

      if (i == 0) { // before split
        m_Successors[i].buildTree(
            sortedIndices,
            startAt,
            belowTheSplitStartsAt - 1,
            dist[i],
            m_Debug,
            attIndicesWindow,
            depth + 1);
      } else { // after split
        m_Successors[i].buildTree(
            sortedIndices,
            belowTheSplitStartsAt,
            endAt,
            dist[i],
            m_Debug,
            attIndicesWindow,
            depth + 1);
      }

      dist[i] = null;
    }
    sortedIndices = null;

  } else { // ------ make leaf --------

    m_Attribute = -1;

    // normalize by dividing with the number of instances (as of ver. 0.97)
    // unless leaf is empty - this can happen with splits on nominal attributes
    if (sortedIndicesLength != 0) {
      for (int c = 0; c < classProbs.length; c++) {
        classProbs[c] /= sortedIndicesLength;
      }
    }
    m_ClassProbs = classProbs;
  }

  this.data = null; // dereference all pointers so data can be GC'd after tree is built
}
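The acceptance test inside the attribute loop is an information-gain check with a small tolerance. Assuming SplitCriteria mirrors weka.core.ContingencyTables (so entropyOverColumns is the class entropy and entropyConditionedOnRows the branch-weighted conditional entropy), a split on the contingency table D filled in by distributionSequentialAtt() is considered sensible when

    \Delta H = \mathtt{entropyOverColumns}(D) - \mathtt{entropyConditionedOnRows}(D) > 10^{-2}

i.e. the pre-split class entropy must drop by more than 0.01, the leeway allowed for imprecision in the entropy computation.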