Skip to content

Commit

Permalink
Faster GLM preprocessing by fusing kernels (rapidsai#4549)
Browse files Browse the repository at this point in the history
Fuse the fit_intercept and normalize kernels when both are enabled. This nearly halves the preprocess/postprocess runtime when `normalize` is enabled (it is disabled by default).
Furthermore, it changes the behavior of the "normalize" switch from dividing by the column-wise L2 norm to dividing by the column-wise standard deviation.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#4549
  • Loading branch information
achirkin authored Feb 10, 2022
1 parent 8dbbbf0 commit 9dfa471
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 36 deletions.
75 changes: 57 additions & 18 deletions cpp/src/glm/preprocess.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/matrix/matrix.hpp>
#include <raft/stats/mean.hpp>
#include <raft/stats/mean_center.hpp>
#include <raft/stats/meanvar.hpp>
#include <raft/stats/stddev.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>
Expand All @@ -32,6 +33,24 @@ namespace GLM {

using namespace MLCommon;

/**
* @brief Center and scale the data, depending on the flags fit_intercept and normalize
*
* @tparam math_t the element type
* @param [inout] input the column-major data of size [n_rows, n_cols]
* @param [in] n_rows
* @param [in] n_cols
* @param [inout] labels vector of size [n_rows]
* @param [out] intercept
* @param [out] mu_input the column-wise means of the input of size [n_cols]
* @param [out] mu_labels the scalar mean of the target (labels vector)
* @param [out] norm2_input the column-wise standard deviations of the input of size [n_cols];
* note, the biased estimator is used to match sklearn's StandardScaler
* (dividing by n_rows, not by (n_rows - 1)).
* @param [in] fit_intercept whether to center the data / to fit the intercept
* @param [in] normalize whether to normalize the data
* @param [in] stream
*/
template <typename math_t>
void preProcessData(const raft::handle_t& handle,
math_t* input,
Expand All @@ -46,28 +65,35 @@ void preProcessData(const raft::handle_t& handle,
bool normalize,
cudaStream_t stream)
{
raft::common::nvtx::range fun_scope("ML::GLM::preProcessData-%d-%d", n_rows, n_cols);
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");

if (fit_intercept) {
raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream);

raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream);

if (normalize) {
raft::linalg::colNorm(norm2_input,
input,
n_cols,
n_rows,
raft::linalg::L2Norm,
false,
stream,
[] __device__(math_t v) { return raft::mySqrt(v); });
raft::matrix::matrixVectorBinaryDivSkipZero(
input, norm2_input, n_rows, n_cols, false, true, stream, true);
raft::stats::meanvar(mu_input, norm2_input, input, n_cols, n_rows, false, false, stream);
raft::linalg::unaryOp(
norm2_input,
norm2_input,
n_cols,
[] __device__(math_t v) { return raft::mySqrt(v); },
stream);
raft::matrix::linewiseOp(
input,
input,
n_rows,
n_cols,
false,
[] __device__(math_t x, math_t m, math_t s) { return s > 1e-10 ? (x - m) / s : 0; },
stream,
mu_input,
norm2_input);
} else {
raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream);
}
raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream);
}
}

Expand All @@ -86,14 +112,14 @@ void postProcessData(const raft::handle_t& handle,
bool normalize,
cudaStream_t stream)
{
raft::common::nvtx::range fun_scope("ML::GLM::postProcessData-%d-%d", n_rows, n_cols);
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");

cublasHandle_t cublas_handle = handle.get_cublas_handle();
rmm::device_scalar<math_t> d_intercept(stream);

if (normalize) {
raft::matrix::matrixVectorBinaryMult(input, norm2_input, n_rows, n_cols, false, true, stream);
raft::matrix::matrixVectorBinaryDivSkipZero(
coef, norm2_input, 1, n_cols, false, true, stream, true);
}
Expand All @@ -104,7 +130,20 @@ void postProcessData(const raft::handle_t& handle,
raft::linalg::subtract(d_intercept.data(), mu_labels, d_intercept.data(), 1, stream);
*intercept = d_intercept.value(stream);

raft::stats::meanAdd(input, input, mu_input, n_cols, n_rows, false, true, stream);
if (normalize) {
raft::matrix::linewiseOp(
input,
input,
n_rows,
n_cols,
false,
[] __device__(math_t x, math_t m, math_t s) { return s * x + m; },
stream,
mu_input,
norm2_input);
} else {
raft::stats::meanAdd(input, input, mu_input, n_cols, n_rows, false, true, stream);
}
raft::stats::meanAdd(labels, labels, mu_labels, 1, n_rows, false, true, stream);
}

Expand Down
8 changes: 5 additions & 3 deletions cpp/src/solver/cd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ __global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc,
* @param fit_intercept
* boolean parameter to control if the intercept will be fitted or not
* @param normalize
* boolean parameter to control if the data will be normalized or not
* boolean parameter to control if the data will be normalized or not;
* NB: the input is scaled by the column-wise biased sample standard deviation estimator.
* @param epochs
* Maximum number of iterations that solver will run
* @param loss
Expand Down Expand Up @@ -183,8 +184,9 @@ void cdFit(const raft::handle_t& handle,

// Precompute the residual
if (normalize) {
// if we normalized the data during preprocessing, no need to compute the norm again.
math_t scalar = math_t(1.0) + l2_alpha;
// if we normalized the data, each column is centered and has unit (biased) sample variance,
// so its squared L2 norm equals n_rows; thus no need to compute the norm again.
math_t scalar = math_t(n_rows) + l2_alpha;
raft::matrix::setValue(squared.data(), squared.data(), scalar, n_cols, stream);
} else {
raft::linalg::colNorm(
Expand Down
51 changes: 50 additions & 1 deletion cpp/test/sg/cd_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include <solver/cd.cuh>
#include <test_utils.h>

#include <raft/stats/mean.hpp>
#include <raft/stats/meanvar.hpp>
#include <raft/stats/stddev.hpp>

namespace ML {
namespace Solver {

Expand Down Expand Up @@ -72,6 +76,16 @@ class CdTest : public ::testing::TestWithParam<CdInputs<T>> {
T labels_h[params.n_row] = {6.0, 8.3, 9.8, 11.2};
raft::update_device(labels.data(), labels_h, params.n_row, stream);

/* How to reproduce the coefficients for this test:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler(with_mean=True, with_std=True)
x_norm = scaler.fit_transform(data_h)
m = ElasticNet(fit_intercept=, normalize=, alpha=, l1_ratio=)
m.fit(x_norm, y)
print(m.coef_ / scaler.scale_ if normalize else m.coef_)
*/

T coef_ref_h[params.n_col] = {4.90832, 0.35031};
raft::update_device(coef_ref.data(), coef_ref_h, params.n_col, stream);

Expand All @@ -81,7 +95,7 @@ class CdTest : public ::testing::TestWithParam<CdInputs<T>> {
T coef3_ref_h[params.n_col] = {2.932841, 1.15248};
raft::update_device(coef3_ref.data(), coef3_ref_h, params.n_col, stream);

T coef4_ref_h[params.n_col] = {0.569439, -0.00542};
T coef4_ref_h[params.n_col] = {1.75420431, -0.16215289};
raft::update_device(coef4_ref.data(), coef4_ref_h, params.n_col, stream);

bool fit_intercept = false;
Expand Down Expand Up @@ -202,6 +216,25 @@ TEST_P(CdTestF, Fit)
ASSERT_TRUE(raft::devArrMatch(
coef3_ref.data(), coef3.data(), params.n_col, raft::CompareApproxAbs<float>(params.tol)));

rmm::device_uvector<float> means_1(params.n_col, stream);
rmm::device_uvector<float> means_2(params.n_col, stream);
rmm::device_uvector<float> vars_1(params.n_col, stream);
rmm::device_uvector<float> vars_2(params.n_col, stream);

raft::stats::mean(means_1.data(), data.data(), params.n_col, params.n_row, false, false, stream);
raft::stats::vars(
vars_1.data(), data.data(), means_1.data(), params.n_col, params.n_row, false, false, stream);
raft::stats::meanvar(
means_2.data(), vars_2.data(), data.data(), params.n_col, params.n_row, false, false, stream);

ASSERT_TRUE(raft::devArrMatch(
means_1.data(), means_2.data(), params.n_col, raft::CompareApprox<float>(0.0001)));
ASSERT_TRUE(raft::devArrMatch(
vars_1.data(), vars_2.data(), params.n_col, raft::CompareApprox<float>(0.0001)));

ASSERT_TRUE(raft::devArrMatch(
coef4_ref.data(), coef4.data(), params.n_col, raft::CompareApproxAbs<float>(params.tol)));

ASSERT_TRUE(raft::devArrMatch(
coef4_ref.data(), coef4.data(), params.n_col, raft::CompareApproxAbs<float>(params.tol)));
}
Expand All @@ -218,6 +251,22 @@ TEST_P(CdTestD, Fit)
ASSERT_TRUE(raft::devArrMatch(
coef3_ref.data(), coef3.data(), params.n_col, raft::CompareApproxAbs<double>(params.tol)));

rmm::device_uvector<double> means_1(params.n_col, stream);
rmm::device_uvector<double> means_2(params.n_col, stream);
rmm::device_uvector<double> vars_1(params.n_col, stream);
rmm::device_uvector<double> vars_2(params.n_col, stream);

raft::stats::mean(means_1.data(), data.data(), params.n_col, params.n_row, false, false, stream);
raft::stats::vars(
vars_1.data(), data.data(), means_1.data(), params.n_col, params.n_row, false, false, stream);
raft::stats::meanvar(
means_2.data(), vars_2.data(), data.data(), params.n_col, params.n_row, false, false, stream);

ASSERT_TRUE(raft::devArrMatch(
means_1.data(), means_2.data(), params.n_col, raft::CompareApprox<double>(0.0001)));
ASSERT_TRUE(raft::devArrMatch(
vars_1.data(), vars_2.data(), params.n_col, raft::CompareApprox<double>(0.0001)));

ASSERT_TRUE(raft::devArrMatch(
coef4_ref.data(), coef4.data(), params.n_col, raft::CompareApproxAbs<double>(params.tol)));
}
Expand Down
34 changes: 29 additions & 5 deletions cpp/test/sg/ridge.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,31 +72,55 @@ class RidgeTest : public ::testing::TestWithParam<RidgeInputs<T>> {
rmm::device_uvector<T> labels(params.n_row, stream);
T alpha = params.alpha;

/* How to reproduce the coefficients for this test:
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge
scaler = StandardScaler(with_mean=True, with_std=True)
x_norm = scaler.fit_transform(x_train)
m = Ridge(
fit_intercept=False, normalize=False, alpha=0.5)
m.fit(x_train, y)
print(m.coef_, m.predict(x_test))
m = Ridge(
fit_intercept=True, normalize=False, alpha=0.5)
m.fit(x_train, y)
print(m.coef_, m.predict(x_test))
m = Ridge(
fit_intercept=True, normalize=False, alpha=0.5)
m.fit(x_norm, y)
print(m.coef_ / scaler.scale_, m.predict(scaler.transform(x_test)))
*/

T data_h[len] = {0.0, 0.0, 1.0, 0.0, 0.0, 1.0};
raft::update_device(data.data(), data_h, len, stream);

T labels_h[params.n_row] = {0.0, 0.1, 1.0};
raft::update_device(labels.data(), labels_h, params.n_row, stream);

T coef_ref_h[params.n_col] = {0.39999998, 0.4};
T coef_ref_h[params.n_col] = {0.4, 0.4};
raft::update_device(coef_ref.data(), coef_ref_h, params.n_col, stream);

T coef2_ref_h[params.n_col] = {0.3454546, 0.34545454};
raft::update_device(coef2_ref.data(), coef2_ref_h, params.n_col, stream);

T coef3_ref_h[params.n_col] = {0.3799999, 0.38000008};
T coef3_ref_h[params.n_col] = {0.43846154, 0.43846154};
raft::update_device(coef3_ref.data(), coef3_ref_h, params.n_col, stream);

T pred_data_h[len2] = {0.5, 2.0, 0.2, 1.0};
raft::update_device(pred_data.data(), pred_data_h, len2, stream);

T pred_ref_h[params.n_row_2] = {0.28, 1.1999999};
T pred_ref_h[params.n_row_2] = {0.28, 1.2};
raft::update_device(pred_ref.data(), pred_ref_h, params.n_row_2, stream);

T pred2_ref_h[params.n_row_2] = {0.37818184, 1.1727273};
T pred2_ref_h[params.n_row_2] = {0.37818182, 1.17272727};
raft::update_device(pred2_ref.data(), pred2_ref_h, params.n_row_2, stream);

T pred3_ref_h[params.n_row_2] = {0.37933332, 1.2533332};
T pred3_ref_h[params.n_row_2] = {0.38128205, 1.38974359};
raft::update_device(pred3_ref.data(), pred3_ref_h, params.n_row_2, stream);

intercept = T(0);
Expand Down
9 changes: 6 additions & 3 deletions python/cuml/linear_model/elastic_net.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -105,9 +105,12 @@ class ElasticNet(Base,
If True, Lasso tries to correct for the global mean of y.
If False, the model expects that you have centered the data.
normalize : boolean (default = False)
If True, the predictors in X will be normalized by dividing by it's L2
norm.
If True, the predictors in X will be normalized by dividing by the
column-wise standard deviation.
If False, no scaling will be done.
Note: this differs from sklearn's deprecated `normalize` flag, which
divides by the column-wise L2 norm; it is equivalent to preprocessing
the data with sklearn's StandardScaler.
max_iter : int (default = 1000)
The maximum number of iterations
tol : float (default = 1e-3)
Expand Down
9 changes: 6 additions & 3 deletions python/cuml/linear_model/linear_regression.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -150,9 +150,12 @@ class LinearRegression(Base,
If False, the model expects that you have centered the data.
normalize : boolean (default = False)
This parameter is ignored when `fit_intercept` is set to False.
If True, the predictors in X will be normalized by dividing by it's
L2 norm.
If True, the predictors in X will be normalized by dividing by the
column-wise standard deviation.
If False, no scaling will be done.
Note: this differs from sklearn's deprecated `normalize` flag, which
divides by the column-wise L2 norm; it is equivalent to preprocessing
the data with sklearn's StandardScaler.
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the CUDA
Expand Down
9 changes: 6 additions & 3 deletions python/cuml/linear_model/ridge.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -154,9 +154,12 @@ class Ridge(Base,
If True, Ridge tries to correct for the global mean of y.
If False, the model expects that you have centered the data.
normalize : boolean (default = False)
If True, the predictors in X will be normalized by dividing by it's L2
norm.
If True, the predictors in X will be normalized by dividing by the
column-wise standard deviation.
If False, no scaling will be done.
Note: this differs from sklearn's deprecated `normalize` flag, which
divides by the column-wise L2 norm; it is equivalent to preprocessing
the data with sklearn's StandardScaler.
handle : cuml.Handle
Specifies the cuml.handle that holds internal CUDA state for
computations in this model. Most importantly, this specifies the CUDA
Expand Down

0 comments on commit 9dfa471

Please sign in to comment.