diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index 9b9e07cf4f..ec0b92dde2 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -59,11 +59,18 @@ namespace raft { */ class device_resources : public resources { public: - // delete copy/move constructors and assignment operators as - // copying and moving underlying resources is unsafe - device_resources(const device_resources&) = delete; - device_resources& operator=(const device_resources&) = delete; - device_resources(device_resources&&) = delete; + device_resources(const device_resources& handle, + rmm::mr::device_memory_resource* workspace_resource) + : resources{handle} + { + // replace the resource factory for the workspace_resources + resources::add_resource_factory( + std::make_shared(workspace_resource)); + } + + device_resources(const device_resources& handle) : resources{handle} {} + + device_resources(device_resources&&) = delete; device_resources& operator=(device_resources&&) = delete; /** @@ -210,7 +217,7 @@ class device_resources : public resources { return resource::get_subcomm(*this, key); } - const rmm::mr::device_memory_resource* get_workspace_resource() const + rmm::mr::device_memory_resource* get_workspace_resource() const { return resource::get_workspace_resource(*this); } diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 6486965cdf..c1e7aa538f 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -32,11 +32,14 @@ namespace raft { */ class handle_t : public raft::device_resources { public: - // delete copy/move constructors and assignment operators as - // copying and moving underlying resources is unsafe - handle_t(const handle_t&) = delete; - handle_t& operator=(const handle_t&) = delete; - handle_t(handle_t&&) = delete; + handle_t(const handle_t& handle, rmm::mr::device_memory_resource* workspace_resource) + : device_resources(handle, workspace_resource) + { + } + + handle_t(const handle_t& handle) : device_resources{handle} {} + + handle_t(handle_t&&) = delete; handle_t& operator=(handle_t&&) = delete; /** diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp index 797fd5968d..64e281e934 100644 --- a/cpp/include/raft/core/resources.hpp +++ b/cpp/include/raft/core/resources.hpp @@ -62,9 +62,12 @@ class resources { } } - resources(const resources&) = delete; - resources& operator=(const resources&) = delete; - resources(resources&&) = delete; + /** + * @brief Shallow copy of underlying resources instance. + * Note that this does not create any new resources. + */ + resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {} + resources(resources&&) = delete; resources& operator=(resources&&) = delete; /** @@ -120,7 +123,7 @@ class resources { return reinterpret_cast(res->get_resource()); } - private: + protected: mutable std::mutex mutex_; mutable std::vector factories_; mutable std::vector resources_; diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 75b2d60bcd..8357c27f38 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -165,6 +165,22 @@ class mock_comms : public comms_iface { int n_ranks; }; +void assert_handles_equal(raft::handle_t& handle_one, raft::handle_t& handle_two) +{ + // Assert shallow copied state + ASSERT_EQ(handle_one.get_stream().value(), handle_two.get_stream().value()); + ASSERT_EQ(handle_one.get_stream_pool_size(), handle_two.get_stream_pool_size()); + + // Sanity check to make sure non-corresponding streams are not equal + ASSERT_NE(handle_one.get_stream_pool().get_stream(0).value(), + handle_two.get_stream_pool().get_stream(1).value()); + + for (size_t i = 0; i < handle_one.get_stream_pool_size(); ++i) { + ASSERT_EQ(handle_one.get_stream_pool().get_stream(i).value(), + handle_two.get_stream_pool().get_stream(i).value()); + } +} + TEST(Raft, HandleDefault) { handle_t h; @@ -268,4 +284,44 @@ TEST(Raft, WorkspaceResource) delete pool_mr; } +TEST(Raft, WorkspaceResourceCopy) +{ + auto stream_pool = std::make_shared(10); + + handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + + auto pool_mr = new rmm::mr::pool_memory_resource(rmm::mr::get_current_device_resource()); + + handle_t copied_handle(handle, pool_mr); + + assert_handles_equal(handle, copied_handle); + + // Assert the workspace_resources are what we expect + ASSERT_TRUE(dynamic_cast*>( + handle.get_workspace_resource()) == nullptr); + + ASSERT_TRUE(dynamic_cast*>( + copied_handle.get_workspace_resource()) != nullptr); +} + +TEST(Raft, HandleCopy) +{ + auto stream_pool = std::make_shared(10); + + handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + handle_t copied_handle(handle); + + assert_handles_equal(handle, copied_handle); +} + +TEST(Raft, HandleAssign) +{ + auto stream_pool = std::make_shared(10); + + handle_t handle(rmm::cuda_stream_per_thread, stream_pool); + handle_t copied_handle = handle; + + assert_handles_equal(handle, copied_handle); +} + } // namespace raft