Skip to content

Commit

Permalink
mdspan-ify make_regression
Browse files Browse the repository at this point in the history
Add an overload of make_regression
that takes mdspan, instead of raw pointers.
The overload does not increase generality
(e.g., it still requires row-major mdspan).

Part of rapidsai#535.
  • Loading branch information
mhoemmen committed Sep 7, 2022
1 parent b5b66a3 commit ab65a7c
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 2 deletions.
90 changes: 88 additions & 2 deletions cpp/include/raft/random/make_regression.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#pragma once

#include <algorithm>
#include <optional>
#include <raft/mdarray.hpp>

#include "detail/make_regression.cuh"

Expand Down Expand Up @@ -58,7 +60,7 @@ namespace raft::random {
* @param[in] tail_strength The relative importance of the fat noisy tail
* of the singular values profile if
* effective_rank is not -1
* @param[in] noise Standard deviation of the gaussian noise
* @param[in] noise Standard deviation of the Gaussian noise
* applied to the output
* @param[in] shuffle Shuffle the samples and the features
* @param[in] seed Seed for the random number generator
Expand Down Expand Up @@ -100,6 +102,90 @@ void make_regression(const raft::handle_t& handle,
type);
}

/**
* @brief GPU-equivalent of sklearn.datasets.make_regression as documented at:
* https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html
*
* @tparam DataT Scalar type
* @tparam IdxT Index type
*
* @param[in] handle RAFT handle
* @param[out] out Row-major (samples, features) matrix to store
* the problem data
* @param[out] values Row-major (samples, targets) matrix to store
* the values for the regression problem
* @param[in] n_informative Number of informative features (non-zero
* coefficients)
* @param[in] stream CUDA stream
* @param[out] coef If present, a row-major (features, targets) matrix
* to store the coefficients used to generate the values
* for the regression problem
* @param[in] bias A scalar that will be added to the values
* @param[in] effective_rank The approximate rank of the data matrix (used
* to create correlations in the data). -1 is the
* code to use well-conditioned data
* @param[in] tail_strength The relative importance of the fat noisy tail
* of the singular values profile if
* effective_rank is not -1
* @param[in] noise Standard deviation of the Gaussian noise
* applied to the output
* @param[in] shuffle Shuffle the samples and the features
* @param[in] seed Seed for the random number generator
* @param[in] type Random generator type
*/
template <typename DataT, typename IdxT>
void make_regression(const raft::handle_t& handle,
raft::device_matrix_view<DataT,
raft::matrix_extent<IdxT>,
raft::row_major> out,
raft::device_matrix_view<DataT,
raft::matrix_extent<IdxT>,
raft::row_major> values,
IdxT n_informative,
cudaStream_t stream,
std::optional<
raft::device_matrix_view<DataT,
raft::matrix_extent<IdxT>,
raft::row_major>> coef,
DataT bias = DataT{},
IdxT effective_rank = static_cast<IdxT>(-1),
DataT tail_strength = DataT{0.5},
DataT noise = DataT{},
bool shuffle = true,
uint64_t seed = 0ULL,
GeneratorType type = GenPhilox)
{
const auto n_samples = out.extent(0);
assert(values.extent(0) == n_samples);
const auto n_features = out.extent(1);
const auto n_targets = values.extent(1);

const bool have_coef = coef.has_value();
if(have_coef) {
const auto coef_ref = *coef;
assert(coef_ref.extent(0) == n_features);
assert(coef_ref.extent(1) == n_targets);
}
DataT* coef_ptr = have_coef ? (*coef).data_handle() : nullptr;

detail::make_regression_caller(handle,
out.data_handle(),
values.data_handle(),
n_samples,
n_features,
n_informative,
stream,
coef_ptr,
n_targets,
bias,
effective_rank,
tail_strength,
noise,
shuffle,
seed,
type);
}

} // namespace raft::random

#endif
#endif
116 changes: 116 additions & 0 deletions cpp/test/random/make_regression.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class MakeRegressionTest : public ::testing::TestWithParam<MakeRegressionInputs<
params.seed,
params.gtype);

// FIXME (mfh 2022/09/07) This test passes even if I don't call make_regression.

// Calculate the values from the data and coefficients (column-major)
T alpha = (T)1.0, beta = (T)0.0;
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(handle.get_cublas_handle(),
Expand Down Expand Up @@ -161,4 +163,118 @@ TEST_P(MakeRegressionTestD, Result)
}
INSTANTIATE_TEST_CASE_P(MakeRegressionTests, MakeRegressionTestD, ::testing::ValuesIn(inputsd_t));

template <typename T>
class MakeRegressionMdspanTest : public ::testing::TestWithParam<MakeRegressionInputs<T>> {
public:
MakeRegressionMdspanTest() = default;

protected:
void SetUp() override
{
// Noise must be zero to compare the actual and expected values
T noise = (T)0.0, tail_strength = (T)0.5;

rmm::device_uvector<T> data(params.n_samples * params.n_features, stream);
rmm::device_uvector<T> values_cm(params.n_samples * params.n_targets, stream);
rmm::device_uvector<T> coef(params.n_features * params.n_targets, stream);

using index_type = typename rmm::device_uvector<T>::index_type;
using matrix_view = raft::device_matrix_view<T, raft::matrix_extent<index_type>, raft::row_major>;
matrix_view out_mat(data.data(), params.n_samples, params.n_features);
matrix_view values_mat(values_ret.data(), params.n_samples, params.n_targets);
matrix_view coef_mat(coef.data(), params.n_features, params.n_targets);

// Create the regression problem
make_regression(handle, out_mat, values_mat,
params.n_informative,
stream,
coef_mat,
params.bias,
params.effective_rank,
tail_strength,
noise,
params.shuffle,
params.seed,
params.gtype);

// FIXME (mfh 2022/09/07) This test passes even if I don't call make_regression.

// Calculate the values from the data and coefficients (column-major)
T alpha{};
T beta{};
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(handle.get_cublas_handle(),
CUBLAS_OP_T,
CUBLAS_OP_T,
params.n_samples,
params.n_targets,
params.n_features,
&alpha,
data.data(),
params.n_features,
coef.data(),
params.n_targets,
&beta,
values_cm.data(),
params.n_samples,
stream));

// Transpose the values to row-major
raft::linalg::transpose(
handle, values_cm.data(), values_prod.data(), params.n_samples, params.n_targets, stream);

// Add the bias
raft::linalg::addScalar(values_prod.data(),
values_prod.data(),
params.bias,
params.n_samples * params.n_targets,
stream);

// Count the number of zeroes in the coefficients
thrust::device_ptr<T> __coef = thrust::device_pointer_cast(coef.data());
constexpr T ZERO{};
zero_count = thrust::count(__coef, __coef + params.n_features * params.n_targets, ZERO);
}

private:
MakeRegressionInputs<T> params{
::testing::TestWithParam<MakeRegressionInputs<T>>::GetParam()};
raft::handle_t handle;
cudaStream_t stream{handle.get_stream()};
rmm::device_uvector<T> values_ret{params.n_samples * params.n_targets, stream};
rmm::device_uvector<T> values_prod{params.n_samples * params.n_targets, stream};
int zero_count = -1;
};

using MakeRegressionMdspanTestF = MakeRegressionTest<float>;

TEST_P(MakeRegressionMdspanTestF, Result)
{
ASSERT_TRUE(match(params.n_targets * (params.n_features - params.n_informative),
zero_count,
raft::Compare<int>()));
ASSERT_TRUE(devArrMatch(values_ret.data(),
values_prod.data(),
params.n_samples,
params.n_targets,
raft::CompareApprox<float>(params.tolerance),
stream));
}
INSTANTIATE_TEST_CASE_P(MakeRegressionMdspanTests, MakeRegressionMdspanTestF, ::testing::ValuesIn(inputsf_t));

using MakeRegressionMdspanTestD = MakeRegressionTest<double>;

TEST_P(MakeRegressionMdspanTestD, Result)
{
ASSERT_TRUE(match(params.n_targets * (params.n_features - params.n_informative),
zero_count,
raft::Compare<int>()));
ASSERT_TRUE(devArrMatch(values_ret.data(),
values_prod.data(),
params.n_samples,
params.n_targets,
raft::CompareApprox<double>(params.tolerance),
stream));
}
INSTANTIATE_TEST_CASE_P(MakeRegressionMdspanTests, MakeRegressionMdspanTestD, ::testing::ValuesIn(inputsd_t));

} // end namespace raft::random

0 comments on commit ab65a7c

Please sign in to comment.