@Override public boolean solve(Gram gram, double[] xy, double yy, double[] beta) { ADMMSolver admm = new ADMMSolver(_lambda, _alpha, 1e-2); if (gram != null) return admm.solve(gram, xy, yy, beta); Arrays.fill(beta, 0); long t1 = System.currentTimeMillis(); final double[] xb = gram.mul(beta); double objval = objectiveVal(xy, yy, beta, xb); final double[] newB = MemoryManager.malloc8d(beta.length); final double[] newG = MemoryManager.malloc8d(beta.length); double step = 1; final double l1pen = _lambda * _alpha; final double l2pen = _lambda * (1 - _alpha); double lsmobjval = lsm_objectiveVal(xy, yy, beta, xb); boolean converged = false; final int intercept = beta.length - 1; int iter = 0; MAIN: while (!converged && iter < 1000) { ++iter; step = 1; while (step > 1e-12) { // line search double l2shrink = 1 / (1 + step * l2pen); double l1shrink = l1pen * step; for (int i = 0; i < beta.length - 1; ++i) newB[i] = l2shrink * shrinkage((beta[i] - step * (xb[i] - xy[i])), l1shrink); newB[intercept] = beta[intercept] - step * (xb[intercept] - xy[intercept]); gram.mul(newB, newG); double newlsmobj = lsm_objectiveVal(xy, yy, newB, newG); double fhat = f_hat(newB, lsmobjval, beta, xb, xy, step); if (newlsmobj <= fhat) { lsmobjval = newlsmobj; converged = betaDiff(beta, newB) < 1e-6; System.arraycopy(newB, 0, beta, 0, newB.length); System.arraycopy(newG, 0, xb, 0, newG.length); continue MAIN; } else step *= 0.8; } converged = true; } return converged; }
@Override public void compute2() { Arrays.fill(z, 0); if (_lambda > 0 || _addedL2 > 0) gram.addDiag(_lambda * (1 - _alpha) + _addedL2); if (_alpha > 0 && _lambda > 0) gram.addDiag(rho); if (_proximalPenalty > 0 && _wgiven != null) { gram.addDiag(_proximalPenalty, true); for (int i = 0; i < xy.length; ++i) xy[i] += _proximalPenalty * _wgiven[i]; } int attempts = 0; long t1 = System.currentTimeMillis(); chol = gram.cholesky(null, true, _id); long t2 = System.currentTimeMillis(); while (!chol.isSPD() && attempts < 10) { if (_addedL2 == 0) _addedL2 = 1e-5; else _addedL2 *= 10; ++attempts; gram.addDiag(_addedL2); // try to add L2 penalty to make the Gram issp gram.cholesky(chol); } decompTime = (t2 - t1); if (!chol.isSPD()) throw new NonSPDMatrixException(gram); if (_alpha == 0 || _lambda == 0) { // no l1 penalty System.arraycopy(xy, 0, z, 0, xy.length); chol.parSolver(this, z, _iBlock, _rBlock).fork(); return; } gerr = Double.POSITIVE_INFINITY; _xyPrime = xy.clone(); _orlx = 1.8; // over-relaxation // first compute the x update // add rho*(z-u) to A'*y new ADMMIteration(this).fork(); }
public boolean solve(Gram gram, double[] xy, double yy, final double[] z, final double rho) { gerr = 0; double d = gram._diagAdded; final int N = xy.length; Arrays.fill(z, 0); if (_lambda > 0 || _addedL2 > 0) gram.addDiag(_lambda * (1 - _alpha) + _addedL2); if (_alpha > 0 && _lambda > 0) gram.addDiag(rho); if (_proximalPenalty > 0 && _wgiven != null) { gram.addDiag(_proximalPenalty, true); xy = xy.clone(); for (int i = 0; i < xy.length; ++i) xy[i] += _proximalPenalty * _wgiven[i]; } int attempts = 0; long t1 = System.currentTimeMillis(); Cholesky chol = gram.cholesky(null, true, _id); long t2 = System.currentTimeMillis(); while (!chol.isSPD() && attempts < 10) { if (_addedL2 == 0) _addedL2 = 1e-5; else _addedL2 *= 10; ++attempts; gram.addDiag(_addedL2); // try to add L2 penalty to make the Gram issp gram.cholesky(chol); } decompTime = (t2 - t1); if (!chol.isSPD()) throw new NonSPDMatrixException(gram); if (_alpha == 0 || _lambda == 0) { // no l1 penalty System.arraycopy(xy, 0, z, 0, xy.length); chol.solve(z); gram.addDiag(-gram._diagAdded + d); return true; } double[] u = MemoryManager.malloc8d(N); double[] xyPrime = xy.clone(); double kappa = _lambda * _alpha / rho; int i; int max_iter = Math.max(500, (int) (50000.0 / (1 + (xy.length >> 3)))); double orlx = 1.8; // over-relaxation double reltol = RELTOL; for (i = 0; i < max_iter; ++i) { long tX = System.currentTimeMillis(); // first compute the x update // add rho*(z-u) to A'*y for (int j = 0; j < N - 1; ++j) xyPrime[j] = xy[j] + rho * (z[j] - u[j]); xyPrime[N - 1] = xy[N - 1]; // updated x chol.solve(xyPrime); // compute u and z updateADMM double rnorm = 0, snorm = 0, unorm = 0, xnorm = 0; for (int j = 0; j < N - 1; ++j) { double x = xyPrime[j]; double zold = z[j]; double x_hat = x * orlx + (1 - orlx) * zold; z[j] = shrinkage(x_hat + u[j], kappa); u[j] += x_hat - z[j]; double r = xyPrime[j] - z[j]; double s = z[j] - zold; rnorm += r * r; snorm += s * s; xnorm += x * x; unorm += u[j] * u[j]; } z[N - 1] = xyPrime[N 
- 1]; if (rnorm < reltol * xnorm && snorm < reltol * unorm) { gerr = 0; double[] grad = grad(gram, z, xy); subgrad(_alpha, _lambda, z, grad); for (int x = 0; x < grad.length - 1; ++x) { if (gerr < grad[x]) gerr = grad[x]; else if (gerr < -grad[x]) gerr = -grad[x]; } if (gerr < 1e-4 || reltol <= 1e-6) break; while (rnorm < reltol * xnorm && snorm < reltol * unorm) reltol *= .1; } if (i % 20 == 0) orlx = (1 + 15 * orlx) * 0.0625; } gram.addDiag(-gram._diagAdded + d); assert gram._diagAdded == d; iterations = i; return _converged = (gerr < _gradientEps); }
/**
 * Computes {@code gram.mul(beta) - xy} element-wise.
 *
 * @param gram Gram matrix supplying the product with {@code beta}
 * @param beta coefficient vector
 * @param xy   vector subtracted from the product; must be at least as long as the product
 * @return the array returned by {@code gram.mul(beta)}, with {@code xy} subtracted in place
 */
public final double[] grad(Gram gram, double[] beta, double[] xy) {
  final double[] result = gram.mul(beta);
  final int len = result.length;
  int k = 0;
  while (k < len) {
    result[k] = result[k] - xy[k];
    ++k;
  }
  return result;
}
/**
 * Undoes the diagonal modifications made to the Gram matrix once the task has completed.
 *
 * <p>Adds {@code d - gram._diagAdded} to the diagonal; afterwards {@code gram._diagAdded} is
 * expected to equal {@code d} again (checked only when assertions are enabled).
 *
 * @param caller the completer that triggered this callback (unused)
 */
@Override
public void onCompletion(CountedCompleter caller) {
  final double correction = d - gram._diagAdded;
  gram.addDiag(correction);
  assert gram._diagAdded == d;
}