From cbd7f43d5093438eaa6d0fde4a3f5fe7f1147f24 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 2 Sep 2024 21:30:24 +0200 Subject: [PATCH] use device_async_resource_ref --- .../failure_alternate_resource_adaptor.hpp | 47 +++++-------------- python/rmm/rmm/_lib/memory_resource.pyx | 25 ++++++---- 2 files changed, 27 insertions(+), 45 deletions(-) diff --git a/include/rmm/mr/device/failure_alternate_resource_adaptor.hpp b/include/rmm/mr/device/failure_alternate_resource_adaptor.hpp index 559a3d6ee..1757f4ce9 100644 --- a/include/rmm/mr/device/failure_alternate_resource_adaptor.hpp +++ b/include/rmm/mr/device/failure_alternate_resource_adaptor.hpp @@ -39,15 +39,9 @@ namespace mr { * An instance of this resource must be constructed with two existing upstream resource to * satisfy allocation requests. * - * @tparam PrimaryUpstream The type of the primary upstream resource used for - * allocation/deallocation. - * @tparam AlternateUpstream The type of the alternate upstream resource used for - * allocation/deallocation when the primary fails. * @tparam ExceptionType The type of exception that this adaptor should respond to. */ -template +template class failure_alternate_resource_adaptor final : public device_memory_resource { public: using exception_type = ExceptionType; ///< The type of exception this object catches/throws @@ -57,18 +51,14 @@ class failure_alternate_resource_adaptor final : public device_memory_resource { * primary resource to satisfy allocation requests and if that fails, use `alternate_upstream` * as an alternate * - * @throws rmm::logic_error if `primary_upstream == nullptr` or `alternate_upstream == nullptr` - * * @param primary_upstream The primary resource used for allocating/deallocating device memory * @param alternate_upstream The alternate resource used for allocating/deallocating device memory * memory */ - failure_alternate_resource_adaptor(PrimaryUpstream* primary_upstream, - AlternateUpstream* alternate_upstream) + failure_alternate_resource_adaptor(device_async_resource_ref primary_upstream, + device_async_resource_ref alternate_upstream) : primary_upstream_{primary_upstream}, alternate_upstream_{alternate_upstream} { - RMM_EXPECTS(nullptr != primary_upstream, "Unexpected null upstream resource pointer."); - RMM_EXPECTS(nullptr != alternate_upstream, "Unexpected null upstream resource pointer."); } failure_alternate_resource_adaptor() = delete; @@ -96,19 +86,6 @@ class failure_alternate_resource_adaptor final : public device_memory_resource { return alternate_upstream_; } - /** - * @briefreturn{PrimaryUpstream* to the upstream memory resource} - */ - [[nodiscard]] PrimaryUpstream* get_upstream() const noexcept { return primary_upstream_; } - - /** - * @briefreturn{AlternateUpstream* to the alternate upstream memory resource} - */ - [[nodiscard]] AlternateUpstream* get_alternate_upstream() const noexcept - { - return alternate_upstream_; - } - private: /** * @brief Allocates memory of size at least `bytes` using the upstream @@ -125,9 +102,9 @@ class failure_alternate_resource_adaptor final : public device_memory_resource { { void* ret{}; try { - ret = primary_upstream_->allocate(bytes, stream); + ret = primary_upstream_.allocate_async(bytes, stream); } catch (exception_type const& e) { - ret = alternate_upstream_->allocate(bytes, stream); + ret = alternate_upstream_.allocate_async(bytes, stream); std::lock_guard lock(mtx_); alternate_allocations_.insert(ret); } @@ -149,9 +126,9 @@ class failure_alternate_resource_adaptor final : public device_memory_resource { count = alternate_allocations_.erase(ptr); } if (count > 0) { - alternate_upstream_->deallocate(ptr, bytes, stream); + alternate_upstream_.deallocate_async(ptr, bytes, stream); } else { - primary_upstream_->deallocate(ptr, bytes, stream); + primary_upstream_.deallocate_async(ptr, bytes, stream); } } @@ -165,16 +142,14 @@ class failure_alternate_resource_adaptor final : public device_memory_resource { [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override { if (this == &other) { return true; } - auto cast = - dynamic_cast const*>( - &other); - if (cast == nullptr) { return primary_upstream_->is_equal(other); } + auto cast = dynamic_cast(&other); + if (cast == nullptr) { return false; } return get_upstream_resource() == cast->get_upstream_resource() && get_alternate_upstream_resource() == cast->get_alternate_upstream_resource(); } - PrimaryUpstream* primary_upstream_; // the primary upstream - AlternateUpstream* alternate_upstream_; // the alternate upstream + device_async_resource_ref primary_upstream_; // the primary upstream + device_async_resource_ref alternate_upstream_; // the alternate upstream std::unordered_set alternate_allocations_; // set of alternate allocations mutable std::mutex mtx_; // Mutex for exclusive lock. }; diff --git a/python/rmm/rmm/_lib/memory_resource.pyx b/python/rmm/rmm/_lib/memory_resource.pyx index f35e4ba84..17c67c8c3 100644 --- a/python/rmm/rmm/_lib/memory_resource.pyx +++ b/python/rmm/rmm/_lib/memory_resource.pyx @@ -32,7 +32,9 @@ from libcpp.string cimport string from cuda.cudart import cudaError_t from rmm._cuda.gpu import CUDARuntimeError, getDevice, setDevice + from rmm._cuda.stream cimport Stream + from rmm._cuda.stream import DEFAULT_STREAM from rmm._lib.cuda_stream_view cimport cuda_stream_view @@ -44,6 +46,7 @@ from rmm._lib.per_device_resource cimport ( cuda_device_id, set_per_device_resource as cpp_set_per_device_resource, ) + from rmm.statistics import Statistics # Transparent handle of a C++ exception @@ -84,6 +87,10 @@ cdef extern from *: # NOTE: Keep extern declarations in .pyx file as much as possible to avoid # leaking dependencies when importing RMM Cython .pxd files + +cdef extern from "rmm/error.hpp" namespace "rmm" nogil: + cdef cppclass out_of_memory + cdef extern from "rmm/mr/device/cuda_memory_resource.hpp" \ namespace "rmm::mr" nogil: cdef cppclass cuda_memory_resource(device_memory_resource): @@ -125,7 +132,6 @@ cdef extern from "rmm/mr/device/cuda_async_memory_resource.hpp" \ win32 win32_kmt - cdef extern from "rmm/mr/device/pool_memory_resource.hpp" \ namespace "rmm::mr" nogil: cdef cppclass pool_memory_resource[Upstream](device_memory_resource): @@ -231,12 +237,15 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \ cdef extern from "rmm/mr/device/failure_alternate_resource_adaptor.hpp" \ namespace "rmm::mr" nogil: - cdef cppclass failure_alternate_resource_adaptor[ - PrimaryUpstream, AlternateUpstream - ](device_memory_resource): + cdef cppclass failure_alternate_resource_adaptor[ExceptionType]( + device_memory_resource + ): + # Notice, `failure_alternate_resource_adaptor` takes `device_async_resource_ref` + # as upstream arguments but we define them as `device_memory_resource*` and + # rely on implicit type conversion. failure_alternate_resource_adaptor( - PrimaryUpstream* upstream_mr, - AlternateUpstream* alternate_upstream_mr, + device_memory_resource* upstream_mr, + device_memory_resource* alternate_upstream_mr, ) except + cdef extern from "rmm/mr/device/prefetch_resource_adaptor.hpp" \ @@ -1061,9 +1070,7 @@ cdef class FailureAlternateResourceAdaptor(UpstreamResourceAdaptor): self.alternate_upstream_mr = alternate_upstream_mr self.c_obj.reset( - new failure_alternate_resource_adaptor[ - device_memory_resource, device_memory_resource - ]( + new failure_alternate_resource_adaptor[out_of_memory]( upstream_mr.get_mr(), alternate_upstream_mr.get_mr(), )