public void prune(ClusNode node) { RegressionStat stat = (RegressionStat) node.getClusteringStat(); m_GlobalDeviation = Math.sqrt(stat.getSVarS(m_ClusteringWeights) / stat.getTotalWeight()); pruneRecursive(node); // System.out.println("Performing test of M5 pruning"); // TestM5PruningRuleNode.performTest(orig, node, m_GlobalDeviation, m_TargetWeights, // m_TrainingData); }
public void pruneRecursive(ClusNode node) { if (node.atBottomLevel()) { return; } for (int i = 0; i < node.getNbChildren(); i++) { ClusNode child = (ClusNode) node.getChild(i); pruneRecursive(child); } RegressionStat stat = (RegressionStat) node.getClusteringStat(); double rmsLeaf = stat.getRMSE(m_ClusteringWeights); double adjustedErrorLeaf = rmsLeaf * pruningFactor(stat.getTotalWeight(), 1); double rmsSubTree = Math.sqrt(node.estimateClusteringSS(m_ClusteringWeights) / stat.getTotalWeight()); double adjustedErrorTree = rmsSubTree * pruningFactor(stat.getTotalWeight(), node.getModelSize()); // System.out.println("C leaf: "+rmsLeaf+" tree: "+rmsSubTree); // System.out.println("C leafadj: "+adjustedErrorLeaf +" treeadj: "+rmsSubTree); if ((adjustedErrorLeaf <= adjustedErrorTree) || (adjustedErrorLeaf < (m_GlobalDeviation * 0.00001))) { node.makeLeaf(); } }