Пример #1
0
 /**
  * Handles the result of one finished {@link GLMIterationTask}.
  *
  * <p>When the task carries both a beta vector and validation data (and the family is not
  * tweedie), the validation is stored on the model, the L2-adjusted gradient is tested against
  * the L1 penalty, and a line search is launched if {@code needLineSearch} asks for one.
  * Otherwise the Gram system is solved with {@link ADMMSolver} to obtain the next beta, and the
  * callback either launches another iteration or hands over to {@code nextLambda}.
  *
  * @param glmt the completed iteration task
  * @throws JobCancelledException if the job is no longer running
  */
 @Override
 public void callback(final GLMIterationTask glmt) {
   if (!isRunning(self())) {
     throw new JobCancelledException();
   }

   boolean converged = false;
   final boolean canCheckGradient =
       glmt._beta != null && glmt._val != null && _glm.family != Family.tweedie;
   if (canCheckGradient) {
     glmt._val.finalize_AIC_AUC();
     _model.setAndTestValidation(_lambdaIdx, glmt._val); // .store();
     _model.clone().update(self());

     final double l1Penalty = alpha[0] * lambda[_lambdaIdx] * glmt._n;
     final double l2Penalty = (1 - alpha[0]) * lambda[_lambdaIdx] * glmt._n;
     final double tolerance = 1e-2;

     // Start optimistic; every coefficient except the last one has to pass the check below.
     converged = true;
     final int lastIdx = glmt._grad.length - 1;
     for (int j = 0; j < lastIdx; ++j) {
       // add l2 reg. term to the gradient (the task's gradient array is updated in place)
       glmt._grad[j] += l2Penalty * glmt._beta[j];
       final double coef = glmt._beta[j];
       final boolean stationary;
       if (coef < 0) {
         stationary = Math.abs(glmt._grad[j] - l1Penalty) < tolerance;
       } else if (coef > 0) {
         stationary = Math.abs(glmt._grad[j] + l1Penalty) < tolerance;
       } else {
         stationary = LSMSolver.shrinkage(glmt._grad[j], l1Penalty + tolerance) == 0;
       }
       converged &= stationary;
     }
     if (converged) {
       Log.info("GLM converged by reaching 0 gradient/subgradient.");
     }

     final double objective =
         glmt._val.residual_deviance + 0.5 * l2Penalty * l2norm(glmt._beta);
     if (!converged && _lastResult != null && needLineSearch(glmt._beta, objective, 1)) {
       // Line search between the previously accepted beta and the current one.
       new GLMTask.GLMLineSearchTask(
               GLM2.this,
               _dinfo,
               _glm,
               _lastResult._glmt._beta,
               glmt._beta,
               1e-8,
               new LineSearchIteration())
           .asyncExec(_dinfo._adaptedFrame);
       return;
     }
     _lastResult = new IterationInfo(GLM2.this._iter - 1, objective, glmt);
   }

   // Seed the solver with the current beta when there is one, otherwise with a fresh array.
   final double[] newBeta;
   if (glmt._beta == null) {
     newBeta = MemoryManager.malloc8d(glmt._xy.length);
   } else {
     newBeta = glmt._beta.clone();
   }
   final ADMMSolver solver = new ADMMSolver(lambda[_lambdaIdx], alpha[0], _addedL2);
   solver.solve(glmt._gram, glmt._xy, glmt._yy, newBeta);
   _addedL2 = solver._addedL2;

   if (Utils.hasNaNsOrInfs(newBeta)) {
     Log.info("GLM forcibly converged by getting NaNs and/or Infs in beta");
     nextLambda(glmt);
     return;
   }

   double[] denormalizedBeta = null;
   if (_dinfo._standardize) {
     denormalizedBeta = newBeta.clone();
     double interceptShift = 0.0; // Reverse any normalization on the intercept
     // denormalize only the numeric coefs (categoricals are not normalized)
     final int firstNumeric = newBeta.length - _dinfo._nums - 1;
     final int interceptIdx = newBeta.length - 1;
     for (int col = firstNumeric, k = 0; col < interceptIdx; ++col, ++k) {
       final double rescaled = denormalizedBeta[col] * _dinfo._normMul[k];
       interceptShift += rescaled * _dinfo._normSub[k]; // Also accumulate the intercept adjustment
       denormalizedBeta[col] = rescaled;
     }
     denormalizedBeta[interceptIdx] -= interceptShift;
   }

   // With standardization the de-normalized beta goes first and the solver's beta second;
   // without it the solver's beta goes first and the second slot is null.
   final boolean standardized = denormalizedBeta != null;
   _model.setLambdaSubmodel(
       _lambdaIdx,
       standardized ? denormalizedBeta : newBeta,
       standardized ? newBeta : null,
       _iter);

   if (beta_diff(glmt._beta, newBeta) < beta_epsilon) {
     Log.info("GLM converged by reaching fixed-point.");
     converged = true;
   }

   final boolean iterateAgain =
       !converged && _glm.family != Family.gaussian && _iter < max_iter;
   if (!iterateAgain) {
     // done with this lambda
     nextLambda(glmt);
     return;
   }
   ++_iter;
   new GLMIterationTask(GLM2.this, _dinfo, glmt._glm, newBeta, _ymu, _reg, new Iteration())
       .asyncExec(_dinfo._adaptedFrame);
 }
Пример #2
0
    /**
     * Processes the result of one finished {@link GLMIterationTask} and decides what runs next.
     *
     * <p>In order, the callback:
     *
     * <ol>
     *   <li>logs the iteration time and aborts if the job is no longer running;
     *   <li>if validation shows the residual deviance is not below the null deviance and high
     *       accuracy is not yet on, switches to high accuracy and restarts from the last stored
     *       result, the previous lambda's submodel, or from scratch;
     *   <li>stores the validation on the model;
     *   <li>if the task computed a gradient, finishes this lambda when the largest absolute
     *       subgradient entry is at most {@code GLM_GRAD_EPS};
     *   <li>launches a line search (or a high-accuracy restart) when {@code needLineSearch}
     *       requests one;
     *   <li>otherwise solves for a new beta with {@link ADMMSolver} and either moves on to the
     *       next lambda or launches the next iteration.
     * </ol>
     *
     * @param glmt the completed iteration task
     * @throws JobCancelledException if the job is no longer running
     */
    @Override
    public void callback(final GLMIterationTask glmt) {
      _model.stop_training();
      Log.info(
          "GLM2 iteration("
              + _iter
              + ") done in "
              + (System.currentTimeMillis() - _iterationStartTime)
              + "ms");
      if (!isRunning(self())) {
        throw new JobCancelledException();
      }
      currentLambdaIter++;
      if (glmt._val != null) {
        // Written as !(a < b) so the branch is also taken when either deviance is NaN.
        if (!(glmt._val.residual_deviance < glmt._val.null_deviance)) {
          // complete fail, look if we can restart with higher_accuracy on
          if (!highAccuracy()) {
            Log.info(
                "GLM2 reached negative explained deviance without line-search, rerunning with high accuracy settings.");
            setHighAccuracy();
            if (_lastResult != null) {
              new GLMIterationTask(
                      GLM2.this,
                      _activeData,
                      glmt._glm,
                      true,
                      true,
                      true,
                      _lastResult._glmt._beta,
                      _ymu,
                      _reg,
                      new Iteration())
                  .asyncExec(_activeData._adaptedFrame);
            } else if (_lambdaIdx > 2) {
              // > 2 because 0 is null model, we don't want to run with that
              new GLMIterationTask(
                      GLM2.this,
                      _activeData,
                      glmt._glm,
                      true,
                      true,
                      true,
                      _model.submodels[_lambdaIdx - 1].norm_beta,
                      _ymu,
                      _reg,
                      new Iteration())
                  .asyncExec(_activeData._adaptedFrame);
            } else {
              // no sane solution to go back to, start from scratch!
              new GLMIterationTask(
                      GLM2.this,
                      _activeData,
                      glmt._glm,
                      true,
                      false,
                      false,
                      null,
                      _ymu,
                      _reg,
                      new Iteration())
                  .asyncExec(_activeData._adaptedFrame);
            }
            // Applies to all three restart paths above.
            _lastResult = null;
            return;
          }
        }
        _model.setAndTestValidation(_lambdaIdx, glmt._val);
        _model.clone().update(self());
      }

      if (glmt._val != null && glmt._computeGradient) { // check gradient
        final double[] grad = glmt.gradient(l2pen());
        ADMMSolver.subgrad(alpha[0], lambda[_lambdaIdx], glmt._beta, grad);
        // err ends up as the largest absolute value found in grad.
        double err = 0;
        for (double d : grad) {
          if (d > err) {
            err = d;
          } else if (d < -err) {
            err = -d;
          }
        }
        Log.info("GLM2 gradient after " + _iter + " iterations = " + err);
        if (err <= GLM_GRAD_EPS) {
          Log.info(
              "GLM2 converged by reaching small enough gradient, with max |subgradient| = " + err);
          setNewBeta(glmt._beta);
          nextLambda(glmt, glmt._beta);
          return;
        }
      }
      if (glmt._beta != null
          && glmt._val != null
          && glmt._computeGradient
          && _glm.family != Family.tweedie) {
        if (_lastResult != null && needLineSearch(glmt._beta, objval(glmt), 1)) {
          if (!highAccuracy()) {
            setHighAccuracy();
            if (_lastResult._iter < (_iter - 2)) {
              // there is a gap from last result...return to it and start again
              final double[] prevBeta =
                  _lastResult._activeCols != _activeCols
                      ? resizeVec(_lastResult._glmt._beta, _activeCols, _lastResult._activeCols)
                      : _lastResult._glmt._beta;
              new GLMIterationTask(
                      GLM2.this,
                      _activeData,
                      glmt._glm,
                      true,
                      true,
                      true,
                      prevBeta,
                      _ymu,
                      _reg,
                      new Iteration())
                  .asyncExec(_activeData._adaptedFrame);
              return;
            }
          }
          // Previous beta resized to the current active columns; reused for both the sanity
          // check and the line search instead of resizing twice.
          final double[] b =
              resizeVec(_lastResult._glmt._beta, _activeCols, _lastResult._activeCols);
          assert (b.length == glmt._beta.length)
              : b.length + " != " + glmt._beta.length + ", activeCols = " + _activeCols.length;
          new GLMTask.GLMLineSearchTask(
                  GLM2.this,
                  _activeData,
                  _glm,
                  b,
                  glmt._beta,
                  1e-4,
                  glmt._nobs,
                  alpha[0],
                  lambda[_lambdaIdx],
                  new LineSearchIteration())
              .asyncExec(_activeData._adaptedFrame);
          return;
        }
        _lastResult = new IterationInfo(GLM2.this._iter - 1, glmt, _activeCols);
      }
      final double[] newBeta = MemoryManager.malloc8d(glmt._xy.length);
      ADMMSolver slvr = new ADMMSolver(lambda[_lambdaIdx], alpha[0], ADMM_GRAD_EPS, _addedL2);
      slvr.solve(glmt._gram, glmt._xy, glmt._yy, newBeta);
      _addedL2 = slvr._addedL2;
      if (Utils.hasNaNsOrInfs(newBeta)) {
        Log.info("GLM2 forcibly converged by getting NaNs and/or Infs in beta");
        // Fall back to the beta this iteration started from.
        nextLambda(glmt, glmt._beta);
      } else {
        setNewBeta(newBeta);
        final double bdiff = beta_diff(glmt._beta, newBeta);
        if (_glm.family == Family.gaussian || bdiff < beta_epsilon || _iter == max_iter) {
          // Gaussian is non-iterative and gradient is ADMMSolver's gradient =>
          // just validate and move on to the next lambda
          // NOTE(review): a bdiff of exactly 0 gives log10 = -Infinity, which the int cast
          // turns into Integer.MIN_VALUE in the message below.
          int diff = (int) Math.log10(bdiff);
          int nzs = 0;
          for (int i = 0; i < newBeta.length; ++i) {
            if (newBeta[i] != 0) {
              ++nzs;
            }
          }
          if (newBeta.length < 20) {
            // Routed through Log like every other message in this callback.
            Log.info("beta = " + Arrays.toString(newBeta));
          }
          Log.info(
              "GLM2 (lambda_"
                  + _lambdaIdx
                  + "="
                  + lambda[_lambdaIdx]
                  + ") converged (reached a fixed point with ~ 1e"
                  + diff
                  + " precision) after "
                  + _iter
                  + " iterations, got "
                  + nzs
                  + " nzs");
          nextLambda(glmt, newBeta);
        } else { // not done yet, launch next iteration
          // Same flag is passed in both boolean slots below: on under higher_accuracy, otherwise
          // on every fifth iteration of the current lambda.
          final boolean validate = higher_accuracy || (currentLambdaIter % 5) == 0;
          ++_iter;
          Log.info("Iter = " + _iter);
          new GLMIterationTask(
                  GLM2.this,
                  _activeData,
                  glmt._glm,
                  true,
                  validate,
                  validate,
                  newBeta,
                  _ymu,
                  _reg,
                  new Iteration())
              .asyncExec(_activeData._adaptedFrame);
        }
      }
    }