Ejemplo n.º 1
0
  /**
   * Checks {@code GLRM.expandCats} against five hand-written iris rows.
   *
   * <p>The rows are first reordered with {@code dinfo._permutation}, then expanded, and the result
   * is compared with a hand-written expectation in which the three indicator columns for
   * {@code class} precede the four numeric columns.
   *
   * <p>Assertion failures and unchecked exceptions propagate unchanged so that JUnit reports them
   * as what they are; only checked throwables are wrapped.
   */
  @Test
  public void testExpandCatsIris() throws InterruptedException, ExecutionException {
    // Input rows: four numeric measurements followed by the class level index.
    double[][] iris =
        ard(
            ard(6.3, 2.5, 4.9, 1.5, 1),
            ard(5.7, 2.8, 4.5, 1.3, 1),
            ard(5.6, 2.8, 4.9, 2.0, 2),
            ard(5.0, 3.4, 1.6, 0.4, 0),
            ard(6.0, 2.2, 5.0, 1.5, 2));
    // Expected output: class as three 0/1 indicator columns, then the numeric columns.
    double[][] iris_expandR =
        ard(
            ard(0, 1, 0, 6.3, 2.5, 4.9, 1.5),
            ard(0, 1, 0, 5.7, 2.8, 4.5, 1.3),
            ard(0, 0, 1, 5.6, 2.8, 4.9, 2.0),
            ard(1, 0, 0, 5.0, 3.4, 1.6, 0.4),
            ard(0, 0, 1, 6.0, 2.2, 5.0, 1.5));
    // Column names and per-column domains are used only for the log output below.
    String[] iris_cols = new String[] {"sepal_len", "sepal_wid", "petal_len", "petal_wid", "class"};
    String[][] iris_domains =
        new String[][] {null, null, null, null, new String[] {"setosa", "versicolor", "virginica"}};

    Frame fr = null;
    try {
      fr = parse_test_file(Key.make("iris.hex"), "smalldata/iris/iris_wheader.csv");
      // No transformation is requested for either TransformType argument.
      DataInfo dinfo =
          new DataInfo(
              Key.make(),
              fr,
              null,
              0,
              true,
              DataInfo.TransformType.NONE,
              DataInfo.TransformType.NONE,
              false,
              false,
              false, /* weights */
              false, /* offset */
              false, /* fold */
              false);

      Log.info("Original matrix:\n" + colFormat(iris_cols, "%8.7s") + ArrayUtils.pprint(iris));
      // Reorder the hand-written columns the same way DataInfo ordered the frame's columns.
      double[][] iris_perm = ArrayUtils.permuteCols(iris, dinfo._permutation);
      Log.info(
          "Permuted matrix:\n"
              + colFormat(iris_cols, "%8.7s", dinfo._permutation)
              + ArrayUtils.pprint(iris_perm));

      double[][] iris_exp = GLRM.expandCats(iris_perm, dinfo);
      Log.info(
          "Expanded matrix:\n"
              + colExpFormat(iris_cols, iris_domains, "%8.7s", dinfo._permutation)
              + ArrayUtils.pprint(iris_exp));
      Assert.assertArrayEquals(iris_expandR, iris_exp);
    } catch (Throwable t) {
      // Rethrow AssertionError (an Error) and unchecked exceptions untouched: wrapping them would
      // turn a test failure into an error and bury the real message in the cause chain.
      if (t instanceof Error) {
        throw (Error) t;
      }
      if (t instanceof RuntimeException) {
        throw (RuntimeException) t;
      }
      // Only a checked throwable needs wrapping; the cause is preserved.
      throw new RuntimeException(t);
    } finally {
      // Always release the parsed frame, even when the test fails.
      if (fr != null) fr.delete();
    }
  }
Ejemplo n.º 2
0
  /**
   * Checks {@code GLRM.expandCats} against four hand-written prostate rows.
   *
   * <p>Columns 1, 3, 4 and 5 of the parsed frame are converted to categoricals and the {@code ID}
   * column is dropped before the {@code DataInfo} is built. The rows are reordered with
   * {@code dinfo._permutation}, expanded, and compared with a hand-written expectation laid out as
   * indicator columns for Dpros (4), Race (3), Capsule (2) and Dcaps (2), followed by the numeric
   * columns Age, PSA, Vol and Gleason.
   *
   * <p>Assertion failures and unchecked exceptions propagate unchanged so that JUnit reports them
   * as what they are; only checked throwables are wrapped.
   */
  @Test
  public void testExpandCatsProstate() throws InterruptedException, ExecutionException {
    // Input rows in the column order given by pros_cols (the ID column is already omitted).
    double[][] prostate =
        ard(
            ard(0, 71, 1, 0, 0, 4.8, 14.0, 7),
            ard(1, 70, 1, 1, 0, 8.4, 21.8, 5),
            ard(0, 73, 1, 3, 0, 10.0, 27.4, 6),
            ard(1, 68, 1, 0, 0, 6.7, 16.7, 6));
    // Expected output: eleven 0/1 indicator columns, then the four numeric columns.
    double[][] pros_expandR =
        ard(
            ard(1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 71, 4.8, 14.0, 7),
            ard(0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 70, 8.4, 21.8, 5),
            ard(0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 73, 10.0, 27.4, 6),
            ard(1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 68, 6.7, 16.7, 6));
    // Column names and per-column domains are used only for the log output below.
    String[] pros_cols =
        new String[] {"Capsule", "Age", "Race", "Dpros", "Dcaps", "PSA", "Vol", "Gleason"};
    String[][] pros_domains =
        new String[][] {
          new String[] {"No", "Yes"},
          null,
          new String[] {"Other", "White", "Black"},
          new String[] {"None", "UniLeft", "UniRight", "Bilobar"},
          new String[] {"No", "Yes"},
          null,
          null,
          null
        };
    // Indices refer to the parsed frame while it still contains the ID column.
    final int[] cats = new int[] {1, 3, 4, 5}; // Categoricals: CAPSULE, RACE, DPROS, DCAPS

    Frame fr = null;
    // Enter the scope before the try so that enter/exit are paired in the usual try/finally form.
    Scope.enter();
    try {
      fr = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      for (int cat : cats) {
        // Swap in the categorical version of the column and track the key of what replace()
        // returns. NOTE(review): presumably the replaced vec, so it is freed at Scope.exit().
        Scope.track(fr.replace(cat, fr.vec(cat).toCategoricalVec())._key);
      }
      fr.remove("ID").remove();
      // Store the modified frame in the DKV under its key.
      DKV.put(fr._key, fr);
      // No transformation is requested for either TransformType argument.
      DataInfo dinfo =
          new DataInfo(
              Key.make(),
              fr,
              null,
              0,
              true,
              DataInfo.TransformType.NONE,
              DataInfo.TransformType.NONE,
              false,
              false,
              false, /* weights */
              false, /* offset */
              false, /* fold */
              false);

      Log.info("Original matrix:\n" + colFormat(pros_cols, "%8.7s") + ArrayUtils.pprint(prostate));
      // Reorder the hand-written columns the same way DataInfo ordered the frame's columns.
      double[][] pros_perm = ArrayUtils.permuteCols(prostate, dinfo._permutation);
      Log.info(
          "Permuted matrix:\n"
              + colFormat(pros_cols, "%8.7s", dinfo._permutation)
              + ArrayUtils.pprint(pros_perm));

      double[][] pros_exp = GLRM.expandCats(pros_perm, dinfo);
      Log.info(
          "Expanded matrix:\n"
              + colExpFormat(pros_cols, pros_domains, "%8.7s", dinfo._permutation)
              + ArrayUtils.pprint(pros_exp));
      Assert.assertArrayEquals(pros_expandR, pros_exp);
    } catch (Throwable t) {
      // Rethrow AssertionError (an Error) and unchecked exceptions untouched: wrapping them would
      // turn a test failure into an error and bury the real message in the cause chain.
      if (t instanceof Error) {
        throw (Error) t;
      }
      if (t instanceof RuntimeException) {
        throw (RuntimeException) t;
      }
      // Only a checked throwable needs wrapping; the cause is preserved.
      throw new RuntimeException(t);
    } finally {
      // Release the frame first, then leave the scope entered above.
      if (fr != null) fr.delete();
      Scope.exit();
    }
  }