diff --git a/cpp/src/glm/qn/qn_solvers.cuh b/cpp/src/glm/qn/qn_solvers.cuh
index 9c357ec0f6..052d593b13 100644
--- a/cpp/src/glm/qn/qn_solvers.cuh
+++ b/cpp/src/glm/qn/qn_solvers.cuh
@@ -68,6 +68,74 @@ inline size_t owlqn_workspace_size(const LBFGSParam<T>& param, const int n)
   return lbfgs_workspace_size(param, n) + vec_size;
 }
 
+template <typename T>
+inline bool update_and_check(const char* solver,
+                             const LBFGSParam<T>& param,
+                             int iter,
+                             LINE_SEARCH_RETCODE lsret,
+                             T& fx,
+                             T& fxp,
+                             ML::SimpleVec<T>& x,
+                             ML::SimpleVec<T>& xp,
+                             ML::SimpleVec<T>& grad,
+                             ML::SimpleVec<T>& gradp,
+                             std::vector<T>& fx_hist,
+                             T* dev_scalar,
+                             OPT_RETCODE& outcode,
+                             cudaStream_t stream)
+{
+  bool stop = false;
+  bool converged = false;
+  bool isLsValid = !isnan(fx) && !isinf(fx);
+  // Linesearch may fail to converge, but still come closer to the solution;
+  // if that is not the case, let `check_convergence` ("insufficient change")
+  // below terminate the loop.
+  bool isLsNonCritical = lsret == LS_INVALID_STEP_MIN || lsret == LS_MAX_ITERS_REACHED;
+  // If the error is not critical, check that the target function does not grow.
+  // This shouldn't really happen, but weird things can happen if the convergence
+  // thresholds are too small.
+  bool isLsInDoubt = isLsValid && fx <= fxp + param.ftol && isLsNonCritical;
+  bool isLsSuccess = lsret == LS_SUCCESS || isLsInDoubt;
+
+  CUML_LOG_TRACE("%s iteration %d, fx=%f", solver, iter, fx);
+
+  // if the target is at least finite, we can check the convergence
+  if (isLsValid)
+    converged = check_convergence(param, iter, fx, x, grad, fx_hist, dev_scalar, stream);
+
+  if (!isLsSuccess && !converged) {
+    CUML_LOG_WARN(
+      "%s line search failed (code %d); stopping at the last valid step", solver, lsret);
+    outcode = OPT_LS_FAILED;
+    stop = true;
+  } else if (!isLsValid) {
+    CUML_LOG_ERROR(
+      "%s error fx=%f at iteration %d; stopping at the last valid step", solver, fx, iter);
+    outcode = OPT_NUMERIC_ERROR;
+    stop = true;
+  } else if (converged) {
+    CUML_LOG_DEBUG("%s converged", solver);
+    outcode = OPT_SUCCESS;
+    stop = true;
+  } else if (isLsInDoubt && fx + param.ftol >= fxp) {
+    // If a non-critical error has happened during the line search, check if the target
+    // is improved at least a bit. Otherwise, stop to avoid spinning till the iteration limit.
+    CUML_LOG_WARN(
+      "%s stopped, because the line search failed to advance (step delta = %f)", solver, fx - fxp);
+    outcode = OPT_LS_FAILED;
+    stop = true;
+  }
+
+  // If the line search wasn't successful, undo the update.
+  if (!isLsSuccess || !isLsValid) {
+    fx = fxp;
+    x.copy_async(xp, stream);
+    grad.copy_async(gradp, stream);
+  }
+
+  return stop;
+}
+
 template <typename T, typename Function>
 inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
                              Function& f,  // function to minimize
@@ -131,6 +199,8 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
   *k = 1;
   int end = 0;
   int n_vec = 0;  // number of vector updates made in lbfgs_search_dir
+  OPT_RETCODE retcode;
+  LINE_SEARCH_RETCODE lsret;
   for (; *k <= param.max_iterations; (*k)++) {
     // Save the curent x and gradient
     xp.copy_async(x, stream);
@@ -138,28 +208,23 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
     fxp = fx;
 
     // Line search to update x, fx and gradient
-    LINE_SEARCH_RETCODE lsret =
-      ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);
-
-    bool isLsSuccess = lsret == LS_SUCCESS;
-    CUML_LOG_TRACE("Iteration %d, fx=%f", *k, fx);
-
-    if (!isLsSuccess || isnan(fx) || isinf(fx)) {
-      fx = fxp;
-      x.copy_async(xp, stream);
-      grad.copy_async(gradp, stream);
-      if (!isLsSuccess) {
-        CUML_LOG_ERROR("L-BFGS line search failed");
-        return OPT_LS_FAILED;
-      }
-      CUML_LOG_ERROR("L-BFGS error fx=%f at iteration %d", fx, *k);
-      return OPT_NUMERIC_ERROR;
-    }
-
-    if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
-      CUML_LOG_DEBUG("L-BFGS converged");
-      return OPT_SUCCESS;
-    }
+    lsret = ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);
+
+    if (update_and_check("L-BFGS",
+                         param,
+                         *k,
+                         lsret,
+                         fx,
+                         fxp,
+                         x,
+                         xp,
+                         grad,
+                         gradp,
+                         fx_hist,
+                         dev_scalar,
+                         retcode,
+                         stream))
+      return retcode;
 
     // Update s and y
     // s_{k+1} = x_{k+1} - x_k
@@ -282,6 +347,8 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,
   int end = 0;
   int n_vec = 0;  // number of vector updates made in lbfgs_search_dir
 
+  OPT_RETCODE retcode;
+  LINE_SEARCH_RETCODE lsret;
   for ((*k) = 1; (*k) <= param.max_iterations; (*k)++) {
     // Save the curent x and gradient
     xp.copy_async(x, stream);
@@ -289,30 +356,29 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,
     fxp = fx;
 
     // Projected line search to update x, fx and gradient
-    LINE_SEARCH_RETCODE lsret = ls_backtrack_projected(
+    lsret = ls_backtrack_projected(
       param, f_wrap, fx, x, grad, pseudo, step, drt, xp, l1_penalty, dev_scalar, stream);
-    bool isLsSuccess = lsret == LS_SUCCESS;
-    if (!isLsSuccess || isnan(fx) || isinf(fx)) {
-      fx = fxp;
-      x.copy_async(xp, stream);
-      grad.copy_async(gradp, stream);
-      if (!isLsSuccess) {
-        CUML_LOG_ERROR("QWL-QN line search failed");
-        return OPT_LS_FAILED;
-      }
-      CUML_LOG_ERROR("OWL-QN error fx=%f at iteration %d", fx, *k);
-      return OPT_NUMERIC_ERROR;
-    }
+    if (update_and_check("OWL-QN",
+                         param,
+                         *k,
+                         lsret,
+                         fx,
+                         fxp,
+                         x,
+                         xp,
+                         grad,
+                         gradp,
+                         fx_hist,
+                         dev_scalar,
+                         retcode,
+                         stream))
+      return retcode;
+
     // recompute pseudo
     // pseudo.assign_binary(x, grad, pseudo_grad);
     update_pseudo(x, grad, pseudo_grad, pg_limit, pseudo, stream);
-    if (check_convergence(param, *k, fx, x, pseudo, fx_hist, dev_scalar, stream)) {
-      CUML_LOG_DEBUG("OWL-QN converged");
-      return OPT_SUCCESS;
-    }
-
     // Update s and y - We should only do this if there is no skipping condition
     col_ref(S, svec, end);