From 5bb564e8f31b37809e7fe5880267582ea9e39481 Mon Sep 17 00:00:00 2001
From: Victor Lafargue
Date: Tue, 23 Mar 2021 21:07:19 +0100
Subject: [PATCH] Support for sample_weight parameter in LogisticRegression (#3572)

Closes #3559

This PR adds the `sample_weight` and `class_weight` parameters to the
`LogisticRegression` estimator.

Authors:
  - Victor Lafargue (@viclafargue)

Approvers:
  - Dante Gama Dessavre (@dantegd)

URL: https://github.com/rapidsai/cuml/pull/3572
---
 cpp/include/cuml/linear_model/glm.hpp         |  4 +-
 cpp/src/glm/glm.cu                            | 10 +--
 cpp/src/glm/qn/glm_base.cuh                   | 56 +++++++++----
 cpp/src/glm/qn/qn.cuh                         |  5 +-
 .../cuml/linear_model/logistic_regression.pyx | 83 +++++++++++++++++--
 python/cuml/solvers/qn.pyx                    | 29 +++++--
 python/cuml/test/test_linear_model.py         | 83 +++++++++++++++++++
 7 files changed, 234 insertions(+), 36 deletions(-)

diff --git a/cpp/include/cuml/linear_model/glm.hpp b/cpp/include/cuml/linear_model/glm.hpp
index 5f26b182ce..e6106cca83 100644
--- a/cpp/include/cuml/linear_model/glm.hpp
+++ b/cpp/include/cuml/linear_model/glm.hpp
@@ -124,12 +124,12 @@ void qnFit(const raft::handle_t &cuml_handle, float *X, float *y, int N, int D,
            int C, bool fit_intercept, float l1, float l2, int max_iter,
            float grad_tol, int linesearch_max_iter, int lbfgs_memory,
            int verbosity, float *w0, float *f, int *num_iters, bool X_col_major,
-           int loss_type);
+           int loss_type, float *sample_weight = nullptr);
 void qnFit(const raft::handle_t &cuml_handle, double *X, double *y, int N,
            int D, int C, bool fit_intercept, double l1, double l2, int max_iter,
            double grad_tol, int linesearch_max_iter, int lbfgs_memory,
            int verbosity, double *w0, double *f, int *num_iters,
-           bool X_col_major, int loss_type);
+           bool X_col_major, int loss_type, double *sample_weight = nullptr);
 /** @} */

 /**
diff --git a/cpp/src/glm/glm.cu b/cpp/src/glm/glm.cu
index 52a705b048..0635068097 100644
--- a/cpp/src/glm/glm.cu
+++ b/cpp/src/glm/glm.cu
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2021, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -71,20 +71,20 @@ void qnFit(const raft::handle_t &cuml_handle, float *X, float *y, int N, int D, int C, bool fit_intercept, float l1, float l2, int max_iter, float grad_tol, int linesearch_max_iter, int lbfgs_memory, int verbosity, float *w0, float *f, int *num_iters, bool X_col_major, - int loss_type) { + int loss_type, float *sample_weight) { qnFit(cuml_handle, X, y, N, D, C, fit_intercept, l1, l2, max_iter, grad_tol, linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters, - X_col_major, loss_type, cuml_handle.get_stream()); + X_col_major, loss_type, cuml_handle.get_stream(), sample_weight); } void qnFit(const raft::handle_t &cuml_handle, double *X, double *y, int N, int D, int C, bool fit_intercept, double l1, double l2, int max_iter, double grad_tol, int linesearch_max_iter, int lbfgs_memory, int verbosity, double *w0, double *f, int *num_iters, - bool X_col_major, int loss_type) { + bool X_col_major, int loss_type, double *sample_weight) { qnFit(cuml_handle, X, y, N, D, C, fit_intercept, l1, l2, max_iter, grad_tol, linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters, - X_col_major, loss_type, cuml_handle.get_stream()); + X_col_major, loss_type, cuml_handle.get_stream(), sample_weight); } void qnDecisionFunction(const raft::handle_t &cuml_handle, float *X, int N, diff --git a/cpp/src/glm/qn/glm_base.cuh b/cpp/src/glm/qn/glm_base.cuh index 1dcfc9b417..25a3b627fb 100644 --- a/cpp/src/glm/qn/glm_base.cuh +++ b/cpp/src/glm/qn/glm_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020, NVIDIA CORPORATION. + * Copyright (c) 2018-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -96,9 +97,22 @@ struct GLMBase : GLMDims { typedef SimpleVec Vec; const raft::handle_t &handle; + T *sample_weights; + T weights_sum; GLMBase(const raft::handle_t &handle, int D, int C, bool fit_intercept) - : GLMDims(C, D, fit_intercept), handle(handle) {} + : GLMDims(C, D, fit_intercept), + handle(handle), + sample_weights(nullptr), + weights_sum(0) {} + + void add_sample_weights(T *sample_weights, int n_samples, + cudaStream_t stream) { + this->sample_weights = sample_weights; + this->weights_sum = + thrust::reduce(thrust::cuda::par.on(stream), sample_weights, + sample_weights + n_samples, (T)0, thrust::plus()); + } /* * Computes the following: @@ -111,22 +125,36 @@ struct GLMBase : GLMDims { cudaStream_t stream) { // Base impl assumes simple case C = 1 Loss *loss = static_cast(this); - T invN = 1.0 / y.len; - - auto f_l = [=] __device__(const T y, const T z) { - return loss->lz(y, z) * invN; - }; // TODO would be nice to have a kernel that fuses these two steps // This would be easy, if mapThenSumReduce allowed outputing the result of // map (supporting inplace) - raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data, - Z.data); - - auto f_dl = [=] __device__(const T y, const T z) { - return loss->dlz(y, z); - }; - raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, f_dl, stream); + if (this->sample_weights) { // Sample weights are in use + T normalization = 1.0 / this->weights_sum; + auto f_l = [=] __device__(const T y, const T z, const T weight) { + return loss->lz(y, z) * (weight * normalization); + }; + raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data, + Z.data, sample_weights); + + auto f_dl = [=] __device__(const T y, const T z, const T weight) { + return 
weight * loss->dlz(y, z); + }; + raft::linalg::map(Z.data, y.len, f_dl, stream, y.data, Z.data, + sample_weights); + } else { // Sample weights are not used + T normalization = 1.0 / y.len; + auto f_l = [=] __device__(const T y, const T z) { + return loss->lz(y, z) * normalization; + }; + raft::linalg::mapThenSumReduce(loss_val, y.len, f_l, stream, y.data, + Z.data); + + auto f_dl = [=] __device__(const T y, const T z) { + return loss->dlz(y, z); + }; + raft::linalg::binaryOp(Z.data, y.data, Z.data, y.len, f_dl, stream); + } } inline void loss_grad(T *loss_val, Mat &G, const Mat &W, diff --git a/cpp/src/glm/qn/qn.cuh b/cpp/src/glm/qn/qn.cuh index 98655a512b..b5fe812b50 100644 --- a/cpp/src/glm/qn/qn.cuh +++ b/cpp/src/glm/qn/qn.cuh @@ -64,7 +64,7 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C, bool fit_intercept, T l1, T l2, int max_iter, T grad_tol, int linesearch_max_iter, int lbfgs_memory, int verbosity, T *w0, T *f, int *num_iters, bool X_col_major, int loss_type, - cudaStream_t stream) { + cudaStream_t stream, T *sample_weight = nullptr) { STORAGE_ORDER ord = X_col_major ? COL_MAJOR : ROW_MAJOR; int C_len = (loss_type == 0) ? (C - 1) : C; @@ -75,6 +75,7 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C, case 0: { ASSERT(C == 2, "qn.h: logistic loss invalid C"); LogisticLoss loss(handle, D, fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); qn_fit(handle, loss, X, y, z.data, N, l1, l2, max_iter, grad_tol, linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters, ord, stream); @@ -82,6 +83,7 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C, case 1: { ASSERT(C == 1, "qn.h: squared loss invalid C"); SquaredLoss loss(handle, D, fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); qn_fit(handle, loss, X, y, z.data, N, l1, l2, max_iter, grad_tol, linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters, ord, stream); @@ -89,6 +91,7 @@ void qnFit(const raft::handle_t &handle, T *X, T *y, int N, int D, int C, case 2: { ASSERT(C > 2, "qn.h: softmax invalid C"); Softmax loss(handle, D, C, fit_intercept); + if (sample_weight) loss.add_sample_weights(sample_weight, N, stream); qn_fit(handle, loss, X, y, z.data, N, l1, l2, max_iter, grad_tol, linesearch_max_iter, lbfgs_memory, verbosity, w0, f, num_iters, ord, stream); diff --git a/python/cuml/linear_model/logistic_regression.pyx b/python/cuml/linear_model/logistic_regression.pyx index 666b058dae..e139f32df7 100644 --- a/python/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/linear_model/logistic_regression.pyx @@ -16,6 +16,7 @@ # distutils: language = c++ +import numpy as np import cupy as cp import pprint @@ -126,6 +127,15 @@ class LogisticRegression(Base, If False, the model expects that you have centered the data. class_weight: None Custom class weighs are currently not supported. + class_weight: dict or 'balanced', default=None + By default all classes have a weight one. However, a dictionary + can be provided with weights associated with classes + in the form ``{class_label: weight}``. The "balanced" mode + uses the values of y to automatically adjust weights + inversely proportional to class frequencies in the input data + as ``n_samples / (n_classes * np.bincount(y))``. Note that + these weights will be multiplied with sample_weight + (passed through the fit method) if sample_weight is specified. 
max_iter: int (default = 1000) Maximum number of iterations taken for the solvers to converge. linesearch_max_iter: int (default = 50) @@ -178,6 +188,8 @@ class LogisticRegression(Base, """ classes_ = CumlArrayDescriptor() + class_weight_ = CumlArrayDescriptor() + expl_spec_weights_ = CumlArrayDescriptor() def __init__( self, @@ -199,9 +211,6 @@ class LogisticRegression(Base, handle=handle, verbose=verbose, output_type=output_type ) - if class_weight: - raise ValueError("`class_weight` not supported.") - if penalty not in supported_penalties: raise ValueError("`penalty` " + str(penalty) + "not supported.") @@ -246,6 +255,21 @@ class LogisticRegression(Base, loss = "sigmoid" + if class_weight is not None: + if class_weight == 'balanced': + self.class_weight_ = 'balanced' + else: + classes = list(class_weight.keys()) + weights = list(class_weight.values()) + max_class = sorted(classes)[-1] + class_weight = cp.ones(max_class + 1) + class_weight[classes] = weights + self.class_weight_, _, _, _ = input_to_cuml_array(class_weight) + self.expl_spec_weights_, _, _, _ = \ + input_to_cuml_array(np.array(classes)) + else: + self.class_weight_ = None + self.solver_model = QN( loss=loss, fit_intercept=self.fit_intercept, @@ -267,7 +291,8 @@ class LogisticRegression(Base, @generate_docstring() @cuml.internals.api_base_return_any(set_output_dtype=True) - def fit(self, X, y, convert_dtype=True) -> "LogisticRegression": + def fit(self, X, y, sample_weight=None, + convert_dtype=True) -> "LogisticRegression": """ Fit the model with X and y. @@ -275,11 +300,54 @@ class LogisticRegression(Base, # Converting y to device array here to use `unique` function # since calling input_to_cuml_array again in QN has no cost # Not needed to check dtype since qn class checks it already - y_m, _, _, _ = input_to_cuml_array(y) - + y_m, n_rows, _, _ = input_to_cuml_array(y) self.classes_ = cp.unique(y_m) self._num_classes = len(self.classes_) + if self._num_classes == 2: + if self.classes_[0] != 0 or self.classes_[1] != 1: + raise ValueError("Only values of 0 and 1 are" + " supported for binary classification.") + + if sample_weight is not None or self.class_weight_ is not None: + if sample_weight is None: + sample_weight = cp.ones(n_rows) + + sample_weight, n_weights, D, _ = input_to_cuml_array(sample_weight) + + if n_rows != n_weights or D != 1: + raise ValueError("sample_weight.shape == {}, " + "expected ({},)!".format(sample_weight.shape, + n_rows)) + + def check_expl_spec_weights(): + with cuml.using_output_type("numpy"): + for c in self.expl_spec_weights_: + i = np.searchsorted(self.classes_, c) + if i >= self._num_classes or self.classes_[i] != c: + msg = "Class label {} not present.".format(c) + raise ValueError(msg) + + if self.class_weight_ is not None: + if self.class_weight_ == 'balanced': + class_weight = n_rows / \ + (self._num_classes * + cp.bincount(y_m.to_output('cupy'))) + class_weight = CumlArray(class_weight) + else: + check_expl_spec_weights() + n_explicit = self.class_weight_.shape[0] + if n_explicit != self._num_classes: + class_weight = cp.ones(self._num_classes) + class_weight[:n_explicit] = self.class_weight_ + class_weight = CumlArray(class_weight) + self.class_weight_ = class_weight + else: + class_weight = self.class_weight_ + out = y_m.to_output('cupy') + sample_weight *= class_weight[out].to_output('cupy') + sample_weight = CumlArray(sample_weight) + if self._num_classes > 2: loss = "softmax" else: @@ -293,7 +361,8 @@ class LogisticRegression(Base, if logger.should_log_for(logger.level_debug): 
logger.debug(self.verb_prefix + "Calling QN fit " + str(loss)) - self.solver_model.fit(X, y_m, convert_dtype=convert_dtype) + self.solver_model.fit(X, y_m, sample_weight=sample_weight, + convert_dtype=convert_dtype) # coefficients and intercept are contained in the same array if logger.should_log_for(logger.level_debug): diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx index 1c0a5e6454..bd2a433617 100644 --- a/python/cuml/solvers/qn.pyx +++ b/python/cuml/solvers/qn.pyx @@ -53,7 +53,8 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM": float *f, int *num_iters, bool X_col_major, - int loss_type) except + + int loss_type, + float *sample_weight) except + void qnFit(handle_t& cuml_handle, double *X, @@ -73,7 +74,8 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM": double *f, int *num_iters, bool X_col_major, - int loss_type) except + + int loss_type, + double *sample_weight) except + void qnDecisionFunction(handle_t& cuml_handle, float *X, @@ -291,7 +293,7 @@ class QN(Base, return self._coef_ @generate_docstring() - def fit(self, X, y, convert_dtype=False) -> "QN": + def fit(self, X, y, sample_weight=None, convert_dtype=False) -> "QN": """ Fit the model with X and y. @@ -310,14 +312,25 @@ class QN(Base, self._num_classes = len(cp.unique(y_m)) + cdef uintptr_t sample_weight_ptr = 0 + if sample_weight is not None: + sample_weight, _, _, _ = \ + input_to_cuml_array(sample_weight, + check_dtype=self.dtype, + check_rows=n_rows, check_cols=1, + convert_to_dtype=(self.dtype + if convert_dtype + else None)) + sample_weight_ptr = sample_weight.ptr + self.loss_type = self._get_loss_int(self.loss) if self.loss_type != 2 and self._num_classes > 2: raise ValueError("Only softmax (multinomial) loss supports more" "than 2 classes.") if self.loss_type == 2 and self._num_classes <= 2: - raise ValueError("Only softmax (multinomial) loss supports more" - "than 2 classes.") + raise ValueError("Two classes or less cannot be trained" + "with softmax (multinomial).") if self.loss_type == 0: self._num_classes_dim = self._num_classes - 1 @@ -357,7 +370,8 @@ class QN(Base, &objective32, &num_iters, True, - self.loss_type) + self.loss_type, + sample_weight_ptr) self.objective = objective32 @@ -380,7 +394,8 @@ class QN(Base, &objective64, &num_iters, True, - self.loss_type) + self.loss_type, + sample_weight_ptr) self.objective = objective64 diff --git a/python/cuml/test/test_linear_model.py b/python/cuml/test/test_linear_model.py index 63660528c7..720d263fbd 100644 --- a/python/cuml/test/test_linear_model.py +++ b/python/cuml/test/test_linear_model.py @@ -488,3 +488,86 @@ def test_logistic_predict_convert_dtype(train_dtype, test_dtype): clf = cuLog() clf.fit(X_train, y_train) clf.predict(X_test.astype(test_dtype)) + + +@pytest.fixture(scope='session', + params=['binary', 'multiclass-3', 'multiclass-7']) +def regression_dataset(request): + regression_type = request.param + + out = {} + for test_status in ['regular', 'stress_test']: + if test_status == 'regular': + n_samples, n_features = 100000, 5 + elif test_status == 'stress_test': + n_samples, n_features = 1000000, 20 + + data = (np.random.rand(n_samples, n_features) * 2) - 1 + + if regression_type == 'binary': + coef = (np.random.rand(n_features) * 2) - 1 + coef /= np.linalg.norm(coef) + output = (data @ coef) > 0 + elif regression_type.startswith('multiclass'): + n_classes = 3 if regression_type == 'multiclass-3' else 7 + coef = (np.random.rand(n_features, n_classes) * 2) - 1 + coef /= np.linalg.norm(coef, 
axis=0) + output = (data @ coef).argmax(axis=1) + output = output.astype(np.int32) + + out[test_status] = (regression_type, data, coef, output) + return out + + +@pytest.mark.parametrize('option', ['sample_weight', 'class_weight', + 'balanced', 'no_weight']) +@pytest.mark.parametrize('test_status', ['regular', + stress_param('stress_test')]) +def test_logistic_regression_weighting(regression_dataset, + option, test_status): + regression_type, data, coef, output = regression_dataset[test_status] + + class_weight = None + sample_weight = None + if option == 'sample_weight': + n_samples = data.shape[0] + sample_weight = np.abs(np.random.rand(n_samples)) + elif option == 'class_weight': + class_weight = np.random.rand(2) + class_weight = {0: class_weight[0], 1: class_weight[1]} + elif option == 'balanced': + class_weight = 'balanced' + + culog = cuLog(fit_intercept=False, class_weight=class_weight) + culog.fit(data, output, sample_weight=sample_weight) + + sklog = skLog(fit_intercept=False, class_weight=class_weight) + sklog.fit(data, output, sample_weight=sample_weight) + + skcoef = np.squeeze(sklog.coef_) + cucoef = np.squeeze(culog.coef_) + if regression_type == 'binary': + skcoef /= np.linalg.norm(skcoef) + cucoef /= np.linalg.norm(cucoef) + unit_tol = 0.04 + total_tol = 0.08 + elif regression_type.startswith('multiclass'): + skcoef = skcoef.T + skcoef /= np.linalg.norm(skcoef, axis=1)[:, None] + cucoef /= np.linalg.norm(cucoef, axis=1)[:, None] + unit_tol = 0.2 + total_tol = 0.3 + + equality = array_equal(skcoef, cucoef, unit_tol=unit_tol, + total_tol=total_tol) + if not equality: + print('\ncoef.shape: ', coef.shape) + print('coef:\n', coef) + print('cucoef.shape: ', cucoef.shape) + print('cucoef:\n', cucoef) + assert equality + + cuOut = culog.predict(data) + skOut = sklog.predict(data) + assert array_equal(skOut, cuOut, unit_tol=unit_tol, + total_tol=total_tol)
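
For reference, the snippet below is a minimal usage sketch of the parameters introduced by this patch. It is not part of the diff: the synthetic data, weight values, and variable names are illustrative assumptions. Only the `fit(X, y, sample_weight=...)` signature, the `class_weight` dict / 'balanced' options, and the 0/1 label requirement for binary problems come from the code added above.

    # Illustrative sketch only: data and weight values are made up;
    # the API calls follow the parameters added in this PR.
    import cupy as cp
    from cuml.linear_model import LogisticRegression

    n_samples, n_features = 1000, 5
    X = cp.random.uniform(-1, 1, (n_samples, n_features), dtype=cp.float32)
    w = cp.random.uniform(-1, 1, n_features, dtype=cp.float32)
    y = (X @ w > 0).astype(cp.float32)   # binary labels must be 0 and 1

    # Per-sample weights: rows with larger weights contribute more to the
    # loss (internally the loss is normalized by the sum of the weights).
    sample_weight = cp.random.rand(n_samples).astype(cp.float32)
    clf = LogisticRegression()
    clf.fit(X, y, sample_weight=sample_weight)

    # Per-class weights: an explicit {class_label: weight} dict, or
    # 'balanced', which weights classes by n_samples / (n_classes *
    # np.bincount(y)). When both are given, class weights are multiplied
    # into the sample weights.
    clf_balanced = LogisticRegression(class_weight='balanced')
    clf_balanced.fit(X, y)

    clf_dict = LogisticRegression(class_weight={0: 0.3, 1: 0.7})
    clf_dict.fit(X, y, sample_weight=sample_weight)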