예제 #1
0
 /**
  * Solves for {@code beta} by delegating to an {@link ADMMSolver} configured with this solver's
  * {@code _lambda} and {@code _alpha} (third constructor argument fixed at {@code 1e-2}).
  *
  * @param gram Gram matrix handed to the ADMM solver; must not be {@code null}
  * @param xy   vector passed through unchanged to the ADMM solver
  * @param yy   scalar passed through unchanged to the ADMM solver
  * @param beta output array for the solution, passed through to the ADMM solver
  * @return whatever {@code ADMMSolver.solve(gram, xy, yy, beta)} returns
  * @throws NullPointerException if {@code gram} is {@code null}; thrown before {@code beta} is touched
  */
 @Override
 public boolean solve(Gram gram, double[] xy, double yy, double[] beta) {
   if (gram == null) {
     // Fail fast: every solving path needs the Gram matrix, and the caller's beta must not be
     // zero-filled before the error is reported.
     throw new NullPointerException("gram must not be null");
   }
   ADMMSolver admm = new ADMMSolver(_lambda, _alpha, 1e-2);
   return admm.solve(gram, xy, yy, beta);
 }

 /**
  * Proximal-gradient solver with a backtracking line search, writing its result into {@code beta}.
  *
  * <p>NOTE(review): this code used to follow the ADMM delegation inside
  * {@code solve(Gram, double[], double, double[])}, but it could never run there. The only way past
  * the {@code gram != null} early return was a null {@code gram}, which immediately threw a
  * NullPointerException at {@code gram.mul(beta)}. It is kept here, currently uncalled, so it can
  * be re-enabled deliberately; delete it if it is no longer wanted.
  *
  * <p>Each outer iteration tries step sizes 1, 0.8, 0.64, ... while the step stays above 1e-12.
  * For a trial step:
  * <ul>
  *   <li>Every coefficient except the last moves by {@code -step * (xb - xy)}. It is then passed
  *       through {@code shrinkage} with threshold {@code step * _lambda * _alpha} and scaled by
  *       {@code 1 / (1 + step * _lambda * (1 - _alpha))}.
  *   <li>The last coefficient (the intercept) takes the same move without shrinkage or scaling.
  * </ul>
  * The trial is accepted when its {@code lsm_objectiveVal} does not exceed {@code f_hat(...)}.
  * If no step is accepted, the method stops and reports convergence.
  *
  * @return {@code true} if an accepted step changed {@code beta} by less than 1e-6 according to
  *     {@code betaDiff}, or if the line search found no acceptable step; {@code false} if 1000
  *     outer iterations ran out first
  */
 private boolean proximalGradientSolve(Gram gram, double[] xy, double yy, double[] beta) {
   Arrays.fill(beta, 0);
   final double[] xb = gram.mul(beta); // kept in sync with beta after each accepted step
   final double[] newB = MemoryManager.malloc8d(beta.length);
   final double[] newG = MemoryManager.malloc8d(beta.length);
   final double l1pen = _lambda * _alpha;
   final double l2pen = _lambda * (1 - _alpha);
   final int intercept = beta.length - 1;
   double lsmobjval = lsm_objectiveVal(xy, yy, beta, xb);
   boolean converged = false;
   int iter = 0;
   MAIN:
   while (!converged && iter < 1000) {
     ++iter;
     double step = 1;
     while (step > 1e-12) { // backtracking line search
       final double l2shrink = 1 / (1 + step * l2pen);
       final double l1shrink = l1pen * step;
       for (int i = 0; i < intercept; ++i) {
         newB[i] = l2shrink * shrinkage(beta[i] - step * (xb[i] - xy[i]), l1shrink);
       }
       // The intercept is neither shrunk nor scaled.
       newB[intercept] = beta[intercept] - step * (xb[intercept] - xy[intercept]);
       gram.mul(newB, newG);
       final double newlsmobj = lsm_objectiveVal(xy, yy, newB, newG);
       final double fhat = f_hat(newB, lsmobjval, beta, xb, xy, step);
       if (newlsmobj <= fhat) { // step accepted
         lsmobjval = newlsmobj;
         converged = betaDiff(beta, newB) < 1e-6;
         System.arraycopy(newB, 0, beta, 0, newB.length);
         System.arraycopy(newG, 0, xb, 0, newG.length);
         continue MAIN;
       }
       step *= 0.8;
     }
     // No step size was accepted, so no further progress is possible: stop here.
     converged = true;
   }
   return converged;
 }
예제 #2
0
 /**
  * Task body: prepares the penalized Gram matrix, factorizes it, and then either starts the direct
  * solve (no L1 term) or forks the first {@code ADMMIteration}.
  *
  * <p>State touched here:
  * <ul>
  *   <li>{@code z} is zero-filled.
  *   <li>Penalty terms are added to the Gram diagonal; on non-SPD factorizations an escalating
  *       {@code _addedL2} is added as well.
  *   <li>{@code _proximalPenalty * _wgiven} is added to {@code xy} in place.
  *   <li>{@code chol}, {@code decompTime}, {@code gerr}, {@code _xyPrime} and {@code _orlx} are
  *       assigned.
  * </ul>
  *
  * <p>NOTE(review): {@code xy} is updated in place, whereas the blocking
  * {@code solve(Gram, double[], double, double[], double)} clones it first. Confirm whether this
  * array is shared with a caller. Also, when {@link NonSPDMatrixException} is thrown, nothing in
  * this method undoes the diagonal changes made above.
  *
  * @throws NonSPDMatrixException if the factorization is still not SPD after 10 ridge retries
  */
 @Override
 public void compute2() {
   Arrays.fill(z, 0);
   // L2 part of the elastic-net penalty plus the current _addedL2.
   if (_lambda > 0 || _addedL2 > 0) {
     gram.addDiag(_lambda * (1 - _alpha) + _addedL2);
   }
   // rho term, only when an L1 penalty is active.
   if (_alpha > 0 && _lambda > 0) {
     gram.addDiag(rho);
   }
   if (_proximalPenalty > 0 && _wgiven != null) {
     gram.addDiag(_proximalPenalty, true);
     for (int k = 0; k < xy.length; k++) {
       xy[k] += _proximalPenalty * _wgiven[k];
     }
   }
   final long factorStart = System.currentTimeMillis();
   chol = gram.cholesky(null, true, _id);
   final long factorEnd = System.currentTimeMillis();
   // Not SPD: add a growing ridge and refactor (_addedL2 becomes 1e-5 if unset, then x10 per retry).
   for (int retry = 0; !chol.isSPD() && retry < 10; retry++) {
     _addedL2 = (_addedL2 == 0) ? 1e-5 : _addedL2 * 10;
     gram.addDiag(_addedL2);
     gram.cholesky(chol);
   }
   decompTime = (factorEnd - factorStart); // first factorization only
   if (!chol.isSPD()) {
     throw new NonSPDMatrixException(gram);
   }
   if (_alpha == 0 || _lambda == 0) {
     // No L1 penalty: hand the factored system to chol.parSolver (this task passed as first arg).
     System.arraycopy(xy, 0, z, 0, xy.length);
     chol.parSolver(this, z, _iBlock, _rBlock).fork();
   } else {
     gerr = Double.POSITIVE_INFINITY;
     _xyPrime = xy.clone();
     _orlx = 1.8; // over-relaxation factor
     // ADMM continues in ADMMIteration, starting with the x update (rho*(z-u) added to A'*y).
     new ADMMIteration(this).fork();
   }
 }
예제 #3
0
 /**
  * Solves the penalized least-squares system built from {@code gram} and {@code xy}, writing the
  * solution into {@code z}.
  *
  * <p>Penalty terms are added to the Gram diagonal for the duration of the call. The diagonal is
  * put back to its entry value before returning. This also happens when an exception (for example
  * {@link NonSPDMatrixException}) escapes, so a caller that catches it still holds an unmodified
  * Gram.
  *
  * <p>With no L1 term ({@code _alpha == 0 || _lambda == 0}) the system is solved directly from its
  * Cholesky factor; otherwise ADMM iterations are run.
  *
  * @param gram Gram matrix; its diagonal is modified temporarily
  * @param xy   right-hand side; never modified (copied before the proximal term is added)
  * @param yy   not used by this method
  * @param z    output; zero-filled on entry, then overwritten with the solution
  * @param rho  ADMM penalty parameter; also added to the diagonal when an L1 term is active
  * @return {@code true} after a direct solve; otherwise {@code gerr < _gradientEps}, which is also
  *     stored in {@code _converged}
  * @throws NonSPDMatrixException if the factorization is still not SPD after 10 ridge retries
  */
 public boolean solve(Gram gram, double[] xy, double yy, final double[] z, final double rho) {
   gerr = 0;
   final double diagOnEntry = gram._diagAdded;
   Arrays.fill(z, 0);
   final boolean result;
   try {
     result = solvePenalizedSystem(gram, xy, z, rho);
   } finally {
     // Undo every diagonal change made during this call, on normal and exceptional exit alike.
     gram.addDiag(-gram._diagAdded + diagOnEntry);
   }
   assert gram._diagAdded == diagOnEntry;
   return result;
 }

 /**
  * Applies the configured penalties to the Gram diagonal, factorizes it, and solves for {@code z}
  * (directly, or by ADMM when an L1 term is active). Leaves the diagonal modified; the caller
  * restores it.
  */
 private boolean solvePenalizedSystem(Gram gram, double[] xy, final double[] z, final double rho) {
   final int N = xy.length;
   if (_lambda > 0 || _addedL2 > 0) gram.addDiag(_lambda * (1 - _alpha) + _addedL2);
   if (_alpha > 0 && _lambda > 0) gram.addDiag(rho);
   if (_proximalPenalty > 0 && _wgiven != null) {
     gram.addDiag(_proximalPenalty, true);
     xy = xy.clone(); // the proximal term must not leak into the caller's array
     for (int i = 0; i < xy.length; ++i) xy[i] += _proximalPenalty * _wgiven[i];
   }
   int attempts = 0;
   long t1 = System.currentTimeMillis();
   Cholesky chol = gram.cholesky(null, true, _id);
   long t2 = System.currentTimeMillis();
   // Not SPD: keep adding a growing ridge (_addedL2 becomes 1e-5 if unset, then x10 per attempt).
   while (!chol.isSPD() && attempts < 10) {
     if (_addedL2 == 0) _addedL2 = 1e-5;
     else _addedL2 *= 10;
     ++attempts;
     gram.addDiag(_addedL2); // try to add L2 penalty to make the Gram SPD
     gram.cholesky(chol);
   }
   decompTime = (t2 - t1); // measures the first factorization only
   if (!chol.isSPD()) throw new NonSPDMatrixException(gram);
   if (_alpha == 0 || _lambda == 0) { // no l1 penalty: a single direct solve
     System.arraycopy(xy, 0, z, 0, xy.length);
     chol.solve(z);
     return true;
   }
   final double[] u = MemoryManager.malloc8d(N);
   final double[] xyPrime = xy.clone();
   final double kappa = _lambda * _alpha / rho; // threshold passed to shrinkage
   final int max_iter = Math.max(500, (int) (50000.0 / (1 + (xy.length >> 3))));
   double orlx = 1.8; // over-relaxation
   double reltol = RELTOL;
   // The solve must not look converged until the residual test below has passed at least once.
   // Starting from 0 made a run that hit max_iter report _converged == true.
   gerr = Double.POSITIVE_INFINITY;
   int i;
   for (i = 0; i < max_iter; ++i) {
     // x update: solve against A'*y + rho*(z - u); the last (intercept) entry gets no rho term.
     for (int j = 0; j < N - 1; ++j) xyPrime[j] = xy[j] + rho * (z[j] - u[j]);
     xyPrime[N - 1] = xy[N - 1];
     chol.solve(xyPrime);
     // z and u updates with over-relaxation, accumulating squared norms for the stopping test.
     double rnorm = 0, snorm = 0, unorm = 0, xnorm = 0;
     for (int j = 0; j < N - 1; ++j) {
       double x = xyPrime[j];
       double zold = z[j];
       double x_hat = x * orlx + (1 - orlx) * zold;
       z[j] = shrinkage(x_hat + u[j], kappa);
       u[j] += x_hat - z[j];
       double r = xyPrime[j] - z[j];
       double s = z[j] - zold;
       rnorm += r * r;
       snorm += s * s;
       xnorm += x * x;
       unorm += u[j] * u[j];
     }
     z[N - 1] = xyPrime[N - 1];
     if (rnorm < reltol * xnorm && snorm < reltol * unorm) {
       // Residuals are small: take the max |grad| over non-intercept entries after subgrad(...).
       gerr = 0;
       double[] grad = grad(gram, z, xy);
       subgrad(_alpha, _lambda, z, grad);
       for (int x = 0; x < grad.length - 1; ++x) {
         if (gerr < grad[x]) gerr = grad[x];
         else if (gerr < -grad[x]) gerr = -grad[x];
       }
       if (gerr < 1e-4 || reltol <= 1e-6) break;
       // Not accurate enough yet: tighten the tolerance until the residual test fails again.
       while (rnorm < reltol * xnorm && snorm < reltol * unorm) reltol *= .1;
     }
     // Every 20 iterations, move orlx 1/16 of the way toward 1.
     if (i % 20 == 0) orlx = (1 + 15 * orlx) * 0.0625;
   }
   iterations = i;
   return _converged = (gerr < _gradientEps);
 }
예제 #4
0
 /**
  * Returns {@code gram.mul(beta) - xy} element-wise, reusing the array produced by
  * {@code gram.mul} to hold the result.
  *
  * @param gram Gram matrix to multiply with
  * @param beta coefficient vector
  * @param xy   vector subtracted from the product; must be at least as long as the product
  * @return the difference, stored in the array returned by {@code gram.mul(beta)}
  */
 public final double[] grad(Gram gram, double[] beta, double[] xy) {
   final double[] g = gram.mul(beta);
   int k = 0;
   while (k < g.length) {
     g[k] -= xy[k];
     k++;
   }
   return g;
 }
예제 #5
0
 /**
  * Completion hook: adds to the Gram diagonal whatever brings {@code gram._diagAdded} back to the
  * value held in field {@code d}. Presumably {@code d} is the offset saved before this task
  * modified the Gram; confirm at the assignment site.
  *
  * @param caller not used by this method
  */
 @Override
 public void onCompletion(CountedCompleter caller) {
   final double delta = d - gram._diagAdded;
   gram.addDiag(delta);
   assert gram._diagAdded == d;
 }