예제 #1
0
  /**
   * Rebuilds the per-attribute example lists from the given data.
   *
   * <p>Every attribute in {@code attrs} that reports itself as sparse first has its examples
   * reset (the original note ties this to ensembles). Afterwards each tuple of {@code data} is
   * asked to add itself to its attributes.
   *
   * @param attrs attributes to inspect; sparse ones are cast to {@code SparseNumericAttrType}
   * @param data rows to register; every tuple is cast to {@code SparseDataTuple}
   */
  public void initializeExamples(ClusAttrType[] attrs, RowData data) {
    // Step 1: drop any example lists still attached to the sparse attributes.
    for (ClusAttrType attr : attrs) {
      if (!attr.isSparse()) {
        continue;
      }
      SparseNumericAttrType sparseAttr = (SparseNumericAttrType) attr;
      sparseAttr.resetExamples();
    }

    // Step 2: let every row register itself with its attributes.
    int row = 0;
    while (row < data.getNbRows()) {
      SparseDataTuple sparseTuple = (SparseDataTuple) data.getTuple(row);
      sparseTuple.addExampleToAttributes();
      row++;
    }
  }
예제 #2
0
  /**
   * Grows the subtree rooted at {@code node} using only the candidate attributes in
   * {@code attrs}.
   *
   * <p>The node becomes a leaf when the stopping criteria fire or when no best test is found.
   * Otherwise the best test is installed in the node, {@code data} is split per branch, and each
   * child is induced through {@code induceRandomForestRecursive}.
   *
   * @param node node to expand or turn into a leaf
   * @param data examples that reached {@code node}
   * @param attrs candidate attributes; each element is cast to {@code ClusAttrType}
   */
  public void induceRandomForestRecursive2(ClusNode node, RowData data, Object[] attrs) {
    // Initialise the selector; stop right away when a stopping criterion holds.
    if (initSelectorAndStopCrit(node, data)) {
      makeLeaf(node);
      return;
    }

    // Evaluate every candidate attribute: nominal ones via findNominal, all others via
    // findNumeric.
    for (Object candidate : attrs) {
      ClusAttrType attr = (ClusAttrType) candidate;
      if (attr instanceof NominalAttrType) {
        m_FindBestTest.findNominal((NominalAttrType) attr, data);
      } else {
        m_FindBestTest.findNumeric((NumericAttrType) attr, data);
      }
    }

    CurrentBestTestAndHeuristic bestTest = m_FindBestTest.getBestTest();
    if (!bestTest.hasBestTest()) {
      // No usable split was found.
      makeLeaf(node);
      return;
    }

    node.testToNode(bestTest);
    if (Settings.VERBOSE > 0) {
      System.out.println("Test: " + node.getTestString() + " -> " + bestTest.getHeuristicValue());
    }

    // Partition the data, one subset per outcome of the chosen test.
    int nbBranches = node.updateArity();
    NodeTest nodeTest = node.getTest();
    RowData[] partitions = new RowData[nbBranches];
    for (int branch = 0; branch < nbBranches; branch++) {
      partitions[branch] = data.applyWeighted(nodeTest, branch);
    }

    if (getSettings().showAlternativeSplits()) {
      filterAlternativeSplits(node, data, partitions);
    }

    // The root keeps its statistics: the child initialisation below reads them.
    boolean isInnerNode = node != m_Root;
    if (isInnerNode && getSettings().hasTreeOptimize(Settings.TREE_OPTIMIZE_NO_INODE_STATS)) {
      node.setClusteringStat(null);
      node.setTargetStat(null);
    }

    // Create each child, initialise its statistics from the root's, and recurse.
    for (int branch = 0; branch < nbBranches; branch++) {
      RowData partition = partitions[branch];
      ClusNode childNode = new ClusNode();
      node.setChild(childNode, branch);
      childNode.initClusteringStat(m_StatManager, m_Root.getClusteringStat(), partition);
      childNode.initTargetStat(m_StatManager, m_Root.getTargetStat(), partition);

      induceRandomForestRecursive(childNode, partition);
    }
  }
예제 #3
0
  /**
   * Entry point for inducing the tree rooted at {@code node} from {@code data}.
   *
   * <p>In ensemble mode with the {@code ENSEMBLE_RFOREST} or {@code ENSEMBLE_NOBAGRFOREST} method
   * the work is delegated to {@code induceRandomForest}. Otherwise the descriptive attributes are
   * filtered: non-sparse attributes are always kept (with a {@code null} example list), sparse
   * ones only when their example weight reaches the configured minimal weight. For each kept
   * sparse attribute its examples are sorted with {@code sortSparse} and stored back on the
   * attribute. The resulting parallel arrays are handed to the four-argument {@code induce}.
   *
   * @param node root of the (sub)tree to induce
   * @param data training examples
   */
  public void induce(ClusNode node, RowData data) {
    boolean useRandomForest =
        getSettings().isEnsembleMode()
            && ((getSettings().getEnsembleMethod() == Settings.ENSEMBLE_RFOREST)
                || (getSettings().getEnsembleMethod() == Settings.ENSEMBLE_NOBAGRFOREST));
    if (useRandomForest) {
      induceRandomForest(node, data);
      return;
    }

    ClusAttrType[] descriptive = getDescriptiveAttributes();
    initializeExamples(descriptive, data);

    // Parallel lists: keptAttrs.get(k) goes with keptExamples.get(k).
    ArrayList<ClusAttrType> keptAttrs = new ArrayList<ClusAttrType>();
    ArrayList<ArrayList<SparseDataTuple>> keptExamples =
        new ArrayList<ArrayList<SparseDataTuple>>();

    for (ClusAttrType attr : descriptive) {
      if (!attr.isSparse()) {
        // Non-sparse attributes carry no example list.
        keptAttrs.add(attr);
        keptExamples.add(null);
        continue;
      }

      SparseNumericAttrType sparseAttr = (SparseNumericAttrType) attr;
      if (sparseAttr.getExampleWeight() >= getSettings().getMinimalWeight()) {
        keptAttrs.add(attr);

        // Sort this attribute's examples through a temporary RowData wrapper.
        Object[] rawExamples = sparseAttr.getExamples().toArray();
        RowData sortedView = new RowData(rawExamples, rawExamples.length);
        sortedView.sortSparse(sparseAttr, m_FindBestTest.getSortHelper());

        // Copy the sorted tuples back into a list and store it on the attribute.
        ArrayList<SparseDataTuple> sortedExamples = new ArrayList<SparseDataTuple>();
        int row = 0;
        while (row < sortedView.getNbRows()) {
          sortedExamples.add((SparseDataTuple) sortedView.getTuple(row));
          row++;
        }
        sparseAttr.setExamples(sortedExamples);
        keptExamples.add(sortedExamples);
      }
    }

    induce(node, data, keptAttrs.toArray(), keptExamples.toArray());
  }
 /**
  * Produces one prediction tuple per row of {@code test}.
  *
  * <p>The model is attached to the schema of {@code test}. For every row the model's
  * {@code predictWeighted} result is asked to fill a fresh {@code DataTuple}, which is stored at
  * the same index in the returned data set.
  *
  * @param model model used to predict each test tuple
  * @param test rows to predict; its schema is reused for the result
  * @return a new {@code RowData} holding the prediction tuples, in the same order as {@code test}
  * @throws ClusException as declared by the signature; propagated from the calls made here
  */
 public static RowData predict(ClusModel model, RowData test) throws ClusException {
   ClusSchema testSchema = test.getSchema();
   testSchema.attachModel(model);

   RowData result = new RowData(testSchema, test.getNbRows());
   int row = 0;
   while (row < test.getNbRows()) {
     DataTuple filled = new DataTuple(testSchema);
     DataTuple input = test.getTuple(row);
     ClusStatistic predicted = model.predictWeighted(input);
     predicted.predictTuple(filled);
     result.setTuple(filled, row);
     row++;
   }
   return result;
 }