コード例 #1
0
  /**
   * Expands {@code node} by searching for the best split among the given attributes only, then
   * creating and expanding one child per branch of that split.
   *
   * <p>The node is turned into a leaf when {@code initSelectorAndStopCrit} reports that a
   * stopping criterion holds, or when no acceptable test is found among {@code attrs}.
   *
   * @param node the node to expand; its test and children are set in place
   * @param data the rows that reach {@code node}
   * @param attrs candidate attributes; every element is cast to {@link ClusAttrType} and is
   *     treated as nominal if it is a {@link NominalAttrType}, otherwise as numeric
   */
  public void induceRandomForestRecursive2(ClusNode node, RowData data, Object[] attrs) {
    // Initialising the selector also evaluates the stopping criteria.
    if (initSelectorAndStopCrit(node, data)) {
      makeLeaf(node);
      return;
    }

    // Score every candidate attribute; m_FindBestTest remembers the winning test.
    for (Object candidate : attrs) {
      ClusAttrType type = (ClusAttrType) candidate;
      if (type instanceof NominalAttrType) {
        m_FindBestTest.findNominal((NominalAttrType) type, data);
      } else {
        m_FindBestTest.findNumeric((NumericAttrType) type, data);
      }
    }

    CurrentBestTestAndHeuristic winner = m_FindBestTest.getBestTest();
    if (!winner.hasBestTest()) {
      makeLeaf(node);
      return;
    }

    node.testToNode(winner);
    if (Settings.VERBOSE > 0) {
      System.out.println("Test: " + node.getTestString() + " -> " + winner.getHeuristicValue());
    }

    // Partition the rows with the chosen test: one weighted subset per branch.
    int nbBranches = node.updateArity();
    NodeTest split = node.getTest();
    RowData[] partitions = new RowData[nbBranches];
    for (int branch = 0; branch < nbBranches; branch++) {
      partitions[branch] = data.applyWeighted(split, branch);
    }

    if (getSettings().showAlternativeSplits()) {
      filterAlternativeSplits(node, data, partitions);
    }

    // With TREE_OPTIMIZE_NO_INODE_STATS, internal nodes drop their statistics. The root keeps
    // them because the children created below are initialised from the root's statistics.
    if (node != m_Root && getSettings().hasTreeOptimize(Settings.TREE_OPTIMIZE_NO_INODE_STATS)) {
      node.setClusteringStat(null);
      node.setTargetStat(null);
    }

    for (int branch = 0; branch < nbBranches; branch++) {
      ClusNode child = new ClusNode();
      node.setChild(child, branch);
      RowData subset = partitions[branch];
      child.initClusteringStat(m_StatManager, m_Root.getClusteringStat(), subset);
      child.initTargetStat(m_StatManager, m_Root.getTargetStat(), subset);
      // NOTE(review): recursion goes through induceRandomForestRecursive(ClusNode, RowData),
      // not this method -- presumably so a new attribute subset is drawn for each node; confirm.
      induceRandomForestRecursive(child, subset);
    }
  }
コード例 #2
0
ファイル: M5Pruner.java プロジェクト: rcerri/Clus-Hyper-Code
 /**
  * Runs M5-style pruning over the tree rooted at {@code node}.
  *
  * <p>First stores in {@code m_GlobalDeviation} the square root of the root statistic's
  * {@code getSVarS(m_ClusteringWeights)} divided by its total weight. {@code pruneRecursive}
  * uses that value as a lower error threshold. Then it prunes the tree bottom-up.
  *
  * @param node root of the tree to prune; its clustering statistic must be a
  *     {@code RegressionStat}
  */
 public void prune(ClusNode node) {
   RegressionStat rootStat = (RegressionStat) node.getClusteringStat();
   double scaledVariance = rootStat.getSVarS(m_ClusteringWeights);
   double rootWeight = rootStat.getTotalWeight();
   m_GlobalDeviation = Math.sqrt(scaledVariance / rootWeight);
   pruneRecursive(node);
 }
コード例 #3
0
ファイル: M5Pruner.java プロジェクト: rcerri/Clus-Hyper-Code
 /**
  * Prunes the subtree rooted at {@code node} bottom-up using the M5 criterion.
  *
  * <p>All children are pruned first. The node is then collapsed into a leaf when either of
  * these holds:
  * <ul>
  *   <li>the leaf's adjusted error (its RMSE times {@code pruningFactor(weight, 1)}) does not
  *       exceed the subtree's adjusted error (the RMS derived from
  *       {@code estimateClusteringSS} times {@code pruningFactor(weight, modelSize)});
  *   <li>the leaf's adjusted error is below {@code 0.00001 * m_GlobalDeviation}.
  * </ul>
  * Nodes for which {@code atBottomLevel()} is true are left untouched.
  *
  * @param node the subtree root; its clustering statistic must be a {@code RegressionStat}
  */
 public void pruneRecursive(ClusNode node) {
   if (!node.atBottomLevel()) {
     for (int idx = 0; idx < node.getNbChildren(); ++idx) {
       pruneRecursive((ClusNode) node.getChild(idx));
     }

     RegressionStat stat = (RegressionStat) node.getClusteringStat();
     double leafRms = stat.getRMSE(m_ClusteringWeights);
     double weight = stat.getTotalWeight();
     double leafError = leafRms * pruningFactor(weight, 1);

     double subtreeRms = Math.sqrt(node.estimateClusteringSS(m_ClusteringWeights) / weight);
     double subtreeError = subtreeRms * pruningFactor(weight, node.getModelSize());

     boolean leafNoWorse = leafError <= subtreeError;
     // A near-zero leaf error (relative to the global deviation) is also enough to collapse.
     boolean leafNegligible = leafError < (m_GlobalDeviation * 0.00001);
     if (leafNoWorse || leafNegligible) {
       node.makeLeaf();
     }
   }
 }