Пример #1
0
  /**
   * Smoke test: fits a GLRM (k = 8, quadratic regularization on X and Y, PlusPlus initialization,
   * standardized data) to the prostate data after converting four columns to categoricals and
   * dropping the ID column, then scores the training frame and logs the resulting metrics.
   *
   * <p>There are no value assertions; the test fails only if training or scoring throws. Any
   * failure is printed once and rethrown wrapped in a {@link RuntimeException}.
   */
  @Test
  public void testCategoricalProstate() throws InterruptedException, ExecutionException {
    GLRM job = null;
    GLRMModel model = null;
    Frame train = null;
    final int[] cats = new int[] {1, 3, 4, 5}; // Categoricals: CAPSULE, RACE, DPROS, DCAPS

    // Entered before the try so the Scope.exit() in the finally block always has a matching enter.
    Scope.enter();
    try {
      train = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      for (int i = 0; i < cats.length; i++) {
        Scope.track(train.replace(cats[i], train.vec(cats[i]).toCategoricalVec())._key);
      }
      train.remove("ID").remove();
      // Re-publish the frame so the column changes above are visible under its key.
      DKV.put(train._key, train);

      GLRMParameters parms = new GLRMParameters();
      parms._train = train._key;
      parms._k = 8;
      parms._gamma_x = parms._gamma_y = 0.1;
      parms._regularization_x = GLRMModel.GLRMParameters.Regularizer.Quadratic;
      parms._regularization_y = GLRMModel.GLRMParameters.Regularizer.Quadratic;
      parms._init = GLRM.Initialization.PlusPlus;
      parms._transform = DataInfo.TransformType.STANDARDIZE;
      parms._recover_svd = false;
      parms._max_iterations = 200;

      try {
        job = new GLRM(parms);
        model = job.trainModel().get();
        Log.info(
            "Iteration "
                + model._output._iterations
                + ": Objective value = "
                + model._output._objective);
        model.score(train).delete();
        ModelMetricsGLRM mm = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
        Log.info(
            "Numeric Sum of Squared Error = "
                + mm._numerr
                + "\tCategorical Misclassification Error = "
                + mm._caterr);
      } finally {
        // job is still null if the GLRM constructor threw; without this guard the resulting
        // NullPointerException would replace the real failure.
        if (job != null) job.remove();
      }
    } catch (Throwable t) {
      t.printStackTrace();
      throw new RuntimeException(t);
    } finally {
      if (train != null) train.delete();
      if (model != null) model.delete();
      Scope.exit();
    }
  }
Пример #2
0
  /**
   * Smoke test: fits a GLRM (k = 4, absolute loss, SVD initialization, no transform, SVD recovery
   * enabled) to the iris data, then scores the training frame and logs the resulting metrics.
   *
   * <p>There are no value assertions; the test fails only if training or scoring throws. Any
   * failure is printed once and rethrown wrapped in a {@link RuntimeException}.
   */
  @Test
  public void testCategoricalIris() throws InterruptedException, ExecutionException {
    GLRM job = null;
    GLRMModel model = null;
    Frame train = null;

    try {
      train = parse_test_file(Key.make("iris.hex"), "smalldata/iris/iris_wheader.csv");
      GLRMParameters parms = new GLRMParameters();
      parms._train = train._key;
      parms._k = 4;
      parms._loss = GLRMParameters.Loss.Absolute;
      parms._init = GLRM.Initialization.SVD;
      parms._transform = DataInfo.TransformType.NONE;
      parms._recover_svd = true;
      parms._max_iterations = 1000;

      try {
        job = new GLRM(parms);
        model = job.trainModel().get();
        Log.info(
            "Iteration "
                + model._output._iterations
                + ": Objective value = "
                + model._output._objective);
        model.score(train).delete();
        ModelMetricsGLRM mm = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
        Log.info(
            "Numeric Sum of Squared Error = "
                + mm._numerr
                + "\tCategorical Misclassification Error = "
                + mm._caterr);
      } finally {
        // job is still null if the GLRM constructor threw; without this guard the resulting
        // NullPointerException would replace the real failure.
        if (job != null) job.remove();
      }
    } catch (Throwable t) {
      t.printStackTrace();
      throw new RuntimeException(t);
    } finally {
      if (train != null) train.delete();
      if (model != null) model.delete();
    }
  }
Пример #3
0
  /**
   * Fits a GLRM (k = 12) to the prostate data with per-column loss overrides and checks, via
   * {@code GLRMTest.checkLossbyCol}, the trained model against the requested parameters. The
   * default losses are Quadratic (numeric) and Categorical (multi-level); three columns are
   * overridden through {@code _loss_by_col} / {@code _loss_by_col_idx}.
   *
   * <p>After the check, the training frame is scored and the metrics are logged. Any failure is
   * printed once and rethrown wrapped in a {@link RuntimeException}.
   */
  @Test
  public void testSetColumnLossCats() throws InterruptedException, ExecutionException {
    GLRM job = null;
    GLRMModel model = null;
    Frame train = null;
    final int[] cats = new int[] {1, 3, 4, 5}; // Categoricals: CAPSULE, RACE, DPROS, DCAPS

    Scope.enter();
    try {
      train = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      for (int i = 0; i < cats.length; i++) {
        Scope.track(train.replace(cats[i], train.vec(cats[i]).toCategoricalVec())._key);
      }
      train.remove("ID").remove();
      // Re-publish the frame so the column changes above are visible under its key.
      DKV.put(train._key, train);

      GLRMParameters parms = new GLRMParameters();
      parms._train = train._key;
      parms._k = 12;
      parms._loss = GLRMParameters.Loss.Quadratic;
      parms._multi_loss = GLRMParameters.Loss.Categorical;
      // Entry i of _loss_by_col applies to the column at _loss_by_col_idx[i].
      parms._loss_by_col =
          new GLRMParameters.Loss[] {
            GLRMParameters.Loss.Ordinal, GLRMParameters.Loss.Poisson, GLRMParameters.Loss.Absolute
          };
      parms._loss_by_col_idx = new int[] {3 /* DPROS */, 1 /* AGE */, 6 /* VOL */};
      parms._init = GLRM.Initialization.PlusPlus;
      parms._min_step_size = 1e-5;
      parms._recover_svd = false;
      parms._max_iterations = 2000;

      try {
        job = new GLRM(parms);
        model = job.trainModel().get();
        Log.info(
            "Iteration "
                + model._output._iterations
                + ": Objective value = "
                + model._output._objective);
        GLRMTest.checkLossbyCol(parms, model);

        model.score(train).delete();
        ModelMetricsGLRM mm = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
        Log.info(
            "Numeric Sum of Squared Error = "
                + mm._numerr
                + "\tCategorical Misclassification Error = "
                + mm._caterr);
      } finally {
        // job is still null if the GLRM constructor threw; without this guard the resulting
        // NullPointerException would replace the real failure.
        if (job != null) job.remove();
      }
    } catch (Throwable t) {
      t.printStackTrace();
      throw new RuntimeException(t);
    } finally {
      if (train != null) train.delete();
      if (model != null) model.delete();
      Scope.exit();
    }
  }
Пример #4
0
  /**
   * Smoke test over loss combinations: for every pairing of a numeric loss with a multi-level
   * loss, trains a GLRM (k = 5, SVD initialization, no transform) on the prostate data, scores the
   * training frame and logs the metrics.
   *
   * <p>Regularizers, gamma values and the per-model seed are drawn from a {@link Random} seeded
   * with a fixed constant, so the sequence of configurations is the same on every run. There are
   * no value assertions; a failure inside a single configuration is printed and rethrown wrapped
   * in a {@link RuntimeException}.
   */
  @Test
  public void testLosses() throws InterruptedException, ExecutionException {
    final long seed = 0xDECAF;
    final Random rng = new Random(seed);
    final int[] cats = {1, 3, 4, 5}; // Categoricals: CAPSULE, RACE, DPROS, DCAPS
    final GLRMParameters.Regularizer[] regs = {
      GLRMParameters.Regularizer.Quadratic,
      GLRMParameters.Regularizer.L1,
      GLRMParameters.Regularizer.NonNegative,
      GLRMParameters.Regularizer.OneSparse,
      GLRMParameters.Regularizer.UnitOneSparse,
      GLRMParameters.Regularizer.Simplex
    };
    final GLRMParameters.Loss[] numericLosses = {
      GLRMParameters.Loss.Quadratic,
      GLRMParameters.Loss.Absolute,
      GLRMParameters.Loss.Huber,
      GLRMParameters.Loss.Poisson,
      GLRMParameters.Loss.Hinge,
      GLRMParameters.Loss.Logistic
    };
    final GLRMParameters.Loss[] multiLosses = {
      GLRMParameters.Loss.Categorical, GLRMParameters.Loss.Ordinal
    };
    Frame train = null;

    Scope.enter();
    try {
      train = parse_test_file(Key.make("prostate.hex"), "smalldata/logreg/prostate.csv");
      for (int col : cats) {
        Scope.track(train.replace(col, train.vec(col).toCategoricalVec())._key);
      }
      train.remove("ID").remove();
      DKV.put(train._key, train);

      for (GLRMParameters.Loss numericLoss : numericLosses) {
        for (GLRMParameters.Loss multiLoss : multiLosses) {
          GLRMModel fitted = null;
          try {
            Scope.enter();
            // The RNG draws below must stay in this order to keep the configurations reproducible.
            final long modelSeed = rng.nextLong();
            Log.info("GLRM using seed = " + modelSeed);

            final GLRMParameters parms = new GLRMParameters();
            parms._train = train._key;
            parms._transform = DataInfo.TransformType.NONE;
            parms._k = 5;
            parms._loss = numericLoss;
            parms._multi_loss = multiLoss;
            parms._init = GLRM.Initialization.SVD;
            parms._regularization_x = regs[rng.nextInt(regs.length)];
            parms._regularization_y = regs[rng.nextInt(regs.length)];
            parms._gamma_x = Math.abs(rng.nextDouble());
            parms._gamma_y = Math.abs(rng.nextDouble());
            parms._recover_svd = false;
            parms._seed = modelSeed;
            parms._verbose = false;
            parms._max_iterations = 500;

            final GLRM builder = new GLRM(parms);
            try {
              fitted = builder.trainModel().get();
              final String progress =
                  "Iteration "
                      + fitted._output._iterations
                      + ": Objective value = "
                      + fitted._output._objective;
              Log.info(progress);

              fitted.score(train).delete();
              final ModelMetricsGLRM metrics =
                  (ModelMetricsGLRM) ModelMetrics.getFromDKV(fitted, train);
              final String errors =
                  "Numeric Sum of Squared Error = "
                      + metrics._numerr
                      + "\tCategorical Misclassification Error = "
                      + metrics._caterr;
              Log.info(errors);
            } finally {
              builder.remove();
            }
          } catch (Throwable failure) {
            failure.printStackTrace();
            throw new RuntimeException(failure);
          } finally {
            if (fitted != null) {
              fitted.delete();
            }
            Scope.exit();
          }
        }
      }
    } finally {
      if (train != null) {
        train.delete();
      }
      Scope.exit();
    }
  }
Пример #5
0
    /**
     * Main worker: builds the {@code PCAModel} using the method selected by
     * {@code _parms._pca_method}.
     *
     * <ul>
     *   <li>{@code GramSVD}: computes the Gram matrix of the training data and decomposes it with
     *       JAMA's SVD.
     *   <li>{@code Power}: delegates to an embedded SVD job and converts its model.
     *   <li>{@code GLRM}: delegates to an embedded GLRM job and converts its model.
     * </ul>
     *
     * <p>On any throwable the job is looked up in the DKV: if its state is {@code CANCELLED} the
     * cancellation is only logged; otherwise the job is marked failed and the throwable is
     * rethrown. The finally block always releases the frame read-locks and unlocks the model.
     */
    @Override
    protected void compute2() {
      PCAModel model = null;
      DataInfo dinfo = null;
      // xinfo and x are never assigned in this method; their cleanup checks in the finally block
      // are therefore no-ops.
      DataInfo xinfo = null;
      Frame x = null;

      try {
        init(true); // Initialize parameters
        _parms.read_lock_frames(PCA.this); // Fetch & read-lock input frames
        if (error_count() > 0)
          throw new IllegalArgumentException("Found validation errors: " + validationErrors());

        // The model to be built
        model = new PCAModel(dest(), _parms, new PCAModel.PCAOutput(PCA.this));
        model.delete_and_lock(_key);

        if (_parms._pca_method == PCAParameters.Method.GramSVD) {
          dinfo =
              new DataInfo(
                  Key.make(),
                  _train,
                  null,
                  0,
                  _parms._use_all_factor_levels,
                  _parms._transform,
                  DataInfo.TransformType.NONE,
                  /* skipMissing */ true,
                  /* missingBucket */ false,
                  /* weights */ false,
                  /* offset */ false,
                  /* intercept */ false);
          DKV.put(dinfo._key, dinfo);

          // Calculate and save Gram matrix of training data
          // NOTE: Gram computes A'A/n where n = nrow(A) = number of rows in training set (excluding
          // rows with NAs)
          GramTask gtsk = new Gram.GramTask(self(), dinfo).doAll(dinfo._adaptedFrame);
          // TODO: This ends up with all NaNs if training data has too many missing values
          Gram gram =
              gtsk._gram;
          assert gram.fullN() == _ncolExp;

          // Compute SVD of Gram A'A/n using JAMA library
          // Note: Singular values ordered in weakly descending order by algorithm
          Matrix gramJ = new Matrix(gtsk._gram.getXX());
          SingularValueDecomposition svdJ = gramJ.svd();
          computeStatsFillModel(model, dinfo, svdJ, gram, gtsk._nobs);

        } else if (_parms._pca_method == PCAParameters.Method.Power) {
          // Mirror the relevant PCA parameters onto the embedded SVD job.
          SVDModel.SVDParameters parms = new SVDModel.SVDParameters();
          parms._train = _parms._train;
          parms._ignored_columns = _parms._ignored_columns;
          parms._ignore_const_cols = _parms._ignore_const_cols;
          parms._score_each_iteration = _parms._score_each_iteration;
          parms._use_all_factor_levels = _parms._use_all_factor_levels;
          parms._transform = _parms._transform;
          parms._nv = _parms._k;
          parms._max_iterations = _parms._max_iterations;
          parms._seed = _parms._seed;

          // Calculate standard deviation and projection as well
          parms._only_v = false;
          parms._u_name = _parms._loading_name;
          parms._keep_u = _parms._keep_loading;

          SVDModel svd = null;
          SVD job = null;
          try {
            job = new EmbeddedSVD(_key, _progressKey, parms);
            svd = job.trainModel().get();
            // Propagate cancellation/crash of the embedded job to this PCA job.
            if (job.isCancelledOrCrashed()) PCA.this.cancel();
          } finally {
            if (job != null) job.remove();
            if (svd != null) svd.remove();
          }
          // Recover PCA results from SVD model
          // NOTE(review): svd is read here after svd.remove() ran in the finally block above;
          // presumably remove() only drops the stored copy and the local object stays usable -
          // confirm.
          computeStatsFillModel(model, svd);

        } else if (_parms._pca_method == PCAParameters.Method.GLRM) {
          // Mirror the relevant PCA parameters onto the embedded GLRM job.
          GLRMModel.GLRMParameters parms = new GLRMModel.GLRMParameters();
          parms._train = _parms._train;
          parms._ignored_columns = _parms._ignored_columns;
          parms._ignore_const_cols = _parms._ignore_const_cols;
          parms._score_each_iteration = _parms._score_each_iteration;
          parms._transform = _parms._transform;
          parms._k = _parms._k;
          parms._max_iterations = _parms._max_iterations;
          parms._seed = _parms._seed;

          // L2 loss with both regularization weights set to zero, and SVD recovery switched on.
          parms._recover_svd = true;
          parms._loss = GLRMModel.GLRMParameters.Loss.L2;
          parms._gamma_x = 0;
          parms._gamma_y = 0;

          GLRMModel glrm = null;
          GLRM job = null;
          try {
            job = new EmbeddedGLRM(_key, _progressKey, parms);
            glrm = job.trainModel().get();
            // Propagate cancellation/crash of the embedded job to this PCA job.
            if (job.isCancelledOrCrashed()) PCA.this.cancel();
          } finally {
            if (job != null) job.remove();
            if (glrm != null) {
              // Delete the object stored under the GLRM's loading key before removing the model.
              glrm._parms._loading_key.get().delete();
              glrm.remove();
            }
          }
          // Recover PCA results from GLRM model
          // NOTE(review): glrm is read here after glrm.remove() ran in the finally block above;
          // presumably the local object stays usable - confirm.
          computeStatsFillModel(model, glrm);
        }

        model.update(self());
        update(1);
        done();
      } catch (Throwable t) {
        // NOTE(review): the result of DKV.getGet(_key) is dereferenced without a null check -
        // confirm the job is always still present at this point.
        Job thisJob = DKV.getGet(_key);
        if (thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        _parms.read_unlock_frames(PCA.this);
        if (model != null) model.unlock(_key);
        if (dinfo != null) dinfo.remove();
        if (xinfo != null) xinfo.remove();
        if (x != null && !_parms._keep_loading) x.delete();
      }
      // Reached on success or after a logged cancellation; not reached when the catch rethrows.
      tryComplete();
    }