diff --git a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh index 636d31c04e..2d19773c3b 100644 --- a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh @@ -16,7 +16,11 @@ #pragma once #include "curand_wrappers.hpp" +#include "random_types.hpp" #include +#include +#include +#include #include #include #include @@ -24,7 +28,9 @@ #include #include #include +#include #include +#include // mvg.cuh takes in matrices that are colomn major (as in fortan) #define IDX2C(i, j, ld) (j * ld + i) @@ -286,5 +292,157 @@ class multi_variable_gaussian_impl { ~multi_variable_gaussian_impl() { deinit(); } }; // end of multi_variable_gaussian_impl +template +class multi_variable_gaussian_setup_token; + +template +multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( + const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, + const int dim, + const multi_variable_gaussian_decomposition_method method); + +template +void compute_multi_variable_gaussian_impl( + multi_variable_gaussian_setup_token& token, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X); + +template +class multi_variable_gaussian_setup_token { + template + friend multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( + const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, + const int dim, + const multi_variable_gaussian_decomposition_method method); + + template + friend void compute_multi_variable_gaussian_impl( + multi_variable_gaussian_setup_token& token, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X); + + private: + typename multi_variable_gaussian_impl::Decomposer new_enum_to_old_enum( + multi_variable_gaussian_decomposition_method method) + { + if (method == multi_variable_gaussian_decomposition_method::CHOLESKY) { + 
return multi_variable_gaussian_impl::chol_decomp; + } else if (method == multi_variable_gaussian_decomposition_method::JACOBI) { + return multi_variable_gaussian_impl::jacobi; + } else { + return multi_variable_gaussian_impl::qr; + } + } + + // Constructor, only for use by friend functions. + // Hiding this will let us change the implementation in the future. + multi_variable_gaussian_setup_token(const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, + const int dim, + const multi_variable_gaussian_decomposition_method method) + : impl_(std::make_unique>( + handle, dim, new_enum_to_old_enum(method))), + handle_(handle), + mem_resource_(mem_resource), + dim_(dim) + { + } + + /** + * @brief Compute the multivariable Gaussian. + * + * @param[in] x optional vector of dim elements (nullptr is forwarded if empty) + * @param[inout] P On input, dim x dim matrix; overwritten on output + * @param[out] X dim x nPoints matrix + */ + void compute(std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X) + { + const int input_dim = P.extent(0); + RAFT_EXPECTS(input_dim == dim(), + "multi_variable_gaussian: " + "P.extent(0) = %d does not match the extent %d " + "with which the token was created", + input_dim, + dim()); + RAFT_EXPECTS(P.extent(0) == P.extent(1), + "multi_variable_gaussian: " + "P must be square, but P.extent(0) = %d != P.extent(1) = %d", + P.extent(0), + P.extent(1)); + RAFT_EXPECTS(P.extent(0) == X.extent(0), + "multi_variable_gaussian: " + "P.extent(0) = %d != X.extent(0) = %d", + P.extent(0), + X.extent(0)); + const bool x_has_value = x.has_value(); + const int x_extent_0 = x_has_value ? (*x).extent(0) : 0; + RAFT_EXPECTS(not x_has_value || P.extent(0) == x_extent_0, + "multi_variable_gaussian: " + "P.extent(0) = %d != x.extent(0) = %d", + P.extent(0), + x_extent_0); + const int nPoints = X.extent(1); + const ValueType* x_ptr = x_has_value ? 
(*x).data_handle() : nullptr; + + auto workspace = allocate_workspace(); + impl_->set_workspace(workspace.data()); + impl_->give_gaussian(nPoints, P.data_handle(), X.data_handle(), x_ptr); + } + + private: + std::unique_ptr> impl_; + const raft::handle_t& handle_; + rmm::mr::device_memory_resource& mem_resource_; + int dim_ = 0; + + auto allocate_workspace() const + { + const auto num_elements = impl_->get_workspace_size(); + return rmm::device_uvector{num_elements, handle_.get_stream(), &mem_resource_}; + } + + int dim() const { return dim_; } +}; + +template +multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( + const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, + const int dim, + const multi_variable_gaussian_decomposition_method method) +{ + return multi_variable_gaussian_setup_token(handle, mem_resource, dim, method); +} + +template +void compute_multi_variable_gaussian_impl( + multi_variable_gaussian_setup_token& token, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X) +{ + token.compute(x, P, X); +} + +template +void compute_multi_variable_gaussian_impl( + const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X, + const multi_variable_gaussian_decomposition_method method) +{ + auto token = + build_multi_variable_gaussian_token_impl(handle, mem_resource, P.extent(0), method); + compute_multi_variable_gaussian_impl(token, x, P, X); +} + }; // end of namespace detail -}; // end of namespace raft::random \ No newline at end of file +}; // end of namespace raft::random diff --git a/cpp/include/raft/random/detail/random_types.hpp b/cpp/include/raft/random/detail/random_types.hpp new file mode 100644 index 0000000000..28108f9513 --- /dev/null +++ b/cpp/include/raft/random/detail/random_types.hpp @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::random::detail { + +enum class multi_variable_gaussian_decomposition_method { CHOLESKY, JACOBI, QR }; + +}; // end of namespace raft::random::detail diff --git a/cpp/include/raft/random/multi_variable_gaussian.cuh b/cpp/include/raft/random/multi_variable_gaussian.cuh index 1d9d63f6c5..796a10fb65 100644 --- a/cpp/include/raft/random/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/multi_variable_gaussian.cuh @@ -59,6 +59,52 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl { ~multi_variable_gaussian() { deinit(); } }; // end of multi_variable_gaussian +/** + * @brief Matrix decomposition method for `compute_multi_variable_gaussian` to use. + * + * `compute_multi_variable_gaussian` can use any of the following methods. + * + * - `CHOLESKY`: Uses Cholesky decomposition on the normal equations. + * This may be faster than the other two methods, but less accurate. + * + * - `JACOBI`: Uses the singular value decomposition (SVD) computed with + * cuSOLVER's gesvdj algorithm, which is based on the Jacobi method + * (sweeps of plane rotations). This exposes more parallelism + * for small and medium size matrices than the QR option below. + * + * - `QR`: Uses the SVD computed with cuSOLVER's gesvd algorithm, + * which is based on the QR algorithm. 
+ */ +using detail::multi_variable_gaussian_decomposition_method; + +template +void compute_multi_variable_gaussian( + const raft::handle_t& handle, + rmm::mr::device_memory_resource& mem_resource, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X, + const multi_variable_gaussian_decomposition_method method) +{ + detail::compute_multi_variable_gaussian_impl(handle, mem_resource, x, P, X, method); +} + +template +void compute_multi_variable_gaussian( + const raft::handle_t& handle, + std::optional> x, + raft::device_matrix_view P, + raft::device_matrix_view X, + const multi_variable_gaussian_decomposition_method method) +{ + rmm::mr::device_memory_resource* mem_resource_ptr = rmm::mr::get_current_device_resource(); + RAFT_EXPECTS(mem_resource_ptr != nullptr, + "compute_multi_variable_gaussian: " + "rmm::mr::get_current_device_resource() returned null; " + "please report this bug to the RAPIDS RAFT developers."); + detail::compute_multi_variable_gaussian_impl(handle, *mem_resource_ptr, x, P, X, method); +} + }; // end of namespace raft::random -#endif \ No newline at end of file +#endif diff --git a/cpp/test/random/multi_variable_gaussian.cu b/cpp/test/random/multi_variable_gaussian.cu index c346fbf426..caf982d4ed 100644 --- a/cpp/test/random/multi_variable_gaussian.cu +++ b/cpp/test/random/multi_variable_gaussian.cu @@ -23,7 +23,7 @@ #include #include -// mvg.h takes in matrices that are colomn major (as in fortan) +// mvg.h takes in column-major matrices (as in Fortran) #define IDX2C(i, j, ld) (j * ld + i) namespace raft::random { @@ -206,6 +206,132 @@ class MVGTest : public ::testing::TestWithParam> { raft::handle_t handle; }; // end of MVGTest class +template +class MVGMdspanTest : public ::testing::TestWithParam> { + private: + static auto old_enum_to_new_enum(typename multi_variable_gaussian::Decomposer method) + { + if (method == multi_variable_gaussian::chol_decomp) { + return 
detail::multi_variable_gaussian_decomposition_method::CHOLESKY; + } else if (method == multi_variable_gaussian::jacobi) { + return detail::multi_variable_gaussian_decomposition_method::JACOBI; + } else { + return detail::multi_variable_gaussian_decomposition_method::QR; + } + } + + protected: + MVGMdspanTest() + : workspace_d(0, handle.get_stream()), + P_d(0, handle.get_stream()), + x_d(0, handle.get_stream()), + X_d(0, handle.get_stream()), + Rand_cov(0, handle.get_stream()), + Rand_mean(0, handle.get_stream()) + { + } + + void SetUp() override + { + params = ::testing::TestWithParam>::GetParam(); + dim = params.dim; + nPoints = params.nPoints; + auto method = old_enum_to_new_enum(params.method); + corr = params.corr; + tolerance = params.tolerance; + + auto cublasH = handle.get_cublas_handle(); + auto cusolverH = handle.get_cusolver_dn_handle(); + auto stream = handle.get_stream(); + + P.resize(dim * dim); + x.resize(dim); + X.resize(dim * nPoints); + P_d.resize(dim * dim, stream); + X_d.resize(nPoints * dim, stream); + x_d.resize(dim, stream); + Rand_cov.resize(dim * dim, stream); + Rand_mean.resize(dim, stream); + + srand(params.seed); + for (int j = 0; j < dim; j++) + x.data()[j] = rand() % 100 + 5.0f; + + std::default_random_engine generator(params.seed); + std::uniform_real_distribution distribution(0.0, 1.0); + + // P (symmetric positive definite matrix) + for (int j = 0; j < dim; j++) { + for (int i = 0; i < j + 1; i++) { + T k = distribution(generator); + if (corr == UNCORRELATED) k = 0.0; + P.data()[IDX2C(i, j, dim)] = k; + P.data()[IDX2C(j, i, dim)] = k; + if (i == j) P.data()[IDX2C(i, j, dim)] += dim; + } + } + + raft::update_device(P_d.data(), P.data(), dim * dim, stream); + raft::update_device(x_d.data(), x.data(), dim, stream); + + std::optional> x_view(std::in_place, x_d.data(), dim); + raft::device_matrix_view P_view(P_d.data(), dim, dim); + raft::device_matrix_view X_view(X_d.data(), dim, nPoints); + + rmm::mr::device_memory_resource* 
mem_resource_ptr = rmm::mr::get_current_device_resource(); + ASSERT_TRUE(mem_resource_ptr != nullptr); + raft::random::compute_multi_variable_gaussian( + handle, *mem_resource_ptr, x_view, P_view, X_view, method); + + // saving the mean of the randoms in Rand_mean + //@todo can be swapped with an API that calculates mean + RAFT_CUDA_TRY(cudaMemset(Rand_mean.data(), 0, dim * sizeof(T))); + dim3 block = (64); + dim3 grid = (raft::ceildiv(nPoints * dim, (int)block.x)); + En_KF_accumulate<<>>(nPoints, dim, X_d.data(), Rand_mean.data()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + grid = (raft::ceildiv(dim, (int)block.x)); + En_KF_normalize<<>>(nPoints, dim, Rand_mean.data()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + // storing the error wrt random point mean in X_d + grid = (raft::ceildiv(dim * nPoints, (int)block.x)); + En_KF_dif<<>>(nPoints, dim, X_d.data(), Rand_mean.data(), X_d.data()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + // finding the cov matrix, placing in Rand_cov + T alfa = 1.0 / (nPoints - 1), beta = 0.0; + + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublasH, + CUBLAS_OP_N, + CUBLAS_OP_T, + dim, + dim, + nPoints, + &alfa, + X_d.data(), + dim, + X_d.data(), + dim, + &beta, + Rand_cov.data(), + dim, + stream)); + + // restoring cov provided into P_d + raft::update_device(P_d.data(), P.data(), dim * dim, stream); + } + + protected: + MVGInputs params; + std::vector P, x, X; + rmm::device_uvector workspace_d, P_d, x_d, X_d, Rand_cov, Rand_mean; + int dim, nPoints; + Correlation corr; + T tolerance; + raft::handle_t handle; +}; // end of MVGMdspanTest class + ///@todo find out the reason that Un-correlated covs are giving problems (in qr) // Declare your inputs const std::vector> inputsf = { @@ -273,8 +399,8 @@ const std::vector> inputsd = { }; // make the tests -typedef MVGTest MVGTestF; -typedef MVGTest MVGTestD; +using MVGTestF = MVGTest; +using MVGTestD = MVGTest; TEST_P(MVGTestF, MeanIsCorrectF) { EXPECT_TRUE(raft::devArrMatch( @@ -308,8 +434,47 
@@ 
TEST_P(MVGTestD, CovIsCorrectD) << " in CovIsCorrect"; } +using MVGMdspanTestF = MVGMdspanTest; +using MVGMdspanTestD = MVGMdspanTest; +TEST_P(MVGMdspanTestF, MeanIsCorrectF) +{ + EXPECT_TRUE(raft::devArrMatch( + x_d.data(), Rand_mean.data(), dim, raft::CompareApprox(tolerance), handle.get_stream())) + << " in MeanIsCorrect"; +} +TEST_P(MVGMdspanTestF, CovIsCorrectF) +{ + EXPECT_TRUE(raft::devArrMatch(P_d.data(), + Rand_cov.data(), + dim, + dim, + raft::CompareApprox(tolerance), + handle.get_stream())) + << " in CovIsCorrect"; +} +TEST_P(MVGMdspanTestD, MeanIsCorrectD) +{ + EXPECT_TRUE(raft::devArrMatch( + x_d.data(), Rand_mean.data(), dim, raft::CompareApprox(tolerance), handle.get_stream())) + << " in MeanIsCorrect"; +} +TEST_P(MVGMdspanTestD, CovIsCorrectD) +{ + EXPECT_TRUE(raft::devArrMatch(P_d.data(), + Rand_cov.data(), + dim, + dim, + raft::CompareApprox(tolerance), + handle.get_stream())) + << " in CovIsCorrect"; +} + // call the tests INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestF, ::testing::ValuesIn(inputsf)); INSTANTIATE_TEST_CASE_P(MVGTests, MVGTestD, ::testing::ValuesIn(inputsd)); -}; // end of namespace raft::random \ No newline at end of file +// call the mdspan tests +INSTANTIATE_TEST_CASE_P(MVGMdspanTests, MVGMdspanTestF, ::testing::ValuesIn(inputsf)); +INSTANTIATE_TEST_CASE_P(MVGMdspanTests, MVGMdspanTestD, ::testing::ValuesIn(inputsd)); + +}; // end of namespace raft::random