Skip to content

Commit

Permalink
fix race condition in limiting resource adapter (#869)
Browse files (browse the repository at this point in the history)
Fixes #868 

Also fixed some clang-tidy warnings.

Authors:
  - Rong Ou (https://github.com/rongou)

Approvers:
  - Alessandro Bellina (https://github.com/abellina)
  - Jake Hemstad (https://github.com/jrhemstad)
  - Mike Wilson (https://github.com/hyperbolic2346)
  - Mark Harris (https://github.com/harrism)

URL: #869
  • Loading branch information
rongou authored Sep 14, 2021
1 parent 73256f3 commit fe53a72
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions include/rmm/mr/device/limiting_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@

#include <cstddef>

namespace rmm {
namespace mr {
namespace rmm::mr {
/**
* @brief Resource that uses `Upstream` to allocate memory and limits the total
* allocations possible.
Expand Down Expand Up @@ -59,12 +58,12 @@ class limiting_resource_adaptor final : public device_memory_resource {
RMM_EXPECTS(nullptr != upstream, "Unexpected null upstream resource pointer.");
}

limiting_resource_adaptor() = delete;
~limiting_resource_adaptor() = default;
limiting_resource_adaptor(limiting_resource_adaptor const&) = delete;
limiting_resource_adaptor(limiting_resource_adaptor&&) = default;
limiting_resource_adaptor() = delete;
~limiting_resource_adaptor() override = default;
limiting_resource_adaptor(limiting_resource_adaptor const&) = delete;
limiting_resource_adaptor(limiting_resource_adaptor&&) noexcept = default;
limiting_resource_adaptor& operator=(limiting_resource_adaptor const&) = delete;
limiting_resource_adaptor& operator=(limiting_resource_adaptor&&) = default;
limiting_resource_adaptor& operator=(limiting_resource_adaptor&&) noexcept = default;

/**
* @brief Return pointer to the upstream resource.
Expand All @@ -79,14 +78,17 @@ class limiting_resource_adaptor final : public device_memory_resource {
* @return true The upstream resource supports streams
* @return false The upstream resource does not support streams.
*/
bool supports_streams() const noexcept override { return upstream_->supports_streams(); }
[[nodiscard]] bool supports_streams() const noexcept override
{
return upstream_->supports_streams();
}

/**
* @brief Query whether the resource supports the get_mem_info API.
*
* @return bool true if the upstream resource supports get_mem_info, false otherwise.
*/
bool supports_get_mem_info() const noexcept override
[[nodiscard]] bool supports_get_mem_info() const noexcept override
{
return upstream_->supports_get_mem_info();
}
Expand All @@ -100,7 +102,7 @@ class limiting_resource_adaptor final : public device_memory_resource {
* @return std::size_t number of bytes that have been allocated through this
* allocator.
*/
std::size_t get_allocated_bytes() const { return allocated_bytes_; }
[[nodiscard]] std::size_t get_allocated_bytes() const { return allocated_bytes_; }

/**
* @brief Query the maximum number of bytes that this allocator is allowed
Expand All @@ -109,7 +111,7 @@ class limiting_resource_adaptor final : public device_memory_resource {
*
* @return std::size_t max number of bytes allowed for this allocator
*/
std::size_t get_allocation_limit() const { return allocation_limit_; }
[[nodiscard]] std::size_t get_allocation_limit() const { return allocation_limit_; }

private:
/**
Expand All @@ -127,17 +129,19 @@ class limiting_resource_adaptor final : public device_memory_resource {
*/
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
void* p = nullptr;

std::size_t proposed_size = rmm::detail::align_up(bytes, allocation_alignment_);
if (proposed_size + allocated_bytes_ <= allocation_limit_) {
p = upstream_->allocate(bytes, stream);
allocated_bytes_ += proposed_size;
} else {
throw rmm::bad_alloc{"Exceeded memory limit"};
auto const proposed_size = rmm::detail::align_up(bytes, allocation_alignment_);
auto const old = allocated_bytes_.fetch_add(proposed_size);
if (old + proposed_size <= allocation_limit_) {
try {
return upstream_->allocate(bytes, stream);
} catch (...) {
allocated_bytes_ -= proposed_size;
throw;
}
}

return p;
allocated_bytes_ -= proposed_size;
RMM_FAIL("Exceeded memory limit", rmm::bad_alloc);
}

/**
Expand Down Expand Up @@ -165,13 +169,12 @@ class limiting_resource_adaptor final : public device_memory_resource {
* @return true If the two resources are equivalent
* @return false If the two resources are not equal
*/
bool do_is_equal(device_memory_resource const& other) const noexcept override
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other)
return true;
else {
limiting_resource_adaptor<Upstream> const* cast =
dynamic_cast<limiting_resource_adaptor<Upstream> const*>(&other);
auto const* cast = dynamic_cast<limiting_resource_adaptor<Upstream> const*>(&other);
if (cast != nullptr)
return upstream_->is_equal(*cast->get_upstream());
else
Expand All @@ -187,7 +190,8 @@ class limiting_resource_adaptor final : public device_memory_resource {
* @param stream Stream on which to get the mem info.
* @return std::pair containing free_size and total_size of memory
*/
std::pair<std::size_t, std::size_t> do_get_mem_info(cuda_stream_view stream) const override
[[nodiscard]] std::pair<std::size_t, std::size_t> do_get_mem_info(
cuda_stream_view stream) const override
{
return {allocation_limit_ - allocated_bytes_, allocation_limit_};
}
Expand Down Expand Up @@ -220,5 +224,4 @@ limiting_resource_adaptor<Upstream> make_limiting_adaptor(Upstream* upstream,
return limiting_resource_adaptor<Upstream>{upstream, allocation_limit};
}

} // namespace mr
} // namespace rmm
} // namespace rmm::mr

0 comments on commit fe53a72

Please sign in to comment.