From 3f3a59eea8cbd3e069913a954e6faac5eb450be3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 13 Jan 2023 06:08:12 -0500 Subject: [PATCH] Adding workspace resource (#1137) This will default to `rmm::mr::get_current_device_resource()` in the event no explicit workspace resource has been set. It's using raw pointers right now, but that may be okay as the RMM memory resource API seems to promote that over shared pointers (or any dereferencing at all). Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1137 --- cpp/include/raft/core/device_resources.hpp | 11 ++- cpp/include/raft/core/handle.hpp | 7 +- .../core/resource/device_memory_resource.hpp | 76 +++++++++++++++++++ .../raft/core/resource/resource_types.hpp | 1 + cpp/test/core/handle.cpp | 20 +++++ 5 files changed, 111 insertions(+), 4 deletions(-) create mode 100644 cpp/include/raft/core/resource/device_memory_resource.hpp diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index faca07e8f4..9b9e07cf4f 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include @@ -73,7 +74,8 @@ class device_resources : public resources { * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) */ device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) : resources{} { resources::add_resource_factory(std::make_shared()); @@ -81,6 +83,8 @@ class device_resources : public resources { std::make_shared(stream_view)); resources::add_resource_factory( std::make_shared(stream_pool)); + resources::add_resource_factory( + std::make_shared(workspace_resource)); } /** Destroys all held-up resources */ @@ -206,6 +210,11 @@ class device_resources : public resources { return resource::get_subcomm(*this, key); } + const rmm::mr::device_memory_resource* get_workspace_resource() const + { + return resource::get_workspace_resource(*this); + } + bool comms_initialized() const { return resource::comms_initialized(*this); } const cudaDeviceProp& get_device_properties() const diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 48c1718eb0..6486965cdf 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -46,9 +46,10 @@ class handle_t : public raft::device_resources { * unspecified) * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) */ - handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) - : device_resources{stream_view, stream_pool} + handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) + : device_resources{stream_view, stream_pool, workspace_resource} { } diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp new file mode 100644 index 0000000000..0706f28f94 --- /dev/null +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class device_memory_resource : public resource { + public: + device_memory_resource(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) + { + if (mr_ == nullptr) { mr = rmm::mr::get_current_device_resource(); } + } + void* get_resource() override { return mr; } + + ~device_memory_resource() override {} + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class workspace_resource_factory : public resource_factory { + public: + workspace_resource_factory(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) {} + resource_type get_resource_type() override { return resource_type::WORKSPACE_RESOURCE; } + resource* make_resource() override { return new device_memory_resource(mr); } + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Load a temp workspace resource from a resources instance (and populate it on the res + * if needed). + * @param res raft resources object for managing resources + * @return device memory resource object + */ +inline rmm::mr::device_memory_resource* get_workspace_resource(resources const& res) +{ + if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) { + res.add_resource_factory(std::make_shared()); + } + return res.get_resource(resource_type::WORKSPACE_RESOURCE); +}; + +/** + * Set a temp workspace resource on a resources instance. + * + * @param res raft resources object for managing resources + * @param mr a valid rmm device_memory_resource + * @return + */ +inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_resource* mr) +{ + res.add_resource_factory(std::make_shared(mr)); +}; +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index c763066c79..ace4b7061b 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -35,6 +35,7 @@ enum resource_type { DEVICE_PROPERTIES, // cuda device properties DEVICE_ID, // cuda device id THRUST_POLICY, // thrust execution policy + WORKSPACE_RESOURCE, // rmm device memory resource LAST_KEY // reserved for the last key }; diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 2148742e83..75b2d60bcd 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include namespace raft { @@ -248,4 +249,23 @@ TEST(Raft, SubComms) ASSERT_EQ(handle.get_subcomm("key2").get_size(), 2); } +TEST(Raft, WorkspaceResource) +{ + handle_t handle; + + ASSERT_TRUE(dynamic_cast*>( + handle.get_workspace_resource()) == nullptr); + ASSERT_EQ(rmm::mr::get_current_device_resource(), handle.get_workspace_resource()); + + auto pool_mr = new rmm::mr::pool_memory_resource(rmm::mr::get_current_device_resource()); + std::shared_ptr pool = {nullptr}; + handle_t handle2(rmm::cuda_stream_per_thread, pool, pool_mr); + + ASSERT_TRUE(dynamic_cast*>( + handle2.get_workspace_resource()) != nullptr); + ASSERT_EQ(pool_mr, handle2.get_workspace_resource()); + + delete pool_mr; +} + } // namespace raft