diff --git a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh index 4b92375b8e..ec4bb7818d 100644 --- a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -300,7 +301,11 @@ multi_variable_gaussian_setup_token setup_multi_variable_gaussian( const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method); template -std::size_t workspace_size(const multi_variable_gaussian_setup_token& token); +multi_variable_gaussian_setup_token setup_multi_variable_gaussian( + const raft::handle_t& handle, + rmm::mr::device_memory_resource* mem_resource, + const int dim, + multi_variable_gaussian_decomposition_method method); // @param x[in] vector of dim elements // @param P[inout] On input, dim x dim matrix; overwritten on output @@ -323,7 +328,11 @@ class multi_variable_gaussian_setup_token { multi_variable_gaussian_decomposition_method method); template - friend std::size_t workspace_size(multi_variable_gaussian_setup_token& token); + friend multi_variable_gaussian_setup_token setup_multi_variable_gaussian( + const raft::handle_t& handle, + rmm::mr::device_memory_resource* mem_resource, + const int dim, + multi_variable_gaussian_decomposition_method method); template friend void compute_multi_variable_gaussian( @@ -348,36 +357,53 @@ class multi_variable_gaussian_setup_token { // Constructor, only for use by friend functions. // Hiding this will let us change the implementation in the future. multi_variable_gaussian_setup_token(const raft::handle_t& handle, + rmm::mr::device_memory_resource* mem_resource, const int dim, multi_variable_gaussian_decomposition_method method) : impl_(std::make_unique>( handle, dim, new_enum_to_old_enum(method))), + handle_(handle), + mem_resource_(mem_resource), dim_(dim) { + RAFT_EXPECTS(mem_resource_ != nullptr, "multi_variable_gaussian: device_memory_resource pointer must be nonnull"); } private: std::unique_ptr> impl_; + const raft::handle_t& handle_; + rmm::mr::device_memory_resource* mem_resource_ = nullptr; int dim_ = 0; public: // FIXME (mfh 2022/09/23) Just a hack, because my friend declarations aren't working. multi_variable_gaussian_impl& get_impl() const { return *impl_; } + auto allocate_workspace() const { + const auto num_elements = impl_->get_workspace_size(); + return rmm::device_uvector{num_elements, handle_.get_stream(), mem_resource_}; + } + int dim() const { return dim_; } }; template multi_variable_gaussian_setup_token setup_multi_variable_gaussian( - const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method) + const raft::handle_t& handle, + rmm::mr::device_memory_resource* mem_resource, + const int dim, + multi_variable_gaussian_decomposition_method method) { - return multi_variable_gaussian_setup_token(handle, dim, method); + return multi_variable_gaussian_setup_token(handle, mem_resource, dim, method); } template -std::size_t workspace_size(multi_variable_gaussian_setup_token& token) +multi_variable_gaussian_setup_token setup_multi_variable_gaussian( + const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method) { - return token.get_impl().get_workspace_size(); + rmm::mr::device_memory_resource* mem_resource = + rmm::mr::get_current_device_resource(); + return multi_variable_gaussian_setup_token(handle, mem_resource, dim, method); } template @@ -385,13 +411,8 @@ void compute_multi_variable_gaussian( multi_variable_gaussian_setup_token& token, std::optional> x, raft::device_matrix_view P, - raft::device_matrix_view X, - raft::device_vector_view workspace) + raft::device_matrix_view X) { - RAFT_EXPECTS(static_cast(workspace.extent(0)) >= workspace_size(token), - "multi_variable_gaussian: Insufficient workspace"); - token.get_impl().set_workspace(workspace.data_handle()); - const int dim = P.extent(0); RAFT_EXPECTS(dim == token.dim(), "multi_variable_gaussian: " @@ -409,7 +430,6 @@ void compute_multi_variable_gaussian( "P.extent(0) = %d != X.extent(0) = %d", P.extent(0), X.extent(0)); - const bool x_has_value = x.has_value(); const int x_extent_0 = x_has_value ? (*x).extent(0) : 0; RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0, @@ -420,6 +440,9 @@ void compute_multi_variable_gaussian( const int nPoints = X.extent(1); const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr; + + auto workspace = token.allocate_workspace(); + token.get_impl().set_workspace(workspace.data()); token.get_impl().give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr); } diff --git a/cpp/test/random/multi_variable_gaussian.cu b/cpp/test/random/multi_variable_gaussian.cu index aeb22af48c..e44ae1beaf 100644 --- a/cpp/test/random/multi_variable_gaussian.cu +++ b/cpp/test/random/multi_variable_gaussian.cu @@ -279,19 +279,21 @@ class MVGMdspanTest : public ::testing::TestWithParam> { raft::update_device(x_d.data(), x.data(), dim, stream); // Set up the multivariable Gaussian computation - auto token = detail::setup_multi_variable_gaussian(handle, dim, method); - std::size_t o = detail::workspace_size(token); - - // give the workspace area to mvg - workspace_d.resize(o, stream); + { + // Test that setup with a default memory resource compiles. + auto token = detail::setup_multi_variable_gaussian(handle, dim, method); + ASSERT_EQ(dim, token.dim()); // just so token is used + } + rmm::mr::device_memory_resource* mem_resource = rmm::mr::get_current_device_resource(); + ASSERT_TRUE(mem_resource != nullptr); + auto token = detail::setup_multi_variable_gaussian(handle, mem_resource, dim, method); - raft::device_vector_view workspace_view(workspace_d.data(), o); std::optional> x_view(std::in_place, x_d.data(), dim); raft::device_matrix_view P_view(P_d.data(), dim, dim); raft::device_matrix_view X_view(X_d.data(), dim, nPoints); // X_view is the output. - detail::compute_multi_variable_gaussian(token, x_view, P_view, X_view, workspace_view); + detail::compute_multi_variable_gaussian(token, x_view, P_view, X_view); // saving the mean of the randoms in Rand_mean //@todo can be swapped with a API that calculates mean