/** Implementation of the FindBestSplit algorithm from E.Ikonomovska et al. */ protected AttributeSplitSuggestion searchForBestSplitOption( Node currentNode, AttributeSplitSuggestion currentBestOption, SplitCriterion criterion, int attIndex) { // Return null if the current node is null or we have finished looking through all the possible // splits if (currentNode == null || rightTotal == 0.0) { return currentBestOption; } if (currentNode.left != null) { currentBestOption = searchForBestSplitOption(currentNode.left, currentBestOption, criterion, attIndex); } sumTotalLeft += currentNode.leftStatistics.getValue(1); sumTotalRight -= currentNode.leftStatistics.getValue(1); sumSqTotalLeft += currentNode.leftStatistics.getValue(2); sumSqTotalRight -= currentNode.leftStatistics.getValue(2); rightTotal -= currentNode.leftStatistics.getValue(0); double[][] postSplitDists = new double[][] { {currentNode.leftStatistics.getValue(0), sumTotalLeft, sumSqTotalLeft}, {rightTotal, sumTotalRight, sumSqTotalRight} }; double[] preSplitDist = new double[] { (currentNode.leftStatistics.getValue(0) + rightTotal), (sumTotalLeft + sumTotalRight), (sumSqTotalLeft + sumSqTotalRight) }; double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists); if ((currentBestOption == null) || (merit > currentBestOption.merit)) { currentBestOption = new AttributeSplitSuggestion( new NumericAttributeBinaryTest(attIndex, currentNode.cut_point, true), postSplitDists, merit); } if (currentNode.right != null) { currentBestOption = searchForBestSplitOption(currentNode.right, currentBestOption, criterion, attIndex); } sumTotalLeft -= currentNode.leftStatistics.getValue(1); sumTotalRight += currentNode.leftStatistics.getValue(1); sumSqTotalLeft -= currentNode.leftStatistics.getValue(2); sumSqTotalRight += currentNode.leftStatistics.getValue(2); rightTotal += currentNode.leftStatistics.getValue(0); return currentBestOption; }
/**
 * Recursive method that first checks all of a node's children before deciding if it is 'bad' and
 * may be removed.
 *
 * <p>A node is reported as bad only when every existing child was reported bad and the merit of
 * splitting at this node, relative to {@code lastCheckSDR}, is below {@code lastCheckRatio - 2 *
 * lastCheckE}. Children reported bad are detached from {@code currentNode}. This method cannot
 * detach the node it was called on; the caller must act on the return value to drop it.
 *
 * @param criterion criterion used to compute the merit of splitting at a node
 * @param currentNode root of the subtree to examine; may be {@code null}
 * @param lastCheckRatio ratio the relative merit is compared against
 * @param lastCheckSDR value the node's merit is divided by
 * @param lastCheckE margin; twice this value is subtracted from {@code lastCheckRatio}
 * @return {@code true} if {@code currentNode} is {@code null} or is bad and may be removed by the
 *     caller, {@code false} otherwise
 */
private boolean removeBadSplitNodes(
    SplitCriterion criterion,
    Node currentNode,
    double lastCheckRatio,
    double lastCheckSDR,
    double lastCheckE) {
  if (currentNode == null) {
    return true;
  }

  // An absent child is vacuously bad, consistent with the null case above. Starting from false
  // would mean a leaf is never evaluated and therefore nothing could ever be reported bad.
  boolean isBad = true;

  if (currentNode.left != null) {
    isBad =
        removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE);
    if (isBad) {
      // Assigning null to a parameter inside the callee has no effect on the tree, so the
      // parent is the one that has to drop the link.
      currentNode.left = null;
    }
  }

  // The right subtree is only examined while this node is still a removal candidate.
  if (currentNode.right != null && isBad) {
    isBad =
        removeBadSplitNodes(
            criterion, currentNode.right, lastCheckRatio, lastCheckSDR, lastCheckE);
    if (isBad) {
      currentNode.right = null;
    }
  }

  if (!isBad) {
    return false;
  }

  double[][] postSplitDists =
      new double[][] {
        {
          currentNode.leftStatistics.getValue(0),
          currentNode.leftStatistics.getValue(1),
          currentNode.leftStatistics.getValue(2)
        },
        {
          currentNode.rightStatistics.getValue(0),
          currentNode.rightStatistics.getValue(1),
          currentNode.rightStatistics.getValue(2)
        }
      };
  double[] preSplitDist =
      new double[] {
        (currentNode.leftStatistics.getValue(0) + currentNode.rightStatistics.getValue(0)),
        (currentNode.leftStatistics.getValue(1) + currentNode.rightStatistics.getValue(1)),
        (currentNode.leftStatistics.getValue(2) + currentNode.rightStatistics.getValue(2))
      };
  double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);

  // True tells the caller this node is bad and should be detached.
  return (merit / lastCheckSDR) < (lastCheckRatio - (2 * lastCheckE));
}