Skip to content

Commit

Permalink
Replace all internal usage of get_upstream with `get_upstream_resou…
Browse files Browse the repository at this point in the history
…rce` (#1491)

We want to get away from raw resources, so prepare deprecation of it by replacing all internal usages

This PR relies on preparation in downstream repositories

Authors:
  - Michael Schellenberger Costa (https://github.com/miscco)
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Mark Harris (https://github.com/harrism)

URL: #1491
  • Loading branch information
miscco authored Mar 14, 2024
1 parent 2c161da commit a98931b
Show file tree
Hide file tree
Showing 14 changed files with 33 additions and 42 deletions.
3 changes: 2 additions & 1 deletion include/rmm/mr/device/aligned_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ class aligned_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<aligned_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr && upstream_->is_equal(*cast->get_upstream()) &&
if (cast == nullptr) { return false; }
return get_upstream_resource() == cast->get_upstream_resource() &&
alignment_ == cast->alignment_ && alignment_threshold_ == cast->alignment_threshold_;
}

Expand Down
13 changes: 6 additions & 7 deletions include/rmm/mr/device/binning_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ class binning_memory_resource final : public device_memory_resource {
* Chooses a memory_resource that allocates the smallest blocks at least as large as `bytes`.
*
* @param bytes Requested allocation size in bytes
* @return rmm::mr::device_memory_resource& memory_resource that can allocate the requested size.
* @return Get the resource reference for the requested size.
*/
device_memory_resource* get_resource(std::size_t bytes)
rmm::device_async_resource_ref get_resource_ref(std::size_t bytes)
{
auto iter = resource_bins_.lower_bound(bytes);
return (iter != resource_bins_.cend()) ? iter->second
: static_cast<device_memory_resource*>(get_upstream());
return (iter != resource_bins_.cend()) ? rmm::device_async_resource_ref{iter->second}
: get_upstream_resource();
}

/**
Expand All @@ -170,7 +170,7 @@ class binning_memory_resource final : public device_memory_resource {
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
if (bytes <= 0) { return nullptr; }
return get_resource(bytes)->allocate(bytes, stream);
return get_resource_ref(bytes).allocate_async(bytes, stream);
}

/**
Expand All @@ -183,8 +183,7 @@ class binning_memory_resource final : public device_memory_resource {
*/
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) override
{
auto res = get_resource(bytes);
if (res != nullptr) { res->deallocate(ptr, bytes, stream); }
get_resource_ref(bytes).deallocate_async(ptr, bytes, stream);
}

Upstream* upstream_mr_; // The upstream memory_resource from which to allocate blocks.
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/failure_callback_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ class failure_callback_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<failure_callback_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

Upstream* upstream_; // the upstream resource used for satisfying allocation requests
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/fixed_size_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class fixed_size_memory_resource
*/
free_list blocks_from_upstream(cuda_stream_view stream)
{
void* ptr = get_upstream()->allocate(upstream_chunk_size_, stream);
void* ptr = get_upstream_resource().allocate_async(upstream_chunk_size_, stream);
block_type block{ptr};
upstream_blocks_.push_back(block);

Expand Down Expand Up @@ -211,7 +211,7 @@ class fixed_size_memory_resource
lock_guard lock(this->get_mutex());

for (auto block : upstream_blocks_) {
get_upstream()->deallocate(block.pointer(), upstream_chunk_size_);
get_upstream_resource().deallocate(block.pointer(), upstream_chunk_size_);
}
upstream_blocks_.clear();
}
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/limiting_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class limiting_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<limiting_resource_adaptor<Upstream> const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

// maximum bytes this allocator is allowed to allocate.
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/logging_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ class logging_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<logging_resource_adaptor<Upstream> const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

// make_logging_adaptor needs access to private get_default_filename
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/pool_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class pool_memory_resource final
if (size == 0) { return {}; }

try {
void* ptr = get_upstream()->allocate_async(size, stream);
void* ptr = get_upstream_resource().allocate_async(size, stream);
return std::optional<block_type>{
*upstream_blocks_.emplace(static_cast<char*>(ptr), size, true).first};
} catch (std::exception const& e) {
Expand Down Expand Up @@ -570,7 +570,7 @@ class pool_memory_resource final
lock_guard lock(this->get_mutex());

for (auto block : upstream_blocks_) {
get_upstream()->deallocate(block.pointer(), block.size());
get_upstream_resource().deallocate(block.pointer(), block.size());
}
upstream_blocks_.clear();
#ifdef RMM_POOL_TRACK_ALLOCATIONS
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/statistics_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ class statistics_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<statistics_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

counter bytes_; // peak, current and total allocated bytes
Expand Down
8 changes: 3 additions & 5 deletions include/rmm/mr/device/thread_safe_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,9 @@ class thread_safe_resource_adaptor final : public device_memory_resource {
bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto thread_safe_other = dynamic_cast<thread_safe_resource_adaptor<Upstream> const*>(&other);
if (thread_safe_other != nullptr) {
return upstream_->is_equal(*thread_safe_other->get_upstream());
}
return upstream_->is_equal(other);
auto cast = dynamic_cast<thread_safe_resource_adaptor<Upstream> const*>(&other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

std::mutex mutable mtx; // mutex for thread safe access to upstream
Expand Down
4 changes: 2 additions & 2 deletions include/rmm/mr/device/tracking_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ class tracking_resource_adaptor final : public device_memory_resource {
{
if (this == &other) { return true; }
auto cast = dynamic_cast<tracking_resource_adaptor<Upstream> const*>(&other);
return cast != nullptr ? upstream_->is_equal(*cast->get_upstream())
: upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

bool capture_stacks_; // whether or not to capture call stacks
Expand Down
4 changes: 2 additions & 2 deletions tests/device_check_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class device_check_resource_adaptor final : public rmm::mr::device_memory_resour
{
if (this == &other) { return true; }
auto const* cast = dynamic_cast<device_check_resource_adaptor const*>(&other);
if (cast != nullptr) { return upstream_->is_equal(*cast->get_upstream()); }
return upstream_->is_equal(other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
}

rmm::cuda_device_id device_id;
Expand Down
9 changes: 0 additions & 9 deletions tests/mr/device/adaptor_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,6 @@ TYPED_TEST(AdaptorTest, Equality)
}
}

TYPED_TEST(AdaptorTest, GetUpstream)
{
if constexpr (std::is_same_v<TypeParam, owning_wrapper>) {
EXPECT_TRUE(this->mr->wrapped().get_upstream()->is_equal(this->cuda));
} else {
EXPECT_TRUE(this->mr->get_upstream()->is_equal(this->cuda));
}
}

TYPED_TEST(AdaptorTest, GetUpstreamResource)
{
rmm::device_async_resource_ref expected{this->cuda};
Expand Down
5 changes: 3 additions & 2 deletions tests/mr/device/statistics_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ TEST(StatisticsTest, PeakAllocations)

TEST(StatisticsTest, MultiTracking)
{
statistics_adaptor mr{rmm::mr::get_current_device_resource()};
auto* orig_device_resource = rmm::mr::get_current_device_resource();
statistics_adaptor mr{orig_device_resource};
rmm::mr::set_current_device_resource(&mr);

std::vector<std::shared_ptr<rmm::device_buffer>> allocations;
Expand Down Expand Up @@ -171,7 +172,7 @@ TEST(StatisticsTest, MultiTracking)
EXPECT_EQ(inner_mr.get_allocations_counter().peak, 5);

// Reset the current device resource
rmm::mr::set_current_device_resource(mr.get_upstream());
rmm::mr::set_current_device_resource(orig_device_resource);
}

TEST(StatisticsTest, NegativeInnerTracking)
Expand Down
5 changes: 3 additions & 2 deletions tests/mr/device/tracking_mr_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ TEST(TrackingTest, AllocationsLeftWithoutStacks)

TEST(TrackingTest, MultiTracking)
{
tracking_adaptor mr{rmm::mr::get_current_device_resource(), true};
auto* orig_device_resource = rmm::mr::get_current_device_resource();
tracking_adaptor mr{orig_device_resource, true};
rmm::mr::set_current_device_resource(&mr);

std::vector<std::shared_ptr<rmm::device_buffer>> allocations;
Expand Down Expand Up @@ -140,7 +141,7 @@ TEST(TrackingTest, MultiTracking)
EXPECT_EQ(inner_mr.get_allocated_bytes(), 0);

// Reset the current device resource
rmm::mr::set_current_device_resource(mr.get_upstream());
rmm::mr::set_current_device_resource(orig_device_resource);
}

TEST(TrackingTest, NegativeInnerTracking)
Expand Down

0 comments on commit a98931b

Please sign in to comment.