From 1451325b65a72637329716f6cde282e530d0d8f2 Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Thu, 2 Dec 2021 11:56:02 -0800
Subject: [PATCH 01/15] First implementation of linear regression weights

---
 cpp/include/cuml/linear_model/glm.hpp       |  7 +++--
 cpp/src/glm/ols.cuh                         | 20 ++++++++++++
 .../cuml/linear_model/linear_regression.pyx | 30 +++++++++++++++----
 3 files changed, 48 insertions(+), 9 deletions(-)

diff --git a/cpp/include/cuml/linear_model/glm.hpp b/cpp/include/cuml/linear_model/glm.hpp
index 19c3d734a8..3b3cdcf795 100644
--- a/cpp/include/cuml/linear_model/glm.hpp
+++ b/cpp/include/cuml/linear_model/glm.hpp
@@ -32,6 +32,7 @@ namespace GLM {
  * @param normalize if true, normalize data to zero mean, unit variance
  * @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
  * QR-decomposition)
+ * @param sample_weights device pointer to sample weights vector of length n_rows
  * @{
  */
 void olsFit(const raft::handle_t& handle,
             float* input,
             int n_rows,
             int n_cols,
             float* labels,
             float* coef,
             float* intercept,
             bool fit_intercept,
             bool normalize,
-            int algo = 0);
+            int algo = 0,
+            float* sample_weights = nullptr);
 void olsFit(const raft::handle_t& handle,
             double* input,
             int n_rows,
             int n_cols,
             double* labels,
             double* coef,
             double* intercept,
             bool fit_intercept,
             bool normalize,
-            int algo = 0);
+            int algo = 0,
+            double *sample_weights = nullptr);
 /** @} */
 
 /**
diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh
index 0334f72906..95fdf1f675 100644
--- a/cpp/src/glm/ols.cuh
+++ b/cpp/src/glm/ols.cuh
@@ -49,6 +49,7 @@ using namespace MLCommon;
  * @param stream cuda stream
  * @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
  * QR-decomposition)
+ * @param sample_weights device pointer to sample weights vector of length n_rows
  */
 template <typename math_t>
 void olsFit(const raft::handle_t& handle,
@@ -61,7 +62,8 @@ void olsFit(const raft::handle_t& handle,
             bool fit_intercept,
             bool normalize,
             cudaStream_t stream,
-            int algo = 0)
+            int algo = 0,
+            math_t* sample_weights = nullptr)
 {
   auto cublas_handle   = handle.get_cublas_handle();
   auto cusolver_handle = handle.get_cusolver_dn_handle();
@@ -73,6 +75,14 @@ void olsFit(const raft::handle_t& handle,
   rmm::device_uvector<math_t> norm2_input(0, stream);
   rmm::device_uvector<math_t> mu_labels(0, stream);
 
+  if (sample_weights != nullptr) {
+    LinAlg::sqrt(sample_weights, sample_weights, n_rows, stream);
+    raft::matrix::matrixVectorBinaryMult(input, sample_weights, n_rows, n_cols, false, false, stream);
+    raft::linalg::map(label, n_rows,
+                      [] __device__(math_t a, math_t b) { return a * b; },
+                      stream, label, sample_weights);
+  }
+
   if (fit_intercept) {
     mu_input.resize(n_cols, stream);
     mu_labels.resize(1, stream);
@@ -123,6 +133,14 @@ void olsFit(const raft::handle_t& handle,
   } else {
     *intercept = math_t(0);
   }
+
+  if (sample_weights != nullptr) {
+    raft::matrix::matrixVectorBinaryDivSkipZero(input, sample_weights, n_rows, n_cols, false, false, stream);
+    raft::linalg::map(label, n_rows,
+                      [] __device__(math_t a, math_t b) { return a / b; },
+                      stream, label, sample_weights);
+    LinAlg::powerScalar(sample_weights, sample_weights, 2, n_rows, stream);
+  }
 }
 
 /**
diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx
index a4023e9307..866b3a5ae5 100644
--- a/python/cuml/linear_model/linear_regression.pyx
+++ b/python/cuml/linear_model/linear_regression.pyx
@@ -49,7 +49,9 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
                      float *coef,
                      float *intercept,
                      bool fit_intercept,
-
bool normalize, int algo) except + + bool normalize, + int algo, + float *sample_weights) except + cdef void olsFit(handle_t& handle, double *input, @@ -59,7 +61,9 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM": double *coef, double *intercept, bool fit_intercept, - bool normalize, int algo) except + + bool normalize, + int algo, + double *sample_weights) except + class LinearRegression(Base, @@ -239,12 +243,12 @@ class LinearRegression(Base, }[algorithm] @generate_docstring() - def fit(self, X, y, convert_dtype=True) -> "LinearRegression": + def fit(self, X, y, convert_dtype=True, sample_weights=None) -> "LinearRegression": """ Fit the model with X and y. """ - cdef uintptr_t X_ptr, y_ptr + cdef uintptr_t X_ptr, y_ptr, sample_weights_ptr X_m, n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) X_ptr = X_m.ptr @@ -256,6 +260,16 @@ class LinearRegression(Base, check_rows=n_rows, check_cols=1) y_ptr = y_m.ptr + if sample_weights: + sample_weights_m, _, _, _ = \ + input_to_cuml_array(sample_weights, check_dtype=self.dtype, + convert_to_dtype=(self.dtype if convert_dtype + else None), + check_rows=n_rows, check_cols=1) + sample_weights_ptr = sample_weights_m.ptr + else: + sample_weights_ptr = 0 + if self.n_cols < 1: msg = "X matrix must have at least a column" raise TypeError(msg) @@ -288,7 +302,8 @@ class LinearRegression(Base, &c_intercept1, self.fit_intercept, self.normalize, - self.algo) + self.algo, + sample_weights_ptr) self.intercept_ = c_intercept1 else: @@ -301,7 +316,8 @@ class LinearRegression(Base, &c_intercept2, self.fit_intercept, self.normalize, - self.algo) + self.algo, + sample_weights_ptr) self.intercept_ = c_intercept2 @@ -309,6 +325,8 @@ class LinearRegression(Base, del X_m del y_m + if sample_weights: + del sample_weights_ptr return self From 99cdebed3655a8c586faabe4d644f879f52cd7c5 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 3 Dec 2021 08:03:49 -0800 Subject: [PATCH 02/15] Fix typos --- cpp/src/glm/glm.cu | 12 ++++++++---- cpp/src/glm/ols.cuh | 13 ++++++++----- python/cuml/linear_model/linear_regression.pyx | 3 +-- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/cpp/src/glm/glm.cu b/cpp/src/glm/glm.cu index 210e40f649..f387dec4fb 100644 --- a/cpp/src/glm/glm.cu +++ b/cpp/src/glm/glm.cu @@ -37,7 +37,8 @@ void olsFit(const raft::handle_t& handle, float* intercept, bool fit_intercept, bool normalize, - int algo) + int algo, + float *sample_weights) { olsFit(handle, input, @@ -49,7 +50,8 @@ void olsFit(const raft::handle_t& handle, fit_intercept, normalize, handle.get_stream(), - algo); + algo, + sample_weights); } void olsFit(const raft::handle_t& handle, @@ -61,7 +63,8 @@ void olsFit(const raft::handle_t& handle, double* intercept, bool fit_intercept, bool normalize, - int algo) + int algo, + double *sample_weights) { olsFit(handle, input, @@ -73,7 +76,8 @@ void olsFit(const raft::handle_t& handle, fit_intercept, normalize, handle.get_stream(), - algo); + algo, + sample_weights); } void gemmPredict(const raft::handle_t& handle, diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh index 95fdf1f675..b00df11886 100644 --- a/cpp/src/glm/ols.cuh +++ b/cpp/src/glm/ols.cuh @@ -18,6 +18,9 @@ #include #include +#include +#include +#include #include #include #include @@ -78,9 +81,9 @@ void olsFit(const raft::handle_t& handle, if (sample_weights != nullptr) { LinAlg::sqrt(sample_weights, sample_weights, n_rows, stream); raft::matrix::matrixVectorBinaryMult(input, sample_weights, n_rows, 
n_cols, false, false, stream); - raft::linalg::map(label, n_rows, + raft::linalg::map(labels, n_rows, [] __device__(math_t a, math_t b) { return a * b; }, - stream, label, sample_weights); + stream, labels, sample_weights); } if (fit_intercept) { @@ -136,10 +139,10 @@ void olsFit(const raft::handle_t& handle, if (sample_weights != nullptr) { raft::matrix::matrixVectorBinaryDivSkipZero(input, sample_weights, n_rows, n_cols, false, false, stream); - raft::linalg::map(label, n_rows, + raft::linalg::map(labels, n_rows, [] __device__(math_t a, math_t b) { return a / b; }, - stream, label, sample_weights); - LinAlg::powerScalar(sample_weights, sample_weights, 2, n_rows, stream); + stream, labels, sample_weights); + LinAlg::powerScalar(sample_weights, sample_weights, (math_t)2, n_rows, stream); } } diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index 866b3a5ae5..5a6861e08e 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -325,8 +325,7 @@ class LinearRegression(Base, del X_m del y_m - if sample_weights: - del sample_weights_ptr + del sample_weights_m return self From 18277955471eba4d7afff46a3c40d807f239aaad Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 3 Dec 2021 11:38:24 -0800 Subject: [PATCH 03/15] Add test and fix typo --- cpp/include/cuml/linear_model/glm.hpp | 6 ++--- cpp/src/glm/glm.cu | 8 +++--- cpp/src/glm/ols.cuh | 20 +++++++-------- .../cuml/linear_model/linear_regression.pyx | 25 ++++++++++--------- 4 files changed, 30 insertions(+), 29 deletions(-) diff --git a/cpp/include/cuml/linear_model/glm.hpp b/cpp/include/cuml/linear_model/glm.hpp index 3b3cdcf795..f6477ef7a7 100644 --- a/cpp/include/cuml/linear_model/glm.hpp +++ b/cpp/include/cuml/linear_model/glm.hpp @@ -32,7 +32,7 @@ namespace GLM { * @param normalize if true, normalize data to zero mean, unit variance * @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2: * QR-decomposition) - * @param sample_weights device pointer to sample weights vector of length n_rows + * @param sample_weight device pointer to sample weight vector of length n_rows * @{ */ void olsFit(const raft::handle_t& handle, @@ -45,7 +45,7 @@ void olsFit(const raft::handle_t& handle, bool fit_intercept, bool normalize, int algo = 0, - float* sample_weights = nullptr); + float* sample_weight = nullptr); void olsFit(const raft::handle_t& handle, double* input, int n_rows, @@ -56,7 +56,7 @@ void olsFit(const raft::handle_t& handle, bool fit_intercept, bool normalize, int algo = 0, - double *sample_weights = nullptr); + double *sample_weight = nullptr); /** @} */ /** diff --git a/cpp/src/glm/glm.cu b/cpp/src/glm/glm.cu index f387dec4fb..c6dd41fa3d 100644 --- a/cpp/src/glm/glm.cu +++ b/cpp/src/glm/glm.cu @@ -38,7 +38,7 @@ void olsFit(const raft::handle_t& handle, bool fit_intercept, bool normalize, int algo, - float *sample_weights) + float *sample_weight) { olsFit(handle, input, @@ -51,7 +51,7 @@ void olsFit(const raft::handle_t& handle, normalize, handle.get_stream(), algo, - sample_weights); + sample_weight); } void olsFit(const raft::handle_t& handle, @@ -64,7 +64,7 @@ void olsFit(const raft::handle_t& handle, bool fit_intercept, bool normalize, int algo, - double *sample_weights) + double *sample_weight) { olsFit(handle, input, @@ -77,7 +77,7 @@ void olsFit(const raft::handle_t& handle, normalize, handle.get_stream(), algo, - sample_weights); + sample_weight); } void gemmPredict(const raft::handle_t& 
handle,
diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh
index b00df11886..6fadaee3f1 100644
--- a/cpp/src/glm/ols.cuh
+++ b/cpp/src/glm/ols.cuh
@@ -52,7 +52,7 @@ using namespace MLCommon;
  * @param stream cuda stream
  * @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
  * QR-decomposition)
- * @param sample_weights device pointer to sample weights vector of length n_rows
+ * @param sample_weight device pointer to sample weight vector of length n_rows
  */
 template <typename math_t>
 void olsFit(const raft::handle_t& handle,
@@ -66,7 +66,7 @@ void olsFit(const raft::handle_t& handle,
             bool normalize,
             cudaStream_t stream,
             int algo = 0,
-            math_t* sample_weights = nullptr)
+            math_t* sample_weight = nullptr)
 {
   auto cublas_handle   = handle.get_cublas_handle();
   auto cusolver_handle = handle.get_cusolver_dn_handle();
@@ -78,12 +78,12 @@ void olsFit(const raft::handle_t& handle,
   rmm::device_uvector<math_t> norm2_input(0, stream);
   rmm::device_uvector<math_t> mu_labels(0, stream);
 
-  if (sample_weights != nullptr) {
-    LinAlg::sqrt(sample_weights, sample_weights, n_rows, stream);
-    raft::matrix::matrixVectorBinaryMult(input, sample_weights, n_rows, n_cols, false, false, stream);
+  if (sample_weight != nullptr) {
+    LinAlg::sqrt(sample_weight, sample_weight, n_rows, stream);
+    raft::matrix::matrixVectorBinaryMult(input, sample_weight, n_rows, n_cols, false, false, stream);
     raft::linalg::map(labels, n_rows,
                       [] __device__(math_t a, math_t b) { return a * b; },
-                      stream, labels, sample_weights);
+                      stream, labels, sample_weight);
   }
 
   if (fit_intercept) {
@@ -137,12 +137,12 @@ void olsFit(const raft::handle_t& handle,
   } else {
     *intercept = math_t(0);
   }
 
-  if (sample_weights != nullptr) {
-    raft::matrix::matrixVectorBinaryDivSkipZero(input, sample_weights, n_rows, n_cols, false, false, stream);
+  if (sample_weight != nullptr) {
+    raft::matrix::matrixVectorBinaryDivSkipZero(input, sample_weight, n_rows, n_cols, false, false, stream);
     raft::linalg::map(labels, n_rows,
                       [] __device__(math_t a, math_t b) { return a / b; },
-                      stream, labels, sample_weights);
-    LinAlg::powerScalar(sample_weights, sample_weights, (math_t)2, n_rows, stream);
+                      stream, labels, sample_weight);
+    LinAlg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream);
   }
 }
diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx
index 5a6861e08e..889be893d6 100644
--- a/python/cuml/linear_model/linear_regression.pyx
+++ b/python/cuml/linear_model/linear_regression.pyx
@@ -51,7 +51,7 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
                      bool fit_intercept,
                      bool normalize,
                      int algo,
-                     float *sample_weights) except +
+                     float *sample_weight) except +
 
     cdef void olsFit(handle_t& handle,
                      double *input,
@@ -63,7 +63,7 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
                      bool fit_intercept,
                      bool normalize,
                      int algo,
-                     double *sample_weights) except +
+                     double *sample_weight) except +
 
 
 class LinearRegression(Base,
@@ -243,12 +243,12 @@ class LinearRegression(Base,
         }[algorithm]
 
     @generate_docstring()
-    def fit(self, X, y, convert_dtype=True, sample_weights=None) -> "LinearRegression":
+    def fit(self, X, y, convert_dtype=True, sample_weight=None) -> "LinearRegression":
         """
         Fit the model with X and y.
""" - cdef uintptr_t X_ptr, y_ptr, sample_weights_ptr + cdef uintptr_t X_ptr, y_ptr, sample_weight_ptr X_m, n_rows, self.n_cols, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) X_ptr = X_m.ptr @@ -260,15 +260,15 @@ class LinearRegression(Base, check_rows=n_rows, check_cols=1) y_ptr = y_m.ptr - if sample_weights: - sample_weights_m, _, _, _ = \ - input_to_cuml_array(sample_weights, check_dtype=self.dtype, + if sample_weight: + sample_weight_m, _, _, _ = \ + input_to_cuml_array(sample_weight, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype else None), check_rows=n_rows, check_cols=1) - sample_weights_ptr = sample_weights_m.ptr + sample_weight_ptr = sample_weight_m.ptr else: - sample_weights_ptr = 0 + sample_weight_ptr = 0 if self.n_cols < 1: msg = "X matrix must have at least a column" @@ -303,7 +303,7 @@ class LinearRegression(Base, self.fit_intercept, self.normalize, self.algo, - sample_weights_ptr) + sample_weight_ptr) self.intercept_ = c_intercept1 else: @@ -317,7 +317,7 @@ class LinearRegression(Base, self.fit_intercept, self.normalize, self.algo, - sample_weights_ptr) + sample_weight_ptr) self.intercept_ = c_intercept2 @@ -325,7 +325,8 @@ class LinearRegression(Base, del X_m del y_m - del sample_weights_m + if sample_weight: + del sample_weight_m return self From ae8b24e3256b3d9ffd5b71c99887d9e9374aa63d Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 6 Dec 2021 09:22:21 -0800 Subject: [PATCH 04/15] Adding test --- .../cuml/linear_model/linear_regression.pyx | 4 +- python/cuml/test/test_linear_model.py | 45 ++++++++++++++++--- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index 889be893d6..703d30b0a0 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -260,7 +260,7 @@ class LinearRegression(Base, check_rows=n_rows, check_cols=1) y_ptr = y_m.ptr - if sample_weight: + if sample_weight is not None: sample_weight_m, _, _, _ = \ input_to_cuml_array(sample_weight, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype @@ -325,7 +325,7 @@ class LinearRegression(Base, del X_m del y_m - if sample_weight: + if sample_weight is not None: del sample_weight_m return self diff --git a/python/cuml/test/test_linear_model.py b/python/cuml/test/test_linear_model.py index 8f06fb7331..b8ea99590e 100644 --- a/python/cuml/test/test_linear_model.py +++ b/python/cuml/test/test_linear_model.py @@ -46,23 +46,23 @@ "set degenerate(.*)::sklearn[.*]") -def _make_regression_dataset_uncached(nrows, ncols, n_info): +def _make_regression_dataset_uncached(nrows, ncols, n_info, **kwargs): X, y = make_regression( - n_samples=nrows, n_features=ncols, n_informative=n_info, random_state=0 + **kwargs, n_samples=nrows, n_features=ncols, n_informative=n_info, random_state=0 ) return train_test_split(X, y, train_size=0.8, random_state=10) @lru_cache(4) -def _make_regression_dataset_from_cache(nrows, ncols, n_info): - return _make_regression_dataset_uncached(nrows, ncols, n_info) +def _make_regression_dataset_from_cache(nrows, ncols, n_info, **kwargs): + return _make_regression_dataset_uncached(nrows, ncols, n_info, **kwargs) -def make_regression_dataset(datatype, nrows, ncols, n_info): +def make_regression_dataset(datatype, nrows, ncols, n_info, **kwargs): if nrows * ncols < 1e8: # Keep cache under 4 GB - dataset = _make_regression_dataset_from_cache(nrows, ncols, n_info) + dataset = 
_make_regression_dataset_from_cache(nrows, ncols, n_info, **kwargs)
     else:
         dataset = _make_regression_dataset_uncached(nrows, ncols, n_info, **kwargs)
     return map(lambda arr: arr.astype(datatype), dataset)
@@ -128,6 +128,37 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info):
     assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True)
 
 
+@pytest.mark.parametrize("datatype", [np.float32, np.float64])
+@pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"])
+@pytest.mark.parametrize("fit_intercept", [True, False])
+def test_weighted_linear_regression(datatype, algorithm, fit_intercept):
+    nrows, ncols, n_info = 1000, 20, 10
+    max_weight = 10
+    noise = 20
+    X_train, X_test, y_train, y_test = make_regression_dataset(
+        datatype, nrows, ncols, n_info, noise=noise
+    )
+
+    # set weight per sample to be from 1 to max_weight
+    wt = np.random.randint(1, high=max_weight, size=len(X_train))
+
+    # Initialization of cuML's linear regression model
+    cuols = cuLinearRegression(fit_intercept=fit_intercept,
+                               normalize=False,
+                               algorithm=algorithm)
+
+    # fit and predict cuml linear regression model
+    cuols.fit(X_train, y_train, sample_weight=wt)
+    cuols_predict = cuols.predict(X_test)
+
+    # sklearn linear regression model initialization, fit and predict
+    skols = skLinearRegression(fit_intercept=fit_intercept, normalize=False)
+    skols.fit(X_train, y_train, sample_weight=wt)
+
+    skols_predict = skols.predict(X_test)
+
+    assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True)
+
 @pytest.mark.skipif(
     rmm._cuda.gpu.runtimeGetVersion() < 11000,
     reason='svd solver does not support more than 46340 rows or columns for'

From f58a8a7ae32ee425a0dc51d23c41e8232b1e4f96 Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Thu, 27 Jan 2022 17:49:20 +0100
Subject: [PATCH 05/15] Draft for sample weighted mean

---
 cpp/src/glm/ols.cuh        | 19 ++++++++++---------
 cpp/src/glm/preprocess.cuh | 37 +++++++++++++++++++++++++++++++++----
 2 files changed, 43 insertions(+), 13 deletions(-)

diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh
index 0e0a673c38..9d274d58ee 100644
--- a/cpp/src/glm/ols.cuh
+++ b/cpp/src/glm/ols.cuh
@@ -79,14 +79,6 @@ void olsFit(const raft::handle_t& handle,
   rmm::device_uvector<math_t> norm2_input(0, stream);
   rmm::device_uvector<math_t> mu_labels(0, stream);
 
-  if (sample_weight != nullptr) {
-    LinAlg::sqrt(sample_weight, sample_weight, n_rows, stream);
-    raft::matrix::matrixVectorBinaryMult(input, sample_weight, n_rows, n_cols, false, false, stream);
-    raft::linalg::map(labels, n_rows,
-                      [] __device__(math_t a, math_t b) { return a * b; },
-                      stream, labels, sample_weight);
-  }
-
   if (fit_intercept) {
     mu_input.resize(n_cols, stream);
     mu_labels.resize(1, stream);
@@ -102,7 +94,16 @@ void olsFit(const raft::handle_t& handle,
                    norm2_input.data(),
                    fit_intercept,
                    normalize,
-                   stream);
+                   stream,
+                   sample_weight);
+  }
+
+  if (sample_weight != nullptr) {
+    LinAlg::sqrt(sample_weight, sample_weight, n_rows, stream);
+    raft::matrix::matrixVectorBinaryMult(input, sample_weight, n_rows, n_cols, false, false, stream);
+    raft::linalg::map(labels, n_rows,
+                      [] __device__(math_t a, math_t b) { return a * b; },
+                      stream, labels, sample_weight);
   }
 
   int selectedAlgo = algo;
diff --git a/cpp/src/glm/preprocess.cuh b/cpp/src/glm/preprocess.cuh
index 8ee77966c2..b383f20dba 100644
--- a/cpp/src/glm/preprocess.cuh
+++ b/cpp/src/glm/preprocess.cuh
@@ -17,7 +17,10 @@
 #pragma once
 
 #include
+#include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -44,16 +47,42 @@ void preProcessData(const raft::handle_t& handle,
                     math_t* norm2_input,
                     bool fit_intercept,
                     bool normalize,
-                    cudaStream_t stream)
+                    cudaStream_t stream,
+                    math_t* sample_weight = nullptr)
 {
   ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
   ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
-
+  rmm::device_uvector<math_t> mu_input2(n_cols, stream);
   if (fit_intercept) {
-    raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
+    rmm::device_scalar<math_t> sum_sw(stream);
+    rmm::device_uvector<math_t> temp(0, stream);
+    rmm::device_uvector<math_t> temp_labels(0, stream);
+    if (sample_weight) {
+      temp.resize(n_rows * n_cols, stream);
+      raft::copy(temp.data(), input, n_rows * n_cols, stream);
+      raft::stats::sum(sum_sw.data(), sample_weight, 1, n_rows, false, stream);
+      raft::matrix::matrixVectorBinaryMult(temp.data(), sample_weight, n_rows, n_cols, false, false, stream);
+
+      raft::stats::mean(mu_input, temp.data(), n_cols, n_rows, false, false, stream);
+      math_t ratio = math_t(n_rows) / sum_sw.value(stream);
+      raft::linalg::scalarMultiply(mu_input, mu_input, ratio, n_cols, stream);
+    } else {
+      raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
+    }
     raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream);
-    raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
+    if (sample_weight) {
+      temp_labels.resize(n_rows, stream);
+      raft::copy(temp_labels.data(), labels, n_rows, stream);
+      raft::linalg::map(temp_labels.data(), n_rows,
+                        [] __device__(math_t a, math_t b) { return a * b; },
+                        stream, temp_labels.data(), sample_weight);
+      raft::stats::mean(mu_labels, temp_labels.data(), 1, n_rows, false, false, stream);
+      math_t ratio = math_t(n_rows) / sum_sw.value(stream);
+      raft::linalg::scalarMultiply(mu_labels, mu_labels, ratio, 1, stream);
+    } else {
+      raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
+    }
     raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream);
 
     if (normalize) {

From f47e360b4d9c259baf4f677169bcbf349bf65c74 Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Fri, 28 Jan 2022 14:43:25 +0100
Subject: [PATCH 06/15] Added rowSampleWeightedMean for cleaner code

---
 cpp/src/glm/ols.cuh                   | 16 +++++------
 cpp/src/glm/preprocess.cuh            | 28 ++++---------------
 cpp/src_prims/stats/weighted_mean.cuh | 40 +++++++++++++++++++++++++++
 python/cuml/test/test_linear_model.py |  7 +++--
 4 files changed, 57 insertions(+), 34 deletions(-)

diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh
index 9d274d58ee..98282f5931 100644
--- a/cpp/src/glm/ols.cuh
+++ b/cpp/src/glm/ols.cuh
@@ -121,6 +121,14 @@ void olsFit(const raft::handle_t& handle,
   }
   raft::common::nvtx::pop_range();
 
+  if (sample_weight != nullptr) {
+    raft::matrix::matrixVectorBinaryDivSkipZero(input, sample_weight, n_rows, n_cols, false, false, stream);
+    raft::linalg::map(labels, n_rows,
+                      [] __device__(math_t a, math_t b) { return a / b; },
+                      stream, labels, sample_weight);
+    LinAlg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream);
+  }
+
   if (fit_intercept) {
     postProcessData(handle,
                     input,
@@ -138,14 +146,6 @@ void olsFit(const raft::handle_t& handle,
   } else {
     *intercept = math_t(0);
   }
-
-  if (sample_weight != nullptr) {
-    raft::matrix::matrixVectorBinaryDivSkipZero(input, sample_weight, n_rows, n_cols, false, false, stream);
-    raft::linalg::map(labels, n_rows,
-                      [] __device__(math_t a, math_t b) { return a / b; },
-                      stream, labels, sample_weight);
-    LinAlg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream);
-  }
 }
 
 /**
diff --git a/cpp/src/glm/preprocess.cuh b/cpp/src/glm/preprocess.cuh
index b383f20dba..550a93f264 100644
--- a/cpp/src/glm/preprocess.cuh
+++ b/cpp/src/glm/preprocess.cuh
@@ -17,10 +17,7 @@
 #pragma once
 
 #include
-#include
 #include
-#include
-#include
 #include
 #include
 #include
@@ -29,6 +26,7 @@
 #include
 #include
 #include
+#include
 
 namespace ML {
 namespace GLM {
@@ -52,34 +50,18 @@ void preProcessData(const raft::handle_t& handle,
 {
   ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
   ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
-  rmm::device_uvector<math_t> mu_input2(n_cols, stream);
   if (fit_intercept) {
-    rmm::device_scalar<math_t> sum_sw(stream);
-    rmm::device_uvector<math_t> temp(0, stream);
-    rmm::device_uvector<math_t> temp_labels(0, stream);
     if (sample_weight) {
-      temp.resize(n_rows * n_cols, stream);
-      raft::copy(temp.data(), input, n_rows * n_cols, stream);
-      raft::stats::sum(sum_sw.data(), sample_weight, 1, n_rows, false, stream);
-      raft::matrix::matrixVectorBinaryMult(temp.data(), sample_weight, n_rows, n_cols, false, false, stream);
-
-      raft::stats::mean(mu_input, temp.data(), n_cols, n_rows, false, false, stream);
-      math_t ratio = math_t(n_rows) / sum_sw.value(stream);
-      raft::linalg::scalarMultiply(mu_input, mu_input, ratio, n_cols, stream);
+      MLCommon::Stats::rowSampleWeightedMean(mu_input, input,
+        sample_weight, n_cols, n_rows, false, false, stream);
     } else {
       raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
     }
     raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream);
     if (sample_weight) {
-      temp_labels.resize(n_rows, stream);
-      raft::copy(temp_labels.data(), labels, n_rows, stream);
-      raft::linalg::map(temp_labels.data(), n_rows,
-                        [] __device__(math_t a, math_t b) { return a * b; },
-                        stream, temp_labels.data(), sample_weight);
-      raft::stats::mean(mu_labels, temp_labels.data(), 1, n_rows, false, false, stream);
-      math_t ratio = math_t(n_rows) / sum_sw.value(stream);
-      raft::linalg::scalarMultiply(mu_labels, mu_labels, ratio, 1, stream);
+      MLCommon::Stats::rowSampleWeightedMean(mu_labels, labels,
+        sample_weight, 1, n_rows, true, false, stream);
     } else {
       raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
     }
diff --git a/cpp/src_prims/stats/weighted_mean.cuh b/cpp/src_prims/stats/weighted_mean.cuh
index da1969fdb7..8f739b038e 100644
--- a/cpp/src_prims/stats/weighted_mean.cuh
+++ b/cpp/src_prims/stats/weighted_mean.cuh
@@ -18,7 +18,9 @@
 
 #include
 #include
+#include
 #include
+#include
 
 namespace MLCommon {
 namespace Stats {
@@ -56,6 +58,44 @@ void rowWeightedMean(
     [WS] __device__(Type v) { return v / WS; });
 }
 
+/**
+ * @brief Compute the row-wise weighted mean of the input matrix
+ *
+ * @tparam Type the data type
+ * @param mu the output mean vector
+ * @param data the input matrix
+ * @param weights per-sample weight
+ * @param D number of columns of data
+ * @param N number of rows of data
+ * @param row_major input matrix is row-major or not
+ * @param along_rows whether to reduce along rows or columns
+ * @param stream cuda stream to launch work on
+ */
+template <typename Type>
+void rowSampleWeightedMean(
+  Type* mu, const Type* data, const Type* weights, int D, int N,
+  bool row_major, bool along_rows, cudaStream_t stream)
+{
+  // sum the weights & copy back to CPU
+  Type WS = 0;
+  raft::stats::sum(mu, weights, 1, N, row_major, stream);
+  raft::update_host(&WS, mu, 1,
stream); + + raft::linalg::reduce( + mu, + data, + D, + N, + (Type)0, + row_major, + along_rows, + stream, + false, + [weights] __device__(Type v, int i) { return v * weights[i]; }, + [] __device__(Type a, Type b) { return a + b; }, + [WS] __device__(Type v) { return v / WS; }); +} + /** * @brief Compute the column-wise weighted mean of the input matrix * diff --git a/python/cuml/test/test_linear_model.py b/python/cuml/test/test_linear_model.py index b8ea99590e..1a8a46526c 100644 --- a/python/cuml/test/test_linear_model.py +++ b/python/cuml/test/test_linear_model.py @@ -131,7 +131,8 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info): @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"]) @pytest.mark.parametrize("fit_intercept", [True, False]) -def test_weighted_linear_regression(datatype, algorithm, fit_intercept): +@pytest.mark.parametrize("normalize", [True, False]) +def test_weighted_linear_regression(datatype, algorithm, fit_intercept, normalize): nrows, ncols, n_info = 1000, 20, 10 max_weight = 10 noise = 20 @@ -144,7 +145,7 @@ def test_weighted_linear_regression(datatype, algorithm, fit_intercept): # Initialization of cuML's linear regression model cuols = cuLinearRegression(fit_intercept=fit_intercept, - normalize=False, + normalize=normalize, algorithm=algorithm) # fit and predict cuml linear regression model @@ -152,7 +153,7 @@ def test_weighted_linear_regression(datatype, algorithm, fit_intercept): cuols_predict = cuols.predict(X_test) # sklearn linear regression model initialization, fit and predict - skols = skLinearRegression(fit_intercept=fit_intercept, normalize=False) + skols = skLinearRegression(fit_intercept=fit_intercept, normalize=normalize) skols.fit(X_train, y_train, sample_weight=wt) skols_predict = skols.predict(X_test) From 65630b1581a5d3caa9deab0e825d19fb696c8a8f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 28 Jan 2022 14:51:37 +0100 Subject: [PATCH 07/15] Fix style and copyrights --- cpp/include/cuml/linear_model/glm.hpp | 6 ++-- cpp/src/glm/glm.cu | 4 +-- cpp/src/glm/ols.cuh | 31 +++++++++++++------ cpp/src/glm/preprocess.cuh | 11 ++++--- cpp/src_prims/stats/weighted_mean.cuh | 13 +++++--- .../cuml/linear_model/linear_regression.pyx | 2 +- python/cuml/test/test_linear_model.py | 2 +- 7 files changed, 43 insertions(+), 26 deletions(-) diff --git a/cpp/include/cuml/linear_model/glm.hpp b/cpp/include/cuml/linear_model/glm.hpp index f6477ef7a7..77400d3f70 100644 --- a/cpp/include/cuml/linear_model/glm.hpp +++ b/cpp/include/cuml/linear_model/glm.hpp @@ -44,7 +44,7 @@ void olsFit(const raft::handle_t& handle, float* intercept, bool fit_intercept, bool normalize, - int algo = 0, + int algo = 0, float* sample_weight = nullptr); void olsFit(const raft::handle_t& handle, double* input, @@ -55,8 +55,8 @@ void olsFit(const raft::handle_t& handle, double* intercept, bool fit_intercept, bool normalize, - int algo = 0, - double *sample_weight = nullptr); + int algo = 0, + double* sample_weight = nullptr); /** @} */ /** diff --git a/cpp/src/glm/glm.cu b/cpp/src/glm/glm.cu index f3da26ad31..31dea95d21 100644 --- a/cpp/src/glm/glm.cu +++ b/cpp/src/glm/glm.cu @@ -38,7 +38,7 @@ void olsFit(const raft::handle_t& handle, bool fit_intercept, bool normalize, int algo, - float *sample_weight) + float* sample_weight) { olsFit(handle, input, @@ -64,7 +64,7 @@ void olsFit(const raft::handle_t& handle, bool fit_intercept, bool normalize, int algo, - double *sample_weight) + 
double* sample_weight) { olsFit(handle, input, diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh index 98282f5931..c25d5c9eb9 100644 --- a/cpp/src/glm/ols.cuh +++ b/cpp/src/glm/ols.cuh @@ -17,11 +17,11 @@ #pragma once #include -#include #include -#include +#include #include #include +#include #include #include #include @@ -53,7 +53,8 @@ using namespace MLCommon; * @param stream cuda stream * @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2: * QR-decomposition) - * @param sample_weight device pointer to sample weight vector of length n_rows + * @param sample_weight device pointer to sample weight vector of length n_rows (nullptr for uniform + * weights) */ template void olsFit(const raft::handle_t& handle, @@ -66,7 +67,7 @@ void olsFit(const raft::handle_t& handle, bool fit_intercept, bool normalize, cudaStream_t stream, - int algo = 0, + int algo = 0, math_t* sample_weight = nullptr) { auto cublas_handle = handle.get_cublas_handle(); @@ -100,10 +101,15 @@ void olsFit(const raft::handle_t& handle, if (sample_weight != nullptr) { LinAlg::sqrt(sample_weight, sample_weight, n_rows, stream); - raft::matrix::matrixVectorBinaryMult(input, sample_weight, n_rows, n_cols, false, false, stream); - raft::linalg::map(labels, n_rows, + raft::matrix::matrixVectorBinaryMult( + input, sample_weight, n_rows, n_cols, false, false, stream); + raft::linalg::map( + labels, + n_rows, [] __device__(math_t a, math_t b) { return a * b; }, - stream, labels, sample_weight); + stream, + labels, + sample_weight); } int selectedAlgo = algo; @@ -122,10 +128,15 @@ void olsFit(const raft::handle_t& handle, raft::common::nvtx::pop_range(); if (sample_weight != nullptr) { - raft::matrix::matrixVectorBinaryDivSkipZero(input, sample_weight, n_rows, n_cols, false, false, stream); - raft::linalg::map(labels, n_rows, + raft::matrix::matrixVectorBinaryDivSkipZero( + input, sample_weight, n_rows, n_cols, false, false, stream); + raft::linalg::map( + labels, + n_rows, [] __device__(math_t a, math_t b) { return a / b; }, - stream, labels, sample_weight); + stream, + labels, + sample_weight); LinAlg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream); } diff --git a/cpp/src/glm/preprocess.cuh b/cpp/src/glm/preprocess.cuh index 550a93f264..c66a4f5ac0 100644 --- a/cpp/src/glm/preprocess.cuh +++ b/cpp/src/glm/preprocess.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -50,18 +50,19 @@ void preProcessData(const raft::handle_t& handle, { ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one"); ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + if (fit_intercept) { if (sample_weight) { - MLCommon::Stats::rowSampleWeightedMean(mu_input, input, - sample_weight, n_cols, n_rows, false, false, stream); + MLCommon::Stats::rowSampleWeightedMean( + mu_input, input, sample_weight, n_cols, n_rows, false, false, stream); } else { raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream); } raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream); if (sample_weight) { - MLCommon::Stats::rowSampleWeightedMean(mu_labels, labels, - sample_weight, 1, n_rows, true, false, stream); + MLCommon::Stats::rowSampleWeightedMean( + mu_labels, labels, sample_weight, 1, n_rows, true, false, stream); } else { raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream); } diff --git a/cpp/src_prims/stats/weighted_mean.cuh b/cpp/src_prims/stats/weighted_mean.cuh index 8f739b038e..0a1b918b51 100644 --- a/cpp/src_prims/stats/weighted_mean.cuh +++ b/cpp/src_prims/stats/weighted_mean.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,9 +72,14 @@ void rowWeightedMean( * @param stream cuda stream to launch work on */ template -void rowSampleWeightedMean( - Type* mu, const Type* data, const Type* weights, int D, int N, - bool row_major, bool along_rows, cudaStream_t stream) +void rowSampleWeightedMean(Type* mu, + const Type* data, + const Type* weights, + int D, + int N, + bool row_major, + bool along_rows, + cudaStream_t stream) { // sum the weights & copy back to CPU Type WS = 0; diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index a28d5b1132..41a58b6fb0 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/cuml/test/test_linear_model.py b/python/cuml/test/test_linear_model.py index 1a8a46526c..35d732f6d2 100644 --- a/python/cuml/test/test_linear_model.py +++ b/python/cuml/test/test_linear_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021, NVIDIA CORPORATION. +# Copyright (c) 2019-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 80908951dffd7f17a21c138bd0834e6d06fd7e98 Mon Sep 17 00:00:00 2001 From: Micka <9810050+lowener@users.noreply.github.com> Date: Mon, 31 Jan 2022 11:55:46 +0100 Subject: [PATCH 08/15] Update Copyright --- cpp/include/cuml/linear_model/glm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cuml/linear_model/glm.hpp b/cpp/include/cuml/linear_model/glm.hpp index 77400d3f70..968b2d99ee 100644 --- a/cpp/include/cuml/linear_model/glm.hpp +++ b/cpp/include/cuml/linear_model/glm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From e81980cc89203a9598a60352632866dbb67fc262 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 31 Jan 2022 12:42:18 +0100 Subject: [PATCH 09/15] Fix style --- .../cuml/linear_model/linear_regression.pyx | 7 ++++--- python/cuml/test/test_linear_model.py | 20 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index 41a58b6fb0..024770aa5b 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -243,7 +243,8 @@ class LinearRegression(Base, }[algorithm] @generate_docstring() - def fit(self, X, y, convert_dtype=True, sample_weight=None) -> "LinearRegression": + def fit(self, X, y, convert_dtype=True, + sample_weight=None) -> "LinearRegression": """ Fit the model with X and y. @@ -263,8 +264,8 @@ class LinearRegression(Base, if sample_weight is not None: sample_weight_m, _, _, _ = \ input_to_cuml_array(sample_weight, check_dtype=self.dtype, - convert_to_dtype=(self.dtype if convert_dtype - else None), + convert_to_dtype=( + self.dtype if convert_dtype else None), check_rows=n_rows, check_cols=1) sample_weight_ptr = sample_weight_m.ptr else: diff --git a/python/cuml/test/test_linear_model.py b/python/cuml/test/test_linear_model.py index 35d732f6d2..ff21fb6d7e 100644 --- a/python/cuml/test/test_linear_model.py +++ b/python/cuml/test/test_linear_model.py @@ -48,7 +48,8 @@ def _make_regression_dataset_uncached(nrows, ncols, n_info, **kwargs): X, y = make_regression( - **kwargs, n_samples=nrows, n_features=ncols, n_informative=n_info, random_state=0 + **kwargs, n_samples=nrows, n_features=ncols, n_informative=n_info, + random_state=0 ) return train_test_split(X, y, train_size=0.8, random_state=10) @@ -60,9 +61,11 @@ def _make_regression_dataset_from_cache(nrows, ncols, n_info, **kwargs): def make_regression_dataset(datatype, nrows, ncols, n_info, **kwargs): if nrows * ncols < 1e8: # Keep cache under 4 GB - dataset = _make_regression_dataset_from_cache(nrows, ncols, n_info, **kwargs) + dataset = _make_regression_dataset_from_cache(nrows, ncols, n_info, + **kwargs) else: - dataset = _make_regression_dataset_uncached(nrows, ncols, n_info, **kwargs) + dataset = _make_regression_dataset_uncached(nrows, ncols, n_info, + **kwargs) return map(lambda arr: arr.astype(datatype), dataset) @@ -132,17 +135,18 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info): @pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"]) @pytest.mark.parametrize("fit_intercept", [True, False]) @pytest.mark.parametrize("normalize", [True, False]) -def test_weighted_linear_regression(datatype, algorithm, fit_intercept, normalize): +def test_weighted_linear_regression(datatype, algorithm, fit_intercept, + normalize): nrows, ncols, n_info = 1000, 20, 10 max_weight = 10 noise = 20 X_train, X_test, y_train, y_test = make_regression_dataset( datatype, nrows, ncols, n_info, noise=noise ) - + # set weight per sample to be from 1 to max_weight wt = np.random.randint(1, high=max_weight, size=len(X_train)) - + # Initialization of cuML's linear regression model cuols = cuLinearRegression(fit_intercept=fit_intercept, normalize=normalize, @@ -153,13 +157,15 @@ def test_weighted_linear_regression(datatype, algorithm, fit_intercept, normaliz cuols_predict = cuols.predict(X_test) # sklearn 
linear regression model initialization, fit and predict - skols = skLinearRegression(fit_intercept=fit_intercept, normalize=normalize) + skols = skLinearRegression(fit_intercept=fit_intercept, + normalize=normalize) skols.fit(X_train, y_train, sample_weight=wt) skols_predict = skols.predict(X_test) assert array_equal(skols_predict, cuols_predict, 1e-1, with_sign=True) + @pytest.mark.skipif( rmm._cuda.gpu.runtimeGetVersion() < 11000, reason='svd solver does not support more than 46340 rows or columns for' From b8c7092c710028485fc578f380e8dbff4284377e Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 31 Jan 2022 17:30:10 +0100 Subject: [PATCH 10/15] add doc --- cpp/include/cuml/linear_model/glm.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/cuml/linear_model/glm.hpp b/cpp/include/cuml/linear_model/glm.hpp index 968b2d99ee..bc1b942640 100644 --- a/cpp/include/cuml/linear_model/glm.hpp +++ b/cpp/include/cuml/linear_model/glm.hpp @@ -32,7 +32,8 @@ namespace GLM { * @param normalize if true, normalize data to zero mean, unit variance * @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2: * QR-decomposition) - * @param sample_weight device pointer to sample weight vector of length n_rows + * @param sample_weight device pointer to sample weight vector of length n_rows (nullptr + for uniform weights) * @{ */ void olsFit(const raft::handle_t& handle, From dd92cdf60385662a5143e8ee5a3480c2b1ff215f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 2 Feb 2022 17:57:04 +0100 Subject: [PATCH 11/15] Add distributions for tests and doc --- cpp/src_prims/stats/weighted_mean.cuh | 9 ++++++--- python/cuml/test/test_linear_model.py | 21 +++++++++++++++++---- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/cpp/src_prims/stats/weighted_mean.cuh b/cpp/src_prims/stats/weighted_mean.cuh index 0a1b918b51..953884726a 100644 --- a/cpp/src_prims/stats/weighted_mean.cuh +++ b/cpp/src_prims/stats/weighted_mean.cuh @@ -26,7 +26,8 @@ namespace MLCommon { namespace Stats { /** - * @brief Compute the row-wise weighted mean of the input matrix + * @brief Compute the row-wise weighted mean of the input matrix with a + * vector of column weights * * @tparam Type the data type * @param mu the output mean vector @@ -59,7 +60,8 @@ void rowWeightedMean( } /** - * @brief Compute the row-wise weighted mean of the input matrix + * @brief Compute the row-wise weighted mean of the input matrix with a + * vector of sample weights * * @tparam Type the data type * @param mu the output mean vector @@ -102,7 +104,8 @@ void rowSampleWeightedMean(Type* mu, } /** - * @brief Compute the column-wise weighted mean of the input matrix + * @brief Compute the column-wise weighted mean of the input matrix with a + * vector of column weights * * @tparam Type the data type * @param mu the output mean vector diff --git a/python/cuml/test/test_linear_model.py b/python/cuml/test/test_linear_model.py index ff21fb6d7e..9d66a801da 100644 --- a/python/cuml/test/test_linear_model.py +++ b/python/cuml/test/test_linear_model.py @@ -133,10 +133,18 @@ def test_linear_regression_model(datatype, algorithm, nrows, column_info): @pytest.mark.parametrize("datatype", [np.float32, np.float64]) @pytest.mark.parametrize("algorithm", ["eig", "svd", "qr", "svd-qr"]) -@pytest.mark.parametrize("fit_intercept", [True, False]) -@pytest.mark.parametrize("normalize", [True, False]) +@pytest.mark.parametrize( + "fit_intercept, normalize, distribution", [ + (True, True, "lognormal"), + (True, True, 
"exponential"), + (True, False, "uniform"), + (True, False, "exponential"), + (False, True, "lognormal"), + (False, False, "uniform"), + ] +) def test_weighted_linear_regression(datatype, algorithm, fit_intercept, - normalize): + normalize, distribution): nrows, ncols, n_info = 1000, 20, 10 max_weight = 10 noise = 20 @@ -145,7 +153,12 @@ def test_weighted_linear_regression(datatype, algorithm, fit_intercept, ) # set weight per sample to be from 1 to max_weight - wt = np.random.randint(1, high=max_weight, size=len(X_train)) + if distribution == "uniform": + wt = np.random.randint(1, high=max_weight, size=len(X_train)) + elif distribution == "exponential": + wt = np.random.exponential(size=len(X_train)) + else: + wt = np.random.lognormal(size=len(X_train)) # Initialization of cuML's linear regression model cuols = cuLinearRegression(fit_intercept=fit_intercept, From f6e8c3385f137339903c241da7d0d43e7c3749a5 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 18 Feb 2022 01:09:08 +0100 Subject: [PATCH 12/15] Update new raft stats --- cpp/cmake/thirdparty/get_raft.cmake | 4 ++-- cpp/src/glm/ols.cuh | 4 ++-- cpp/src/glm/preprocess.cuh | 8 +++----- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 63f795d519..edf0c3bb30 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -57,8 +57,8 @@ set(CUML_BRANCH_VERSION_raft "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}") # To use a different RAFT locally, set the CMake variable # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft} - FORK rapidsai - PINNED_TAG branch-${CUML_BRANCH_VERSION_raft} + FORK lowener + PINNED_TAG 22.04-weighted-mean USE_RAFT_NN ${CUML_USE_RAFT_NN} USE_FAISS_STATIC ${CUML_USE_FAISS_STATIC} ) diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh index 1f5d005781..509a13c3f5 100644 --- a/cpp/src/glm/ols.cuh +++ b/cpp/src/glm/ols.cuh @@ -98,7 +98,7 @@ void olsFit(const raft::handle_t& handle, } if (sample_weight != nullptr) { - LinAlg::sqrt(sample_weight, sample_weight, n_rows, stream); + raft::linalg::sqrt(sample_weight, sample_weight, n_rows, stream); raft::matrix::matrixVectorBinaryMult( input, sample_weight, n_rows, n_cols, false, false, stream); raft::linalg::map( @@ -137,7 +137,7 @@ void olsFit(const raft::handle_t& handle, stream, labels, sample_weight); - LinAlg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream); + raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream); } if (fit_intercept) { diff --git a/cpp/src/glm/preprocess.cuh b/cpp/src/glm/preprocess.cuh index 79d844a205..ee9f37f281 100644 --- a/cpp/src/glm/preprocess.cuh +++ b/cpp/src/glm/preprocess.cuh @@ -25,9 +25,9 @@ #include #include #include +#include #include #include -#include namespace ML { namespace GLM { @@ -90,15 +90,13 @@ void preProcessData(const raft::handle_t& handle, norm2_input); } else { if (sample_weight != nullptr) { - MLCommon::Stats::rowSampleWeightedMean( + raft::stats::rowSampleWeightedMean( mu_input, input, sample_weight, n_cols, n_rows, false, false, stream); } else { raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream); } raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream); if (normalize) { - /* - */ raft::linalg::colNorm(norm2_input, input, n_cols, @@ -113,7 +111,7 @@ void preProcessData(const raft::handle_t& handle, } if (sample_weight != nullptr) { - 
MLCommon::Stats::rowSampleWeightedMean( + raft::stats::rowSampleWeightedMean( mu_labels, labels, sample_weight, 1, n_rows, true, false, stream); } else { raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream); From 6e35d8d2dae6a8503a1a6fa0f54bdda24e1bdacb Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 18 Feb 2022 01:23:57 +0100 Subject: [PATCH 13/15] Fix style --- cpp/src/glm/ols.cuh | 2 +- cpp/src/glm/preprocess.cuh | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/glm/ols.cuh b/cpp/src/glm/ols.cuh index 509a13c3f5..b9fa783f82 100644 --- a/cpp/src/glm/ols.cuh +++ b/cpp/src/glm/ols.cuh @@ -137,7 +137,7 @@ void olsFit(const raft::handle_t& handle, stream, labels, sample_weight); - raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream); + raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream); } if (fit_intercept) { diff --git a/cpp/src/glm/preprocess.cuh b/cpp/src/glm/preprocess.cuh index ee9f37f281..fecc6fd52f 100644 --- a/cpp/src/glm/preprocess.cuh +++ b/cpp/src/glm/preprocess.cuh @@ -98,13 +98,13 @@ void preProcessData(const raft::handle_t& handle, raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream); if (normalize) { raft::linalg::colNorm(norm2_input, - input, - n_cols, - n_rows, - raft::linalg::L2Norm, - false, - stream, - [] __device__(math_t v) { return raft::mySqrt(v); }); + input, + n_cols, + n_rows, + raft::linalg::L2Norm, + false, + stream, + [] __device__(math_t v) { return raft::mySqrt(v); }); raft::matrix::matrixVectorBinaryDivSkipZero( input, norm2_input, n_rows, n_cols, false, true, stream, true); } From da761af87934b946ac99e3d5a361384e34c25439 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 2 Mar 2022 05:41:47 -0800 Subject: [PATCH 14/15] Use latest raft updates --- cpp/cmake/thirdparty/get_raft.cmake | 4 ++-- cpp/src/glm/preprocess.cuh | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index edf0c3bb30..63f795d519 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -57,8 +57,8 @@ set(CUML_BRANCH_VERSION_raft "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}") # To use a different RAFT locally, set the CMake variable # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft} - FORK lowener - PINNED_TAG 22.04-weighted-mean + FORK rapidsai + PINNED_TAG branch-${CUML_BRANCH_VERSION_raft} USE_RAFT_NN ${CUML_USE_RAFT_NN} USE_FAISS_STATIC ${CUML_USE_FAISS_STATIC} ) diff --git a/cpp/src/glm/preprocess.cuh b/cpp/src/glm/preprocess.cuh index fecc6fd52f..f6fbbdfa03 100644 --- a/cpp/src/glm/preprocess.cuh +++ b/cpp/src/glm/preprocess.cuh @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include @@ -90,7 +90,7 @@ void preProcessData(const raft::handle_t& handle, norm2_input); } else { if (sample_weight != nullptr) { - raft::stats::rowSampleWeightedMean( + raft::stats::weightedMean( mu_input, input, sample_weight, n_cols, n_rows, false, false, stream); } else { raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream); @@ -111,7 +111,7 @@ void preProcessData(const raft::handle_t& handle, } if (sample_weight != nullptr) { - raft::stats::rowSampleWeightedMean( + raft::stats::weightedMean( mu_labels, labels, sample_weight, 1, n_rows, true, false, stream); } else { raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, 
stream);
   }

From df297c52da6f1b87ff333fd0752403ccfa541e2b Mon Sep 17 00:00:00 2001
From: Mickael Ide
Date: Mon, 7 Mar 2022 15:12:03 +0100
Subject: [PATCH 15/15] fix style

---
 cpp/src/glm/preprocess.cuh | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/cpp/src/glm/preprocess.cuh b/cpp/src/glm/preprocess.cuh
index f6fbbdfa03..1d2b470e52 100644
--- a/cpp/src/glm/preprocess.cuh
+++ b/cpp/src/glm/preprocess.cuh
@@ -111,8 +111,7 @@ void preProcessData(const raft::handle_t& handle,
   }
 
   if (sample_weight != nullptr) {
-    raft::stats::weightedMean(
-      mu_labels, labels, sample_weight, 1, n_rows, true, false, stream);
+    raft::stats::weightedMean(mu_labels, labels, sample_weight, 1, n_rows, true, false, stream);
   } else {
     raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
   }
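
Note on the weighting scheme used throughout this series: weighted least squares reduces to ordinary least squares on rescaled data. olsFit takes sqrt(w_i) of the weights and multiplies each row of the input matrix and each label by it before solving (after first centering about the weighted mean when fitting an intercept, which is what rowSampleWeightedMean and later raft::stats::weightedMean compute), then divides the data back out and re-squares the weights to restore the caller's buffers. A minimal NumPy sketch of that identity follows; it is an illustration of the math only, not cuML code, and the variable names are made up for the example:

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 3))
    w = rng.uniform(0.5, 10.0, size=100)   # positive per-sample weights
    y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=100)

    # Weighted normal equations: coef = (X^T W X)^{-1} X^T W y
    coef_wls = np.linalg.solve(X.T @ (w[:, None] * X), X.T @ (w * y))

    # The same fit via plain OLS on sqrt(w)-scaled rows and labels,
    # mirroring the sqrt / matrixVectorBinaryMult step in olsFit
    sw = np.sqrt(w)
    coef_ols = np.linalg.lstsq(sw[:, None] * X, sw * y, rcond=None)[0]

    assert np.allclose(coef_wls, coef_ols)

From the Python side the new parameter mirrors scikit-learn, as exercised by test_weighted_linear_regression: LinearRegression().fit(X, y, sample_weight=w).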