Example 1
  @Test
  public void testClone() {
    System.out.println("clone");
    for (boolean useCatFeatures : new boolean[] {true, false}) {
      RandomForest instance = new RandomForest();

      ClassificationDataSet t1 = FixedProblems.getSimpleKClassLinear(100, 3);
      ClassificationDataSet t2 = FixedProblems.getSimpleKClassLinear(100, 6);
      if (useCatFeatures) {
        t1.applyTransform(new NumericalToHistogram(t1));
        t2.applyTransform(new NumericalToHistogram(t2));
      }

      // a clone taken before training must itself be trainable
      instance = instance.clone();

      instance.trainC(t1);

      // a clone taken after training must predict like the original
      RandomForest result = instance.clone();
      for (int i = 0; i < t1.getSampleSize(); i++)
        assertEquals(t1.getDataPointCategory(i), result.classify(t1.getDataPoint(i)).mostLikely());
      // retraining the clone must not disturb the original model
      result.trainC(t2);

      for (int i = 0; i < t1.getSampleSize(); i++)
        assertEquals(
            t1.getDataPointCategory(i), instance.classify(t1.getDataPoint(i)).mostLikely());

      for (int i = 0; i < t2.getSampleSize(); i++)
        assertEquals(t2.getDataPointCategory(i), result.classify(t2.getDataPoint(i)).mostLikely());
    }
  }
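
This test exercises RandomForest's clone contract: a clone taken before training must itself be trainable, and a clone taken after training must predict like the original while staying independent of it. Below is a minimal sketch of the same contract outside JUnit, building a small two-class data set by hand; the CloneSketch class and its toy data are illustrative, not part of JSAT, while the ClassificationDataSet calls mirror the ones used in these examples.

import java.util.Random;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.trees.RandomForest;
import jsat.linear.DenseVector;

public class CloneSketch {
  public static void main(String[] args) {
    Random rand = new Random(42);
    // two numerical features, no categorical features, two classes
    ClassificationDataSet data =
        new ClassificationDataSet(2, new CategoricalData[0], new CategoricalData(2));
    for (int i = 0; i < 100; i++) {
      int label = rand.nextInt(2);
      DenseVector v = new DenseVector(2);
      v.set(0, label * 2.0 + rand.nextGaussian() * 0.25); // class-separated feature
      v.set(1, rand.nextGaussian()); // pure noise feature
      data.addDataPoint(v, label, 1.0); // same (Vec, class, weight) overload as in Example 2
    }

    RandomForest original = new RandomForest();
    original.trainC(data);

    RandomForest copy = original.clone(); // the copy must work on its own
    int agree = 0;
    for (int i = 0; i < data.getSampleSize(); i++)
      if (copy.classify(data.getDataPoint(i)).mostLikely()
          == original.classify(data.getDataPoint(i)).mostLikely()) agree++;
    System.out.println("clone agrees with original on " + agree + "/" + data.getSampleSize());
  }
}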
Example 2
  @Override
  public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
    final int models = baseClassifiers.size();
    final int C = dataSet.getClassSize();
    weightsPerModel = C == 2 ? 1 : C; // a binary problem needs only one meta feature per model
    ClassificationDataSet metaSet =
        new ClassificationDataSet(
            weightsPerModel * models, new CategoricalData[0], dataSet.getPredicting());

    List<ClassificationDataSet> dataFolds = dataSet.cvSet(folds);
    // iterate in the order of the folds so we get the right datum weights
    for (ClassificationDataSet cds : dataFolds)
      for (int i = 0; i < cds.getSampleSize(); i++)
        metaSet.addDataPoint(
            new DenseVector(weightsPerModel * models),
            cds.getDataPointCategory(i),
            cds.getDataPoint(i).getWeight());

    // create the meta training set
    for (int c = 0; c < baseClassifiers.size(); c++) {
      Classifier cl = baseClassifiers.get(c);
      int pos = 0;
      for (int f = 0; f < dataFolds.size(); f++) {
        ClassificationDataSet train = ClassificationDataSet.comineAllBut(dataFolds, f); // sic: this is how JSAT spells the method
        ClassificationDataSet test = dataFolds.get(f);
        if (threadPool == null) cl.trainC(train);
        else cl.trainC(train, threadPool);
        // evaluate and mark each point in the held-out fold
        for (int i = 0; i < test.getSampleSize(); i++) {
          CategoricalResults pred = cl.classify(test.getDataPoint(i));
          if (C == 2) // one column per model: map P(class 0) from [0, 1] to [-1, 1]
            metaSet.getDataPoint(pos).getNumericalValues().set(c, pred.getProb(0) * 2 - 1);
          else { // C columns per model: copy the model's full probability vector
            Vec toSet = metaSet.getDataPoint(pos).getNumericalValues();
            for (int j = weightsPerModel * c; j < weightsPerModel * (c + 1); j++)
              toSet.set(j, pred.getProb(j - weightsPerModel * c));
          }
          }

          pos++;
        }
      }
    }

    // train the meta model
    if (threadPool == null) aggregatingClassifier.trainC(metaSet);
    else aggregatingClassifier.trainC(metaSet, threadPool);

    // train the final classifiers, unless folds=1. In that case they are already trained
    if (folds != 1) {
      for (Classifier cl : baseClassifiers)
        if (threadPool == null) cl.trainC(dataSet);
        else cl.trainC(dataSet, threadPool);
    }
  }
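
Example 2 builds the stacking meta features out of fold: each point's row in metaSet is filled in by base models that never saw that point during the fold's training. A binary problem gets one column per model, with P(class 0) rescaled from [0, 1] to [-1, 1]; a C-class problem gets C probability columns per model. Below is a minimal sketch of just that encoding rule; the MetaEncodingSketch class and the encode helper are illustrative names, not JSAT API.

import jsat.classifiers.CategoricalResults;
import jsat.linear.Vec;

public class MetaEncodingSketch {
  // Writes model number modelIndex's prediction into one row of the meta data set.
  static void encode(Vec row, CategoricalResults pred, int modelIndex, int C) {
    int weightsPerModel = C == 2 ? 1 : C;
    if (C == 2) {
      // one column per model: map P(class 0) from [0, 1] to [-1, 1]
      row.set(modelIndex, pred.getProb(0) * 2 - 1);
    } else {
      // C columns per model: copy the model's full probability vector
      for (int j = 0; j < C; j++)
        row.set(weightsPerModel * modelIndex + j, pred.getProb(j));
    }
  }
}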
Example 3
  public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
    if (dataSet.getClassSize() != 2)
      throw new FailedToFitException(
          "Logistic Regression works only in the case of two classes, and can not handle "
              + dataSet.getClassSize()
              + " classes");
    RegressionDataSet rds =
        new RegressionDataSet(dataSet.getNumNumericalVars(), dataSet.getCategories());
    for (int i = 0; i < dataSet.getSampleSize(); i++) {
      // getDataPointCategory returns either 0 or 1, so the class index serves directly as the regression target
      rds.addDataPoint(dataSet.getDataPoint(i), (double) dataSet.getDataPointCategory(i));
    }

    train(rds, threadPool);
  }
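
Example 3 reduces binary classification to regression: each point's 0/1 class index becomes its regression target, and the actual fit is delegated to the regression trainer. Below is the same reduction as a standalone helper; the ReductionSketch class and the toRegression name are illustrative, while the RegressionDataSet calls are the ones used above.

import jsat.classifiers.ClassificationDataSet;
import jsat.regression.RegressionDataSet;

public class ReductionSketch {
  // Copies a two-class data set into a regression data set whose target
  // is the 0/1 class index, exactly as Example 3 does before calling train.
  static RegressionDataSet toRegression(ClassificationDataSet cds) {
    RegressionDataSet rds =
        new RegressionDataSet(cds.getNumNumericalVars(), cds.getCategories());
    for (int i = 0; i < cds.getSampleSize(); i++)
      rds.addDataPoint(cds.getDataPoint(i), (double) cds.getDataPointCategory(i));
    return rds;
  }
}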