Tolerate QN linesearch failures when it's harmless (rapidsai#3791)
In some cases, the linesearch subroutine of the Quasi-Newton solver (logistic regression) may fail in a way that is tolerable for the solver. Currently, any linesearch failure makes the outer algorithm stop, even though the solution found at that point usually appears to be accurate to a good precision. This PR ignores some non-critical failures (`LS_MAX_ITERS_REACHED` and `LS_INVALID_STEP_MIN`) and stops gracefully if the convergence check triggers or if the change in the objective function is too small.
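
For illustration, the tolerance rule amounts to a small classification of the linesearch return code and the new objective value. The sketch below restates that logic outside the solver; the enumerators `LS_SUCCESS`, `LS_INVALID_STEP_MIN`, and `LS_MAX_ITERS_REACHED` come from the diff, while `LS_OTHER_FAILURE` and the helper name `tolerable_linesearch` are made up for this example.

```cpp
#include <cmath>

// Return codes mentioned in this PR; the real LINE_SEARCH_RETCODE enum may
// contain more values, all of which are treated as critical here.
enum LINE_SEARCH_RETCODE { LS_SUCCESS, LS_INVALID_STEP_MIN, LS_MAX_ITERS_REACHED, LS_OTHER_FAILURE };

// A linesearch outcome is tolerable if it succeeded outright, or if it failed
// in a non-critical way while the objective fx stayed finite and did not grow
// by more than ftol relative to the previous value fxp.
template <typename T>
bool tolerable_linesearch(LINE_SEARCH_RETCODE lsret, T fx, T fxp, T ftol)
{
  if (lsret == LS_SUCCESS) return true;
  bool non_critical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED;
  bool finite       = !std::isnan(fx) && !std::isinf(fx);
  return non_critical && finite && fx <= fxp + ftol;
}
```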

The PR mostly changes the logging behavior, with one small exception: the last condition in `update_and_check` does not trigger if the linesearch error is non-critical and the objective function still decreases (i.e. the solver keeps advancing). Previously, the solver would stop regardless of the change in the objective.
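
A concrete illustration of that last condition, continuing the hypothetical sketch above with made-up numbers (`param.ftol = 1e-4`, previous objective `fxp = 1.0`):

```cpp
void example()
{
  // A non-critical failure that still decreases the objective by more than
  // ftol is tolerated, and the solver keeps iterating.
  bool keeps_going = tolerable_linesearch(LS_MAX_ITERS_REACHED, 0.9990, 1.0, 1e-4);  // true
  // A change within ftol is also tolerated (the step is kept), but because
  // fx + ftol >= fxp the "failed to advance" branch fires, and the solver stops
  // with OPT_LS_FAILED (unless the convergence check has already succeeded).
  bool in_doubt = tolerable_linesearch(LS_MAX_ITERS_REACHED, 0.99995, 1.0, 1e-4);    // true
}
```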

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#3791
achirkin authored Jul 22, 2021
1 parent 0aa2491 commit d1f5efd
1 changed file: cpp/src/glm/qn/qn_solvers.cuh (106 additions, 40 deletions)
@@ -68,6 +68,74 @@ inline size_t owlqn_workspace_size(const LBFGSParam<T>& param, const int n)
   return lbfgs_workspace_size(param, n) + vec_size;
 }

+template <typename T>
+inline bool update_and_check(const char* solver,
+                             const LBFGSParam<T>& param,
+                             int iter,
+                             LINE_SEARCH_RETCODE lsret,
+                             T& fx,
+                             T& fxp,
+                             ML::SimpleVec<T>& x,
+                             ML::SimpleVec<T>& xp,
+                             ML::SimpleVec<T>& grad,
+                             ML::SimpleVec<T>& gradp,
+                             std::vector<T>& fx_hist,
+                             T* dev_scalar,
+                             OPT_RETCODE& outcode,
+                             cudaStream_t stream)
+{
+  bool stop = false;
+  bool converged = false;
+  bool isLsValid = !isnan(fx) && !isinf(fx);
+  // Linesearch may fail to converge, but still come closer to the solution;
+  // if that is not the case, let `check_convergence` ("insufficient change")
+  // below terminate the loop.
+  bool isLsNonCritical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED;
+  // If the error is not critical, check that the target function does not grow.
+  // This shouldn't really happen, but weird things can happen if the convergence
+  // thresholds are too small.
+  bool isLsInDoubt = isLsValid && fx <= fxp + param.ftol && isLsNonCritical;
+  bool isLsSuccess = lsret == LS_SUCCESS || isLsInDoubt;
+
+  CUML_LOG_TRACE("%s iteration %d, fx=%f", solver, iter, fx);
+
+  // if the target is at least finite, we can check the convergence
+  if (isLsValid)
+    converged = check_convergence(param, iter, fx, x, grad, fx_hist, dev_scalar, stream);
+
+  if (!isLsSuccess && !converged) {
+    CUML_LOG_WARN(
+      "%s line search failed (code %d); stopping at the last valid step", solver, lsret);
+    outcode = OPT_LS_FAILED;
+    stop = true;
+  } else if (!isLsValid) {
+    CUML_LOG_ERROR(
+      "%s error fx=%f at iteration %d; stopping at the last valid step", solver, fx, iter);
+    outcode = OPT_NUMERIC_ERROR;
+    stop = true;
+  } else if (converged) {
+    CUML_LOG_DEBUG("%s converged", solver);
+    outcode = OPT_SUCCESS;
+    stop = true;
+  } else if (isLsInDoubt && fx + param.ftol >= fxp) {
+    // If a non-critical error has happened during the line search, check if the target
+    // is improved at least a bit. Otherwise, stop to avoid spinning till the iteration limit.
+    CUML_LOG_WARN(
+      "%s stopped, because the line search failed to advance (step delta = %f)", solver, fx - fxp);
+    outcode = OPT_LS_FAILED;
+    stop = true;
+  }
+
+  // if linesearch wasn't successful, undo the update.
+  if (!isLsSuccess || !isLsValid) {
+    fx = fxp;
+    x.copy_async(xp, stream);
+    grad.copy_async(gradp, stream);
+  }
+
+  return stop;
+}
+
 template <typename T, typename Function>
 inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
                              Function& f, // function to minimize
@@ -131,35 +199,32 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
   *k = 1;
   int end = 0;
   int n_vec = 0; // number of vector updates made in lbfgs_search_dir
+  OPT_RETCODE retcode;
+  LINE_SEARCH_RETCODE lsret;
   for (; *k <= param.max_iterations; (*k)++) {
     // Save the current x and gradient
     xp.copy_async(x, stream);
     gradp.copy_async(grad, stream);
     fxp = fx;

     // Line search to update x, fx and gradient
-    LINE_SEARCH_RETCODE lsret =
-      ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);
-
-    bool isLsSuccess = lsret == LS_SUCCESS;
-    CUML_LOG_TRACE("Iteration %d, fx=%f", *k, fx);
-
-    if (!isLsSuccess || isnan(fx) || isinf(fx)) {
-      fx = fxp;
-      x.copy_async(xp, stream);
-      grad.copy_async(gradp, stream);
-      if (!isLsSuccess) {
-        CUML_LOG_ERROR("L-BFGS line search failed");
-        return OPT_LS_FAILED;
-      }
-      CUML_LOG_ERROR("L-BFGS error fx=%f at iteration %d", fx, *k);
-      return OPT_NUMERIC_ERROR;
-    }
-
-    if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
-      CUML_LOG_DEBUG("L-BFGS converged");
-      return OPT_SUCCESS;
-    }
+    lsret = ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);
+
+    if (update_and_check("L-BFGS",
+                         param,
+                         *k,
+                         lsret,
+                         fx,
+                         fxp,
+                         x,
+                         xp,
+                         grad,
+                         gradp,
+                         fx_hist,
+                         dev_scalar,
+                         retcode,
+                         stream))
+      return retcode;

     // Update s and y
     // s_{k+1} = x_{k+1} - x_k
@@ -282,37 +347,38 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,

   int end = 0;
   int n_vec = 0; // number of vector updates made in lbfgs_search_dir
+  OPT_RETCODE retcode;
+  LINE_SEARCH_RETCODE lsret;
   for ((*k) = 1; (*k) <= param.max_iterations; (*k)++) {
     // Save the current x and gradient
     xp.copy_async(x, stream);
     gradp.copy_async(grad, stream);
     fxp = fx;

     // Projected line search to update x, fx and gradient
-    LINE_SEARCH_RETCODE lsret = ls_backtrack_projected(
+    lsret = ls_backtrack_projected(
       param, f_wrap, fx, x, grad, pseudo, step, drt, xp, l1_penalty, dev_scalar, stream);

-    bool isLsSuccess = lsret == LS_SUCCESS;
-    if (!isLsSuccess || isnan(fx) || isinf(fx)) {
-      fx = fxp;
-      x.copy_async(xp, stream);
-      grad.copy_async(gradp, stream);
-      if (!isLsSuccess) {
-        CUML_LOG_ERROR("QWL-QN line search failed");
-        return OPT_LS_FAILED;
-      }
-      CUML_LOG_ERROR("OWL-QN error fx=%f at iteration %d", fx, *k);
-      return OPT_NUMERIC_ERROR;
-    }
+    if (update_and_check("QWL-QN",
+                         param,
+                         *k,
+                         lsret,
+                         fx,
+                         fxp,
+                         x,
+                         xp,
+                         grad,
+                         gradp,
+                         fx_hist,
+                         dev_scalar,
+                         retcode,
+                         stream))
+      return retcode;
+
     // recompute pseudo
     // pseudo.assign_binary(x, grad, pseudo_grad);
     update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream);

-    if (check_convergence(param, *k, fx, x, pseudo, fx_hist, dev_scalar, stream)) {
-      CUML_LOG_DEBUG("OWL-QN converged");
-      return OPT_SUCCESS;
-    }
-
     // Update s and y - We should only do this if there is no skipping condition

     col_ref(S, svec, end);
