diff --git a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh index ec4bb7818d..0899c7551c 100644 --- a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -28,6 +27,7 @@ #include #include #include +#include #include #include @@ -297,51 +297,36 @@ template class multi_variable_gaussian_setup_token; template -multi_variable_gaussian_setup_token setup_multi_variable_gaussian( - const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method); - -template -multi_variable_gaussian_setup_token setup_multi_variable_gaussian( +multi_variable_gaussian_setup_token setup_multi_variable_gaussian_impl( const raft::handle_t& handle, - rmm::mr::device_memory_resource* mem_resource, + rmm::mr::device_memory_resource& mem_resource, const int dim, - multi_variable_gaussian_decomposition_method method); + const multi_variable_gaussian_decomposition_method method); -// @param x[in] vector of dim elements -// @param P[inout] On input, dim x dim matrix; overwritten on output -// @param X[out] dim x nPoints matrix template -void compute_multi_variable_gaussian( +void compute_multi_variable_gaussian_impl( multi_variable_gaussian_setup_token& token, std::optional> x, raft::device_matrix_view P, - raft::device_matrix_view X, - raft::device_vector_view workspace); + raft::device_matrix_view X); template class multi_variable_gaussian_setup_token { - private: template - friend multi_variable_gaussian_setup_token setup_multi_variable_gaussian( + friend multi_variable_gaussian_setup_token setup_multi_variable_gaussian_impl( const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, const int dim, - multi_variable_gaussian_decomposition_method method); + const multi_variable_gaussian_decomposition_method method); template - friend multi_variable_gaussian_setup_token setup_multi_variable_gaussian( - const raft::handle_t& handle, - rmm::mr::device_memory_resource* mem_resource, - const int dim, - multi_variable_gaussian_decomposition_method method); - - template - friend void compute_multi_variable_gaussian( + friend void compute_multi_variable_gaussian_impl( multi_variable_gaussian_setup_token& token, std::optional> x, - raft::device_matrix_view P, - raft::device_matrix_view X, - raft::device_vector_view workspace); + raft::device_matrix_view P, + raft::device_matrix_view X); + private: typename multi_variable_gaussian_impl::Decomposer new_enum_to_old_enum( multi_variable_gaussian_decomposition_method method) { @@ -357,93 +342,107 @@ class multi_variable_gaussian_setup_token { // Constructor, only for use by friend functions. // Hiding this will let us change the implementation in the future. multi_variable_gaussian_setup_token(const raft::handle_t& handle, - rmm::mr::device_memory_resource* mem_resource, + rmm::mr::device_memory_resource& mem_resource, const int dim, - multi_variable_gaussian_decomposition_method method) + const multi_variable_gaussian_decomposition_method method) : impl_(std::make_unique>( handle, dim, new_enum_to_old_enum(method))), handle_(handle), mem_resource_(mem_resource), dim_(dim) { - RAFT_EXPECTS(mem_resource_ != nullptr, "multi_variable_gaussian: device_memory_resource pointer must be nonnull"); + } + + /** + * @brief Compute the multivariable Gaussian. + * + * @param[in] x vector of dim elements + * @param[inout] P On input, dim x dim matrix; overwritten on output + * @param[out] X dim x nPoints matrix + */ + void compute(std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X) + { + const int input_dim = P.extent(0); + RAFT_EXPECTS(input_dim == dim(), + "multi_variable_gaussian: " + "P.extent(0) = %d does not match the extent %d " + "with which the token was created", + input_dim, + dim()); + RAFT_EXPECTS(P.extent(0) == P.extent(1), + "multi_variable_gaussian: " + "P must be square, but P.extent(0) = %d != P.extent(1) = %d", + P.extent(0), + P.extent(1)); + RAFT_EXPECTS(P.extent(0) == X.extent(0), + "multi_variable_gaussian: " + "P.extent(0) = %d != X.extent(0) = %d", + P.extent(0), + X.extent(0)); + const bool x_has_value = x.has_value(); + const int x_extent_0 = x_has_value ? (*x).extent(0) : 0; + RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0, + "multi_variable_gaussian: " + "P.extent(0) = %d != x.extent(0) = %d", + P.extent(0), + x_extent_0); + const int nPoints = X.extent(1); + const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr; + + auto workspace = allocate_workspace(); + impl_->set_workspace(workspace.data()); + impl_->give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr); } private: std::unique_ptr> impl_; const raft::handle_t& handle_; - rmm::mr::device_memory_resource* mem_resource_ = nullptr; + rmm::mr::device_memory_resource& mem_resource_; int dim_ = 0; - public: - // FIXME (mfh 2022/09/23) Just a hack, because my friend declarations aren't working. - multi_variable_gaussian_impl& get_impl() const { return *impl_; } - - auto allocate_workspace() const { + auto allocate_workspace() const + { const auto num_elements = impl_->get_workspace_size(); - return rmm::device_uvector{num_elements, handle_.get_stream(), mem_resource_}; + return rmm::device_uvector{num_elements, handle_.get_stream(), &mem_resource_}; } int dim() const { return dim_; } }; template -multi_variable_gaussian_setup_token setup_multi_variable_gaussian( +multi_variable_gaussian_setup_token setup_multi_variable_gaussian_impl( const raft::handle_t& handle, - rmm::mr::device_memory_resource* mem_resource, + rmm::mr::device_memory_resource& mem_resource, const int dim, - multi_variable_gaussian_decomposition_method method) + const multi_variable_gaussian_decomposition_method method) { return multi_variable_gaussian_setup_token(handle, mem_resource, dim, method); } template -multi_variable_gaussian_setup_token setup_multi_variable_gaussian( - const raft::handle_t& handle, const int dim, multi_variable_gaussian_decomposition_method method) +void compute_multi_variable_gaussian_impl( + multi_variable_gaussian_setup_token& token, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X) { - rmm::mr::device_memory_resource* mem_resource = - rmm::mr::get_current_device_resource(); - return multi_variable_gaussian_setup_token(handle, mem_resource, dim, method); + token.compute(x, P, X); } template -void compute_multi_variable_gaussian( - multi_variable_gaussian_setup_token& token, +void compute_multi_variable_gaussian_impl( + const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, std::optional> x, raft::device_matrix_view P, - raft::device_matrix_view X) + raft::device_matrix_view X, + const multi_variable_gaussian_decomposition_method method) { - const int dim = P.extent(0); - RAFT_EXPECTS(dim == token.dim(), - "multi_variable_gaussian: " - "P.extent(0) = %d does not match the dimension %d " - "with which the token was created", - P.extent(0), - token.dim()); - RAFT_EXPECTS(P.extent(0) == P.extent(1), - "multi_variable_gaussian: " - "P must be square, but P.extent(0) = %d != P.extent(1) = %d", - P.extent(0), - P.extent(1)); - RAFT_EXPECTS(P.extent(0) == X.extent(0), - "multi_variable_gaussian: " - "P.extent(0) = %d != X.extent(0) = %d", - P.extent(0), - X.extent(0)); - const bool x_has_value = x.has_value(); - const int x_extent_0 = x_has_value ? (*x).extent(0) : 0; - RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0, - "multi_variable_gaussian: " - "P.extent(0) = %d != x.extent(0) = %d", - P.extent(0), - x_extent_0); - - const int nPoints = X.extent(1); - const ValueType* x_ptr = x_has_value ? (*x).data_handle() : nullptr; - - auto workspace = token.allocate_workspace(); - token.get_impl().set_workspace(workspace.data()); - token.get_impl().give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr); + auto token = + setup_multi_variable_gaussian_impl(handle, mem_resource, P.extent(0), method); + compute_multi_variable_gaussian_impl(token, x, P, X); } }; // end of namespace detail diff --git a/cpp/include/raft/random/multi_variable_gaussian.cuh b/cpp/include/raft/random/multi_variable_gaussian.cuh index 1d9d63f6c5..a44c67b145 100644 --- a/cpp/include/raft/random/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/multi_variable_gaussian.cuh @@ -59,6 +59,81 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl { ~multi_variable_gaussian() { deinit(); } }; // end of multi_variable_gaussian +/** + * @brief Matrix decomposition method for `compute_multi_variable_gaussian` to use. + * + * `compute_multi_variable_gaussian` can use any of the following methods. + * + * - `CHOLESKY`: Cholesky decomposition + * - `JACOBI`: + + */ +using detail::multi_variable_gaussian_decomposition_method; +using detail::multi_variable_gaussian_setup_token; + +template +multi_variable_gaussian_setup_token setup_multi_variable_gaussian( + const raft::handle_t& handle, + const int dim, + const multi_variable_gaussian_decomposition_method method) +{ + rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource(); + RAFT_EXPECTS(mem_resource_ptr != nullptr, + "setup_multi_variable_gaussian: " + "rmm::mr::get_current_device_resource() returned null; " + "please report this bug to the RAPIDS RAFT developers."); + return detail::setup_multi_variable_gaussian_impl( + handle, *mem_resource_ptr, dim, method); +} + +template +multi_variable_gaussian_setup_token setup_multi_variable_gaussian( + const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, + const int dim, + multi_variable_gaussian_decomposition_method method) +{ + return detail::setup_multi_variable_gaussian_impl(handle, mem_resource, dim, method); +} + +template +void compute_multi_variable_gaussian( + multi_variable_gaussian_setup_token& token, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X) +{ + detail::compute_multi_variable_gaussian_impl(token, x, P, X); +} + +template +void compute_multi_variable_gaussian( + const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X, + const multi_variable_gaussian_decomposition_method method) +{ + detail::compute_multi_variable_gaussian_impl(handle, mem_resource, x, P, X, method); +} + +template +void compute_multi_variable_gaussian( + const raft::handle_t& handle, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X, + const multi_variable_gaussian_decomposition_method method) +{ + rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource(); + RAFT_EXPECTS(mem_resource_ptr != nullptr, + "compute_multi_variable_gaussian: " + "rmm::mr::get_current_device_resource() returned null; " + "please report this bug to the RAPIDS RAFT developers."); + detail::compute_multi_variable_gaussian_impl(handle, *mem_resource_ptr, x, P, X, method); +} + }; // end of namespace raft::random -#endif \ No newline at end of file +#endif diff --git a/cpp/test/random/multi_variable_gaussian.cu b/cpp/test/random/multi_variable_gaussian.cu index e44ae1beaf..274edb454d 100644 --- a/cpp/test/random/multi_variable_gaussian.cu +++ b/cpp/test/random/multi_variable_gaussian.cu @@ -244,7 +244,6 @@ class MVGMdspanTest : public ::testing::TestWithParam> { auto cusolverH = handle.get_cusolver_dn_handle(); auto stream = handle.get_stream(); - // preparing to store stuff P.resize(dim * dim); x.resize(dim); X.resize(dim * nPoints); @@ -254,12 +253,10 @@ class MVGMdspanTest : public ::testing::TestWithParam> { Rand_cov.resize(dim * dim, stream); Rand_mean.resize(dim, stream); - // generating random mean and cov. srand(params.seed); for (int j = 0; j < dim; j++) x.data()[j] = rand() % 100 + 5.0f; - // for random Cov. matrix std::default_random_engine generator(params.seed); std::uniform_real_distribution distribution(0.0, 1.0); @@ -274,26 +271,23 @@ class MVGMdspanTest : public ::testing::TestWithParam> { } } - // porting inputs to gpu raft::update_device(P_d.data(), P.data(), dim * dim, stream); raft::update_device(x_d.data(), x.data(), dim, stream); - // Set up the multivariable Gaussian computation - { - // Test that setup with a default memory resource compiles. - auto token = detail::setup_multi_variable_gaussian(handle, dim, method); - ASSERT_EQ(dim, token.dim()); // just so token is used - } - rmm::mr::device_memory_resource* mem_resource = rmm::mr::get_current_device_resource(); - ASSERT_TRUE(mem_resource != nullptr); - auto token = detail::setup_multi_variable_gaussian(handle, mem_resource, dim, method); - std::optional> x_view(std::in_place, x_d.data(), dim); raft::device_matrix_view P_view(P_d.data(), dim, dim); raft::device_matrix_view X_view(X_d.data(), dim, nPoints); - // X_view is the output. - detail::compute_multi_variable_gaussian(token, x_view, P_view, X_view); + { + // Test that setup with a default memory resource compiles. + auto token = raft::random::setup_multi_variable_gaussian(handle, dim, method); + (void)token; + } + + rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource(); + ASSERT_TRUE(mem_resource_ptr != nullptr); + raft::random::compute_multi_variable_gaussian( + handle, *mem_resource_ptr, x_view, P_view, X_view, method); // saving the mean of the randoms in Rand_mean //@todo can be swapped with a API that calculates mean @@ -446,10 +440,6 @@ TEST_P(MVGTestD, CovIsCorrectD) << " in CovIsCorrect"; } -// call the tests -INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestF, ::testing::ValuesIn(inputsf)); -INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestD, ::testing::ValuesIn(inputsd)); - using MVGMdspanTestF = MVGMdspanTest; using MVGMdspanTestD = MVGMdspanTest; TEST_P(MVGMdspanTestF, MeanIsCorrectF) @@ -485,6 +475,10 @@ TEST_P(MVGMdspanTestD, CovIsCorrectD) << " in CovIsCorrect"; } +// call the tests +INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestF, ::testing::ValuesIn(inputsf)); +INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestD, ::testing::ValuesIn(inputsd)); + // call the tests INSTANTIATE_TEST_CASE_P(MVGMdspanTests, MVGMdspanTestF, ::testing::ValuesIn(inputsf)); INSTANTIATE_TEST_CASE_P(MVGMdspanTests, MVGMdspanTestD, ::testing::ValuesIn(inputsd));