コード例 #1
0
  /**
   * Smoke test: trains a GLRM on the prostate data set with several columns converted to
   * categoricals, scores the training frame, and logs the resulting metrics.
   *
   * <p>The test makes no assertions on the fitted values; it passes when parsing, training and
   * scoring all complete without throwing. Any failure is printed once and rethrown wrapped in a
   * {@link RuntimeException}.
   *
   * @throws InterruptedException declared for signature compatibility; failures are rethrown as
   *     {@link RuntimeException}
   * @throws ExecutionException declared for signature compatibility; failures are rethrown as
   *     {@link RuntimeException}
   */
  @Test
  public void testCategoricalProstate() throws InterruptedException, ExecutionException {
    GLRM job = null;
    GLRMModel model = null;
    Frame train = null;
    final int[] cats = new int[] {1, 3, 4, 5}; // Categoricals: CAPSULE, RACE, DPROS, DCAPS

    try {
      Scope.enter();
      train = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      // Swap each listed column for its categorical version and track the key returned by
      // replace() in the current Scope.
      for (int col : cats) {
        Scope.track(train.replace(col, train.vec(col).toCategoricalVec())._key);
      }
      train.remove("ID").remove();
      // Re-publish the frame so the DKV copy reflects the column changes made above.
      DKV.put(train._key, train);

      GLRMParameters parms = new GLRMParameters();
      parms._train = train._key;
      parms._k = 8;
      parms._gamma_x = parms._gamma_y = 0.1;
      parms._regularization_x = GLRMModel.GLRMParameters.Regularizer.Quadratic;
      parms._regularization_y = GLRMModel.GLRMParameters.Regularizer.Quadratic;
      parms._init = GLRM.Initialization.PlusPlus;
      parms._transform = DataInfo.TransformType.STANDARDIZE;
      parms._recover_svd = false;
      parms._max_iterations = 200;

      // No catch here: the outer handler already prints and wraps every Throwable, so catching
      // at this level as well would log the same failure twice and double-wrap it.
      try {
        job = new GLRM(parms);
        model = job.trainModel().get();
        Log.info(
            "Iteration "
                + model._output._iterations
                + ": Objective value = "
                + model._output._objective);
        model.score(train).delete();
        ModelMetricsGLRM mm = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
        Log.info(
            "Numeric Sum of Squared Error = "
                + mm._numerr
                + "\tCategorical Misclassification Error = "
                + mm._caterr);
      } finally {
        // job is still null if the GLRM constructor threw; without this guard the resulting
        // NullPointerException would replace the real failure.
        if (job != null) {
          job.remove();
        }
      }
    } catch (Throwable t) {
      t.printStackTrace();
      throw new RuntimeException(t);
    } finally {
      if (train != null) {
        train.delete();
      }
      if (model != null) {
        model.delete();
      }
      Scope.exit();
    }
  }
コード例 #2
0
  /**
   * Smoke test: trains and scores a GLRM on the prostate data set once for every pairing of a
   * {@code _loss} value with a {@code _multi_loss} value.
   *
   * <p>Each run draws its seed, both regularizers and both gamma weights from a single
   * {@link Random} seeded with {@code 0xDECAF}, so the sequence of configurations is the same on
   * every execution. No assertions are made on the fitted values; a run fails only if training or
   * scoring throws, in which case the error is printed and rethrown wrapped in a
   * {@link RuntimeException}.
   *
   * @throws InterruptedException declared by the test signature
   * @throws ExecutionException declared by the test signature
   */
  @Test
  public void testLosses() throws InterruptedException, ExecutionException {
    final long seed = 0xDECAF;
    final Random rng = new Random(seed);
    // Categoricals: CAPSULE, RACE, DPROS, DCAPS
    final int[] cats = {1, 3, 4, 5};
    // Pool from which the X and Y regularizers are drawn at random for each run.
    final GLRMParameters.Regularizer[] regs = {
      GLRMParameters.Regularizer.Quadratic,
      GLRMParameters.Regularizer.L1,
      GLRMParameters.Regularizer.NonNegative,
      GLRMParameters.Regularizer.OneSparse,
      GLRMParameters.Regularizer.UnitOneSparse,
      GLRMParameters.Regularizer.Simplex
    };
    // Values assigned in turn to parms._loss.
    final GLRMParameters.Loss[] losses = {
      GLRMParameters.Loss.Quadratic,
      GLRMParameters.Loss.Absolute,
      GLRMParameters.Loss.Huber,
      GLRMParameters.Loss.Poisson,
      GLRMParameters.Loss.Hinge,
      GLRMParameters.Loss.Logistic
    };
    // Values assigned in turn to parms._multi_loss.
    final GLRMParameters.Loss[] multiLosses = {
      GLRMParameters.Loss.Categorical, GLRMParameters.Loss.Ordinal
    };
    Frame train = null;

    Scope.enter();
    try {
      train = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      for (int col : cats) {
        Scope.track(train.replace(col, train.vec(col).toCategoricalVec())._key);
      }
      train.remove("ID").remove();
      DKV.put(train._key, train);

      for (GLRMParameters.Loss loss : losses) {
        for (GLRMParameters.Loss multiloss : multiLosses) {
          GLRMModel model = null;
          try {
            Scope.enter();
            final long runSeed = rng.nextLong();
            Log.info("GLRM using seed = " + runSeed);

            final GLRMParameters parms = new GLRMParameters();
            parms._train = train._key;
            parms._seed = runSeed;
            parms._k = 5;
            parms._max_iterations = 500;
            parms._transform = DataInfo.TransformType.NONE;
            parms._init = GLRM.Initialization.SVD;
            parms._recover_svd = false;
            parms._verbose = false;
            parms._loss = loss;
            parms._multi_loss = multiloss;
            // The four draws below must stay in this order so the random stream, and hence each
            // run's configuration, is unchanged.
            parms._regularization_x = regs[rng.nextInt(regs.length)];
            parms._regularization_y = regs[rng.nextInt(regs.length)];
            parms._gamma_x = Math.abs(rng.nextDouble());
            parms._gamma_y = Math.abs(rng.nextDouble());

            final GLRM job = new GLRM(parms);
            try {
              model = job.trainModel().get();
              final String progress =
                  "Iteration "
                      + model._output._iterations
                      + ": Objective value = "
                      + model._output._objective;
              Log.info(progress);

              model.score(train).delete();
              final ModelMetricsGLRM mm =
                  (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
              final String errors =
                  "Numeric Sum of Squared Error = "
                      + mm._numerr
                      + "\tCategorical Misclassification Error = "
                      + mm._caterr;
              Log.info(errors);
            } finally {
              job.remove();
            }
          } catch (Throwable t) {
            t.printStackTrace();
            throw new RuntimeException(t);
          } finally {
            if (model != null) {
              model.delete();
            }
            Scope.exit();
          }
        }
      }
    } finally {
      if (train != null) {
        train.delete();
      }
      Scope.exit();
    }
  }