Expose the secondary stopping condition for QN solver (rapidsai#3777)
- Expose a parameter `delta` of the `QN` solver to control the loss-value-change stopping condition.
- Set a reasonable default for the parameter that should keep the behavior close to sklearn in most cases.

Note: this change does not expose `delta` in the wrapper class `LogisticRegression`.

Note: although this change does not break the Python API, it does break the C/C++ API.

Contributes to solving rapidsai#3645

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#3777
achirkin authored Apr 26, 2021
1 parent 80bc810 commit 93fed4d
Showing 8 changed files with 104 additions and 53 deletions.
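
For context, here is a minimal sketch of how the newly exposed parameter surfaces on the Python side (assuming the `QN` solver is importable as `cuml.solvers.QN`; the data and parameter values are purely illustrative, not part of this commit):

```python
import cupy as cp
from cuml.solvers import QN

# Toy binary-classification data (illustrative only).
X = cp.random.rand(1000, 20, dtype=cp.float32)
y = (cp.random.rand(1000) > 0.5).astype(cp.float32)

# Default behavior: delta is derived from tol internally (tol * 0.01).
qn = QN(loss='sigmoid', max_iter=1000, tol=1e-4)
qn.fit(X, y)

# Set delta=0 to disable the loss-change check, so the solver does not
# stop earlier than sklearn's LogisticRegression (gradient check only).
qn_sklearn_like = QN(loss='sigmoid', tol=1e-4, delta=0)
qn_sklearn_like.fit(X, y)
```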
21 changes: 14 additions & 7 deletions cpp/include/cuml/linear_model/glm.hpp
@@ -103,7 +103,13 @@ void gemmPredict(const raft::handle_t &handle, const double *input, int n_rows,
* @param l2 l2 regularization strength. Note, that as in
* scikit, the bias will not be regularized.
* @param max_iter limit on iteration number
* @param grad_tol tolerance for gradient norm convergence check
* @param grad_tol tolerance for gradient norm convergence check.
* The training process will stop if
* `norm(current_loss_grad, inf) <= grad_tol * max(current_loss, grad_tol)`.
* @param change_tol tolerance for function change convergence check.
* The training process will stop if
* `abs(current_loss - previous_loss) <= change_tol * max(current_loss, grad_tol)`,
* where `previous_loss` is the loss value a small fixed number of steps ago.
* @param linesearch_max_iter max number of linesearch iterations per outer
* iteration
* @param lbfgs_memory rank of the lbfgs inverse-Hessian approximation.
@@ -122,14 +128,15 @@ void gemmPredict(const raft::handle_t &handle, const double *input, int n_rows,
*/
void qnFit(const raft::handle_t &cuml_handle, float *X, float *y, int N, int D,
int C, bool fit_intercept, float l1, float l2, int max_iter,
float grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, float *w0, float *f, int *num_iters, bool X_col_major,
int loss_type, float *sample_weight = nullptr);
float grad_tol, float change_tol, int linesearch_max_iter,
int lbfgs_memory, int verbosity, float *w0, float *f, int *num_iters,
bool X_col_major, int loss_type, float *sample_weight = nullptr);
void qnFit(const raft::handle_t &cuml_handle, double *X, double *y, int N,
int D, int C, bool fit_intercept, double l1, double l2, int max_iter,
double grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, double *w0, double *f, int *num_iters,
bool X_col_major, int loss_type, double *sample_weight = nullptr);
double grad_tol, double change_tol, int linesearch_max_iter,
int lbfgs_memory, int verbosity, double *w0, double *f,
int *num_iters, bool X_col_major, int loss_type,
double *sample_weight = nullptr);
/** @} */

/**
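
The two stopping conditions documented in the header above can be summarized with a short sketch (plain Python pseudocode of the checks, not the actual C++/CUDA implementation; the function names are illustrative):

```python
def gradient_converged(current_loss, grad_inf_norm, grad_tol):
    # Primary check: infinity norm of the loss gradient, scaled by the loss value.
    return grad_inf_norm <= grad_tol * max(current_loss, grad_tol)

def loss_change_converged(current_loss, previous_loss, change_tol, grad_tol):
    # Secondary check: relative change of the loss, where previous_loss is the
    # loss value a small fixed number of iterations ago.
    return abs(current_loss - previous_loss) <= change_tol * max(current_loss, grad_tol)
```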
14 changes: 8 additions & 6 deletions cpp/include/cuml/linear_model/glm_api.h
@@ -24,15 +24,17 @@ extern "C" {

cumlError_t cumlSpQnFit(cumlHandle_t cuml_handle, float *X, float *y, int N,
int D, int C, bool fit_intercept, float l1, float l2,
int max_iter, float grad_tol, int linesearch_max_iter,
int lbfgs_memory, int verbosity, float *w0, float *f,
int *num_iters, bool X_col_major, int loss_type);
int max_iter, float grad_tol, float change_tol,
int linesearch_max_iter, int lbfgs_memory,
int verbosity, float *w0, float *f, int *num_iters,
bool X_col_major, int loss_type);

cumlError_t cumlDpQnFit(cumlHandle_t cuml_handle, double *X, double *y, int N,
int D, int C, bool fit_intercept, double l1, double l2,
int max_iter, double grad_tol, int linesearch_max_iter,
int lbfgs_memory, int verbosity, double *w0, double *f,
int *num_iters, bool X_col_major, int loss_type);
int max_iter, double grad_tol, double change_tol,
int linesearch_max_iter, int lbfgs_memory,
int verbosity, double *w0, double *f, int *num_iters,
bool X_col_major, int loss_type);

#ifdef __cplusplus
}
23 changes: 13 additions & 10 deletions cpp/src/glm/glm.cu
@@ -72,22 +72,25 @@ void ridgeFit(const raft::handle_t &handle, double *input, int n_rows,

void qnFit(const raft::handle_t &cuml_handle, float *X, float *y, int N, int D,
int C, bool fit_intercept, float l1, float l2, int max_iter,
float grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, float *w0, float *f, int *num_iters, bool X_col_major,
int loss_type, float *sample_weight) {
float grad_tol, float change_tol, int linesearch_max_iter,
int lbfgs_memory, int verbosity, float *w0, float *f, int *num_iters,
bool X_col_major, int loss_type, float *sample_weight) {
qnFit(cuml_handle, X, y, N, D, C, fit_intercept, l1, l2, max_iter, grad_tol,
linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters,
X_col_major, loss_type, cuml_handle.get_stream(), sample_weight);
change_tol, linesearch_max_iter, lbfgs_memory, verbosity, w0, f,
num_iters, X_col_major, loss_type, cuml_handle.get_stream(),
sample_weight);
}

void qnFit(const raft::handle_t &cuml_handle, double *X, double *y, int N,
int D, int C, bool fit_intercept, double l1, double l2, int max_iter,
double grad_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, double *w0, double *f, int *num_iters,
bool X_col_major, int loss_type, double *sample_weight) {
double grad_tol, double change_tol, int linesearch_max_iter,
int lbfgs_memory, int verbosity, double *w0, double *f,
int *num_iters, bool X_col_major, int loss_type,
double *sample_weight) {
qnFit(cuml_handle, X, y, N, D, C, fit_intercept, l1, l2, max_iter, grad_tol,
linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters,
X_col_major, loss_type, cuml_handle.get_stream(), sample_weight);
change_tol, linesearch_max_iter, lbfgs_memory, verbosity, w0, f,
num_iters, X_col_major, loss_type, cuml_handle.get_stream(),
sample_weight);
}

void qnDecisionFunction(const raft::handle_t &cuml_handle, float *X, int N,
24 changes: 14 additions & 10 deletions cpp/src/glm/glm_api.cpp
@@ -23,17 +23,19 @@ extern "C" {

cumlError_t cumlSpQnFit(cumlHandle_t cuml_handle, float *X, float *y, int N,
int D, int C, bool fit_intercept, float l1, float l2,
int max_iter, float grad_tol, int linesearch_max_iter,
int lbfgs_memory, int verbosity, float *w0, float *f,
int *num_iters, bool X_col_major, int loss_type) {
int max_iter, float grad_tol, float change_tol,
int linesearch_max_iter, int lbfgs_memory,
int verbosity, float *w0, float *f, int *num_iters,
bool X_col_major, int loss_type) {
cumlError_t status;
raft::handle_t *handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(cuml_handle);
if (status == CUML_SUCCESS) {
try {
ML::GLM::qnFit(*handle_ptr, X, y, N, D, C, fit_intercept, l1, l2,
max_iter, grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, X_col_major, loss_type);
max_iter, grad_tol, change_tol, linesearch_max_iter,
lbfgs_memory, verbosity, w0, f, num_iters, X_col_major,
loss_type);

}
// TODO: Implement this
@@ -51,17 +53,19 @@ cumlError_t cumlSpQnFit(cumlHandle_t cuml_handle, float *X, float *y, int N,

cumlError_t cumlDpQnFit(cumlHandle_t cuml_handle, double *X, double *y, int N,
int D, int C, bool fit_intercept, double l1, double l2,
int max_iter, double grad_tol, int linesearch_max_iter,
int lbfgs_memory, int verbosity, double *w0, double *f,
int *num_iters, bool X_col_major, int loss_type) {
int max_iter, double grad_tol, double change_tol,
int linesearch_max_iter, int lbfgs_memory,
int verbosity, double *w0, double *f, int *num_iters,
bool X_col_major, int loss_type) {
cumlError_t status;
raft::handle_t *handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(cuml_handle);
if (status == CUML_SUCCESS) {
try {
ML::GLM::qnFit(*handle_ptr, X, y, N, D, C, fit_intercept, l1, l2,
max_iter, grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, X_col_major, loss_type);
max_iter, grad_tol, change_tol, linesearch_max_iter,
lbfgs_memory, verbosity, w0, f, num_iters, X_col_major,
loss_type);

}
// TODO: Implement this
25 changes: 15 additions & 10 deletions cpp/src/glm/qn/qn.cuh
@@ -28,12 +28,14 @@ namespace ML {
namespace GLM {
template <typename T, typename LossFunction>
int qn_fit(const raft::handle_t &handle, LossFunction &loss, T *Xptr, T *yptr,
T *zptr, int N, T l1, T l2, int max_iter, T grad_tol,
T *zptr, int N, T l1, T l2, int max_iter, T grad_tol, T change_tol,
int linesearch_max_iter, int lbfgs_memory, int verbosity,
T *w0, // initial value and result
T *fx, int *num_iters, STORAGE_ORDER ordX, cudaStream_t stream) {
LBFGSParam<T> opt_param;
opt_param.epsilon = grad_tol;
if (change_tol > 0) opt_param.past = 10; // even number - to detect zig-zags
opt_param.delta = change_tol;
opt_param.max_iterations = max_iter;
opt_param.m = lbfgs_memory;
opt_param.max_linesearch = linesearch_max_iter;
@@ -62,9 +64,9 @@ int qn_fit(const raft::handle_t &handle, LossFunction &loss, T *Xptr, T *yptr,
template <typename T>
void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C,
bool fit_intercept, T l1, T l2, int max_iter, T grad_tol,
int linesearch_max_iter, int lbfgs_memory, int verbosity, T *w0,
T *f, int *num_iters, bool X_col_major, int loss_type,
cudaStream_t stream, T *sample_weight = nullptr) {
T change_tol, int linesearch_max_iter, int lbfgs_memory,
int verbosity, T *w0, T *f, int *num_iters, bool X_col_major,
int loss_type, cudaStream_t stream, T *sample_weight = nullptr) {
STORAGE_ORDER ord = X_col_major ? COL_MAJOR : ROW_MAJOR;
int C_len = (loss_type == 0) ? (C - 1) : C;

@@ -77,24 +79,27 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C,
LogisticLoss<T> loss(handle, D, fit_intercept);
if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
grad_tol, change_tol, linesearch_max_iter,
lbfgs_memory, verbosity, w0, f, num_iters, ord,
stream);
} break;
case 1: {
ASSERT(C == 1, "qn.h: squared loss invalid C");
SquaredLoss<T> loss(handle, D, fit_intercept);
if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
grad_tol, change_tol, linesearch_max_iter,
lbfgs_memory, verbosity, w0, f, num_iters, ord,
stream);
} break;
case 2: {
ASSERT(C > 2, "qn.h: softmax invalid C");
Softmax<T> loss(handle, D, C, fit_intercept);
if (sample_weight) loss.add_sample_weights(sample_weight, N, stream);
qn_fit<T, decltype(loss)>(handle, loss, X, y, z.data, N, l1, l2, max_iter,
grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0, f, num_iters, ord, stream);
grad_tol, change_tol, linesearch_max_iter,
lbfgs_memory, verbosity, w0, f, num_iters, ord,
stream);
} break;
default: {
ASSERT(false, "qn.h: unknown loss function.");
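
With `opt_param.past = 10` in the hunk above, the secondary check compares the current loss against the loss recorded ten iterations earlier (an even lag, so a solver zig-zagging between two values is still caught). A rough Python sketch of that bookkeeping, under the assumption that the L-BFGS parameters behave as documented in the header (names and values are illustrative, not the actual solver code):

```python
past = 10          # compare against the loss from `past` iterations ago
change_tol = 1e-6  # the new stopping tolerance; <= 0 disables the check
grad_tol = 1e-8
loss_history = []  # recent loss values, one entry per iteration

def loss_change_stop(current_loss):
    loss_history.append(current_loss)
    if change_tol <= 0 or len(loss_history) <= past:
        return False  # check disabled, or not enough history yet
    previous_loss = loss_history[-(past + 1)]
    return abs(current_loss - previous_loss) <= change_tol * max(current_loss, grad_tol)
```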
6 changes: 3 additions & 3 deletions cpp/test/c_api/glm_api_test.c
@@ -21,8 +21,8 @@ void test_glm() {
cumlHandle_t handle = 0;
cumlError_t response = CUML_SUCCESS;

response = cumlSpQnFit(handle, NULL, NULL, 0, 1, 2, false, 1.0f, 2.0f, 3, 3.0f, 4, 5, 6, NULL, NULL, NULL, true, 7);
response = cumlSpQnFit(handle, NULL, NULL, 0, 1, 2, false, 1.0f, 2.0f, 3, 3.0f, 3.5f, 4, 5, 6, NULL, NULL, NULL, true, 7);

response = cumlDpQnFit(handle, NULL, NULL, 0, 1, 2, false, 1.0f, 2.0f, 3, 3.0f, 4, 5, 6, NULL, NULL, NULL, true, 7);
response = cumlDpQnFit(handle, NULL, NULL, 0, 1, 2, false, 1.0f, 2.0f, 3, 3.0f, 3.5f, 4, 5, 6, NULL, NULL, NULL, true, 7);

}
}
11 changes: 7 additions & 4 deletions cpp/test/sg/quasi_newton.cu
@@ -107,6 +107,7 @@ T run(const raft::handle_t &handle, LossFunction &loss, const SimpleMat<T> &X,
cudaStream_t stream) {
int max_iter = 100;
T grad_tol = 1e-16;
T change_tol = 1e-16;
int linesearch_max_iter = 50;
int lbfgs_memory = 5;
int num_iters = 0;
@@ -115,8 +116,9 @@
SimpleVec<T> w0(w, loss.n_param);

qn_fit<T, LossFunction>(handle, loss, X.data, y.data, z.data, X.m, l1, l2,
max_iter, grad_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w0.data, &fx, &num_iters, X.ord, stream);
max_iter, grad_tol, change_tol, linesearch_max_iter,
lbfgs_memory, verbosity, w0.data, &fx, &num_iters,
X.ord, stream);

return fx;
}
@@ -128,6 +130,7 @@ T run_api(const raft::handle_t &cuml_handle, int loss_type, int C,
cudaStream_t stream) {
int max_iter = 100;
T grad_tol = 1e-8;
T change_tol = 1e-8;
int linesearch_max_iter = 50;
int lbfgs_memory = 5;
int num_iters = 0;
@@ -137,8 +140,8 @@
T fx;

qnFit(cuml_handle, X.data, y.data, X.m, X.n, C, fit_intercept, l1, l2,
max_iter, grad_tol, linesearch_max_iter, lbfgs_memory, verbosity, w,
&fx, &num_iters, false, loss_type);
max_iter, grad_tol, change_tol, linesearch_max_iter, lbfgs_memory,
verbosity, w, &fx, &num_iters, false, loss_type);

return fx;
}
33 changes: 30 additions & 3 deletions python/cuml/solvers/qn.pyx
@@ -46,6 +46,7 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
float l2,
int max_iter,
float grad_tol,
float change_tol,
int linesearch_max_iter,
int lbfgs_memory,
int verbosity,
@@ -67,6 +68,7 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
double l2,
int max_iter,
double grad_tol,
double change_tol,
int linesearch_max_iter,
int lbfgs_memory,
int verbosity,
@@ -225,6 +227,26 @@ class QN(Base,
To account for the differences you may divide the `tol` by the sample
size; this would ensure that the cuML solver does not stop earlier than
the Scikit-learn solver.
delta: Optional[float] (default = None)
The training process will stop if
`abs(current_loss - previous_loss) <= delta * max(current_loss, tol)`.
When `None`, it's set to `tol * 0.01`; when `0`, the check is disabled.
Given the current step `k`, parameter `previous_loss` here is the loss
at the step `k - p`, where `p` is a small positive integer set
internally.
Note, this parameter corresponds to `ftol` in
`scipy.optimize.minimize(method=’L-BFGS-B’)
<https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html>`_,
which is set by default to a minuscule `2.2e-9` and is not exposed in
`sklearn.LogisticRegression()
<https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html>`_.
This condition is meant to protect the solver against doing vanishingly
small linesearch steps or zigzagging.
You may choose to set `delta = 0` to make sure the cuML solver does
not stop earlier than the Scikit-learn solver.
linesearch_max_iter: int (default = 50)
Max number of linesearch iterations per outer iteration of the
algorithm.
@@ -266,7 +288,7 @@ class QN(Base,
Wright - Numerical Optimization (1999)]
- `Orthant-wise limited-memory quasi-newton (OWL-QN)
[Andrew, Gao - ICML 2007]
<https://www.microsoft.com/en-us/research/publication/scalable-training-of-l1-regularized-log-linear-models/>`_
"""

@@ -275,7 +297,7 @@

def __init__(self, *, loss='sigmoid', fit_intercept=True,
l1_strength=0.0, l2_strength=0.0, max_iter=1000, tol=1e-4,
linesearch_max_iter=50, lbfgs_memory=5,
delta=None, linesearch_max_iter=50, lbfgs_memory=5,
verbose=False, handle=None, output_type=None,
warm_start=False):

@@ -288,6 +310,7 @@
self.l2_strength = l2_strength
self.max_iter = max_iter
self.tol = tol
self.delta = delta
self.linesearch_max_iter = linesearch_max_iter
self.lbfgs_memory = lbfgs_memory
self.num_iter = 0
@@ -377,6 +400,8 @@

cdef int num_iters

delta = self.delta if self.delta is not None else (self.tol * 0.01)

if self.dtype == np.float32:
qnFit(handle_[0],
<float*>X_ptr,
@@ -389,6 +414,7 @@
<float> self.l2_strength,
<int> self.max_iter,
<float> self.tol,
<float> delta,
<int> self.linesearch_max_iter,
<int> self.lbfgs_memory,
<int> self.verbose,
@@ -413,6 +439,7 @@
<double> self.l2_strength,
<int> self.max_iter,
<double> self.tol,
<double> delta,
<int> self.linesearch_max_iter,
<int> self.lbfgs_memory,
<int> self.verbose,
@@ -579,4 +606,4 @@ class QN(Base,
return super().get_param_names() + \
['loss', 'fit_intercept', 'l1_strength', 'l2_strength',
'max_iter', 'tol', 'linesearch_max_iter', 'lbfgs_memory',
'warm_start']
'warm_start', 'delta']
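
Since `delta` is now listed in `get_param_names()`, it participates in the usual parameter plumbing. A small sanity check, assuming the estimator exposes the scikit-learn-style `get_params()` from its base class (illustrative only):

```python
from cuml.solvers import QN

params = QN(delta=1e-6).get_params()
assert params['delta'] == 1e-6  # delta round-trips like any other hyperparameter
```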
