Skip to content

Commit

Permalink
Address reviewer feedback
Browse files Browse the repository at this point in the history
1. Add new issue #814
   to capture observation that the make_regression test
   still passes even if the test doesn't call make_regression.

2. Remove stream from new make_regression interface,
   since the handle encapsulates the stream.
  • Loading branch information
mhoemmen committed Sep 8, 2022
1 parent 95b475d commit 1b6103a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
4 changes: 1 addition & 3 deletions cpp/include/raft/random/make_regression.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ void make_regression(const raft::handle_t& handle,
* the values for the regression problem
* @param[in] n_informative Number of informative features (non-zero
* coefficients)
* @param[in] stream CUDA stream
* @param[out] coef If present, a row-major (features, targets) matrix
* to store the coefficients used to generate the values
* for the regression problem
Expand All @@ -139,7 +138,6 @@ void make_regression(
raft::device_matrix_view<DataT, raft::matrix_extent<IdxT>, raft::row_major> out,
raft::device_matrix_view<DataT, raft::matrix_extent<IdxT>, raft::row_major> values,
IdxT n_informative,
cudaStream_t stream,
std::optional<raft::device_matrix_view<DataT, raft::matrix_extent<IdxT>, raft::row_major>> coef,
DataT bias = DataT{},
IdxT effective_rank = static_cast<IdxT>(-1),
Expand Down Expand Up @@ -168,7 +166,7 @@ void make_regression(
n_samples,
n_features,
n_informative,
stream,
handle.get_stream(),
coef_ptr,
n_targets,
bias,
Expand Down
23 changes: 13 additions & 10 deletions cpp/test/random/make_regression.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ class MakeRegressionTest : public ::testing::TestWithParam<MakeRegressionInputs<
params.seed,
params.gtype);

// FIXME (mfh 2022/09/07) This test passes even if I don't call make_regression.
// FIXME (mfh 2022/09/07) This test passes even if it doesn't call
// make_regression. Please see
// https://github.com/rapidsai/raft/issues/814.

// Calculate the values from the data and coefficients (column-major)
T alpha = (T)1.0, beta = (T)0.0;
Expand Down Expand Up @@ -171,6 +173,8 @@ class MakeRegressionMdspanTest : public ::testing::TestWithParam<MakeRegressionI
protected:
void SetUp() override
{
auto stream = handle.get_stream();

// Noise must be zero to compare the actual and expected values
T noise = (T)0.0, tail_strength = (T)0.5;

Expand All @@ -190,7 +194,6 @@ class MakeRegressionMdspanTest : public ::testing::TestWithParam<MakeRegressionI
out_mat,
values_mat,
params.n_informative,
stream,
coef_mat,
params.bias,
params.effective_rank,
Expand All @@ -200,7 +203,9 @@ class MakeRegressionMdspanTest : public ::testing::TestWithParam<MakeRegressionI
params.seed,
params.gtype);

// FIXME (mfh 2022/09/07) This test passes even if I don't call make_regression.
// FIXME (mfh 2022/09/07) This test passes even if it doesn't call
// make_regression. Please see
// https://github.com/rapidsai/raft/issues/814.

// Calculate the values from the data and coefficients (column-major)
T alpha{};
Expand All @@ -222,8 +227,7 @@ class MakeRegressionMdspanTest : public ::testing::TestWithParam<MakeRegressionI
stream));

// Transpose the values to row-major
raft::linalg::transpose(
handle, values_cm.data(), values_prod.data(), params.n_samples, params.n_targets, stream);
raft::linalg::transpose(handle, values_cm.data(), values_prod.data(), params.n_samples, params.n_targets, stream);

// Add the bias
raft::linalg::addScalar(values_prod.data(),
Expand All @@ -241,9 +245,8 @@ class MakeRegressionMdspanTest : public ::testing::TestWithParam<MakeRegressionI
private:
MakeRegressionInputs<T> params{::testing::TestWithParam<MakeRegressionInputs<T>>::GetParam()};
raft::handle_t handle;
cudaStream_t stream{handle.get_stream()};
rmm::device_uvector<T> values_ret{params.n_samples * params.n_targets, stream};
rmm::device_uvector<T> values_prod{params.n_samples * params.n_targets, stream};
rmm::device_uvector<T> values_ret{params.n_samples * params.n_targets, handle.get_stream()};
rmm::device_uvector<T> values_prod{params.n_samples * params.n_targets, handle.get_stream()};
int zero_count = -1;
};

Expand All @@ -259,7 +262,7 @@ TEST_P(MakeRegressionMdspanTestF, Result)
params.n_samples,
params.n_targets,
raft::CompareApprox<float>(params.tolerance),
stream));
handle.get_stream()));
}
INSTANTIATE_TEST_CASE_P(MakeRegressionMdspanTests,
MakeRegressionMdspanTestF,
Expand All @@ -277,7 +280,7 @@ TEST_P(MakeRegressionMdspanTestD, Result)
params.n_samples,
params.n_targets,
raft::CompareApprox<double>(params.tolerance),
stream));
handle.get_stream()));
}
INSTANTIATE_TEST_CASE_P(MakeRegressionMdspanTests,
MakeRegressionMdspanTestD,
Expand Down

0 comments on commit 1b6103a

Please sign in to comment.