diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h index ffe8f8717b..9b086fdb23 100644 --- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h +++ b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -74,13 +75,14 @@ inline auto rmm_oom_callback(std::size_t bytes, void*) -> bool */ class shared_raft_resources { public: - using pool_mr_type = rmm::mr::pool_memory_resource; - using mr_type = rmm::mr::failure_callback_resource_adaptor; + using pool_mr_type = rmm::mr::pool_memory_resource; + using mr_type = rmm::mr::failure_callback_resource_adaptor; + using large_mr_type = rmm::mr::managed_memory_resource; shared_raft_resources() try : orig_resource_{rmm::mr::get_current_device_resource()}, pool_resource_(orig_resource_, 1024 * 1024 * 1024ull), - resource_(&pool_resource_, rmm_oom_callback, nullptr) { + resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() { rmm::mr::set_current_device_resource(&resource_); } catch (const std::exception& e) { auto cuda_status = cudaGetLastError(); @@ -103,10 +105,16 @@ class shared_raft_resources { ~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); } + auto get_large_memory_resource() noexcept + { + return static_cast(&large_mr_); + } + private: rmm::mr::device_memory_resource* orig_resource_; pool_mr_type pool_resource_; mr_type resource_; + large_mr_type large_mr_; }; /** @@ -129,6 +137,12 @@ class configured_raft_resources { res_{std::make_unique( rmm::cuda_stream_view(get_stream_from_global_pool()))} { + // set the large workspace resource to the raft handle, but without the deleter + // (this resource is managed by the shared_res). + raft::resource::set_large_workspace_resource( + *res_, + std::shared_ptr(shared_res_->get_large_memory_resource(), + raft::void_op{})); } /** Default constructor creates all resources anew. */ diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 9aa9e4fb85..b785010a0a 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,16 @@ namespace raft::resource { * @{ */ +class device_memory_resource : public resource { + public: + explicit device_memory_resource(std::shared_ptr mr) : mr_(mr) {} + ~device_memory_resource() override = default; + auto get_resource() -> void* override { return mr_.get(); } + + private: + std::shared_ptr mr_; +}; + class limiting_memory_resource : public resource { public: limiting_memory_resource(std::shared_ptr mr, @@ -66,6 +76,29 @@ class limiting_memory_resource : public resource { } }; +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class large_workspace_resource_factory : public resource_factory { + public: + explicit large_workspace_resource_factory( + std::shared_ptr mr = {nullptr}) + : mr_{mr ? mr + : std::shared_ptr{ + rmm::mr::get_current_device_resource(), void_op{}}} + { + } + auto get_resource_type() -> resource_type override + { + return resource_type::LARGE_WORKSPACE_RESOURCE; + } + auto make_resource() -> resource* override { return new device_memory_resource(mr_); } + + private: + std::shared_ptr mr_; +}; + /** * Factory that knows how to construct a specific raft::resource to populate * the resources instance. @@ -144,7 +177,7 @@ class workspace_resource_factory : public resource_factory { // Note, the workspace does not claim all this memory from the start, so it's still usable by // the main resource as well. // This limit is merely an order for algorithm internals to plan the batching accordingly. - return total_size / 2; + return total_size / 4; } }; @@ -241,6 +274,21 @@ inline void set_workspace_to_global_resource( workspace_resource_factory::default_plain_resource(), allocation_limit, std::nullopt)); }; +inline auto get_large_workspace_resource(resources const& res) -> rmm::mr::device_memory_resource* +{ + if (!res.has_resource_factory(resource_type::LARGE_WORKSPACE_RESOURCE)) { + res.add_resource_factory(std::make_shared()); + } + return res.get_resource(resource_type::LARGE_WORKSPACE_RESOURCE); +}; + +inline void set_large_workspace_resource(resources const& res, + std::shared_ptr mr = { + nullptr}) +{ + res.add_resource_factory(std::make_shared(mr)); +}; + /** @} */ } // namespace raft::resource diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index d2021728c4..d9126251c9 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -28,23 +28,24 @@ namespace raft::resource { */ enum resource_type { // device-specific resource types - CUBLAS_HANDLE = 0, // cublas handle - CUSOLVER_DN_HANDLE, // cusolver dn handle - CUSOLVER_SP_HANDLE, // cusolver sp handle - CUSPARSE_HANDLE, // cusparse handle - CUDA_STREAM_VIEW, // view of a cuda stream - CUDA_STREAM_POOL, // cuda stream pool - CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams - COMMUNICATOR, // raft communicator - SUB_COMMUNICATOR, // raft sub communicator - DEVICE_PROPERTIES, // cuda device properties - DEVICE_ID, // cuda device id - STREAM_VIEW, // view of a cuda stream or a placeholder in - // CUDA-free builds - THRUST_POLICY, // thrust execution policy - WORKSPACE_RESOURCE, // rmm device memory resource - CUBLASLT_HANDLE, // cublasLt handle - CUSTOM, // runtime-shared default-constructible resource + CUBLAS_HANDLE = 0, // cublas handle + CUSOLVER_DN_HANDLE, // cusolver dn handle + CUSOLVER_SP_HANDLE, // cusolver sp handle + CUSPARSE_HANDLE, // cusparse handle + CUDA_STREAM_VIEW, // view of a cuda stream + CUDA_STREAM_POOL, // cuda stream pool + CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams + COMMUNICATOR, // raft communicator + SUB_COMMUNICATOR, // raft sub communicator + DEVICE_PROPERTIES, // cuda device properties + DEVICE_ID, // cuda device id + STREAM_VIEW, // view of a cuda stream or a placeholder in + // CUDA-free builds + THRUST_POLICY, // thrust execution policy + WORKSPACE_RESOURCE, // rmm device memory resource for small temporary allocations + CUBLASLT_HANDLE, // cublasLt handle + CUSTOM, // runtime-shared default-constructible resource + LARGE_WORKSPACE_RESOURCE, // rmm device memory resource for somewhat large temporary allocations LAST_KEY // reserved for the last key }; diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 9480c8e202..2207b0216e 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -894,7 +894,7 @@ void radix_topk(const T* in, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()) + rmm::device_async_resource_ref mr) { // TODO: is it possible to relax this restriction? static_assert(calc_num_passes() > 1);