diff --git a/include/rmm/mr/device/limiting_resource_adaptor.hpp b/include/rmm/mr/device/limiting_resource_adaptor.hpp index 5002962d5..b83fe3911 100644 --- a/include/rmm/mr/device/limiting_resource_adaptor.hpp +++ b/include/rmm/mr/device/limiting_resource_adaptor.hpp @@ -21,8 +21,7 @@ #include -namespace rmm { -namespace mr { +namespace rmm::mr { /** * @brief Resource that uses `Upstream` to allocate memory and limits the total * allocations possible. @@ -59,12 +58,12 @@ class limiting_resource_adaptor final : public device_memory_resource { RMM_EXPECTS(nullptr != upstream, "Unexpected null upstream resource pointer."); } - limiting_resource_adaptor() = delete; - ~limiting_resource_adaptor() = default; - limiting_resource_adaptor(limiting_resource_adaptor const&) = delete; - limiting_resource_adaptor(limiting_resource_adaptor&&) = default; + limiting_resource_adaptor() = delete; + ~limiting_resource_adaptor() override = default; + limiting_resource_adaptor(limiting_resource_adaptor const&) = delete; + limiting_resource_adaptor(limiting_resource_adaptor&&) noexcept = default; limiting_resource_adaptor& operator=(limiting_resource_adaptor const&) = delete; - limiting_resource_adaptor& operator=(limiting_resource_adaptor&&) = default; + limiting_resource_adaptor& operator=(limiting_resource_adaptor&&) noexcept = default; /** * @brief Return pointer to the upstream resource. @@ -79,14 +78,17 @@ class limiting_resource_adaptor final : public device_memory_resource { * @return true The upstream resource supports streams * @return false The upstream resource does not support streams. */ - bool supports_streams() const noexcept override { return upstream_->supports_streams(); } + [[nodiscard]] bool supports_streams() const noexcept override + { + return upstream_->supports_streams(); + } /** * @brief Query whether the resource supports the get_mem_info API. * * @return bool true if the upstream resource supports get_mem_info, false otherwise. */ - bool supports_get_mem_info() const noexcept override + [[nodiscard]] bool supports_get_mem_info() const noexcept override { return upstream_->supports_get_mem_info(); } @@ -100,7 +102,7 @@ class limiting_resource_adaptor final : public device_memory_resource { * @return std::size_t number of bytes that have been allocated through this * allocator. */ - std::size_t get_allocated_bytes() const { return allocated_bytes_; } + [[nodiscard]] std::size_t get_allocated_bytes() const { return allocated_bytes_; } /** * @brief Query the maximum number of bytes that this allocator is allowed @@ -109,7 +111,7 @@ class limiting_resource_adaptor final : public device_memory_resource { * * @return std::size_t max number of bytes allowed for this allocator */ - std::size_t get_allocation_limit() const { return allocation_limit_; } + [[nodiscard]] std::size_t get_allocation_limit() const { return allocation_limit_; } private: /** @@ -127,17 +129,19 @@ class limiting_resource_adaptor final : public device_memory_resource { */ void* do_allocate(std::size_t bytes, cuda_stream_view stream) override { - void* p = nullptr; - - std::size_t proposed_size = rmm::detail::align_up(bytes, allocation_alignment_); - if (proposed_size + allocated_bytes_ <= allocation_limit_) { - p = upstream_->allocate(bytes, stream); - allocated_bytes_ += proposed_size; - } else { - throw rmm::bad_alloc{"Exceeded memory limit"}; + auto const proposed_size = rmm::detail::align_up(bytes, allocation_alignment_); + auto const old = allocated_bytes_.fetch_add(proposed_size); + if (old + proposed_size <= allocation_limit_) { + try { + return upstream_->allocate(bytes, stream); + } catch (...) { + allocated_bytes_ -= proposed_size; + throw; + } } - return p; + allocated_bytes_ -= proposed_size; + RMM_FAIL("Exceeded memory limit", rmm::bad_alloc); } /** @@ -165,13 +169,12 @@ class limiting_resource_adaptor final : public device_memory_resource { * @return true If the two resources are equivalent * @return false If the two resources are not equal */ - bool do_is_equal(device_memory_resource const& other) const noexcept override + [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override { if (this == &other) return true; else { - limiting_resource_adaptor const* cast = - dynamic_cast const*>(&other); + auto const* cast = dynamic_cast const*>(&other); if (cast != nullptr) return upstream_->is_equal(*cast->get_upstream()); else @@ -187,7 +190,8 @@ class limiting_resource_adaptor final : public device_memory_resource { * @param stream Stream on which to get the mem info. * @return std::pair contaiing free_size and total_size of memory */ - std::pair do_get_mem_info(cuda_stream_view stream) const override + [[nodiscard]] std::pair do_get_mem_info( + cuda_stream_view stream) const override { return {allocation_limit_ - allocated_bytes_, allocation_limit_}; } @@ -220,5 +224,4 @@ limiting_resource_adaptor make_limiting_adaptor(Upstream* upstream, return limiting_resource_adaptor{upstream, allocation_limit}; } -} // namespace mr -} // namespace rmm +} // namespace rmm::mr