Skip to content

Commit

Permalink
Add support for sample_weights in LinearRegression (#4428)
Browse files Browse the repository at this point in the history
Closes #4031.
Scikit-learn is rescaling the data ([here](https://github.com/scikit-learn/scikit-learn/blob/0d378913be6d7e485b792ea36e9268be31ed52d0/sklearn/linear_model/_base.py#L313)) to take into account the sample_weight parameter.

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #4428
  • Loading branch information
lowener authored Mar 10, 2022
1 parent 4c1d671 commit fc94e5f
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 25 deletions.
8 changes: 6 additions & 2 deletions cpp/include/cuml/linear_model/glm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace GLM {
* @param normalize if true, normalize data to zero mean, unit variance
* @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
* QR-decomposition)
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
for uniform weights)
* @{
*/
void olsFit(const raft::handle_t& handle,
Expand All @@ -45,7 +47,8 @@ void olsFit(const raft::handle_t& handle,
float* intercept,
bool fit_intercept,
bool normalize,
int algo = 0);
int algo = 0,
float* sample_weight = nullptr);
void olsFit(const raft::handle_t& handle,
double* input,
int n_rows,
Expand All @@ -55,7 +58,8 @@ void olsFit(const raft::handle_t& handle,
double* intercept,
bool fit_intercept,
bool normalize,
int algo = 0);
int algo = 0,
double* sample_weight = nullptr);
/** @} */

/**
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/glm/glm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ void olsFit(const raft::handle_t& handle,
float* intercept,
bool fit_intercept,
bool normalize,
int algo)
int algo,
float* sample_weight)
{
olsFit(handle,
input,
Expand All @@ -47,7 +48,8 @@ void olsFit(const raft::handle_t& handle,
fit_intercept,
normalize,
handle.get_stream(),
algo);
algo,
sample_weight);
}

void olsFit(const raft::handle_t& handle,
Expand All @@ -59,7 +61,8 @@ void olsFit(const raft::handle_t& handle,
double* intercept,
bool fit_intercept,
bool normalize,
int algo)
int algo,
double* sample_weight)
{
olsFit(handle,
input,
Expand All @@ -71,7 +74,8 @@ void olsFit(const raft::handle_t& handle,
fit_intercept,
normalize,
handle.get_stream(),
algo);
algo,
sample_weight);
}

void gemmPredict(const raft::handle_t& handle,
Expand Down
37 changes: 35 additions & 2 deletions cpp/src/glm/ols.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
#include <raft/linalg/add.hpp>
#include <raft/linalg/gemv.hpp>
#include <raft/linalg/lstsq.hpp>
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.hpp>
#include <raft/linalg/power.cuh>
#include <raft/linalg/sqrt.cuh>
#include <raft/linalg/subtract.hpp>
#include <raft/matrix/math.hpp>
#include <raft/matrix/matrix.hpp>
Expand Down Expand Up @@ -48,6 +51,8 @@ namespace GLM {
* @param stream cuda stream
* @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
* QR-decomposition)
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr for uniform
* weights)
*/
template <typename math_t>
void olsFit(const raft::handle_t& handle,
Expand All @@ -60,7 +65,8 @@ void olsFit(const raft::handle_t& handle,
bool fit_intercept,
bool normalize,
cudaStream_t stream,
int algo = 0)
int algo = 0,
math_t* sample_weight = nullptr)
{
auto cublas_handle = handle.get_cublas_handle();
auto cusolver_handle = handle.get_cusolver_dn_handle();
Expand All @@ -87,7 +93,21 @@ void olsFit(const raft::handle_t& handle,
norm2_input.data(),
fit_intercept,
normalize,
stream);
stream,
sample_weight);
}

if (sample_weight != nullptr) {
raft::linalg::sqrt(sample_weight, sample_weight, n_rows, stream);
raft::matrix::matrixVectorBinaryMult(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a * b; },
stream,
labels,
sample_weight);
}

int selectedAlgo = algo;
Expand All @@ -107,6 +127,19 @@ void olsFit(const raft::handle_t& handle,
}
raft::common::nvtx::pop_range();

if (sample_weight != nullptr) {
raft::matrix::matrixVectorBinaryDivSkipZero(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a / b; },
stream,
labels,
sample_weight);
raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream);
}

if (fit_intercept) {
postProcessData(handle,
input,
Expand Down
32 changes: 28 additions & 4 deletions cpp/src/glm/preprocess.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/stats/mean_center.hpp>
#include <raft/stats/meanvar.hpp>
#include <raft/stats/stddev.hpp>
#include <raft/stats/weighted_mean.cuh>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

Expand Down Expand Up @@ -61,14 +62,15 @@ void preProcessData(const raft::handle_t& handle,
math_t* norm2_input,
bool fit_intercept,
bool normalize,
cudaStream_t stream)
cudaStream_t stream,
math_t* sample_weight = nullptr)
{
raft::common::nvtx::range fun_scope("ML::GLM::preProcessData-%d-%d", n_rows, n_cols);
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");

if (fit_intercept) {
if (normalize) {
if (normalize && sample_weight == nullptr) {
raft::stats::meanvar(mu_input, norm2_input, input, n_cols, n_rows, false, false, stream);
raft::linalg::unaryOp(
norm2_input,
Expand All @@ -87,10 +89,32 @@ void preProcessData(const raft::handle_t& handle,
mu_input,
norm2_input);
} else {
raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
if (sample_weight != nullptr) {
raft::stats::weightedMean(
mu_input, input, sample_weight, n_cols, n_rows, false, false, stream);
} else {
raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
}
raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream);
if (normalize) {
raft::linalg::colNorm(norm2_input,
input,
n_cols,
n_rows,
raft::linalg::L2Norm,
false,
stream,
[] __device__(math_t v) { return raft::mySqrt(v); });
raft::matrix::matrixVectorBinaryDivSkipZero(
input, norm2_input, n_rows, n_cols, false, true, stream, true);
}
}

if (sample_weight != nullptr) {
raft::stats::weightedMean(mu_labels, labels, sample_weight, 1, n_rows, true, false, stream);
} else {
raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
}
raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream);
}
}
Expand Down
31 changes: 25 additions & 6 deletions python/cuml/linear_model/linear_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
float *coef,
float *intercept,
bool fit_intercept,
bool normalize, int algo) except +
bool normalize,
int algo,
float *sample_weight) except +

cdef void olsFit(handle_t& handle,
double *input,
Expand All @@ -59,7 +61,9 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
double *coef,
double *intercept,
bool fit_intercept,
bool normalize, int algo) except +
bool normalize,
int algo,
double *sample_weight) except +


class LinearRegression(Base,
Expand Down Expand Up @@ -242,12 +246,13 @@ class LinearRegression(Base,
}[algorithm]

@generate_docstring()
def fit(self, X, y, convert_dtype=True) -> "LinearRegression":
def fit(self, X, y, convert_dtype=True,
sample_weight=None) -> "LinearRegression":
"""
Fit the model with X and y.

"""
cdef uintptr_t X_ptr, y_ptr
cdef uintptr_t X_ptr, y_ptr, sample_weight_ptr
X_m, n_rows, self.n_cols, self.dtype = \
input_to_cuml_array(X, check_dtype=[np.float32, np.float64])
X_ptr = X_m.ptr
Expand All @@ -259,6 +264,16 @@ class LinearRegression(Base,
check_rows=n_rows, check_cols=1)
y_ptr = y_m.ptr

if sample_weight is not None:
sample_weight_m, _, _, _ = \
input_to_cuml_array(sample_weight, check_dtype=self.dtype,
convert_to_dtype=(
self.dtype if convert_dtype else None),
check_rows=n_rows, check_cols=1)
sample_weight_ptr = sample_weight_m.ptr
else:
sample_weight_ptr = 0

if self.n_cols < 1:
msg = "X matrix must have at least a column"
raise TypeError(msg)
Expand Down Expand Up @@ -291,7 +306,8 @@ class LinearRegression(Base,
<float*>&c_intercept1,
<bool>self.fit_intercept,
<bool>self.normalize,
<int>self.algo)
<int>self.algo,
<float*>sample_weight_ptr)

self.intercept_ = c_intercept1
else:
Expand All @@ -304,14 +320,17 @@ class LinearRegression(Base,
<double*>&c_intercept2,
<bool>self.fit_intercept,
<bool>self.normalize,
<int>self.algo)
<int>self.algo,
<double*>sample_weight_ptr)

self.intercept_ = c_intercept2

self.handle.sync()

del X_m
del y_m
if sample_weight is not None:
del sample_weight_m

return self

Expand Down
65 changes: 58 additions & 7 deletions python/cuml/test/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,26 @@
"set degenerate(.*)::sklearn[.*]")


def _make_regression_dataset_uncached(nrows, ncols, n_info):
def _make_regression_dataset_uncached(nrows, ncols, n_info, **kwargs):
X, y = make_regression(
n_samples=nrows, n_features=ncols, n_informative=n_info, random_state=0
**kwargs, n_samples=nrows, n_features=ncols, n_informative=n_info,
random_state=0
)
return train_test_split(X, y, train_size=0.8, random_state=10)


@lru_cache(4)
def _make_regression_dataset_from_cache(nrows, ncols, n_info):
return _make_regression_dataset_uncached(nrows, ncols, n_info)
def _make_regression_dataset_from_cache(nrows, ncols, n_info, **kwargs):
return _make_regression_dataset_uncached(nrows, ncols, n_info, **kwargs)


def make_regression_dataset(datatype, nrows, ncols, n_info):
def make_regression_dataset(datatype, nrows, ncols, n_info, **kwargs):
if nrows * ncols < 1e8: # Keep cache under 4 GB
dataset = _make_regression_dataset_from_cache(nrows, ncols, n_info)
dataset = _make_regression_dataset_from_cache(nrows, ncols, n_info,
**kwargs)
else:
dataset = _make_regression_dataset_uncached(nrows, ncols, n_info)
dataset = _make_regression_dataset_uncached(nrows, ncols, n_info,
**kwargs)

return map(lambda arr: arr.astype(datatype), dataset)

Expand Down Expand Up @@ -129,6 +132,54 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info):
assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True)


@pytest.mark.parametrize("datatype", [np.float32, np.float64])
@pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"])
@pytest.mark.parametrize(
"fit_intercept, normalize, distribution", [
(True, True, "lognormal"),
(True, True, "exponential"),
(True, False, "uniform"),
(True, False, "exponential"),
(False, True, "lognormal"),
(False, False, "uniform"),
]
)
def test_weighted_linear_regression(datatype, algorithm, fit_intercept,
normalize, distribution):
nrows, ncols, n_info = 1000, 20, 10
max_weight = 10
noise = 20
X_train, X_test, y_train, y_test = make_regression_dataset(
datatype, nrows, ncols, n_info, noise=noise
)

# set weight per sample to be from 1 to max_weight
if distribution == "uniform":
wt = np.random.randint(1, high=max_weight, size=len(X_train))
elif distribution == "exponential":
wt = np.random.exponential(size=len(X_train))
else:
wt = np.random.lognormal(size=len(X_train))

# Initialization of cuML's linear regression model
cuols = cuLinearRegression(fit_intercept=fit_intercept,
normalize=normalize,
algorithm=algorithm)

# fit and predict cuml linear regression model
cuols.fit(X_train, y_train, sample_weight=wt)
cuols_predict = cuols.predict(X_test)

# sklearn linear regression model initialization, fit and predict
skols = skLinearRegression(fit_intercept=fit_intercept,
normalize=normalize)
skols.fit(X_train, y_train, sample_weight=wt)

skols_predict = skols.predict(X_test)

assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True)


@pytest.mark.skipif(
rmm._cuda.gpu.runtimeGetVersion() < 11000,
reason='svd solver does not support more than 46340 rows or columns for'
Expand Down

0 comments on commit fc94e5f

Please sign in to comment.