Esempio n. 1
0
    /**
     * Saves the data-adaptation info needed for scoring into {@code pca._output} and derives the
     * PCA statistics (standard deviations, raw eigenvectors, total variance) from an SVD of the
     * Gram matrix.
     *
     * <p>The Gram matrix holds X'X/n, so every variance-like quantity is rescaled by n/(n-1) to use
     * the sample (n-1) denominator.
     *
     * @param pca model whose output is filled in
     * @param dinfo data info the Gram matrix was built from; its normalization vectors, column
     *     permutation, numeric/categorical counts and categorical offsets are copied for scoring
     * @param svd SVD of the Gram matrix; only read, never modified
     * @param gram the decomposed Gram matrix; its diagonal sum yields the total variance
     * @param nobs number of row observations processed when building the Gram matrix
     */
    protected void computeStatsFillModel(
        PCAModel pca, DataInfo dinfo, SingularValueDecomposition svd, Gram gram, long nobs) {
      final int k = _parms._k;

      // Save adapted frame info for scoring later. Missing normalization vectors are replaced by
      // the identity transform (subtract 0, multiply by 1).
      pca._output._normSub = dinfo._normSub == null ? new double[dinfo._nums] : dinfo._normSub;
      if (dinfo._normMul == null) {
        pca._output._normMul = new double[dinfo._nums];
        Arrays.fill(pca._output._normMul, 1.0);
      } else {
        pca._output._normMul = dinfo._normMul;
      }
      pca._output._permutation = dinfo._permutation;
      pca._output._nnums = dinfo._nums;
      pca._output._ncats = dinfo._cats;
      pca._output._catOffsets = dinfo._catOffsets;

      // Degrees of freedom = n-1, where n = nobs = # row observations processed.
      // NOTE(review): nobs <= 1 makes this non-finite; presumably excluded upstream — confirm.
      final double dfcorr = nobs / (nobs - 1.0);

      // JAMA's getSingularValues() returns the decomposition's internal array, so the correction
      // is applied to a local value rather than written back; svd stays unmodified.
      final double[] sval = svd.getSingularValues();
      pca._output._std_deviation = new double[k]; // Only want first k standard deviations
      for (int i = 0; i < k; i++) {
        pca._output._std_deviation[i] = Math.sqrt(dfcorr * sval[i]);
      }

      // Only want the first k eigenvectors (columns of V).
      final double[][] eigvec = svd.getV().getArray();
      pca._output._eigenvectors_raw = new double[eigvec.length][k];
      for (int i = 0; i < eigvec.length; i++) {
        System.arraycopy(eigvec[i], 0, pca._output._eigenvectors_raw[i], 0, k);
      }

      // Since gram = X'X/n, but variance requires n-1 in the denominator.
      pca._output._total_variance = dfcorr * gram.diagSum();
      buildTables(pca, dinfo.coefNames());
    }
Esempio n. 2
0
    /**
     * Main worker thread: validates parameters, builds the PCA model with the configured method
     * (GramSVD, Power or GLRM) and publishes it.
     *
     * <p>On any failure other than a user cancellation, the failure is recorded via {@code
     * failed(t)} and the throwable is rethrown. Regardless of outcome, the input frames are
     * read-unlocked, the model (if created) is unlocked, and any temporary {@code DataInfo} is
     * removed.
     */
    @Override
    protected void compute2() {
      PCAModel model = null;
      DataInfo dinfo = null;

      try {
        init(true); // Initialize parameters
        _parms.read_lock_frames(PCA.this); // Fetch & read-lock input frames
        if (error_count() > 0)
          throw new IllegalArgumentException("Found validation errors: " + validationErrors());

        // The model to be built
        model = new PCAModel(dest(), _parms, new PCAModel.PCAOutput(PCA.this));
        model.delete_and_lock(_key);

        if (_parms._pca_method == PCAParameters.Method.GramSVD) {
          dinfo =
              new DataInfo(
                  Key.make(),
                  _train,
                  null,
                  0,
                  _parms._use_all_factor_levels,
                  _parms._transform,
                  DataInfo.TransformType.NONE,
                  /* skipMissing */ true,
                  /* missingBucket */ false,
                  /* weights */ false,
                  /* offset */ false,
                  /* intercept */ false);
          DKV.put(dinfo._key, dinfo);

          // Calculate and save Gram matrix of training data.
          // NOTE: Gram computes A'A/n where n = nrow(A) = number of rows in training set
          // (excluding rows with NAs).
          GramTask gtsk = new Gram.GramTask(self(), dinfo).doAll(dinfo._adaptedFrame);
          // TODO: This ends up with all NaNs if training data has too many missing values
          Gram gram = gtsk._gram;
          assert gram.fullN() == _ncolExp;

          // Compute SVD of Gram A'A/n using JAMA library.
          // Note: Singular values ordered in weakly descending order by algorithm.
          Matrix gramJ = new Matrix(gram.getXX());
          SingularValueDecomposition svdJ = gramJ.svd();
          computeStatsFillModel(model, dinfo, svdJ, gram, gtsk._nobs);

        } else if (_parms._pca_method == PCAParameters.Method.Power) {
          SVDModel.SVDParameters parms = new SVDModel.SVDParameters();
          parms._train = _parms._train;
          parms._ignored_columns = _parms._ignored_columns;
          parms._ignore_const_cols = _parms._ignore_const_cols;
          parms._score_each_iteration = _parms._score_each_iteration;
          parms._use_all_factor_levels = _parms._use_all_factor_levels;
          parms._transform = _parms._transform;
          parms._nv = _parms._k;
          parms._max_iterations = _parms._max_iterations;
          parms._seed = _parms._seed;

          // Calculate standard deviation and projection as well
          parms._only_v = false;
          parms._u_name = _parms._loading_name;
          parms._keep_u = _parms._keep_loading;

          SVDModel svd = null;
          SVD job = null;
          try {
            job = new EmbeddedSVD(_key, _progressKey, parms);
            svd = job.trainModel().get();
            if (job.isCancelledOrCrashed()) PCA.this.cancel();
          } finally {
            if (job != null) job.remove();
            if (svd != null) svd.remove();
          }
          // Recover PCA results from SVD model.
          // NOTE(review): svd has already been removed from the DKV at this point; this relies on
          // its in-memory fields still being usable — confirm.
          computeStatsFillModel(model, svd);

        } else if (_parms._pca_method == PCAParameters.Method.GLRM) {
          GLRMModel.GLRMParameters parms = new GLRMModel.GLRMParameters();
          parms._train = _parms._train;
          parms._ignored_columns = _parms._ignored_columns;
          parms._ignore_const_cols = _parms._ignore_const_cols;
          parms._score_each_iteration = _parms._score_each_iteration;
          parms._transform = _parms._transform;
          parms._k = _parms._k;
          parms._max_iterations = _parms._max_iterations;
          parms._seed = _parms._seed;

          // Unregularized L2 GLRM with SVD recovery, so it reduces to PCA.
          parms._recover_svd = true;
          parms._loss = GLRMModel.GLRMParameters.Loss.L2;
          parms._gamma_x = 0;
          parms._gamma_y = 0;

          GLRMModel glrm = null;
          GLRM job = null;
          try {
            job = new EmbeddedGLRM(_key, _progressKey, parms);
            glrm = job.trainModel().get();
            if (job.isCancelledOrCrashed()) PCA.this.cancel();
          } finally {
            if (job != null) job.remove();
            if (glrm != null) {
              glrm._parms._loading_key.get().delete();
              glrm.remove();
            }
          }
          // Recover PCA results from GLRM model.
          // NOTE(review): glrm has already been removed from the DKV at this point; this relies on
          // its in-memory fields still being usable — confirm.
          computeStatsFillModel(model, glrm);
        }

        model.update(self());
        update(1);
        done();
      } catch (Throwable t) {
        // The job entry may already be gone from the DKV; treat that as "not cancelled" so the
        // original failure is reported rather than masked by a NullPointerException here.
        Job thisJob = DKV.getGet(_key);
        if (thisJob != null && thisJob._state == JobState.CANCELLED) {
          Log.info("Job cancelled by user.");
        } else {
          t.printStackTrace();
          failed(t);
          throw t;
        }
      } finally {
        _parms.read_unlock_frames(PCA.this);
        if (model != null) model.unlock(_key);
        if (dinfo != null) dinfo.remove();
      }
      tryComplete();
    }