
QN solvers: Use different gradient norms for different loss functions. (rapidsai#4491)

Different loss functions may scale differently with the number of features, which affects the convergence criteria. To account for that, I let each loss function define its preferred gradient-norm metric. As a result, the number of iterations should be less dependent on the number of features for all loss functions.
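To illustrate the idea (an editorial sketch, not code from this commit): the base objective delegates the choice of gradient norm to the concrete loss, so an L2-flavoured loss can report 0.5 * ||g||_2^2 while a softmax-style loss reports the max-norm, and the convergence test consumes whatever metric the loss prefers. All names below (`ObjectiveBase`, `SquaredLossSketch`, `SoftmaxLossSketch`, `convergenceMetric`) are hypothetical.

```cpp
// Hypothetical host-side C++ sketch of per-loss gradient-norm dispatch.
// It mirrors the shape of the change (a gradNorm hook on each loss), not the cuML API.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

using Vec = std::vector<double>;

// CRTP-style base: the solver asks the objective for "its" gradient norm and the
// concrete loss decides which norm suits its scaling with the number of features.
template <typename Loss>
struct ObjectiveBase {
  double gradNorm(const Vec& g) const { return static_cast<const Loss&>(*this).gradNorm(g); }
};

struct SquaredLossSketch : ObjectiveBase<SquaredLossSketch> {
  // 0.5 * ||g||_2^2 keeps the metric on the same scale as the squared loss itself.
  double gradNorm(const Vec& g) const
  {
    double s = 0;
    for (double v : g) s += v * v;
    return 0.5 * s;
  }
};

struct SoftmaxLossSketch : ObjectiveBase<SoftmaxLossSketch> {
  // The max-norm is largely insensitive to the number of features.
  double gradNorm(const Vec& g) const
  {
    double m = 0;
    for (double v : g) m = std::max(m, std::abs(v));
    return m;
  }
};

// The convergence check only sees the loss-specific metric, never a hard-coded norm.
template <typename Loss>
double convergenceMetric(const ObjectiveBase<Loss>& obj, const Vec& grad)
{
  return obj.gradNorm(grad);
}

int main()
{
  Vec g = {0.1, -0.2, 0.05};
  std::printf("squared-loss metric: %g\n", convergenceMetric(SquaredLossSketch{}, g));
  std::printf("softmax metric:      %g\n", convergenceMetric(SoftmaxLossSketch{}, g));
  return 0;
}
```

In the actual change, each GLM loss (SquaredLoss, AbsLoss, LogisticLoss, Softmax, and the SVC/SVR losses) implements `gradNorm` this way, and `RegularizedGLM` forwards to the wrapped loss, as the hunks below show.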

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

URL: rapidsai#4491
achirkin authored Feb 10, 2022
1 parent cdb41f9 commit 1abf605
Showing 8 changed files with 75 additions and 17 deletions.
16 changes: 16 additions & 0 deletions cpp/src/glm/qn/glm_base.cuh
@@ -222,6 +222,22 @@ struct GLMWithData : GLMDims {
raft::interruptible::synchronize(stream);
return loss_host;
}

/**
* @brief Calculate a norm of the gradient computed using the given Loss instance.
*
* This function is intended to be used in `check_convergence`; its output is supposed
* to be proportional to the loss value w.r.t. the number of features (D).
*
* Different loss functions may scale differently with the number of features (D).
* This has an effect on the convergence criteria. To account for that, we let a
* loss function define its preferred metric. Normally, we differentiate between the
* L2 norm (e.g. for Squared loss) and LInf norm (e.g. for Softmax loss).
*/
inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return objective->gradNorm(grad, dev_scalar, stream);
}
};

}; // namespace GLM
10 changes: 10 additions & 0 deletions cpp/src/glm/qn/glm_linear.cuh
@@ -44,6 +44,11 @@ struct SquaredLoss : GLMBase<T, SquaredLoss<T>> {
: Super(handle, D, 1, has_bias), lz{}, dlz{}
{
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return squaredNorm(grad, dev_scalar, stream) * 0.5;
}
};

template <typename T>
@@ -65,6 +70,11 @@ struct AbsLoss : GLMBase<T, AbsLoss<T>> {
: Super(handle, D, 1, has_bias), lz{}, dlz{}
{
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return nrm1(grad, dev_scalar, stream);
}
};

}; // namespace GLM
5 changes: 5 additions & 0 deletions cpp/src/glm/qn/glm_logistic.cuh
@@ -57,6 +57,11 @@ struct LogisticLoss : GLMBase<T, LogisticLoss<T>> {
: Super(handle, D, 1, has_bias), lz{}, dlz{}
{
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return nrmMax(grad, dev_scalar, stream);
}
};
}; // namespace GLM
}; // namespace ML
5 changes: 5 additions & 0 deletions cpp/src/glm/qn/glm_regularizer.cuh
@@ -85,6 +85,11 @@ struct RegularizedGLM : GLMDims {

lossVal.fill(loss_host + reg_host, stream);
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return loss->gradNorm(grad, dev_scalar, stream);
}
};
}; // namespace GLM
}; // namespace ML
5 changes: 5 additions & 0 deletions cpp/src/glm/qn/glm_softmax.cuh
@@ -188,6 +188,11 @@ struct Softmax : GLMBase<T, Softmax<T>> {
{
launchLogsoftmax(loss_val, Z.data, Z.data, y.data, Z.m, Z.n, stream);
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return nrmMax(grad, dev_scalar, stream);
}
};

}; // namespace GLM
20 changes: 20 additions & 0 deletions cpp/src/glm/qn/glm_svm.cuh
@@ -48,6 +48,11 @@ struct SVCL1Loss : GLMBase<T, SVCL1Loss<T>> {
: Super(handle, D, 1, has_bias), lz{}, dlz{}
{
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return nrm1(grad, dev_scalar, stream);
}
};

template <typename T>
@@ -75,6 +80,11 @@ struct SVCL2Loss : GLMBase<T, SVCL2Loss<T>> {
: Super(handle, D, 1, has_bias), lz{}, dlz{}
{
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return squaredNorm(grad, dev_scalar, stream) * 0.5;
}
};

template <typename T>
@@ -103,6 +113,11 @@ struct SVRL1Loss : GLMBase<T, SVRL1Loss<T>> {
: Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity}
{
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return nrm1(grad, dev_scalar, stream);
}
};

template <typename T>
@@ -132,6 +147,11 @@ struct SVRL2Loss : GLMBase<T, SVRL2Loss<T>> {
: Super(handle, D, 1, has_bias), lz{sensitivity}, dlz{sensitivity}
{
}

inline T gradNorm(const SimpleVec<T>& grad, T* dev_scalar, cudaStream_t stream)
{
return squaredNorm(grad, dev_scalar, stream) * 0.5;
}
};

}; // namespace GLM
18 changes: 12 additions & 6 deletions cpp/src/glm/qn/qn_solvers.cuh
@@ -75,6 +75,7 @@ inline bool update_and_check(const char* solver,
LINE_SEARCH_RETCODE lsret,
T& fx,
T& fxp,
const T& gnorm,
ML::SimpleVec<T>& x,
ML::SimpleVec<T>& xp,
ML::SimpleVec<T>& grad,
@@ -100,8 +101,7 @@
CUML_LOG_TRACE("%s iteration %d, fx=%f", solver, iter, fx);

// if the target is at least finite, we can check the convergence
if (isLsValid)
converged = check_convergence(param, iter, fx, x, grad, fx_hist, dev_scalar, stream);
if (isLsValid) converged = check_convergence(param, iter, fx, gnorm, fx_hist);

if (!isLsSuccess && !converged) {
CUML_LOG_WARN(
@@ -179,12 +179,13 @@ inline OPT_RETCODE min_lbfgs(const LBFGSParam<T>& param,
CUML_LOG_DEBUG("Running L-BFGS");

// Evaluate function and compute gradient
fx = f(x, grad, dev_scalar, stream);
T gnorm = f.gradNorm(grad, dev_scalar, stream);

if (param.past > 0) fx_hist[0] = fx;

// Early exit if the initial x is already a minimizer
if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
if (check_convergence(param, *k, fx, gnorm, fx_hist)) {
CUML_LOG_DEBUG("Initial solution fulfills optimality condition.");
return OPT_SUCCESS;
}
@@ -209,13 +210,15 @@

// Line search to update x, fx and gradient
lsret = ls_backtrack(param, f, fx, x, grad, step, drt, xp, dev_scalar, stream);
gnorm = f.gradNorm(grad, dev_scalar, stream);

if (update_and_check("L-BFGS",
param,
*k,
lsret,
fx,
fxp,
gnorm,
x,
xp,
grad,
@@ -321,8 +324,9 @@ inline OPT_RETCODE min_owlqn(const LBFGSParam<T>& param,
// op to compute the pseudo gradients
op_pseudo_grad<T> pseudo_grad(l1_penalty);

fx = f_wrap(x, grad, dev_scalar,
stream); // fx is loss+regularizer, grad is grad of loss only
T gnorm = f.gradNorm(grad, dev_scalar, stream);

// compute pseudo grad, but don't overwrite grad: used to build H
// pseudo.assign_binary(x, grad, pseudo_grad);
@@ -331,7 +335,7 @@
if (param.past > 0) fx_hist[0] = fx;

// Early exit if the initial x is already a minimizer
if (check_convergence(param, *k, fx, x, grad, fx_hist, dev_scalar, stream)) {
if (check_convergence(param, *k, fx, gnorm, fx_hist)) {
CUML_LOG_DEBUG("Initial solution fulfills optimality condition.");
return OPT_SUCCESS;
}
@@ -358,13 +362,15 @@
// Projected line search to update x, fx and gradient
lsret = ls_backtrack_projected(
param, f_wrap, fx, x, grad, pseudo, step, drt, xp, l1_penalty, dev_scalar, stream);
gnorm = f.gradNorm(grad, dev_scalar, stream);

if (update_and_check("QWL-QN",
param,
*k,
lsret,
fx,
fxp,
gnorm,
x,
xp,
grad,
13 changes: 2 additions & 11 deletions cpp/src/glm/qn/qn_util.cuh
@@ -143,18 +143,9 @@ HDI T project_orth(T x, T y)
}

template <typename T>
inline bool check_convergence(const LBFGSParam<T>& param,
const int k,
const T fx,
SimpleVec<T>& x,
SimpleVec<T>& grad,
std::vector<T>& fx_hist,
T* dev_scalar,
cudaStream_t stream)
inline bool check_convergence(
const LBFGSParam<T>& param, const int k, const T fx, const T gnorm, std::vector<T>& fx_hist)
{
// Gradient norm is now in Linf to match the reference implementation
// (originally it was L2-norm)
T gnorm = nrmMax(grad, dev_scalar, stream);
// Positive scale factor for the stop condition
T fmag = std::max(fx, param.epsilon);

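The refactored `check_convergence` above now receives the precomputed, loss-specific `gnorm` instead of taking the raw gradient vector and computing an Linf norm itself, so it no longer needs `dev_scalar` or the stream. The remainder of its body is truncated in this hunk; below is a standalone editorial sketch of a relative stopping test of that shape. It assumes only the pattern visible above (`fmag = max(fx, epsilon)` and a gradient test against `epsilon * fmag`) plus the usual L-BFGS objective-history test, and is not the exact cuML body; `ParamsSketch` is a hypothetical stand-in for `LBFGSParam<T>`.

```cpp
// Standalone sketch of a relative stopping test with the new signature shape.
// ParamsSketch stands in for LBFGSParam<T>; the exact cuML body is truncated above,
// and the objective-history test here is an assumed, conventional L-BFGS pattern.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

template <typename T>
struct ParamsSketch {
  T epsilon = T(1e-6);  // gradient tolerance
  int past  = 0;        // objective-history window (0 disables that test)
  T delta   = T(0);     // tolerance on the relative change of fx over `past` iterations
};

template <typename T>
bool check_convergence_sketch(
  const ParamsSketch<T>& param, const int k, const T fx, const T gnorm, std::vector<T>& fx_hist)
{
  // Positive scale factor for the stop condition (as in the hunk above).
  T fmag = std::max(fx, param.epsilon);
  // Converged when the loss-specific gradient norm is small relative to the objective.
  if (gnorm <= param.epsilon * fmag) return true;
  // Optionally, stop when fx has barely changed over the last `past` iterations.
  if (param.past > 0) {
    if (k >= param.past && std::abs(fx_hist[k % param.past] - fx) <= param.delta * fmag)
      return true;
    fx_hist[k % param.past] = fx;
  }
  return false;
}

int main()
{
  ParamsSketch<double> p;
  std::vector<double> hist;  // unused while p.past == 0
  std::printf("converged: %d\n", check_convergence_sketch(p, 0, 0.5, 1e-9, hist) ? 1 : 0);
  return 0;
}
```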
