diff --git a/cpp/include/raft/random/make_regression.cuh b/cpp/include/raft/random/make_regression.cuh index e18f0ac8d6..aac0d19e00 100644 --- a/cpp/include/raft/random/make_regression.cuh +++ b/cpp/include/raft/random/make_regression.cuh @@ -116,7 +116,6 @@ void make_regression(const raft::handle_t& handle, * the values for the regression problem * @param[in] n_informative Number of informative features (non-zero * coefficients) - * @param[in] stream CUDA stream * @param[out] coef If present, a row-major (features, targets) matrix * to store the coefficients used to generate the values * for the regression problem @@ -139,7 +138,6 @@ void make_regression( raft::device_matrix_view, raft::row_major> out, raft::device_matrix_view, raft::row_major> values, IdxT n_informative, - cudaStream_t stream, std::optional, raft::row_major>> coef, DataT bias = DataT{}, IdxT effective_rank = static_cast(-1), @@ -168,7 +166,7 @@ void make_regression( n_samples, n_features, n_informative, - stream, + handle.get_stream(), coef_ptr, n_targets, bias, diff --git a/cpp/test/random/make_regression.cu b/cpp/test/random/make_regression.cu index 30d81cea92..84dadf1e24 100644 --- a/cpp/test/random/make_regression.cu +++ b/cpp/test/random/make_regression.cu @@ -78,7 +78,9 @@ class MakeRegressionTest : public ::testing::TestWithParam params{::testing::TestWithParam>::GetParam()}; raft::handle_t handle; - cudaStream_t stream{handle.get_stream()}; - rmm::device_uvector values_ret{params.n_samples * params.n_targets, stream}; - rmm::device_uvector values_prod{params.n_samples * params.n_targets, stream}; + rmm::device_uvector values_ret{params.n_samples * params.n_targets, handle.get_stream()}; + rmm::device_uvector values_prod{params.n_samples * params.n_targets, handle.get_stream()}; int zero_count = -1; }; @@ -259,7 +263,7 @@ TEST_P(MakeRegressionMdspanTestF, Result) params.n_samples, params.n_targets, raft::CompareApprox(params.tolerance), - stream)); + handle.get_stream())); } INSTANTIATE_TEST_CASE_P(MakeRegressionMdspanTests, MakeRegressionMdspanTestF, @@ -277,7 +281,7 @@ TEST_P(MakeRegressionMdspanTestD, Result) params.n_samples, params.n_targets, raft::CompareApprox(params.tolerance), - stream)); + handle.get_stream())); } INSTANTIATE_TEST_CASE_P(MakeRegressionMdspanTests, MakeRegressionMdspanTestD,