@Test public void testFactorNACalc() { FactorExtraTrees et = getFET(10, 5, false); FactorExtraTrees etw = getFET(10, 5, true); double gini; gini = 1 - (0.4 * 0.4 + 0.6 * 0.6); assertEquals(gini, et.get1NaNScore(AbstractTrees.seq(10)), 1e-6); gini = 1 - (Math.pow(3.0 / 9.0, 2) + Math.pow(6.0 / 9.0, 2)); assertEquals(gini, et.get1NaNScore(AbstractTrees.seq(9)), 1e-6); // testing NaN counts: CutResult cr; cr = new CutResult(); et.calculateCutScore(AbstractTrees.seq(9), 2, 0.5, cr); assertEquals(3.0, cr.nanWeigth, 1e-6); cr = new CutResult(); et.calculateCutScore(AbstractTrees.seq(9), 1, 0.5, cr); assertEquals(0.0, cr.nanWeigth, 1e-6); // testing weights: cr = new CutResult(); etw.calculateCutScore(AbstractTrees.seq(9), 2, 0.5, cr); assertEquals(1.5, cr.nanWeigth, 1e-6); }
@Test public void testRegressionNACalc() { ExtraTrees et = getET(10, 5, false); ExtraTrees etw = getET(10, 5, true); double var, mean; mean = 4.0 / 10.0; var = 0.6 * Math.pow(mean, 2) + 0.4 * Math.pow(1 - mean, 2); assertEquals(var, et.get1NaNScore(AbstractTrees.seq(10)), 1e-6); mean = 3.0 / 9.0; var = 6 / 9.0 * Math.pow(mean, 2) + 3 / 9.0 * Math.pow(1 - mean, 2); assertEquals(var, et.get1NaNScore(AbstractTrees.seq(9)), 1e-6); // testing NaN counts: CutResult cr; cr = new CutResult(); et.calculateCutScore(AbstractTrees.seq(9), 2, 0.5, cr); assertEquals(3.0, cr.nanWeigth, 1e-6); cr = new CutResult(); et.calculateCutScore(AbstractTrees.seq(9), 1, 0.5, cr); assertEquals(0.0, cr.nanWeigth, 1e-6); // testing weights: cr = new CutResult(); etw.calculateCutScore(AbstractTrees.seq(9), 2, 0.5, cr); assertEquals(1.5, cr.nanWeigth, 1e-6); }
@Test public void testRegressionNALearn() { int ndim = 5; ExtraTrees et = getET(100, ndim, false); ExtraTrees etw = getET(100, ndim, true); et.learnTrees(3, 3, 5); etw.learnTrees(3, 3, 5); double[] x = new double[ndim]; for (int i = 0; i < x.length; i++) { x[i] = Double.NaN; } double[] val; val = et.getValues(new Matrix(x, 1, ndim)); assertTrue(Double.isNaN(val[0])); val = etw.getValues(new Matrix(x, 1, ndim)); assertTrue(Double.isNaN(val[0])); // checking if getRange works with NaN double[] col2 = ((Matrix) et.input).getCol(2); double[] range2 = AbstractTrees.getRange(col2); assertEquals(0.5, range2[0], 1e-6); assertEquals(0.5, range2[1], 1e-6); }