Add support for sample_weights in LinearRegression #4428

Merged (18 commits) on Mar 10, 2022
Changes from 12 commits
10 changes: 7 additions & 3 deletions cpp/include/cuml/linear_model/glm.hpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -32,6 +32,8 @@ namespace GLM {
* @param normalize if true, normalize data to zero mean, unit variance
* @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
* QR-decomposition)
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
for uniform weights)
* @{
*/
void olsFit(const raft::handle_t& handle,
@@ -43,7 +45,8 @@ void olsFit(const raft::handle_t& handle,
float* intercept,
bool fit_intercept,
bool normalize,
int algo = 0);
int algo = 0,
float* sample_weight = nullptr);
void olsFit(const raft::handle_t& handle,
double* input,
int n_rows,
@@ -53,7 +56,8 @@ void olsFit(const raft::handle_t& handle,
double* intercept,
bool fit_intercept,
bool normalize,
int algo = 0);
int algo = 0,
double* sample_weight = nullptr);
/** @} */

/**
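For reference, a minimal usage sketch of the extended overload declared above (not part of this PR): it assumes device-resident, column-major X and y, an optional per-sample weight vector w, and the helper name fit_weighted is hypothetical.

#include <cuml/linear_model/glm.hpp>
#include <raft/handle.hpp>

// Hypothetical helper, for illustration only: fit OLS with per-sample weights
// through the new sample_weight argument. Passing nullptr for w keeps the
// documented default of uniform weights. X, y, and w may be modified in place
// temporarily while the solver runs.
void fit_weighted(const raft::handle_t& handle,
                  float* X,          // device, n_rows x n_cols, column-major
                  float* y,          // device, n_rows
                  float* w,          // device, n_rows, or nullptr
                  int n_rows,
                  int n_cols,
                  float* coef,       // device, n_cols (output)
                  float* intercept)  // host scalar (output)
{
  ML::GLM::olsFit(handle,
                  X,
                  n_rows,
                  n_cols,
                  y,
                  coef,
                  intercept,
                  /*fit_intercept=*/true,
                  /*normalize=*/false,
                  /*algo=*/0,
                  /*sample_weight=*/w);
}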
12 changes: 8 additions & 4 deletions cpp/src/glm/glm.cu
@@ -37,7 +37,8 @@ void olsFit(const raft::handle_t& handle,
float* intercept,
bool fit_intercept,
bool normalize,
int algo)
int algo,
float* sample_weight)
{
olsFit(handle,
input,
@@ -49,7 +50,8 @@
fit_intercept,
normalize,
handle.get_stream(),
algo);
algo,
sample_weight);
}

void olsFit(const raft::handle_t& handle,
@@ -61,7 +63,8 @@
double* intercept,
bool fit_intercept,
bool normalize,
int algo)
int algo,
double* sample_weight)
{
olsFit(handle,
input,
@@ -73,7 +76,8 @@
fit_intercept,
normalize,
handle.get_stream(),
algo);
algo,
sample_weight);
}

void gemmPredict(const raft::handle_t& handle,
37 changes: 35 additions & 2 deletions cpp/src/glm/ols.cuh
@@ -17,8 +17,11 @@
#pragma once

#include <linalg/lstsq.cuh>
#include <linalg/power.cuh>
#include <linalg/sqrt.cuh>
#include <raft/linalg/add.cuh>
#include <raft/linalg/gemv.h>
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/subtract.cuh>
#include <raft/matrix/math.hpp>
@@ -50,6 +53,8 @@ using namespace MLCommon;
* @param stream cuda stream
* @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
* QR-decomposition)
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr for uniform
* weights)
*/
template <typename math_t>
void olsFit(const raft::handle_t& handle,
@@ -62,7 +67,8 @@
bool fit_intercept,
bool normalize,
cudaStream_t stream,
int algo = 0)
int algo = 0,
math_t* sample_weight = nullptr)
Member

Not something you need to change in this PR, but it would be nice to start adopting std::optional for arguments like these. It becomes self-documenting as well.

Contributor Author

I agree, so I started to make the change, but I saw that std::optional support in Cython is very recent and will only be added in version 0.30. Since we're currently on version 0.29, adopting std::optional in Python-facing functions should probably wait a bit longer.
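To illustrate the suggestion above, one hypothetical shape the declaration could take with std::optional (not what this PR implements, and the exact wrapper type is an open design choice):

#include <optional>

// Hypothetical sketch only: an std::optional variant of the new argument.
// Call sites then read olsFit(..., algo, std::nullopt) for uniform weights,
// which is more self-documenting than a raw pointer defaulting to nullptr.
void olsFit(const raft::handle_t& handle,
            float* input,
            int n_rows,
            int n_cols,
            float* labels,
            float* coef,
            float* intercept,
            bool fit_intercept,
            bool normalize,
            int algo                            = 0,
            std::optional<float*> sample_weight = std::nullopt);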

{
auto cublas_handle = handle.get_cublas_handle();
auto cusolver_handle = handle.get_cusolver_dn_handle();
@@ -89,7 +95,21 @@
norm2_input.data(),
fit_intercept,
normalize,
stream);
stream,
sample_weight);
}

if (sample_weight != nullptr) {
LinAlg::sqrt(sample_weight, sample_weight, n_rows, stream);
raft::matrix::matrixVectorBinaryMult(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a * b; },
stream,
labels,
sample_weight);
}

int selectedAlgo = algo;
@@ -107,6 +127,19 @@
}
raft::common::nvtx::pop_range();

if (sample_weight != nullptr) {
raft::matrix::matrixVectorBinaryDivSkipZero(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a / b; },
stream,
labels,
sample_weight);
LinAlg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream);
}

if (fit_intercept) {
postProcessData(handle,
input,
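The square-root scaling before the solver and the matching division and powerScalar call afterwards follow the standard reduction of weighted least squares to ordinary least squares; as a sketch of the identity being used:

\[
  \min_{\beta} \sum_{i=1}^{n} w_i \left(y_i - x_i^{\top}\beta\right)^2
  \;=\;
  \min_{\beta} \sum_{i=1}^{n} \left(\sqrt{w_i}\, y_i - \left(\sqrt{w_i}\, x_i\right)^{\top}\beta\right)^2 .
\]

Scaling each row of input and each label by the square root of its weight therefore lets the existing unweighted solvers (SVD, eigendecomposition, QR) return the weighted solution, after which the division and the squaring restore input, labels, and sample_weight to their original values.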
20 changes: 16 additions & 4 deletions cpp/src/glm/preprocess.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
#include <raft/stats/stddev.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>
#include <stats/weighted_mean.cuh>

namespace ML {
namespace GLM {
@@ -44,16 +45,27 @@ void preProcessData(const raft::handle_t& handle,
math_t* norm2_input,
bool fit_intercept,
bool normalize,
cudaStream_t stream)
cudaStream_t stream,
math_t* sample_weight = nullptr)
{
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");

if (fit_intercept) {
raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
if (sample_weight) {
MLCommon::Stats::rowSampleWeightedMean(
mu_input, input, sample_weight, n_cols, n_rows, false, false, stream);
} else {
raft::stats::mean(mu_input, input, n_cols, n_rows, false, false, stream);
}
raft::stats::meanCenter(input, input, mu_input, n_cols, n_rows, false, true, stream);

raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
if (sample_weight) {
MLCommon::Stats::rowSampleWeightedMean(
mu_labels, labels, sample_weight, 1, n_rows, true, false, stream);
} else {
raft::stats::mean(mu_labels, labels, 1, n_rows, false, false, stream);
}
raft::stats::meanCenter(labels, labels, mu_labels, 1, n_rows, false, true, stream);

if (normalize) {
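When an intercept is fitted with sample weights, the centering above uses weighted rather than plain means; writing w_i for the weight of row i, the quantities computed by rowSampleWeightedMean are

\[
  \mu_j = \frac{\sum_{i} w_i\, x_{ij}}{\sum_{i} w_i},
  \qquad
  \mu_y = \frac{\sum_{i} w_i\, y_i}{\sum_{i} w_i},
\]

and the intercept is later recovered from these means in postProcessData (not shown in this diff).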
47 changes: 46 additions & 1 deletion cpp/src_prims/stats/weighted_mean.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
* Copyright (c) 2019-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,7 +18,9 @@

#include <raft/cudart_utils.h>
#include <raft/linalg/coalesced_reduction.cuh>
#include <raft/linalg/reduce.cuh>
#include <raft/linalg/strided_reduction.cuh>
#include <raft/stats/sum.hpp>

namespace MLCommon {
namespace Stats {
@@ -56,6 +58,49 @@ void rowWeightedMean(
[WS] __device__(Type v) { return v / WS; });
}

/**
* @brief Compute the row-wise weighted mean of the input matrix
*
* @tparam Type the data type
* @param mu the output mean vector
* @param data the input matrix
* @param weights per-sample weight
* @param D number of columns of data
* @param N number of rows of data
* @param row_major input matrix is row-major or not
* @param along_rows whether to reduce along rows or columns
* @param stream cuda stream to launch work on
*/
template <typename Type>
void rowSampleWeightedMean(Type* mu,
const Type* data,
const Type* weights,
int D,
int N,
bool row_major,
bool along_rows,
cudaStream_t stream)
{
// sum the weights & copy back to CPU
Type WS = 0;
raft::stats::sum(mu, weights, 1, N, row_major, stream);
raft::update_host(&WS, mu, 1, stream);

raft::linalg::reduce(
mu,
data,
D,
N,
(Type)0,
row_major,
along_rows,
stream,
false,
[weights] __device__(Type v, int i) { return v * weights[i]; },
[] __device__(Type a, Type b) { return a + b; },
[WS] __device__(Type v) { return v / WS; });
}

/**
* @brief Compute the column-wise weighted mean of the input matrix
*
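As a plain-CPU reference for what the reduction above computes when called from preProcessData (column-major data, one weight per row); the helper name and the std::vector layout are illustration-only assumptions:

#include <vector>

// Reference loop: weighted mean of each of the D columns over the N rows,
// normalized by the total weight, for column-major data of size N x D.
template <typename T>
std::vector<T> weighted_column_means(const std::vector<T>& data,     // column-major, N * D
                                     const std::vector<T>& weights,  // length N
                                     int D,
                                     int N)
{
  T ws = T(0);
  for (int i = 0; i < N; ++i) ws += weights[i];  // WS: sum of the weights

  std::vector<T> mu(D, T(0));
  for (int j = 0; j < D; ++j) {
    for (int i = 0; i < N; ++i) {
      mu[j] += weights[i] * data[j * N + i];  // accumulate w_i * x_ij
    }
    mu[j] /= ws;  // normalize by the total weight
  }
  return mu;
}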
33 changes: 26 additions & 7 deletions python/cuml/linear_model/linear_regression.pyx
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -49,7 +49,9 @@ cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM":
float *coef,
float *intercept,
bool fit_intercept,
bool normalize, int algo) except +
bool normalize,
int algo,
float *sample_weight) except +

cdef void olsFit(handle_t& handle,
double *input,
@@ -59,7 +61,9 @@
double *coef,
double *intercept,
bool fit_intercept,
bool normalize, int algo) except +
bool normalize,
int algo,
double *sample_weight) except +


class LinearRegression(Base,
@@ -239,12 +243,13 @@ class LinearRegression(Base,
}[algorithm]

@generate_docstring()
def fit(self, X, y, convert_dtype=True) -> "LinearRegression":
def fit(self, X, y, convert_dtype=True,
sample_weight=None) -> "LinearRegression":
"""
Fit the model with X and y.

"""
cdef uintptr_t X_ptr, y_ptr
cdef uintptr_t X_ptr, y_ptr, sample_weight_ptr
X_m, n_rows, self.n_cols, self.dtype = \
input_to_cuml_array(X, check_dtype=[np.float32, np.float64])
X_ptr = X_m.ptr
@@ -256,6 +261,16 @@
check_rows=n_rows, check_cols=1)
y_ptr = y_m.ptr

if sample_weight is not None:
sample_weight_m, _, _, _ = \
input_to_cuml_array(sample_weight, check_dtype=self.dtype,
convert_to_dtype=(
self.dtype if convert_dtype else None),
check_rows=n_rows, check_cols=1)
sample_weight_ptr = sample_weight_m.ptr
else:
sample_weight_ptr = 0

if self.n_cols < 1:
msg = "X matrix must have at least a column"
raise TypeError(msg)
@@ -288,7 +303,8 @@
<float*>&c_intercept1,
<bool>self.fit_intercept,
<bool>self.normalize,
<int>self.algo)
<int>self.algo,
<float*>sample_weight_ptr)

self.intercept_ = c_intercept1
else:
@@ -301,14 +317,17 @@
<double*>&c_intercept2,
<bool>self.fit_intercept,
<bool>self.normalize,
<int>self.algo)
<int>self.algo,
<double*>sample_weight_ptr)

self.intercept_ = c_intercept2

self.handle.sync()

del X_m
del y_m
if sample_weight is not None:
del sample_weight_m

return self
