    /**
     * Calculates the weighted difference between the prototype of this statistic and
     * the prototype of the given statistic. Numeric attributes contribute the absolute
     * difference of their prototype values, nominal attributes the Manhattan distance
     * between their frequency vectors; each term is scaled by its clustering weight.
     *
     * @return the weighted prototype difference
     */
    public double prototypeDifference(CombStat stat) {
        double sumdiff = 0;
        double weight;
        // Numeric attributes: weighted absolute difference of the prototypes
        for (int i = 0; i < m_RegStat.getNbNumericAttributes(); i++) {
            weight = m_StatManager.getClusteringWeights().getWeight(m_RegStat.getAttribute(i));
            sumdiff += Math.abs(prototypeNum(i) - stat.prototypeNum(i)) * weight;
            // System.err.println("sumdiff: " + Math.abs(prototypeNum(i) - stat.prototypeNum(i)) * weight);
        }
        // Nominal attributes: weighted Manhattan distance between frequency vectors
        for (int i = 0; i < m_ClassStat.getNbNominalAttributes(); i++) {
            weight = m_StatManager.getClusteringWeights().getWeight(m_ClassStat.getAttribute(i));
            double sum = 0;
            double[] proto1 = prototypeNom(i);
            double[] proto2 = stat.prototypeNom(i);
            for (int j = 0; j < proto1.length; j++) {
                sum += Math.abs(proto1[j] - proto2[j]);
            }
            sumdiff += sum * weight;
            // System.err.println("sumdiff: " + (sum * weight));
        }
        // System.err.println("sumdiff-total: " + sumdiff);
        return sumdiff != 0 ? sumdiff : 0.0;
    }
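    // A minimal standalone sketch (not part of the Clus API) of the distance computed above:
    // numeric attributes contribute |p1 - p2| and nominal attributes the Manhattan distance
    // between their frequency vectors, each scaled by an attribute weight. All parameter
    // names below are hypothetical.
    static double prototypeDistanceSketch(double[] numProto1, double[] numProto2, double[] numWeights,
            double[][] nomProto1, double[][] nomProto2, double[] nomWeights) {
        double sumdiff = 0.0;
        for (int i = 0; i < numProto1.length; i++) {
            // numeric part: weighted absolute difference
            sumdiff += Math.abs(numProto1[i] - numProto2[i]) * numWeights[i];
        }
        for (int i = 0; i < nomProto1.length; i++) {
            // nominal part: weighted Manhattan distance between frequency vectors
            double sum = 0.0;
            for (int j = 0; j < nomProto1[i].length; j++) {
                sum += Math.abs(nomProto1[i][j] - nomProto2[i][j]);
            }
            sumdiff += sum * nomWeights[i];
        }
        return sumdiff;
    }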
    // TODO: Move all heuristic stuff to ClusRuleHeuristic*
    /**
     * Relative dispersion multiplied heuristic: the rule's dispersion (plus an offset)
     * relative to the default (train set) dispersion, scaled by a coverage term and,
     * optionally, by the prototype distance to the default rule and a significance test.
     */
    public double rDispersionMltHeur() {
        // Original
        /*
         * double train_sum_w = m_StatManager.getTrainSetStat().getTotalWeight();
         * double comp = dispersion(IN_HEURISTIC);
         * double def_comp = ((CombStat)m_StatManager.getTrainSetStat()).dispersion(IN_HEURISTIC);
         * return -m_SumWeight/train_sum_w*(def_comp-comp);
         */
        double offset = getSettings().getHeurDispOffset();
        double disp = dispersion(IN_HEURISTIC) + offset;
        double dis1 = disp;
        double def_disp = ((CombStat) m_StatManager.getTrainSetStat()).dispersion(IN_HEURISTIC);
        disp = disp - def_disp; // This should be < 0 most of the time
        double dis2 = disp;
        // Coverage part
        double train_sum_w = m_StatManager.getTrainSetStat().getTotalWeight();
        double cov_par = getSettings().getHeurCoveragePar();
        // comp *= (1.0 + cov_par*train_sum_w/m_SumWeight); // How about this???
        // comp *= cov_par*train_sum_w/m_SumWeight;
        // comp *= cov_par*m_SumWeight/train_sum_w;
        disp *= Math.pow(m_SumWeight / train_sum_w, cov_par);
        double dis3 = disp;
        // Prototype distance part
        // Prefers rules that predict a different class than the default rule
        if (getSettings().isHeurPrototypeDistPar()) {
            double proto_par = getSettings().getHeurPrototypeDistPar();
            double proto_val = prototypeDifference((CombStat) m_StatManager.getTrainSetStat());
            // disp *= (1.0 + proto_par*m_SumWeight/train_sum_w*proto_val);
            disp = proto_val > 0 ? disp / Math.pow(proto_val, proto_par) : 0.0;
        }
        // Significance testing part - TODO: Complete or remove altogether
        if (Settings.IS_RULE_SIG_TESTING) {
            int sign_diff;
            int thresh = getSettings().getRuleNbSigAtt();
            if (thresh > 0) {
                sign_diff = signDifferent();
                if (sign_diff < thresh) {
                    disp *= 1000; // Some big number ??? - What if comp < 0???
                }
            }
            else if (thresh < 0) {
                // Testing just one target attribute - TODO: change!
                if (!targetSignDifferent()) {
                    disp *= 1000; // Some big number ???
                }
            }
        }
        // System.err.println("Disp: " + dis1 + " DDisp: " + def_disp + " RDisp: " + dis2 + " FDisp: " + dis3);
        return disp;
    }
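    // A minimal standalone sketch (not part of the Clus API) of the core of the heuristic
    // above, omitting the significance-testing part: the rule's dispersion (plus an offset)
    // relative to the default dispersion, scaled by (coverage / total weight)^covPar and,
    // when protoPar is set, divided by protoDist^protoPar. All parameter names here are
    // hypothetical.
    static double rDispersionMltSketch(double ruleDisp, double defaultDisp, double offset,
            double coverage, double totalWeight, double covPar, double protoDist, double protoPar) {
        double disp = (ruleDisp + offset) - defaultDisp;           // usually < 0 when the rule beats the default
        disp *= Math.pow(coverage / totalWeight, covPar);          // coverage term
        if (protoPar != 0.0) {
            // prototype distance term: favor rules whose prediction differs from the default
            disp = protoDist > 0 ? disp / Math.pow(protoDist, protoPar) : 0.0;
        }
        return disp;
    }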
    /**
     * Entry point for M5-style pruning: computes the global deviation of the training
     * data and then prunes the tree bottom-up.
     */
    public void prune(ClusNode node) {
        RegressionStat stat = (RegressionStat) node.getClusteringStat();
        m_GlobalDeviation = Math.sqrt(stat.getSVarS(m_ClusteringWeights) / stat.getTotalWeight());
        pruneRecursive(node);
        // System.out.println("Performing test of M5 pruning");
        // TestM5PruningRuleNode.performTest(orig, node, m_GlobalDeviation, m_TargetWeights, m_TrainingData);
    }
    /**
     * Bottom-up pruning: after pruning the children, the subtree rooted at this node is
     * replaced by a leaf if the leaf's adjusted error does not exceed the subtree's
     * adjusted error, or if it is negligible relative to the global deviation.
     */
    public void pruneRecursive(ClusNode node) {
        if (node.atBottomLevel()) {
            return;
        }
        for (int i = 0; i < node.getNbChildren(); i++) {
            ClusNode child = (ClusNode) node.getChild(i);
            pruneRecursive(child);
        }
        RegressionStat stat = (RegressionStat) node.getClusteringStat();
        double rmsLeaf = stat.getRMSE(m_ClusteringWeights);
        double adjustedErrorLeaf = rmsLeaf * pruningFactor(stat.getTotalWeight(), 1);
        double rmsSubTree = Math.sqrt(node.estimateClusteringSS(m_ClusteringWeights) / stat.getTotalWeight());
        double adjustedErrorTree = rmsSubTree * pruningFactor(stat.getTotalWeight(), node.getModelSize());
        // System.out.println("C leaf: " + rmsLeaf + " tree: " + rmsSubTree);
        // System.out.println("C leafadj: " + adjustedErrorLeaf + " treeadj: " + adjustedErrorTree);
        if ((adjustedErrorLeaf <= adjustedErrorTree) || (adjustedErrorLeaf < (m_GlobalDeviation * 0.00001))) {
            node.makeLeaf();
        }
    }
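    // A minimal standalone sketch (not part of the Clus API) of the pruning decision above,
    // assuming the classic M5 adjusted-error factor (n + k*v) / (n - v), where n is the
    // (weighted) number of covered examples, v the number of parameters of the model (1 for
    // a leaf, the model size for a subtree) and k a pruning multiplier (2 in classic M5).
    // The exact form of pruningFactor() in Clus may differ; all names below are hypothetical.
    static boolean shouldPruneToLeafSketch(double rmsLeaf, double rmsSubTree, double n,
            double subTreeSize, double pruningMult, double globalDeviation) {
        if (n <= subTreeSize) {
            // Too few examples to support the subtree's parameters: always prefer the leaf.
            return true;
        }
        double leafFactor = (n + pruningMult) / (n - 1.0);
        double treeFactor = (n + pruningMult * subTreeSize) / (n - subTreeSize);
        double adjErrLeaf = rmsLeaf * leafFactor;
        double adjErrTree = rmsSubTree * treeFactor;
        // Replace the subtree by a leaf when the leaf's adjusted error is no worse,
        // or when the leaf error is negligible compared to the global deviation.
        return adjErrLeaf <= adjErrTree || adjErrLeaf < globalDeviation * 1e-5;
    }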