diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index faca07e8f4..9b9e07cf4f 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -45,6 +45,7 @@ #include #include #include +#include #include #include #include @@ -73,7 +74,8 @@ class device_resources : public resources { * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) */ device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) : resources{} { resources::add_resource_factory(std::make_shared()); @@ -81,6 +83,8 @@ class device_resources : public resources { std::make_shared(stream_view)); resources::add_resource_factory( std::make_shared(stream_pool)); + resources::add_resource_factory( + std::make_shared(workspace_resource)); } /** Destroys all held-up resources */ @@ -206,6 +210,11 @@ class device_resources : public resources { return resource::get_subcomm(*this, key); } + const rmm::mr::device_memory_resource* get_workspace_resource() const + { + return resource::get_workspace_resource(*this); + } + bool comms_initialized() const { return resource::comms_initialized(*this); } const cudaDeviceProp& get_device_properties() const diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 48c1718eb0..6486965cdf 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -46,9 +46,10 @@ class handle_t : public raft::device_resources { * unspecified) * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) */ - handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, - std::shared_ptr stream_pool = {nullptr}) - : device_resources{stream_view, stream_pool} + handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) + : device_resources{stream_view, stream_pool, workspace_resource} { } diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp new file mode 100644 index 0000000000..0706f28f94 --- /dev/null +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class device_memory_resource : public resource { + public: + device_memory_resource(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) + { + if (mr_ == nullptr) { mr = rmm::mr::get_current_device_resource(); } + } + void* get_resource() override { return mr; } + + ~device_memory_resource() override {} + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class workspace_resource_factory : public resource_factory { + public: + workspace_resource_factory(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) {} + resource_type get_resource_type() override { return resource_type::WORKSPACE_RESOURCE; } + resource* make_resource() override { return new device_memory_resource(mr); } + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Load a temp workspace resource from a resources instance (and populate it on the res + * if needed). + * @param res raft resources object for managing resources + * @return device memory resource object + */ +inline rmm::mr::device_memory_resource* get_workspace_resource(resources const& res) +{ + if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) { + res.add_resource_factory(std::make_shared()); + } + return res.get_resource(resource_type::WORKSPACE_RESOURCE); +}; + +/** + * Set a temp workspace resource on a resources instance. + * + * @param res raft resources object for managing resources + * @param mr a valid rmm device_memory_resource + * @return + */ +inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_resource* mr) +{ + res.add_resource_factory(std::make_shared(mr)); +}; +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index c763066c79..ace4b7061b 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -35,6 +35,7 @@ enum resource_type { DEVICE_PROPERTIES, // cuda device properties DEVICE_ID, // cuda device id THRUST_POLICY, // thrust execution policy + WORKSPACE_RESOURCE, // rmm device memory resource LAST_KEY // reserved for the last key }; diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 2148742e83..75b2d60bcd 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include namespace raft { @@ -248,4 +249,23 @@ TEST(Raft, SubComms) ASSERT_EQ(handle.get_subcomm("key2").get_size(), 2); } +TEST(Raft, WorkspaceResource) +{ + handle_t handle; + + ASSERT_TRUE(dynamic_cast*>( + handle.get_workspace_resource()) == nullptr); + ASSERT_EQ(rmm::mr::get_current_device_resource(), handle.get_workspace_resource()); + + auto pool_mr = new rmm::mr::pool_memory_resource(rmm::mr::get_current_device_resource()); + std::shared_ptr pool = {nullptr}; + handle_t handle2(rmm::cuda_stream_per_thread, pool, pool_mr); + + ASSERT_TRUE(dynamic_cast*>( + handle2.get_workspace_resource()) != nullptr); + ASSERT_EQ(pool_mr, handle2.get_workspace_resource()); + + delete pool_mr; +} + } // namespace raft