diff --git a/cpp/include/raft/util/memory_pool-ext.hpp b/cpp/include/raft/util/memory_pool-ext.hpp index a02908346b..5c752e450b 100644 --- a/cpp/include/raft/util/memory_pool-ext.hpp +++ b/cpp/include/raft/util/memory_pool-ext.hpp @@ -18,10 +18,11 @@ #include // size_t #include // std::unique_ptr #include // rmm::mr::device_memory_resource +#include // rmm::mr::pool_memory_resource namespace raft { -std::unique_ptr get_pool_memory_resource( - rmm::mr::device_memory_resource*& mr, size_t initial_size); +std::unique_ptr> +get_pool_memory_resource(rmm::mr::device_memory_resource*& mr, size_t initial_size); } // namespace raft diff --git a/cpp/include/raft/util/memory_pool-inl.hpp b/cpp/include/raft/util/memory_pool-inl.hpp index a227b6e53f..f954819ad1 100644 --- a/cpp/include/raft/util/memory_pool-inl.hpp +++ b/cpp/include/raft/util/memory_pool-inl.hpp @@ -57,8 +57,9 @@ namespace raft { * @return if a new memory pool is created, it returns a unique_ptr to it; * this managed pointer controls the lifetime of the created memory resource. */ -RAFT_INLINE_CONDITIONAL std::unique_ptr get_pool_memory_resource( - rmm::mr::device_memory_resource*& mr, size_t initial_size) +RAFT_INLINE_CONDITIONAL +std::unique_ptr> +get_pool_memory_resource(rmm::mr::device_memory_resource*& mr, size_t initial_size) { using pool_res_t = rmm::mr::pool_memory_resource; std::unique_ptr pool_res{};