diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 618e307f5d..c7350a978c 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -197,8 +197,9 @@ class mdarray #endif // RAFT_MDARRAY_CTOR_CONSTEXPR /** - * @brief The only constructor that can create storage, this is to make sure CUDA stream is being - * used. + * @brief The only constructor that can create storage, raft::resources is accepted + * so that the device implementation can make sure the relevant CUDA stream is + * being used for allocation. */ RAFT_MDARRAY_CTOR_CONSTEXPR mdarray(raft::resources const& handle, mapping_type const& m, diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index 4cba028d87..bc7c551d89 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -18,6 +18,8 @@ #include "cublas_wrappers.hpp" #include "cusolver_wrappers.hpp" +#include +#include #include #include #include @@ -42,10 +44,10 @@ namespace detail { */ template void qrGetQ_inplace( - raft::device_resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) + raft::resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) { RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols."); - cusolverDnHandle_t cusolver = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolver = resource::get_cusolver_dn_handle(handle); rmm::device_uvector tau(n_cols, stream); RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream)); @@ -83,7 +85,7 @@ void qrGetQ_inplace( } template -void qrGetQ(raft::device_resources const& handle, +void qrGetQ(raft::resources const& handle, const math_t* M, math_t* Q, int n_rows, @@ -95,7 +97,7 @@ void qrGetQ(raft::device_resources const& handle, } template -void qrGetQR(raft::device_resources const& handle, +void qrGetQR(raft::resources const& handle, math_t* M, math_t* Q, math_t* R, @@ -103,7 +105,7 @@ void qrGetQR(raft::device_resources const& handle, int n_cols, cudaStream_t stream) { - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); int m = n_rows, n = n_cols; rmm::device_uvector R_full(m * n, stream); diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 05588bda9c..bbd71a4cf1 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -19,7 +19,9 @@ #include "cublas_wrappers.hpp" #include -#include +#include +#include +#include #include #include #include @@ -29,14 +31,14 @@ namespace linalg { namespace detail { template -void transpose(raft::device_resources const& handle, +void transpose(raft::resources const& handle, math_t* in, math_t* out, int n_rows, int n_cols, cudaStream_t stream) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + cublasHandle_t cublas_h = resource::get_cublas_handle(handle); RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream)); int out_n_rows = n_cols; @@ -83,7 +85,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream) template void transpose_row_major_impl( - raft::device_resources const& handle, + raft::resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { @@ -92,7 +94,7 @@ void transpose_row_major_impl( T constexpr kOne = 1; T constexpr kZero = 0; - CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), + CUBLAS_TRY(cublasgeam(resource::get_cublas_handle(handle), CUBLAS_OP_T, CUBLAS_OP_N, out_n_cols, @@ -105,12 +107,12 @@ void transpose_row_major_impl( out.stride(0), out.data_handle(), out.stride(0), - handle.get_stream())); + resource::get_cuda_stream(handle))); } template void transpose_col_major_impl( - raft::device_resources const& handle, + raft::resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { @@ -119,7 +121,7 @@ void transpose_col_major_impl( T constexpr kOne = 1; T constexpr kZero = 0; - CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(), + CUBLAS_TRY(cublasgeam(resource::get_cublas_handle(handle), CUBLAS_OP_T, CUBLAS_OP_N, out_n_rows, @@ -132,7 +134,7 @@ void transpose_col_major_impl( out.stride(1), out.data_handle(), out.stride(1), - handle.get_stream())); + resource::get_cuda_stream(handle))); } }; // end namespace detail }; // end namespace linalg diff --git a/cpp/include/raft/linalg/qr.cuh b/cpp/include/raft/linalg/qr.cuh index 8e58af63c1..948996d0ac 100644 --- a/cpp/include/raft/linalg/qr.cuh +++ b/cpp/include/raft/linalg/qr.cuh @@ -19,6 +19,8 @@ #pragma once #include "detail/qr.cuh" +#include +#include namespace raft { namespace linalg { @@ -33,7 +35,7 @@ namespace linalg { * @param stream cuda stream */ template -void qrGetQ(raft::device_resources const& handle, +void qrGetQ(raft::resources const& handle, const math_t* M, math_t* Q, int n_rows, @@ -54,7 +56,7 @@ void qrGetQ(raft::device_resources const& handle, * @param stream cuda stream */ template -void qrGetQR(raft::device_resources const& handle, +void qrGetQR(raft::resources const& handle, math_t* M, math_t* Q, math_t* R, @@ -77,13 +79,18 @@ void qrGetQR(raft::device_resources const& handle, * @param[out] Q Output raft::device_matrix_view */ template -void qr_get_q(raft::device_resources const& handle, +void qr_get_q(raft::resources const& handle, raft::device_matrix_view M, raft::device_matrix_view Q) { RAFT_EXPECTS(Q.size() == M.size(), "Size mismatch between Output and Input"); - qrGetQ(handle, M.data_handle(), Q.data_handle(), M.extent(0), M.extent(1), handle.get_stream()); + qrGetQ(handle, + M.data_handle(), + Q.data_handle(), + M.extent(0), + M.extent(1), + resource::get_cuda_stream(handle)); } /** @@ -94,7 +101,7 @@ void qr_get_q(raft::device_resources const& handle, * @param[out] R Output raft::device_matrix_view */ template -void qr_get_qr(raft::device_resources const& handle, +void qr_get_qr(raft::resources const& handle, raft::device_matrix_view M, raft::device_matrix_view Q, raft::device_matrix_view R) @@ -107,7 +114,7 @@ void qr_get_qr(raft::device_resources const& handle, R.data_handle(), M.extent(0), M.extent(1), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @} */ diff --git a/cpp/include/raft/linalg/transpose.cuh b/cpp/include/raft/linalg/transpose.cuh index 2f31cfd722..0fe752347d 100644 --- a/cpp/include/raft/linalg/transpose.cuh +++ b/cpp/include/raft/linalg/transpose.cuh @@ -20,6 +20,7 @@ #include "detail/transpose.cuh" #include +#include namespace raft { namespace linalg { @@ -34,7 +35,7 @@ namespace linalg { * @param stream: cuda stream */ template -void transpose(raft::device_resources const& handle, +void transpose(raft::resources const& handle, math_t* in, math_t* out, int n_rows, @@ -76,7 +77,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream) * @param[out] out Output matirx, storage is pre-allocated by caller. */ template -auto transpose(raft::device_resources const& handle, +auto transpose(raft::resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) -> std::enable_if_t, void> diff --git a/cpp/include/raft/random/detail/make_regression.cuh b/cpp/include/raft/random/detail/make_regression.cuh index 01d97d496d..1715dcbe81 100644 --- a/cpp/include/raft/random/detail/make_regression.cuh +++ b/cpp/include/raft/random/detail/make_regression.cuh @@ -22,7 +22,8 @@ #include -#include +#include +#include #include #include #include @@ -52,7 +53,7 @@ static __global__ void _singular_profile_kernel(DataT* out, IdxT n, DataT tail_s /* Internal auxiliary function to generate a low-rank matrix */ template -static void _make_low_rank_matrix(raft::device_resources const& handle, +static void _make_low_rank_matrix(raft::resources const& handle, DataT* out, IdxT n_rows, IdxT n_cols, @@ -61,8 +62,7 @@ static void _make_low_rank_matrix(raft::device_resources const& handle, raft::random::RngState& r, cudaStream_t stream) { - cusolverDnHandle_t cusolver_handle = handle.get_cusolver_dn_handle(); - cublasHandle_t cublas_handle = handle.get_cublas_handle(); + cublasHandle_t cublas_handle = resource::get_cublas_handle(handle); IdxT n = std::min(n_rows, n_cols); @@ -143,7 +143,7 @@ static __global__ void _gather2d_kernel( } template -void make_regression_caller(raft::device_resources const& handle, +void make_regression_caller(raft::resources const& handle, DataT* out, DataT* values, IdxT n_rows, @@ -162,8 +162,7 @@ void make_regression_caller(raft::device_resources const& handle, { n_informative = std::min(n_informative, n_cols); - cusolverDnHandle_t cusolver_handle = handle.get_cusolver_dn_handle(); - cublasHandle_t cublas_handle = handle.get_cublas_handle(); + cublasHandle_t cublas_handle = resource::get_cublas_handle(handle); cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST); raft::random::RngState r(seed, type); diff --git a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh index 16f50446ae..68934ac1ff 100644 --- a/cpp/include/raft/random/detail/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/detail/multi_variable_gaussian.cuh @@ -20,7 +20,10 @@ #include #include #include -#include +#include +#include +#include +#include #include #include #include @@ -139,18 +142,16 @@ class multi_variable_gaussian_impl { int *info, Lwork, info_h; syevjInfo_t syevj_params = NULL; curandGenerator_t gen; - raft::device_resources const& handle; + raft::resources const& handle; cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR; bool deinitilized = false; public: // functions multi_variable_gaussian_impl() = delete; - multi_variable_gaussian_impl(raft::device_resources const& handle, - const int dim, - Decomposer method) + multi_variable_gaussian_impl(raft::resources const& handle, const int dim, Decomposer method) : handle(handle), dim(dim), method(method) { - auto cusolverHandle = handle.get_cusolver_dn_handle(); + auto cusolverHandle = resource::get_cusolver_dn_handle(handle); CURAND_CHECK(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT)); CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(gen, 28)); // SEED @@ -191,9 +192,9 @@ class multi_variable_gaussian_impl { void give_gaussian(const int nPoints, T* P, T* X, const T* x = 0) { - auto cusolverHandle = handle.get_cusolver_dn_handle(); - auto cublasHandle = handle.get_cublas_handle(); - auto cudaStream = handle.get_stream(); + auto cusolverHandle = resource::get_cusolver_dn_handle(handle); + auto cublasHandle = resource::get_cublas_handle(handle); + auto cudaStream = resource::get_cuda_stream(handle); if (method == chol_decomp) { // lower part will contains chol_decomp RAFT_CUSOLVER_TRY(raft::linalg::detail::cusolverDnpotrf( @@ -299,7 +300,7 @@ class multi_variable_gaussian_setup_token; template multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( - raft::device_resources const& handle, + raft::resources const& handle, rmm::mr::device_memory_resource& mem_resource, const int dim, const multi_variable_gaussian_decomposition_method method); @@ -315,7 +316,7 @@ template class multi_variable_gaussian_setup_token { template friend multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( - raft::device_resources const& handle, + raft::resources const& handle, rmm::mr::device_memory_resource& mem_resource, const int dim, const multi_variable_gaussian_decomposition_method method); @@ -342,7 +343,7 @@ class multi_variable_gaussian_setup_token { // Constructor, only for use by friend functions. // Hiding this will let us change the implementation in the future. - multi_variable_gaussian_setup_token(raft::device_resources const& handle, + multi_variable_gaussian_setup_token(raft::resources const& handle, rmm::mr::device_memory_resource& mem_resource, const int dim, const multi_variable_gaussian_decomposition_method method) @@ -399,14 +400,15 @@ class multi_variable_gaussian_setup_token { private: std::unique_ptr> impl_; - raft::device_resources const& handle_; + raft::resources const& handle_; rmm::mr::device_memory_resource& mem_resource_; int dim_ = 0; auto allocate_workspace() const { const auto num_elements = impl_->get_workspace_size(); - return rmm::device_uvector{num_elements, handle_.get_stream(), &mem_resource_}; + return rmm::device_uvector{ + num_elements, resource::get_cuda_stream(handle_), &mem_resource_}; } int dim() const { return dim_; } @@ -414,7 +416,7 @@ class multi_variable_gaussian_setup_token { template multi_variable_gaussian_setup_token build_multi_variable_gaussian_token_impl( - raft::device_resources const& handle, + raft::resources const& handle, rmm::mr::device_memory_resource& mem_resource, const int dim, const multi_variable_gaussian_decomposition_method method) @@ -434,7 +436,7 @@ void compute_multi_variable_gaussian_impl( template void compute_multi_variable_gaussian_impl( - raft::device_resources const& handle, + raft::resources const& handle, rmm::mr::device_memory_resource& mem_resource, std::optional> x, raft::device_matrix_view P, @@ -455,7 +457,7 @@ class multi_variable_gaussian : public detail::multi_variable_gaussian_impl { // using detail::multi_variable_gaussian_impl::Decomposer::qr; multi_variable_gaussian() = delete; - multi_variable_gaussian(raft::device_resources const& handle, + multi_variable_gaussian(raft::resources const& handle, const int dim, typename detail::multi_variable_gaussian_impl::Decomposer method) : detail::multi_variable_gaussian_impl{handle, dim, method} diff --git a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh index b5e0610405..d00fc29056 100644 --- a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh @@ -18,7 +18,8 @@ #include "rmat_rectangular_generator_types.cuh" -#include +#include +#include #include #include #include @@ -206,7 +207,7 @@ void rmat_rectangular_gen_caller(IdxT* out, * @param[in] c_scale 2^c_scale represents the number of destination nodes */ template -void rmat_rectangular_gen_impl(raft::device_resources const& handle, +void rmat_rectangular_gen_impl(raft::resources const& handle, raft::random::RngState& r, raft::device_vector_view theta, raft::random::detail::rmat_rectangular_gen_output output, @@ -247,7 +248,7 @@ void rmat_rectangular_gen_impl(raft::device_resources const& handle, r_scale, c_scale, n_edges, - handle.get_stream(), + resource::get_cuda_stream(handle), r); } @@ -259,7 +260,7 @@ void rmat_rectangular_gen_impl(raft::device_resources const& handle, * `theta` parameter. */ template -void rmat_rectangular_gen_impl(raft::device_resources const& handle, +void rmat_rectangular_gen_impl(raft::resources const& handle, raft::random::RngState& r, raft::random::detail::rmat_rectangular_gen_output output, ProbT a, @@ -286,8 +287,17 @@ void rmat_rectangular_gen_impl(raft::device_resources const& handle, IdxT* out_dst_ptr = out_dst_has_value ? (*out_dst).data_handle() : nullptr; const IdxT n_edges = output.number_of_edges(); - detail::rmat_rectangular_gen_caller( - out_ptr, out_src_ptr, out_dst_ptr, a, b, c, r_scale, c_scale, n_edges, handle.get_stream(), r); + detail::rmat_rectangular_gen_caller(out_ptr, + out_src_ptr, + out_dst_ptr, + a, + b, + c, + r_scale, + c_scale, + n_edges, + resource::get_cuda_stream(handle), + r); } } // end namespace detail diff --git a/cpp/include/raft/random/detail/rng_impl_deprecated.cuh b/cpp/include/raft/random/detail/rng_impl_deprecated.cuh index 362c844fb3..8895d22cf0 100644 --- a/cpp/include/raft/random/detail/rng_impl_deprecated.cuh +++ b/cpp/include/raft/random/detail/rng_impl_deprecated.cuh @@ -23,7 +23,7 @@ #include "rng_device.cuh" #include -#include +#include #include #include #include @@ -259,7 +259,7 @@ class RngImpl { template METHOD_DEPR(sampleWithoutReplacement) - void sampleWithoutReplacement(raft::device_resources const& handle, + void sampleWithoutReplacement(raft::resources const& handle, DataT* out, IdxT* outIdx, const DataT* in, diff --git a/cpp/include/raft/random/make_blobs.cuh b/cpp/include/raft/random/make_blobs.cuh index 7aa0362f6d..079ab43b74 100644 --- a/cpp/include/raft/random/make_blobs.cuh +++ b/cpp/include/raft/random/make_blobs.cuh @@ -22,6 +22,8 @@ #include "detail/make_blobs.cuh" #include #include +#include +#include namespace raft::random { @@ -129,7 +131,7 @@ void make_blobs(DataT* out, */ template void make_blobs( - raft::device_resources const& handle, + raft::resources const& handle, raft::device_matrix_view out, raft::device_vector_view labels, IdxT n_clusters = 5, @@ -167,7 +169,7 @@ void make_blobs( (IdxT)out.extent(0), (IdxT)out.extent(1), n_clusters, - handle.get_stream(), + resource::get_cuda_stream(handle), row_major, prm_centers, prm_cluster_std, diff --git a/cpp/include/raft/random/make_regression.cuh b/cpp/include/raft/random/make_regression.cuh index f4a7e82308..0aa9cc4daa 100644 --- a/cpp/include/raft/random/make_regression.cuh +++ b/cpp/include/raft/random/make_regression.cuh @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include "detail/make_regression.cuh" @@ -67,7 +69,7 @@ namespace raft::random { * @param[in] type Random generator type */ template -void make_regression(raft::device_resources const& handle, +void make_regression(raft::resources const& handle, DataT* out, DataT* values, IdxT n_rows, @@ -138,7 +140,7 @@ void make_regression(raft::device_resources const& handle, * @param[in] type Random generator type */ template -void make_regression(raft::device_resources const& handle, +void make_regression(raft::resources const& handle, raft::device_matrix_view out, raft::device_matrix_view values, IdxT n_informative, @@ -170,7 +172,7 @@ void make_regression(raft::device_resources const& handle, n_samples, n_features, n_informative, - handle.get_stream(), + resource::get_cuda_stream(handle), coef_ptr, n_targets, bias, diff --git a/cpp/include/raft/random/multi_variable_gaussian.cuh b/cpp/include/raft/random/multi_variable_gaussian.cuh index 91a7695f2c..eada1c9521 100644 --- a/cpp/include/raft/random/multi_variable_gaussian.cuh +++ b/cpp/include/raft/random/multi_variable_gaussian.cuh @@ -20,6 +20,7 @@ #pragma once #include "detail/multi_variable_gaussian.cuh" +#include #include namespace raft::random { @@ -30,7 +31,7 @@ namespace raft::random { */ template -void multi_variable_gaussian(raft::device_resources const& handle, +void multi_variable_gaussian(raft::resources const& handle, rmm::mr::device_memory_resource& mem_resource, std::optional> x, raft::device_matrix_view P, @@ -41,7 +42,7 @@ void multi_variable_gaussian(raft::device_resources const& handle, } template -void multi_variable_gaussian(raft::device_resources const& handle, +void multi_variable_gaussian(raft::resources const& handle, std::optional> x, raft::device_matrix_view P, raft::device_matrix_view X, diff --git a/cpp/include/raft/random/permute.cuh b/cpp/include/raft/random/permute.cuh index 16de1d676d..d349b68add 100644 --- a/cpp/include/raft/random/permute.cuh +++ b/cpp/include/raft/random/permute.cuh @@ -23,7 +23,8 @@ #include #include -#include +#include +#include #include namespace raft::random { @@ -94,7 +95,7 @@ using perms_out_view_t = typename perms_out_view -void permute(raft::device_resources const& handle, +void permute(raft::resources const& handle, raft::device_matrix_view in, std::optional> permsOut, std::optional> out) @@ -127,8 +128,13 @@ void permute(raft::device_resources const& handle, if (permsOut_ptr != nullptr || out_ptr != nullptr) { const IdxType N = in.extent(0); const IdxType D = in.extent(1); - detail::permute( - permsOut_ptr, out_ptr, in.data_handle(), D, N, is_row_major, handle.get_stream()); + detail::permute(permsOut_ptr, + out_ptr, + in.data_handle(), + D, + N, + is_row_major, + resource::get_cuda_stream(handle)); } } @@ -141,7 +147,7 @@ template -void permute(raft::device_resources const& handle, +void permute(raft::resources const& handle, raft::device_matrix_view in, PermsOutType&& permsOut, OutType&& out) diff --git a/cpp/include/raft/random/rmat_rectangular_generator.cuh b/cpp/include/raft/random/rmat_rectangular_generator.cuh index d578794d31..90cd9baf81 100644 --- a/cpp/include/raft/random/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/rmat_rectangular_generator.cuh @@ -17,6 +17,7 @@ #pragma once #include "detail/rmat_rectangular_generator.cuh" +#include namespace raft::random { @@ -78,7 +79,7 @@ namespace raft::random { */ template void rmat_rectangular_gen( - raft::device_resources const& handle, + raft::resources const& handle, raft::random::RngState& r, raft::device_vector_view theta, raft::device_mdspan, raft::row_major> out, @@ -102,7 +103,7 @@ void rmat_rectangular_gen( * @pre `out_src.extent(0) == out_dst.extent(0)` is `true` */ template -void rmat_rectangular_gen(raft::device_resources const& handle, +void rmat_rectangular_gen(raft::resources const& handle, raft::random::RngState& r, raft::device_vector_view theta, raft::device_vector_view out_src, @@ -125,7 +126,7 @@ void rmat_rectangular_gen(raft::device_resources const& handle, */ template void rmat_rectangular_gen( - raft::device_resources const& handle, + raft::resources const& handle, raft::random::RngState& r, raft::device_vector_view theta, raft::device_mdspan, raft::row_major> out, @@ -152,7 +153,7 @@ void rmat_rectangular_gen( */ template void rmat_rectangular_gen( - raft::device_resources const& handle, + raft::resources const& handle, raft::random::RngState& r, raft::device_mdspan, raft::row_major> out, raft::device_vector_view out_src, @@ -179,7 +180,7 @@ void rmat_rectangular_gen( * @pre `out_src.extent(0) == out_dst.extent(0)` is `true` */ template -void rmat_rectangular_gen(raft::device_resources const& handle, +void rmat_rectangular_gen(raft::resources const& handle, raft::random::RngState& r, raft::device_vector_view out_src, raft::device_vector_view out_dst, @@ -204,7 +205,7 @@ void rmat_rectangular_gen(raft::device_resources const& handle, */ template void rmat_rectangular_gen( - raft::device_resources const& handle, + raft::resources const& handle, raft::random::RngState& r, raft::device_mdspan, raft::row_major> out, ProbT a, diff --git a/cpp/include/raft/random/rng.cuh b/cpp/include/raft/random/rng.cuh index d03975d0db..c3b44a7577 100644 --- a/cpp/include/raft/random/rng.cuh +++ b/cpp/include/raft/random/rng.cuh @@ -22,7 +22,8 @@ #include #include #include -#include +#include +#include #include #include @@ -41,13 +42,14 @@ namespace raft::random { * @param[in] end end of the range */ template -void uniform(raft::device_resources const& handle, +void uniform(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType start, OutputValueType end) { - detail::uniform(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream()); + detail::uniform( + rng_state, out.data_handle(), out.extent(0), start, end, resource::get_cuda_stream(handle)); } /** @@ -63,14 +65,14 @@ void uniform(raft::device_resources const& handle, * @param[in] end end of the range */ template -void uniform(raft::device_resources const& handle, +void uniform(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType start, OutType end) { - detail::uniform(rng_state, ptr, len, start, end, handle.get_stream()); + detail::uniform(rng_state, ptr, len, start, end, resource::get_cuda_stream(handle)); } /** @@ -86,7 +88,7 @@ void uniform(raft::device_resources const& handle, * @param[in] end end of the range */ template -void uniformInt(raft::device_resources const& handle, +void uniformInt(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType start, @@ -98,7 +100,8 @@ void uniformInt(raft::device_resources const& handle, "so that we can write to it."); static_assert(std::is_integral::value, "uniformInt: The elements of the output vector must have integral type."); - detail::uniformInt(rng_state, out.data_handle(), out.extent(0), start, end, handle.get_stream()); + detail::uniformInt( + rng_state, out.data_handle(), out.extent(0), start, end, resource::get_cuda_stream(handle)); } /** @@ -114,14 +117,14 @@ void uniformInt(raft::device_resources const& handle, * @param[in] end end of the range */ template -void uniformInt(raft::device_resources const& handle, +void uniformInt(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType start, OutType end) { - detail::uniformInt(rng_state, ptr, len, start, end, handle.get_stream()); + detail::uniformInt(rng_state, ptr, len, start, end, resource::get_cuda_stream(handle)); } /** @@ -138,13 +141,14 @@ void uniformInt(raft::device_resources const& handle, * @param[in] sigma std-dev of the distribution */ template -void normal(raft::device_resources const& handle, +void normal(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType mu, OutputValueType sigma) { - detail::normal(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); + detail::normal( + rng_state, out.data_handle(), out.extent(0), mu, sigma, resource::get_cuda_stream(handle)); } /** @@ -160,14 +164,14 @@ void normal(raft::device_resources const& handle, * @param[in] sigma std-dev of the distribution */ template -void normal(raft::device_resources const& handle, +void normal(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType mu, OutType sigma) { - detail::normal(rng_state, ptr, len, mu, sigma, handle.get_stream()); + detail::normal(rng_state, ptr, len, mu, sigma, resource::get_cuda_stream(handle)); } /** @@ -183,7 +187,7 @@ void normal(raft::device_resources const& handle, * @param[in] sigma standard deviation of the distribution */ template -void normalInt(raft::device_resources const& handle, +void normalInt(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType mu, @@ -196,7 +200,8 @@ void normalInt(raft::device_resources const& handle, static_assert(std::is_integral::value, "normalInt: The output vector's value type must be an integer."); - detail::normalInt(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); + detail::normalInt( + rng_state, out.data_handle(), out.extent(0), mu, sigma, resource::get_cuda_stream(handle)); } /** @@ -212,14 +217,14 @@ void normalInt(raft::device_resources const& handle, * @param[in] sigma std-dev of the distribution */ template -void normalInt(raft::device_resources const& handle, +void normalInt(raft::resources const& handle, RngState& rng_state, IntType* ptr, LenType len, IntType mu, IntType sigma) { - detail::normalInt(rng_state, ptr, len, mu, sigma, handle.get_stream()); + detail::normalInt(rng_state, ptr, len, mu, sigma, resource::get_cuda_stream(handle)); } /** @@ -244,7 +249,7 @@ void normalInt(raft::device_resources const& handle, */ template void normalTable( - raft::device_resources const& handle, + raft::resources const& handle, RngState& rng_state, raft::device_vector_view mu_vec, std::variant, OutputValueType> sigma, @@ -283,7 +288,7 @@ void normalTable( mu_vec.data_handle(), sigma_vec_ptr, sigma_value, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -307,7 +312,7 @@ void normalTable( * @param[in] sigma scalar sigma to be used if 'sigma_vec' is nullptr */ template -void normalTable(raft::device_resources const& handle, +void normalTable(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType n_rows, @@ -317,7 +322,7 @@ void normalTable(raft::device_resources const& handle, OutType sigma) { detail::normalTable( - rng_state, ptr, n_rows, n_cols, mu_vec, sigma_vec, sigma, handle.get_stream()); + rng_state, ptr, n_rows, n_cols, mu_vec, sigma_vec, sigma, resource::get_cuda_stream(handle)); } /** @@ -332,12 +337,12 @@ void normalTable(raft::device_resources const& handle, * @param[out] out the output vector */ template -void fill(raft::device_resources const& handle, +void fill(raft::resources const& handle, RngState& rng_state, OutputValueType val, raft::device_vector_view out) { - detail::fill(rng_state, out.data_handle(), out.extent(0), val, handle.get_stream()); + detail::fill(rng_state, out.data_handle(), out.extent(0), val, resource::get_cuda_stream(handle)); } /** @@ -353,9 +358,9 @@ void fill(raft::device_resources const& handle, */ template void fill( - raft::device_resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType val) + raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType val) { - detail::fill(rng_state, ptr, len, val, handle.get_stream()); + detail::fill(rng_state, ptr, len, val, resource::get_cuda_stream(handle)); } /** @@ -372,12 +377,13 @@ void fill( * @param[in] prob coin-toss probability for heads */ template -void bernoulli(raft::device_resources const& handle, +void bernoulli(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, Type prob) { - detail::bernoulli(rng_state, out.data_handle(), out.extent(0), prob, handle.get_stream()); + detail::bernoulli( + rng_state, out.data_handle(), out.extent(0), prob, resource::get_cuda_stream(handle)); } /** @@ -395,9 +401,9 @@ void bernoulli(raft::device_resources const& handle, */ template void bernoulli( - raft::device_resources const& handle, RngState& rng_state, OutType* ptr, LenType len, Type prob) + raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, Type prob) { - detail::bernoulli(rng_state, ptr, len, prob, handle.get_stream()); + detail::bernoulli(rng_state, ptr, len, prob, resource::get_cuda_stream(handle)); } /** @@ -413,14 +419,14 @@ void bernoulli( * @param[in] scale scaling factor */ template -void scaled_bernoulli(raft::device_resources const& handle, +void scaled_bernoulli(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType prob, OutputValueType scale) { detail::scaled_bernoulli( - rng_state, out.data_handle(), out.extent(0), prob, scale, handle.get_stream()); + rng_state, out.data_handle(), out.extent(0), prob, scale, resource::get_cuda_stream(handle)); } /** @@ -436,14 +442,14 @@ void scaled_bernoulli(raft::device_resources const& handle, * @param[in] scale scaling factor */ template -void scaled_bernoulli(raft::device_resources const& handle, +void scaled_bernoulli(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType prob, OutType scale) { - detail::scaled_bernoulli(rng_state, ptr, len, prob, scale, handle.get_stream()); + detail::scaled_bernoulli(rng_state, ptr, len, prob, scale, resource::get_cuda_stream(handle)); } /** @@ -460,13 +466,14 @@ void scaled_bernoulli(raft::device_resources const& handle, * @note https://en.wikipedia.org/wiki/Gumbel_distribution */ template -void gumbel(raft::device_resources const& handle, +void gumbel(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType mu, OutputValueType beta) { - detail::gumbel(rng_state, out.data_handle(), out.extent(0), mu, beta, handle.get_stream()); + detail::gumbel( + rng_state, out.data_handle(), out.extent(0), mu, beta, resource::get_cuda_stream(handle)); } /** @@ -483,14 +490,14 @@ void gumbel(raft::device_resources const& handle, * @note https://en.wikipedia.org/wiki/Gumbel_distribution */ template -void gumbel(raft::device_resources const& handle, +void gumbel(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType mu, OutType beta) { - detail::gumbel(rng_state, ptr, len, mu, beta, handle.get_stream()); + detail::gumbel(rng_state, ptr, len, mu, beta, resource::get_cuda_stream(handle)); } /** @@ -506,13 +513,14 @@ void gumbel(raft::device_resources const& handle, * @param[in] sigma standard deviation of the distribution */ template -void lognormal(raft::device_resources const& handle, +void lognormal(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType mu, OutputValueType sigma) { - detail::lognormal(rng_state, out.data_handle(), out.extent(0), mu, sigma, handle.get_stream()); + detail::lognormal( + rng_state, out.data_handle(), out.extent(0), mu, sigma, resource::get_cuda_stream(handle)); } /** @@ -528,14 +536,14 @@ void lognormal(raft::device_resources const& handle, * @param[in] sigma standard deviation of the distribution */ template -void lognormal(raft::device_resources const& handle, +void lognormal(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType mu, OutType sigma) { - detail::lognormal(rng_state, ptr, len, mu, sigma, handle.get_stream()); + detail::lognormal(rng_state, ptr, len, mu, sigma, resource::get_cuda_stream(handle)); } /** @@ -551,13 +559,14 @@ void lognormal(raft::device_resources const& handle, * @param[in] scale scale value */ template -void logistic(raft::device_resources const& handle, +void logistic(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType mu, OutputValueType scale) { - detail::logistic(rng_state, out.data_handle(), out.extent(0), mu, scale, handle.get_stream()); + detail::logistic( + rng_state, out.data_handle(), out.extent(0), mu, scale, resource::get_cuda_stream(handle)); } /** @@ -573,14 +582,14 @@ void logistic(raft::device_resources const& handle, * @param[in] scale scale value */ template -void logistic(raft::device_resources const& handle, +void logistic(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType mu, OutType scale) { - detail::logistic(rng_state, ptr, len, mu, scale, handle.get_stream()); + detail::logistic(rng_state, ptr, len, mu, scale, resource::get_cuda_stream(handle)); } /** @@ -595,12 +604,13 @@ void logistic(raft::device_resources const& handle, * @param[in] lambda the exponential distribution's lambda parameter */ template -void exponential(raft::device_resources const& handle, +void exponential(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType lambda) { - detail::exponential(rng_state, out.data_handle(), out.extent(0), lambda, handle.get_stream()); + detail::exponential( + rng_state, out.data_handle(), out.extent(0), lambda, resource::get_cuda_stream(handle)); } /** @@ -615,13 +625,10 @@ void exponential(raft::device_resources const& handle, * @param[in] lambda the exponential distribution's lambda parameter */ template -void exponential(raft::device_resources const& handle, - RngState& rng_state, - OutType* ptr, - LenType len, - OutType lambda) +void exponential( + raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType lambda) { - detail::exponential(rng_state, ptr, len, lambda, handle.get_stream()); + detail::exponential(rng_state, ptr, len, lambda, resource::get_cuda_stream(handle)); } /** @@ -636,12 +643,13 @@ void exponential(raft::device_resources const& handle, * @param[in] sigma the distribution's sigma parameter */ template -void rayleigh(raft::device_resources const& handle, +void rayleigh(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType sigma) { - detail::rayleigh(rng_state, out.data_handle(), out.extent(0), sigma, handle.get_stream()); + detail::rayleigh( + rng_state, out.data_handle(), out.extent(0), sigma, resource::get_cuda_stream(handle)); } /** @@ -656,15 +664,11 @@ void rayleigh(raft::device_resources const& handle, * @param[in] sigma the distribution's sigma parameter */ template -void rayleigh(raft::device_resources const& handle, - RngState& rng_state, - OutType* ptr, - LenType len, - OutType sigma) +void rayleigh( + raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType sigma) { - detail::rayleigh(rng_state, ptr, len, sigma, handle.get_stream()); + detail::rayleigh(rng_state, ptr, len, sigma, resource::get_cuda_stream(handle)); } - /** * @brief Generate laplace distributed random numbers * @@ -678,13 +682,14 @@ void rayleigh(raft::device_resources const& handle, * @param[in] scale the scale */ template -void laplace(raft::device_resources const& handle, +void laplace(raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, OutputValueType mu, OutputValueType scale) { - detail::laplace(rng_state, out.data_handle(), out.extent(0), mu, scale, handle.get_stream()); + detail::laplace( + rng_state, out.data_handle(), out.extent(0), mu, scale, resource::get_cuda_stream(handle)); } /** @@ -700,14 +705,14 @@ void laplace(raft::device_resources const& handle, * @param[in] scale the scale */ template -void laplace(raft::device_resources const& handle, +void laplace(raft::resources const& handle, RngState& rng_state, OutType* ptr, LenType len, OutType mu, OutType scale) { - detail::laplace(rng_state, ptr, len, mu, scale, handle.get_stream()); + detail::laplace(rng_state, ptr, len, mu, scale, resource::get_cuda_stream(handle)); } /** @@ -716,10 +721,10 @@ void laplace(raft::device_resources const& handle, * Usage example: * @code{.cpp} * #include - * #include + * #include * #include * - * raft::raft::device_resources handle; + * raft::resources handle; * ... * raft::random::RngState rng(seed); * auto indices = raft::make_device_vector(handle, n_samples); @@ -737,7 +742,7 @@ void laplace(raft::device_resources const& handle, */ template std::enable_if_t> discrete( - raft::device_resources const& handle, + raft::resources const& handle, RngState& rng_state, raft::device_vector_view out, raft::device_vector_view weights) @@ -747,7 +752,7 @@ std::enable_if_t> discrete( weights.data_handle(), out.extent(0), weights.extent(0), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -770,7 +775,7 @@ std::enable_if_t> discrete( * @param[in] len input array length */ template -void sampleWithoutReplacement(raft::device_resources const& handle, +void sampleWithoutReplacement(raft::resources const& handle, RngState& rng_state, DataT* out, IdxT* outIdx, @@ -780,7 +785,7 @@ void sampleWithoutReplacement(raft::device_resources const& handle, IdxT len) { detail::sampleWithoutReplacement( - rng_state, out, outIdx, in, wts, sampledLen, len, handle.get_stream()); + rng_state, out, outIdx, in, wts, sampledLen, len, resource::get_cuda_stream(handle)); } /** @@ -1106,7 +1111,7 @@ class DEPR Rng : public detail::RngImpl { * @param stream cuda stream */ template - void sampleWithoutReplacement(raft::device_resources const& handle, + void sampleWithoutReplacement(raft::resources const& handle, DataT* out, IdxT* outIdx, const DataT* in, diff --git a/cpp/include/raft/random/sample_without_replacement.cuh b/cpp/include/raft/random/sample_without_replacement.cuh index be8bda8cd3..b074f68af6 100644 --- a/cpp/include/raft/random/sample_without_replacement.cuh +++ b/cpp/include/raft/random/sample_without_replacement.cuh @@ -21,7 +21,8 @@ #include #include #include -#include +#include +#include #include #include @@ -93,7 +94,7 @@ using weight_t = typename weight_alias::type; * equals the number of inputs `in.extent(0)`. */ template -void sample_without_replacement(raft::device_resources const& handle, +void sample_without_replacement(raft::resources const& handle, RngState& rng_state, raft::device_vector_view in, WeightsVectorType&& weights_opt, @@ -144,7 +145,7 @@ void sample_without_replacement(raft::device_resources const& handle, wts_ptr, sampledLen, len, - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** diff --git a/cpp/test/core/temporary_device_buffer.cu b/cpp/test/core/temporary_device_buffer.cu index 52a2ec4c9b..cc8af24f10 100644 --- a/cpp/test/core/temporary_device_buffer.cu +++ b/cpp/test/core/temporary_device_buffer.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" +#include #include #include diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index bb63cc9be3..413e548532 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" #include +#include #include #include #include diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index 17955abb34..6f5800dd8f 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" +#include #include #include #include diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index 2292772b1a..9a65918f8f 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index c2dbc5dc1c..0565635e3b 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -18,6 +18,8 @@ #include #include #include +#include + #include #include #include diff --git a/cpp/test/random/make_regression.cu b/cpp/test/random/make_regression.cu index 7508b57bdd..74aa00171b 100644 --- a/cpp/test/random/make_regression.cu +++ b/cpp/test/random/make_regression.cu @@ -20,8 +20,10 @@ #include #include "../test_utils.cuh" +#include #include #include + #include #include #include diff --git a/cpp/test/random/multi_variable_gaussian.cu b/cpp/test/random/multi_variable_gaussian.cu index 1aa8b6a555..a27dffc7bf 100644 --- a/cpp/test/random/multi_variable_gaussian.cu +++ b/cpp/test/random/multi_variable_gaussian.cu @@ -18,8 +18,10 @@ #include #include #include +#include #include #include + #include #include diff --git a/cpp/test/random/permute.cu b/cpp/test/random/permute.cu index d5fcca270e..2c5ddf9d5a 100644 --- a/cpp/test/random/permute.cu +++ b/cpp/test/random/permute.cu @@ -16,8 +16,10 @@ #include "../test_utils.cuh" #include +#include #include #include + #include #include #include diff --git a/cpp/test/random/rmat_rectangular_generator.cu b/cpp/test/random/rmat_rectangular_generator.cu index aae3898389..fd9a8ec732 100644 --- a/cpp/test/random/rmat_rectangular_generator.cu +++ b/cpp/test/random/rmat_rectangular_generator.cu @@ -21,8 +21,10 @@ #include "../test_utils.cuh" +#include #include #include + #include #include diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index d3b8e44b05..92f79b1fa0 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -20,9 +20,11 @@ #include "../test_utils.cuh" #include #include +#include #include #include #include + #include #include diff --git a/cpp/test/random/rng_discrete.cu b/cpp/test/random/rng_discrete.cu index 741f7c65e0..b9b283b87d 100644 --- a/cpp/test/random/rng_discrete.cu +++ b/cpp/test/random/rng_discrete.cu @@ -18,9 +18,11 @@ #include #include #include +#include #include #include #include + #include #include #include diff --git a/cpp/test/random/rng_int.cu b/cpp/test/random/rng_int.cu index 83300b3ecc..8208b04489 100644 --- a/cpp/test/random/rng_int.cu +++ b/cpp/test/random/rng_int.cu @@ -17,6 +17,7 @@ #include "../test_utils.cuh" #include #include +#include #include #include #include diff --git a/cpp/test/random/sample_without_replacement.cu b/cpp/test/random/sample_without_replacement.cu index ae5a58da3d..dcad32ce8a 100644 --- a/cpp/test/random/sample_without_replacement.cu +++ b/cpp/test/random/sample_without_replacement.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" #include +#include #include #include #include diff --git a/cpp/test/stats/histogram.cu b/cpp/test/stats/histogram.cu index 9ad7998180..c6c3dd48ca 100644 --- a/cpp/test/stats/histogram.cu +++ b/cpp/test/stats/histogram.cu @@ -16,6 +16,7 @@ #include "../test_utils.cuh" #include +#include #include #include #include diff --git a/cpp/test/stats/minmax.cu b/cpp/test/stats/minmax.cu index 8b58f9692a..e0dc77520d 100644 --- a/cpp/test/stats/minmax.cu +++ b/cpp/test/stats/minmax.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include