Skip to content

Commit

Permalink
Enable shallow copy of handle_t's resources with different workspac…
Browse files Browse the repository at this point in the history
…e_resource (#1165)

This effectively affords users the flexibility to shallow copy a handle and it's underlying vectors and change out only the `workspace_resource` so that they can, for example, configure multiple different pools or managed pools for workspace resources. 

cc @Nyrio RE: hierarchical k-means API.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #1165
  • Loading branch information
cjnolet authored Jan 25, 2023
1 parent 0076101 commit 7c12b1e
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 15 deletions.
19 changes: 13 additions & 6 deletions cpp/include/raft/core/device_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,18 @@ namespace raft {
*/
class device_resources : public resources {
public:
// delete copy/move constructors and assignment operators as
// copying and moving underlying resources is unsafe
device_resources(const device_resources&) = delete;
device_resources& operator=(const device_resources&) = delete;
device_resources(device_resources&&) = delete;
device_resources(const device_resources& handle,
rmm::mr::device_memory_resource* workspace_resource)
: resources{handle}
{
// replace the resource factory for the workspace_resources
resources::add_resource_factory(
std::make_shared<resource::workspace_resource_factory>(workspace_resource));
}

device_resources(const device_resources& handle) : resources{handle} {}

device_resources(device_resources&&) = delete;
device_resources& operator=(device_resources&&) = delete;

/**
Expand Down Expand Up @@ -210,7 +217,7 @@ class device_resources : public resources {
return resource::get_subcomm(*this, key);
}

const rmm::mr::device_memory_resource* get_workspace_resource() const
rmm::mr::device_memory_resource* get_workspace_resource() const
{
return resource::get_workspace_resource(*this);
}
Expand Down
13 changes: 8 additions & 5 deletions cpp/include/raft/core/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ namespace raft {
*/
class handle_t : public raft::device_resources {
public:
// delete copy/move constructors and assignment operators as
// copying and moving underlying resources is unsafe
handle_t(const handle_t&) = delete;
handle_t& operator=(const handle_t&) = delete;
handle_t(handle_t&&) = delete;
handle_t(const handle_t& handle, rmm::mr::device_memory_resource* workspace_resource)
: device_resources(handle, workspace_resource)
{
}

handle_t(const handle_t& handle) : device_resources{handle} {}

handle_t(handle_t&&) = delete;
handle_t& operator=(handle_t&&) = delete;

/**
Expand Down
11 changes: 7 additions & 4 deletions cpp/include/raft/core/resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ class resources {
}
}

resources(const resources&) = delete;
resources& operator=(const resources&) = delete;
resources(resources&&) = delete;
/**
* @brief Shallow copy of underlying resources instance.
* Note that this does not create any new resources.
*/
resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {}
resources(resources&&) = delete;
resources& operator=(resources&&) = delete;

/**
Expand Down Expand Up @@ -120,7 +123,7 @@ class resources {
return reinterpret_cast<res_t*>(res->get_resource());
}

private:
protected:
mutable std::mutex mutex_;
mutable std::vector<pair_res_factory> factories_;
mutable std::vector<pair_resource> resources_;
Expand Down
56 changes: 56 additions & 0 deletions cpp/test/core/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,22 @@ class mock_comms : public comms_iface {
int n_ranks;
};

void assert_handles_equal(raft::handle_t& handle_one, raft::handle_t& handle_two)
{
// Assert shallow copied state
ASSERT_EQ(handle_one.get_stream().value(), handle_two.get_stream().value());
ASSERT_EQ(handle_one.get_stream_pool_size(), handle_two.get_stream_pool_size());

// Sanity check to make sure non-corresponding streams are not equal
ASSERT_NE(handle_one.get_stream_pool().get_stream(0).value(),
handle_two.get_stream_pool().get_stream(1).value());

for (size_t i = 0; i < handle_one.get_stream_pool_size(); ++i) {
ASSERT_EQ(handle_one.get_stream_pool().get_stream(i).value(),
handle_two.get_stream_pool().get_stream(i).value());
}
}

TEST(Raft, HandleDefault)
{
handle_t h;
Expand Down Expand Up @@ -268,4 +284,44 @@ TEST(Raft, WorkspaceResource)
delete pool_mr;
}

TEST(Raft, WorkspaceResourceCopy)
{
auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(10);

handle_t handle(rmm::cuda_stream_per_thread, stream_pool);

auto pool_mr = new rmm::mr::pool_memory_resource(rmm::mr::get_current_device_resource());

handle_t copied_handle(handle, pool_mr);

assert_handles_equal(handle, copied_handle);

// Assert the workspace_resources are what we expect
ASSERT_TRUE(dynamic_cast<const rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>*>(
handle.get_workspace_resource()) == nullptr);

ASSERT_TRUE(dynamic_cast<const rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>*>(
copied_handle.get_workspace_resource()) != nullptr);
}

TEST(Raft, HandleCopy)
{
auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(10);

handle_t handle(rmm::cuda_stream_per_thread, stream_pool);
handle_t copied_handle(handle);

assert_handles_equal(handle, copied_handle);
}

TEST(Raft, HandleAssign)
{
auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(10);

handle_t handle(rmm::cuda_stream_per_thread, stream_pool);
handle_t copied_handle = handle;

assert_handles_equal(handle, copied_handle);
}

} // namespace raft

0 comments on commit 7c12b1e

Please sign in to comment.