From 2747816b6dbaa2112de570b2da33521f86cd3ba1 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 11 Jan 2023 13:13:44 -0500 Subject: [PATCH 1/3] Adding workspace resource which defaults to the `current_device_resource()`. --- cpp/include/raft/core/device_resources.hpp | 11 ++- cpp/include/raft/core/handle.hpp | 7 +- .../core/resource/device_memory_resource.hpp | 75 +++++++++++++++++++ .../raft/core/resource/resource_types.hpp | 1 + cpp/test/core/handle.cpp | 20 +++++ 5 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 cpp/include/raft/core/resource/device_memory_resource.hpp diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index faca07e8f4..9b9e07cf4f 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include @@ -73,7 +74,8 @@ class device_resources : public resources { * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) */ device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) : resources{} { resources::add_resource_factory(std::make_shared()); @@ -81,6 +83,8 @@ class device_resources : public resources { std::make_shared(stream_view)); resources::add_resource_factory( std::make_shared(stream_pool)); + resources::add_resource_factory( + std::make_shared(workspace_resource)); } /** Destroys all held-up resources */ @@ -206,6 +210,11 @@ class device_resources : public resources { return resource::get_subcomm(*this, key); } + const rmm::mr::device_memory_resource* get_workspace_resource() const + { + return resource::get_workspace_resource(*this); + } + bool comms_initialized() const { return resource::comms_initialized(*this); } const cudaDeviceProp& get_device_properties() const diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 48c1718eb0..6486965cdf 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -46,9 +46,10 @@ class handle_t : public raft::device_resources { * unspecified) * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) */ - handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) - : device_resources{stream_view, stream_pool} + handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) + : device_resources{stream_view, stream_pool, workspace_resource} { } diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp new file mode 100644 index 0000000000..789b66c9fd --- /dev/null +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class device_memory_resource : public resource { + public: + device_memory_resource(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) + { + if (mr_ == nullptr) { mr = rmm::mr::get_current_device_resource(); } + } + void* get_resource() override { return mr; } + + ~device_memory_resource() override {} + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class workspace_resource_factory : public resource_factory { + public: + workspace_resource_factory(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) {} + resource_type get_resource_type() override { return resource_type::WORKSPACE_RESOURCE; } + resource* make_resource() override { return new device_memory_resource(mr); } + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Load a temp workspace resource from a resources instance (and populate it on the res + * if needed). + * @param res raft resources object for managing resources + * @return + */ +inline rmm::mr::device_memory_resource* get_workspace_resource(resources const& res) +{ + if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) { + res.add_resource_factory(std::make_shared()); + } + return res.get_resource(resource_type::WORKSPACE_RESOURCE); +}; + +/** + * Set a temp workspace resource on a resources instance. + * + * @param res raft resources object for managing resources + * @return + */ +inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_resource* mr) +{ + res.add_resource_factory(std::make_shared(mr)); +}; +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index c763066c79..ace4b7061b 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -35,6 +35,7 @@ enum resource_type { DEVICE_PROPERTIES, // cuda device properties DEVICE_ID, // cuda device id THRUST_POLICY, // thrust execution policy + WORKSPACE_RESOURCE, // rmm device memory resource LAST_KEY // reserved for the last key }; diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 2148742e83..fc609f8542 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include namespace raft { @@ -248,4 +249,23 @@ TEST(Raft, SubComms) ASSERT_EQ(handle.get_subcomm("key2").get_size(), 2); } +TEST(Raft, WorkspaceResource) +{ + handle_t handle; + + // We can't assert equality but we can test that a pool resource was not returned + ASSERT_TRUE(dynamic_cast*>( + handle.get_workspace_resource()) == nullptr); + + auto pool_mr = new rmm::mr::pool_memory_resource(rmm::mr::get_current_device_resource()); + std::shared_ptr pool = {nullptr}; + handle_t handle2(rmm::cuda_stream_per_thread, pool, pool_mr); + + // We can't assert equality so we can test that a pool resource was, in fact, returned + ASSERT_TRUE(dynamic_cast*>( + handle2.get_workspace_resource()) != nullptr); + + delete pool_mr; +} + } // namespace raft From 53ea66b67fb8eb92f022e9d13cd339e9a71ce56a Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 11 Jan 2023 13:20:06 -0500 Subject: [PATCH 2/3] Fixing assertions to verify exact pointer as well. --- cpp/test/core/handle.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index fc609f8542..75b2d60bcd 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -253,17 +253,17 @@ TEST(Raft, WorkspaceResource) { handle_t handle; - // We can't assert equality but we can test that a pool resource was not returned ASSERT_TRUE(dynamic_cast*>( handle.get_workspace_resource()) == nullptr); + ASSERT_EQ(rmm::mr::get_current_device_resource(), handle.get_workspace_resource()); auto pool_mr = new rmm::mr::pool_memory_resource(rmm::mr::get_current_device_resource()); std::shared_ptr pool = {nullptr}; handle_t handle2(rmm::cuda_stream_per_thread, pool, pool_mr); - // We can't assert equality so we can test that a pool resource was, in fact, returned ASSERT_TRUE(dynamic_cast*>( handle2.get_workspace_resource()) != nullptr); + ASSERT_EQ(pool_mr, handle2.get_workspace_resource()); delete pool_mr; } From 5aa9fd308664caeb30293047762d933b8fd743e5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 12 Jan 2023 19:25:15 -0500 Subject: [PATCH 3/3] Fixing the docs a tad --- cpp/include/raft/core/resource/device_memory_resource.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 789b66c9fd..0706f28f94 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -52,7 +52,7 @@ class workspace_resource_factory : public resource_factory { * Load a temp workspace resource from a resources instance (and populate it on the res * if needed). * @param res raft resources object for managing resources - * @return + * @return device memory resource object */ inline rmm::mr::device_memory_resource* get_workspace_resource(resources const& res) { @@ -66,6 +66,7 @@ inline rmm::mr::device_memory_resource* get_workspace_resource(resources const& * Set a temp workspace resource on a resources instance. * * @param res raft resources object for managing resources + * @param mr a valid rmm device_memory_resource * @return */ inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_resource* mr)