Skip to content

Commit

Permalink
Add sample_weight for Ridge (#4696)
Browse files Browse the repository at this point in the history
Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #4696
  • Loading branch information
lowener authored May 27, 2022
1 parent ee0f42e commit b37a968
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 44 deletions.
10 changes: 7 additions & 3 deletions cpp/include/cuml/linear_model/glm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace GLM {
* @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
* QR-decomposition)
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
for uniform weights)
for uniform weights). This vector is modified during the computation.
* @{
*/
void olsFit(const raft::handle_t& handle,
Expand Down Expand Up @@ -75,6 +75,8 @@ void olsFit(const raft::handle_t& handle,
* @param fit_intercept if true, fit intercept
* @param normalize if true, normalize data to zero mean, unit variance
* @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition)
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr
for uniform weights). This vector is modified during the computation.
* @{
*/
void ridgeFit(const raft::handle_t& handle,
Expand All @@ -88,7 +90,8 @@ void ridgeFit(const raft::handle_t& handle,
float* intercept,
bool fit_intercept,
bool normalize,
int algo = 0);
int algo = 0,
float* sample_weight = nullptr);
void ridgeFit(const raft::handle_t& handle,
double* input,
int n_rows,
Expand All @@ -100,7 +103,8 @@ void ridgeFit(const raft::handle_t& handle,
double* intercept,
bool fit_intercept,
bool normalize,
int algo = 0);
int algo = 0,
double* sample_weight = nullptr);
/** @} */

/**
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/glm/glm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ void ridgeFit(const raft::handle_t& handle,
float* intercept,
bool fit_intercept,
bool normalize,
int algo)
int algo,
float* sample_weight)
{
ridgeFit(handle,
input,
Expand All @@ -125,7 +126,8 @@ void ridgeFit(const raft::handle_t& handle,
fit_intercept,
normalize,
handle.get_stream(),
algo);
algo,
sample_weight);
}

void ridgeFit(const raft::handle_t& handle,
Expand All @@ -139,7 +141,8 @@ void ridgeFit(const raft::handle_t& handle,
double* intercept,
bool fit_intercept,
bool normalize,
int algo)
int algo,
double* sample_weight)
{
ridgeFit(handle,
input,
Expand All @@ -153,7 +156,8 @@ void ridgeFit(const raft::handle_t& handle,
fit_intercept,
normalize,
handle.get_stream(),
algo);
algo,
sample_weight);
}

template <typename T, typename I>
Expand Down
28 changes: 13 additions & 15 deletions cpp/src/glm/ols.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@

#pragma once

#include <raft/linalg/add.hpp>
#include <raft/linalg/gemv.hpp>
#include <raft/linalg/lstsq.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/gemv.cuh>
#include <raft/linalg/lstsq.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/power.cuh>
#include <raft/linalg/sqrt.cuh>
#include <raft/linalg/subtract.hpp>
#include <raft/matrix/math.hpp>
#include <raft/matrix/matrix.hpp>
#include <raft/stats/mean.hpp>
#include <raft/stats/mean_center.hpp>
#include <raft/stats/stddev.hpp>
#include <raft/stats/sum.hpp>
#include <raft/linalg/subtract.cuh>
#include <raft/matrix/math.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/stats/mean.cuh>
#include <raft/stats/mean_center.cuh>
#include <raft/stats/stddev.cuh>
#include <raft/stats/sum.cuh>
#include <rmm/device_uvector.hpp>

#include "preprocess.cuh"
Expand All @@ -52,7 +52,7 @@ namespace GLM {
* @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition, 2:
* QR-decomposition)
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr for uniform
* weights)
 * weights). This vector is modified during the computation.
*/
template <typename math_t>
void olsFit(const raft::handle_t& handle,
Expand Down Expand Up @@ -93,7 +93,6 @@ void olsFit(const raft::handle_t& handle,
norm2_input.data(),
fit_intercept,
normalize,
stream,
sample_weight);
}

Expand Down Expand Up @@ -152,8 +151,7 @@ void olsFit(const raft::handle_t& handle,
mu_labels.data(),
norm2_input.data(),
fit_intercept,
normalize,
stream);
normalize);
} else {
*intercept = math_t(0);
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/glm/preprocess.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ void preProcessData(const raft::handle_t& handle,
math_t* norm2_input,
bool fit_intercept,
bool normalize,
cudaStream_t stream,
math_t* sample_weight = nullptr)
{
cudaStream_t stream = handle.get_stream();
raft::common::nvtx::range fun_scope("ML::GLM::preProcessData-%d-%d", n_rows, n_cols);
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
Expand Down Expand Up @@ -131,9 +131,9 @@ void postProcessData(const raft::handle_t& handle,
math_t* mu_labels,
math_t* norm2_input,
bool fit_intercept,
bool normalize,
cudaStream_t stream)
bool normalize)
{
cudaStream_t stream = handle.get_stream();
raft::common::nvtx::range fun_scope("ML::GLM::postProcessData-%d-%d", n_rows, n_cols);
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
Expand Down
35 changes: 31 additions & 4 deletions cpp/src/glm/ridge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ void ridgeEig(const raft::handle_t& handle,
* @param normalize if true, normalize data to zero mean, unit variance
* @param stream cuda stream
* @param algo specifies which solver to use (0: SVD, 1: Eigendecomposition)
* @param sample_weight device pointer to sample weight vector of length n_rows (nullptr for uniform
 * weights). This vector is modified during the computation.
*/
template <typename math_t>
void ridgeFit(const raft::handle_t& handle,
Expand All @@ -160,7 +162,8 @@ void ridgeFit(const raft::handle_t& handle,
bool fit_intercept,
bool normalize,
cudaStream_t stream,
int algo = 0)
int algo = 0,
math_t* sample_weight = nullptr)
{
auto cublas_handle = handle.get_cublas_handle();
auto cusolver_handle = handle.get_cusolver_dn_handle();
Expand All @@ -187,7 +190,19 @@ void ridgeFit(const raft::handle_t& handle,
norm2_input.data(),
fit_intercept,
normalize,
stream);
sample_weight);
}
if (sample_weight != nullptr) {
raft::linalg::sqrt(sample_weight, sample_weight, n_rows, stream);
raft::matrix::matrixVectorBinaryMult(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a * b; },
stream,
labels,
sample_weight);
}

if (algo == 0 || n_cols == 1) {
Expand All @@ -200,6 +215,19 @@ void ridgeFit(const raft::handle_t& handle,
ASSERT(false, "ridgeFit: no algorithm with this id has been implemented");
}

if (sample_weight != nullptr) {
raft::matrix::matrixVectorBinaryDivSkipZero(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a / b; },
stream,
labels,
sample_weight);
raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream);
}

if (fit_intercept) {
postProcessData(handle,
input,
Expand All @@ -212,8 +240,7 @@ void ridgeFit(const raft::handle_t& handle,
mu_labels.data(),
norm2_input.data(),
fit_intercept,
normalize,
stream);
normalize);
} else {
*intercept = math_t(0);
}
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/solver/cd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,7 @@ void cdFit(const raft::handle_t& handle,
mu_labels.data(),
norm2_input.data(),
fit_intercept,
normalize,
stream);
normalize);
}

std::vector<int> ri(n_cols);
Expand Down Expand Up @@ -267,8 +266,7 @@ void cdFit(const raft::handle_t& handle,
mu_labels.data(),
norm2_input.data(),
fit_intercept,
normalize,
stream);
normalize);

} else {
*intercept = math_t(0);
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/solver/sgd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ void sgdFit(const raft::handle_t& handle,
mu_labels.data(),
norm2_input.data(),
fit_intercept,
false,
stream);
false);
}

rmm::device_uvector<math_t> grads(n_cols, stream);
Expand Down Expand Up @@ -307,8 +306,7 @@ void sgdFit(const raft::handle_t& handle,
mu_labels.data(),
norm2_input.data(),
fit_intercept,
false,
stream);
false);
} else {
*intercept = math_t(0);
}
Expand Down
55 changes: 54 additions & 1 deletion cpp/test/sg/ridge.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ class RidgeTest : public ::testing::TestWithParam<RidgeInputs<T>> {
pred3(params.n_row_2, stream),
pred3_ref(params.n_row_2, stream),
coef_sc(1, stream),
coef_sc_ref(1, stream)
coef_sc_ref(1, stream),
coef_sw(1, stream),
coef_sw_ref(1, stream)
{
basicTest();
basicTest2();
testSampleWeight();
}

protected:
Expand Down Expand Up @@ -238,6 +241,49 @@ class RidgeTest : public ::testing::TestWithParam<RidgeInputs<T>> {
params.algo);
}

void testSampleWeight()
{
int len = params.n_row * params.n_col;

rmm::device_uvector<T> data_sw(len, stream);
rmm::device_uvector<T> labels_sw(len, stream);
rmm::device_uvector<T> sample_weight(len, stream);

std::vector<T> data_h = {1.0, 1.0, 2.0, 2.0, 1.0, 2.0};
data_h.resize(len);
raft::update_device(data_sw.data(), data_h.data(), len, stream);

std::vector<T> labels_h = {6.0, 8.0, 9.0, 11.0, -1.0, 2.0};
labels_h.resize(len);
raft::update_device(labels_sw.data(), labels_h.data(), len, stream);

std::vector<T> coef_sw_ref_h = {0.26052};
coef_sw_ref_h.resize(1);
raft::update_device(coef_sw_ref.data(), coef_sw_ref_h.data(), 1, stream);

std::vector<T> sample_weight_h = {0.2, 0.3, 0.09, 0.15, 0.11, 0.15};
sample_weight_h.resize(len);
raft::update_device(sample_weight.data(), sample_weight_h.data(), len, stream);

T intercept_sw = T(0);
T alpha_sw = T(1.0);

ridgeFit(handle,
data_sw.data(),
len,
1,
labels_sw.data(),
&alpha_sw,
1,
coef_sw.data(),
&intercept_sw,
true,
false,
stream,
params.algo,
sample_weight.data());
}

protected:
raft::handle_t handle;
cudaStream_t stream = 0;
Expand All @@ -247,6 +293,7 @@ class RidgeTest : public ::testing::TestWithParam<RidgeInputs<T>> {
rmm::device_uvector<T> coef2, coef2_ref, pred2, pred2_ref;
rmm::device_uvector<T> coef3, coef3_ref, pred3, pred3_ref;
rmm::device_uvector<T> coef_sc, coef_sc_ref;
rmm::device_uvector<T> coef_sw, coef_sw_ref;
T intercept, intercept2, intercept3;
};

Expand Down Expand Up @@ -279,6 +326,9 @@ TEST_P(RidgeTestF, Fit)

ASSERT_TRUE(raft::devArrMatch(
coef_sc_ref.data(), coef_sc.data(), 1, raft::CompareApproxAbs<float>(params.tol)));

ASSERT_TRUE(raft::devArrMatch(
coef_sw_ref.data(), coef_sw.data(), 1, raft::CompareApproxAbs<float>(params.tol)));
}

typedef RidgeTest<double> RidgeTestD;
Expand All @@ -304,6 +354,9 @@ TEST_P(RidgeTestD, Fit)

ASSERT_TRUE(raft::devArrMatch(
coef_sc_ref.data(), coef_sc.data(), 1, raft::CompareApproxAbs<double>(params.tol)));

ASSERT_TRUE(raft::devArrMatch(
coef_sw_ref.data(), coef_sw.data(), 1, raft::CompareApproxAbs<double>(params.tol)));
}

INSTANTIATE_TEST_CASE_P(RidgeTests, RidgeTestF, ::testing::ValuesIn(inputsf2));
Expand Down
Loading

0 comments on commit b37a968

Please sign in to comment.