示例#1
0
  /**
   * Checks that cloning a {@code RandomForest} produces an independent model.
   *
   * <p>For each setting of {@code useCatFeatures} the test trains a forest on a first data set,
   * clones it, and expects the clone to reproduce every label of that data set. The clone is then
   * retrained on a second data set; afterwards the original must still reproduce the labels of
   * the first set and the clone must reproduce the labels of the second set.
   */
  @Test
  public void testClone() {
    System.out.println("clone");
    final boolean[] featureModes = {true, false};
    for (int mode = 0; mode < featureModes.length; mode++) {
      final boolean useCatFeatures = featureModes[mode];

      // Construction order is kept (model first, then data) in case they share a random source.
      RandomForest original = new RandomForest();
      ClassificationDataSet firstSet = FixedProblems.getSimpleKClassLinear(100, 3);
      ClassificationDataSet secondSet = FixedProblems.getSimpleKClassLinear(100, 6);

      if (useCatFeatures) {
        // Each data set gets a histogram transform built from that same data set.
        NumericalToHistogram firstHistogram = new NumericalToHistogram(firstSet);
        firstSet.applyTransform(firstHistogram);
        NumericalToHistogram secondHistogram = new NumericalToHistogram(secondSet);
        secondSet.applyTransform(secondHistogram);
      }

      // Cloning an untrained forest must also work; training continues on that clone.
      original = original.clone();
      original.trainC(firstSet);

      // A clone of the trained forest has to classify the first data set correctly.
      RandomForest copy = original.clone();
      for (int row = 0; row < firstSet.getSampleSize(); row++) {
        int expected = firstSet.getDataPointCategory(row);
        int predicted = copy.classify(firstSet.getDataPoint(row)).mostLikely();
        assertEquals(expected, predicted);
      }

      // Retraining the clone must not leak into the forest it was cloned from.
      copy.trainC(secondSet);

      for (int row = 0; row < firstSet.getSampleSize(); row++) {
        int expected = firstSet.getDataPointCategory(row);
        int predicted = original.classify(firstSet.getDataPoint(row)).mostLikely();
        assertEquals(expected, predicted);
      }

      for (int row = 0; row < secondSet.getSampleSize(); row++) {
        int expected = secondSet.getDataPointCategory(row);
        int predicted = copy.classify(secondSet.getDataPoint(row)).mostLikely();
        assertEquals(expected, predicted);
      }
    }
  }
示例#2
0
  /**
   * Trains a {@code RandomForest} on data with missing values inserted and checks the error rate
   * on a separate test set.
   *
   * <p>An {@code InsertMissingValuesTransform} is applied to both sets, constructed with 0.1 for
   * the training set and 0.01 for the test set. The scenario runs once with a
   * {@code NumericalToHistogram} transform factory installed on the evaluation and once without;
   * the accepted error rate is at most 0.17 in the first case and at most 0.1 in the second.
   */
  @Test
  public void testTrainC_ClassificationDataSetMissingFeat() {
    System.out.println("trainC");
    final boolean[] featureModes = {true, false};
    for (int mode = 0; mode < featureModes.length; mode++) {
      final boolean useCatFeatures = featureModes[mode];

      RandomForest forest = new RandomForest();

      ClassificationDataSet trainSet = FixedProblems.getCircles(1000, 1.0, 10.0, 100.0);
      // The forest may not get the class boundary exactly right, so the test set is generated
      // without noise.
      ClassificationDataSet testSet =
          FixedProblems.getCircles(1000, 0.0, new XORWOW(), 1.0, 10.0, 100.0);

      // Insert missing values into both sets.
      trainSet.applyTransform(new InsertMissingValuesTransform(0.1));
      testSet.applyTransform(new InsertMissingValuesTransform(0.01));

      ClassificationModelEvaluation evaluation =
          new ClassificationModelEvaluation(forest, trainSet);
      if (useCatFeatures) {
        DataTransformProcess histogramProcess =
            new DataTransformProcess(
                new NumericalToHistogram.NumericalToHistogramTransformFactory());
        evaluation.setDataTransformProcess(histogramProcess);
      }
      evaluation.evaluateTestSet(testSet);

      // The histogram variant is hard to get right with only 2 features like this, so it is
      // given a looser bound.
      final double maxErrorRate = useCatFeatures ? 0.17 : 0.1;
      assertTrue(evaluation.getErrorRate() <= maxErrorRate);
    }
  }