/**
 * Checks {@code GLRM.expandCats} on a five-row sample of the iris data.
 *
 * <p>Each hand-written row in {@code iris} holds four numeric measurements followed by the class
 * as a level index into {"setosa", "versicolor", "virginica"}. A {@code DataInfo} is built from
 * smalldata/iris/iris_wheader.csv (both {@code TransformType} arguments set to NONE). Its column
 * permutation is applied to the sample, and the permuted sample is expanded. The result must equal
 * {@code iris_expandR}, in which the class becomes three indicator columns placed before the four
 * numeric columns.
 *
 * <p>Assertion failures and other unchecked throwables propagate unchanged, so JUnit reports a
 * mismatch as a test failure. Any other throwable is wrapped in a {@code RuntimeException}.
 */
@Test
public void testExpandCatsIris() throws InterruptedException, ExecutionException {
  double[][] iris =
      ard(
          ard(6.3, 2.5, 4.9, 1.5, 1),
          ard(5.7, 2.8, 4.5, 1.3, 1),
          ard(5.6, 2.8, 4.9, 2.0, 2),
          ard(5.0, 3.4, 1.6, 0.4, 0),
          ard(6.0, 2.2, 5.0, 1.5, 2));
  // Expected expansion: class indicator columns (one per level) first, then the numeric columns.
  double[][] iris_expandR =
      ard(
          ard(0, 1, 0, 6.3, 2.5, 4.9, 1.5),
          ard(0, 1, 0, 5.7, 2.8, 4.5, 1.3),
          ard(0, 0, 1, 5.6, 2.8, 4.9, 2.0),
          ard(1, 0, 0, 5.0, 3.4, 1.6, 0.4),
          ard(0, 0, 1, 6.0, 2.2, 5.0, 1.5));
  // Column names and domains below are only used to label the logged matrices.
  String[] iris_cols =
      new String[] {"sepal_len", "sepal_wid", "petal_len", "petal_wid", "class"};
  String[][] iris_domains =
      new String[][] {
        null, null, null, null, new String[] {"setosa", "versicolor", "virginica"}
      };

  Frame fr = null;
  try {
    fr = parse_test_file(Key.make("iris.hex"), "smalldata/iris/iris_wheader.csv");
    DataInfo dinfo =
        new DataInfo(
            Key.make(),
            fr,
            null,
            0,
            true,
            DataInfo.TransformType.NONE,
            DataInfo.TransformType.NONE,
            false,
            false,
            false,
            /* weights */ false,
            /* offset */ false,
            /* fold */ false);
    Log.info("Original matrix:\n" + colFormat(iris_cols, "%8.7s") + ArrayUtils.pprint(iris));

    // Reorder the sample's columns to match the DataInfo layout before expanding.
    double[][] iris_perm = ArrayUtils.permuteCols(iris, dinfo._permutation);
    Log.info(
        "Permuted matrix:\n"
            + colFormat(iris_cols, "%8.7s", dinfo._permutation)
            + ArrayUtils.pprint(iris_perm));

    double[][] iris_exp = GLRM.expandCats(iris_perm, dinfo);
    Log.info(
        "Expanded matrix:\n"
            + colExpFormat(iris_cols, iris_domains, "%8.7s", dinfo._permutation)
            + ArrayUtils.pprint(iris_exp));

    Assert.assertArrayEquals(iris_expandR, iris_exp);
  } catch (RuntimeException | Error e) {
    // Let AssertionError (and other unchecked problems) reach JUnit as-is, so a wrong expansion
    // is reported as a failure rather than as an error wrapping the assertion.
    throw e;
  } catch (Throwable t) {
    throw new RuntimeException(t);
  } finally {
    if (fr != null) {
      fr.delete();
    }
  }
}
@Test public void testExpandCatsProstate() throws InterruptedException, ExecutionException { double[][] prostate = ard( ard(0, 71, 1, 0, 0, 4.8, 14.0, 7), ard(1, 70, 1, 1, 0, 8.4, 21.8, 5), ard(0, 73, 1, 3, 0, 10.0, 27.4, 6), ard(1, 68, 1, 0, 0, 6.7, 16.7, 6)); double[][] pros_expandR = ard( ard(1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 71, 4.8, 14.0, 7), ard(0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 70, 8.4, 21.8, 5), ard(0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 73, 10.0, 27.4, 6), ard(1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 68, 6.7, 16.7, 6)); String[] pros_cols = new String[] {"Capsule", "Age", "Race", "Dpros", "Dcaps", "PSA", "Vol", "Gleason"}; String[][] pros_domains = new String[][] { new String[] {"No", "Yes"}, null, new String[] {"Other", "White", "Black"}, new String[] {"None", "UniLeft", "UniRight", "Bilobar"}, new String[] {"No", "Yes"}, null, null, null }; final int[] cats = new int[] {1, 3, 4, 5}; // Categoricals: CAPSULE, RACE, DPROS, DCAPS Frame fr = null; try { Scope.enter(); fr = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv"); for (int i = 0; i < cats.length; i++) Scope.track(fr.replace(cats[i], fr.vec(cats[i]).toCategoricalVec())._key); fr.remove("ID").remove(); DKV.put(fr._key, fr); DataInfo dinfo = new DataInfo( Key.make(), fr, null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, false, false, /* weights */ false, /* offset */ false, /* fold */ false); Log.info("Original matrix:\n" + colFormat(pros_cols, "%8.7s") + ArrayUtils.pprint(prostate)); double[][] pros_perm = ArrayUtils.permuteCols(prostate, dinfo._permutation); Log.info( "Permuted matrix:\n" + colFormat(pros_cols, "%8.7s", dinfo._permutation) + ArrayUtils.pprint(pros_perm)); double[][] pros_exp = GLRM.expandCats(pros_perm, dinfo); Log.info( "Expanded matrix:\n" + colExpFormat(pros_cols, pros_domains, "%8.7s", dinfo._permutation) + ArrayUtils.pprint(pros_exp)); Assert.assertArrayEquals(pros_expandR, pros_exp); } catch (Throwable t) { t.printStackTrace(); throw new 
RuntimeException(t); } finally { if (fr != null) fr.delete(); Scope.exit(); } }