From 1abf605c0cd8edb02bf0704b2c6ba0e6e04f8219 Mon Sep 17 00:00:00 2001
From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com>
Date: Thu, 10 Feb 2022 01:20:46 +0100
Subject: [PATCH] QN solvers: Use different gradient norms for different loss
 functions. (#4491)

Different loss functions may scale differently with the number of features.
This has an effect on the convergence criteria. To account for that, I let a
loss function define its preferred metric. As a result, the number of
iterations should be less dependent on the number of features for all loss
functions.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: https://github.com/rapidsai/cuml/pull/4491
---
 cpp/src/glm/qn/glm_base.cuh        | 16 ++++++++++++++++
 cpp/src/glm/qn/glm_linear.cuh      | 10 ++++++++++
 cpp/src/glm/qn/glm_logistic.cuh    |  5 +++++
 cpp/src/glm/qn/glm_regularizer.cuh |  5 +++++
 cpp/src/glm/qn/glm_softmax.cuh     |  5 +++++
 cpp/src/glm/qn/glm_svm.cuh         | 20 ++++++++++++++++++++
 cpp/src/glm/qn/qn_solvers.cuh      | 18 ++++++++++++------
 cpp/src/glm/qn/qn_util.cuh         | 13 ++-----------
 8 files changed, 75 insertions(+), 17 deletions(-)
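The claim above, that losses scale differently with the number of features (D), can be made concrete with a small standalone example. The snippet is not part of the patch or of cuML; it only illustrates how the three metrics used in the diffs below (half the squared L2 norm, the L1 norm, and the LInf norm) grow with D for a gradient of constant per-feature magnitude, which is why a single hard-coded norm makes the stopping criterion depend on D.

// Standalone illustration (not cuML code): for a gradient whose components all
// have magnitude g, the L2^2/2 and L1 metrics grow linearly with D, while the
// LInf metric stays constant. This is the scaling the patch accounts for.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
  const double g = 1e-3;  // illustrative per-feature gradient magnitude
  for (int D : {10, 100, 1000}) {
    std::vector<double> grad(D, g);
    double l2sq_half = 0, l1 = 0, linf = 0;
    for (double v : grad) {
      l2sq_half += 0.5 * v * v;            // metric of SquaredLoss, SVCL2Loss, SVRL2Loss
      l1 += std::abs(v);                   // metric of AbsLoss, SVCL1Loss, SVRL1Loss
      linf = std::max(linf, std::abs(v));  // metric of LogisticLoss, Softmax
    }
    std::printf("D=%4d  0.5*||g||_2^2=%.2e  ||g||_1=%.2e  ||g||_inf=%.2e\n",
                D, l2sq_half, l1, linf);
  }
  return 0;
}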
diff --git a/cpp/src/glm/qn/glm_base.cuh b/cpp/src/glm/qn/glm_base.cuh
index cfe975b48c..ef935b317e 100644
--- a/cpp/src/glm/qn/glm_base.cuh
+++ b/cpp/src/glm/qn/glm_base.cuh
@@ -222,6 +222,22 @@ struct GLMWithData : GLMDims {
     raft::interruptible::synchronize(stream);
     return loss_host;
   }
+
+  /**
+   * @brief Calculate a norm of the gradient computed using the given Loss instance.
+   *
+   * This function is intended to be used in `check_convergence`; its output is supposed
+   * to scale with the number of features (D) in the same way as the loss value.
+   *
+   * Different loss functions may scale differently with the number of features (D).
+   * This has an effect on the convergence criteria. To account for that, we let a
+   * loss function define its preferred metric. Normally, we differentiate between the
+   * L2 norm (e.g. for Squared loss) and LInf norm (e.g. for Softmax loss).
+   */
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return objective->gradNorm(grad, dev_scalar, stream);
+  }
 };
 
 };  // namespace GLM
diff --git a/cpp/src/glm/qn/glm_linear.cuh b/cpp/src/glm/qn/glm_linear.cuh
index 11df1d5833..d5f7f51aed 100644
--- a/cpp/src/glm/qn/glm_linear.cuh
+++ b/cpp/src/glm/qn/glm_linear.cuh
@@ -44,6 +44,11 @@ struct SquaredLoss : GLMBase<T, SquaredLoss<T>> {
     : Super(handle, D, 1, has_bias), lz{}, dlz{}
   {
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return squaredNorm(grad, dev_scalar, stream) * 0.5;
+  }
 };
 
 template <typename T>
@@ -65,6 +70,11 @@ struct AbsLoss : GLMBase<T, AbsLoss<T>> {
     : Super(handle, D, 1, has_bias), lz{}, dlz{}
   {
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return nrm1(grad, dev_scalar, stream);
+  }
 };
 
 };  // namespace GLM
diff --git a/cpp/src/glm/qn/glm_logistic.cuh b/cpp/src/glm/qn/glm_logistic.cuh
index 5e76da4843..5129006ec9 100644
--- a/cpp/src/glm/qn/glm_logistic.cuh
+++ b/cpp/src/glm/qn/glm_logistic.cuh
@@ -57,6 +57,11 @@ struct LogisticLoss : GLMBase<T, LogisticLoss<T>> {
     : Super(handle, D, 1, has_bias), lz{}, dlz{}
   {
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return nrmMax(grad, dev_scalar, stream);
+  }
 };
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/glm_regularizer.cuh b/cpp/src/glm/qn/glm_regularizer.cuh
index 4650205fc2..0f99df716e 100644
--- a/cpp/src/glm/qn/glm_regularizer.cuh
+++ b/cpp/src/glm/qn/glm_regularizer.cuh
@@ -85,6 +85,11 @@ struct RegularizedGLM : GLMDims {
 
     lossVal.fill(loss_host + reg_host, stream);
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return loss->gradNorm(grad, dev_scalar, stream);
+  }
 };
 };  // namespace GLM
 };  // namespace ML
diff --git a/cpp/src/glm/qn/glm_softmax.cuh b/cpp/src/glm/qn/glm_softmax.cuh
index f0f3835403..abf7708e03 100644
--- a/cpp/src/glm/qn/glm_softmax.cuh
+++ b/cpp/src/glm/qn/glm_softmax.cuh
@@ -188,6 +188,11 @@ struct Softmax : GLMBase<T, Softmax<T>> {
   {
     launchLogsoftmax(loss_val, Z.data, Z.data, y.data, Z.m, Z.n, stream);
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return nrmMax(grad, dev_scalar, stream);
+  }
 };
 
 };  // namespace GLM
diff --git a/cpp/src/glm/qn/glm_svm.cuh b/cpp/src/glm/qn/glm_svm.cuh
index fa71377760..944bfefe7a 100644
--- a/cpp/src/glm/qn/glm_svm.cuh
+++ b/cpp/src/glm/qn/glm_svm.cuh
@@ -48,6 +48,11 @@ struct SVCL1Loss : GLMBase<T, SVCL1Loss<T>> {
     : Super(handle, D, 1, has_bias), lz{}, dlz{}
   {
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return nrm1(grad, dev_scalar, stream);
+  }
 };
 
 template <typename T>
@@ -75,6 +80,11 @@ struct SVCL2Loss : GLMBase<T, SVCL2Loss<T>> {
     : Super(handle, D, 1, has_bias), lz{}, dlz{}
   {
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return squaredNorm(grad, dev_scalar, stream) * 0.5;
+  }
 };
 
 template <typename T>
@@ -103,6 +113,11 @@ struct SVRL1Loss : GLMBase<T, SVRL1Loss<T>> {
     : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity}
   {
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return nrm1(grad, dev_scalar, stream);
+  }
 };
 
 template <typename T>
@@ -132,6 +147,11 @@ struct SVRL2Loss : GLMBase<T, SVRL2Loss<T>> {
     : Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity}
   {
   }
+
+  inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
+  {
+    return squaredNorm(grad, dev_scalar, stream) * 0.5;
+  }
 };
 
 };  // namespace GLM
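The `gradNorm` doc comment in glm_base.cuh above describes a contract rather than a single implementation: each loss reports its gradient norm in the metric that matches its own scaling, and wrappers such as `GLMWithData` and `RegularizedGLM` only delegate. The following standalone sketch uses toy types (not the cuML classes) to mirror that delegation pattern.

// Toy analogue (not cuML code) of the delegation pattern above: the solver-facing
// wrapper never hard-codes a norm; it forwards to whatever metric the loss picked.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct ToySquaredLoss {
  // analogue of `squaredNorm(grad, dev_scalar, stream) * 0.5`
  static double gradNorm(const std::vector<double>& g)
  {
    double s = 0;
    for (double v : g) s += v * v;
    return 0.5 * s;
  }
};

struct ToySoftmaxLoss {
  // analogue of `nrmMax(grad, dev_scalar, stream)`
  static double gradNorm(const std::vector<double>& g)
  {
    double m = 0;
    for (double v : g) m = std::max(m, std::abs(v));
    return m;
  }
};

// analogue of GLMWithData::gradNorm / RegularizedGLM::gradNorm: forward the call
// and stay agnostic of the metric.
template <typename Loss>
double wrapped_grad_norm(const std::vector<double>& grad)
{
  return Loss::gradNorm(grad);
}

int main()
{
  std::vector<double> grad(1000, 1e-3);
  std::printf("squared-loss metric: %.2e, softmax metric: %.2e\n",
              wrapped_grad_norm<ToySquaredLoss>(grad),
              wrapped_grad_norm<ToySoftmaxLoss>(grad));
  return 0;
}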
diff --git a/cpp/src/glm/qn/qn_solvers.cuh b/cpp/src/glm/qn/qn_solvers.cuh
index fdda4b4485..57322d692a 100644
--- a/cpp/src/glm/qn/qn_solvers.cuh
+++ b/cpp/src/glm/qn/qn_solvers.cuh
@@ -75,6 +75,7 @@ inline bool update_and_check(const char* solver,
                              LINE_SEARCH_RETCODE lsret,
                              T& fx,
                              T& fxp,
+                             const T& gnorm,
                              ML::SimpleVec<T>& x,
                              ML::SimpleVec<T>& xp,
                              ML::SimpleVec<T>& grad,
@@ -100,8 +101,7 @@ inline bool update_and_check(const char* solver,
   CUML_LOG_TRACE("%s iteration %d, fx=%f", solver, iter, fx);
 
   // if the target is at least finite, we can check the convergence
-  if (isLsValid)
-    converged = check_convergence(param, iter, fx, x, grad, fx_hist, dev_scalar, stream);
+  if (isLsValid) converged = check_convergence(param, iter, fx, gnorm, fx_hist);
 
   if (!isLsSuccess && !converged) {
     CUML_LOG_WARN(
@@ -179,12 +179,13 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
   CUML_LOG_DEBUG("Running L-BFGS");
 
   // Evaluate function and compute gradient
-  fx = f(x, grad, dev_scalar, stream);
+  fx      = f(x, grad, dev_scalar, stream);
+  T gnorm = f.gradNorm(grad, dev_scalar, stream);
 
   if (param.past > 0) fx_hist[0] = fx;
 
   // Early exit if the initial x is already a minimizer
-  if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
+  if (check_convergence(param, *k, fx, gnorm, fx_hist)) {
     CUML_LOG_DEBUG("Initial solution fulfills optimality condition.");
     return OPT_SUCCESS;
   }
@@ -209,6 +210,7 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
 
     // Line search to update x, fx and gradient
     lsret = ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);
+    gnorm = f.gradNorm(grad, dev_scalar, stream);
 
     if (update_and_check("L-BFGS",
                          param,
@@ -216,6 +218,7 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
                          lsret,
                          fx,
                          fxp,
+                         gnorm,
                          x,
                          xp,
                          grad,
@@ -321,8 +324,9 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,
   // op to compute the pseudo gradients
   op_pseudo_grad<T> pseudo_grad(l1_penalty);
 
-  fx = f_wrap(x, grad, dev_scalar,
+  fx      = f_wrap(x, grad, dev_scalar,
               stream);  // fx is loss+regularizer, grad is grad of loss only
+  T gnorm = f.gradNorm(grad, dev_scalar, stream);
 
   // compute pseudo grad, but don't overwrite grad: used to build H
   // pseudo.assign_binary(x, grad, pseudo_grad);
@@ -331,7 +335,7 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,
   if (param.past > 0) fx_hist[0] = fx;
 
   // Early exit if the initial x is already a minimizer
-  if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
+  if (check_convergence(param, *k, fx, gnorm, fx_hist)) {
     CUML_LOG_DEBUG("Initial solution fulfills optimality condition.");
     return OPT_SUCCESS;
   }
@@ -358,6 +362,7 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,
     // Projected line search to update x, fx and gradient
     lsret = ls_backtrack_projected(
       param, f_wrap, fx, x, grad, pseudo, step, drt, xp, l1_penalty, dev_scalar, stream);
+    gnorm = f.gradNorm(grad, dev_scalar, stream);
 
     if (update_and_check("QWL-QN",
                          param,
@@ -365,6 +370,7 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,
                          lsret,
                          fx,
                          fxp,
+                         gnorm,
                          x,
                          xp,
                          grad,
diff --git a/cpp/src/glm/qn/qn_util.cuh b/cpp/src/glm/qn/qn_util.cuh
index c0620ee445..a5c28172c6 100644
--- a/cpp/src/glm/qn/qn_util.cuh
+++ b/cpp/src/glm/qn/qn_util.cuh
@@ -143,18 +143,9 @@ HDI T project_orth(T x, T y)
 }
 
 template <typename T>
-inline bool check_convergence(const LBFGSParam<T>& param,
-                              const int k,
-                              const T fx,
-                              SimpleVec<T>& x,
-                              SimpleVec<T>& grad,
-                              std::vector<T>& fx_hist,
-                              T* dev_scalar,
-                              cudaStream_t stream)
+inline bool check_convergence(
+  const LBFGSParam<T>& param, const int k, const T fx, const T gnorm, std::vector<T>& fx_hist)
 {
-  // Gradient norm is now in Linf to match the reference implementation
-  // (originally it was L2-norm)
-  T gnorm = nrmMax(grad, dev_scalar, stream);
   // Positive scale factor for the stop condition
   T fmag = std::max(fx, param.epsilon);
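The hunk above ends before the actual comparison, so the full stop condition is not visible in this patch. Assuming a conventional relative test, with the gradient norm measured against `fmag = max(fx, epsilon)`, a simplified standalone analogue of the new `check_convergence(param, k, fx, gnorm, fx_hist)` contract could look like the sketch below; the comparison itself and the toy names are assumptions, not taken from the diff, and the `fx_hist` part of the check is omitted.

// Simplified standalone analogue (not cuML code) of the new check_convergence
// contract: the caller supplies gnorm in the loss's preferred metric, and the
// check only scales it against the objective value. The relative test below is
// an assumed form; only the fmag line appears in the hunk above.
#include <algorithm>
#include <cstdio>

bool toy_check_convergence(double fx, double gnorm, double epsilon)
{
  // Positive scale factor for the stop condition
  double fmag = std::max(fx, epsilon);
  // Assumed convergence test on the gradient: small relative to the loss magnitude.
  return gnorm <= epsilon * fmag;
}

int main()
{
  std::printf("gnorm=1e-6 -> %d, gnorm=1e-2 -> %d\n",
              toy_check_convergence(1.0, 1e-6, 1e-4),
              toy_check_convergence(1.0, 1e-2, 1e-4));
  return 0;
}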