Skip to content

Commit

Permalink
use device_async_resource_ref
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Sep 2, 2024
1 parent e42fc60 commit cbd7f43
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 45 deletions.
47 changes: 11 additions & 36 deletions include/rmm/mr/device/failure_alternate_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,9 @@ namespace mr {
* An instance of this resource must be constructed with two existing upstream resource to
* satisfy allocation requests.
*
* @tparam PrimaryUpstream The type of the primary upstream resource used for
* allocation/deallocation.
* @tparam AlternateUpstream The type of the alternate upstream resource used for
* allocation/deallocation when the primary fails.
* @tparam ExceptionType The type of exception that this adaptor should respond to.
*/
template <typename PrimaryUpstream,
typename AlternateUpstream,
typename ExceptionType = rmm::out_of_memory>
template <typename ExceptionType = rmm::out_of_memory>
class failure_alternate_resource_adaptor final : public device_memory_resource {
public:
using exception_type = ExceptionType; ///< The type of exception this object catches/throws
Expand All @@ -57,18 +51,14 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
* primary resource to satisfy allocation requests and if that fails, use `alternate_upstream`
* as an alternate
*
* @throws rmm::logic_error if `primary_upstream == nullptr` or `alternate_upstream == nullptr`
*
* @param primary_upstream The primary resource used for allocating/deallocating device memory
* @param alternate_upstream The alternate resource used for allocating/deallocating device memory
* memory
*/
failure_alternate_resource_adaptor(PrimaryUpstream* primary_upstream,
AlternateUpstream* alternate_upstream)
failure_alternate_resource_adaptor(device_async_resource_ref primary_upstream,
device_async_resource_ref alternate_upstream)
: primary_upstream_{primary_upstream}, alternate_upstream_{alternate_upstream}
{
RMM_EXPECTS(nullptr != primary_upstream, "Unexpected null upstream resource pointer.");
RMM_EXPECTS(nullptr != alternate_upstream, "Unexpected null upstream resource pointer.");
}

failure_alternate_resource_adaptor() = delete;
Expand Down Expand Up @@ -96,19 +86,6 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
return alternate_upstream_;
}

/**
* @briefreturn{PrimaryUpstream* to the upstream memory resource}
*/
[[nodiscard]] PrimaryUpstream* get_upstream() const noexcept { return primary_upstream_; }

/**
* @briefreturn{AlternateUpstream* to the alternate upstream memory resource}
*/
[[nodiscard]] AlternateUpstream* get_alternate_upstream() const noexcept
{
return alternate_upstream_;
}

private:
/**
* @brief Allocates memory of size at least `bytes` using the upstream
Expand All @@ -125,9 +102,9 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
{
void* ret{};
try {
ret = primary_upstream_->allocate(bytes, stream);
ret = primary_upstream_.allocate_async(bytes, stream);
} catch (exception_type const& e) {
ret = alternate_upstream_->allocate(bytes, stream);
ret = alternate_upstream_.allocate_async(bytes, stream);
std::lock_guard<std::mutex> lock(mtx_);
alternate_allocations_.insert(ret);
}
Expand All @@ -149,9 +126,9 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
count = alternate_allocations_.erase(ptr);
}
if (count > 0) {
alternate_upstream_->deallocate(ptr, bytes, stream);
alternate_upstream_.deallocate_async(ptr, bytes, stream);
} else {
primary_upstream_->deallocate(ptr, bytes, stream);
primary_upstream_.deallocate_async(ptr, bytes, stream);
}
}

Expand All @@ -165,16 +142,14 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto cast =
dynamic_cast<failure_alternate_resource_adaptor<PrimaryUpstream, AlternateUpstream> const*>(
&other);
if (cast == nullptr) { return primary_upstream_->is_equal(other); }
auto cast = dynamic_cast<failure_alternate_resource_adaptor const*>(&other);
if (cast == nullptr) { return false; }
return get_upstream_resource() == cast->get_upstream_resource() &&
get_alternate_upstream_resource() == cast->get_alternate_upstream_resource();
}

PrimaryUpstream* primary_upstream_; // the primary upstream
AlternateUpstream* alternate_upstream_; // the alternate upstream
device_async_resource_ref primary_upstream_; // the primary upstream
device_async_resource_ref alternate_upstream_; // the alternate upstream
std::unordered_set<void*> alternate_allocations_; // set of alternate allocations
mutable std::mutex mtx_; // Mutex for exclusive lock.
};
Expand Down
25 changes: 16 additions & 9 deletions python/rmm/rmm/_lib/memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ from libcpp.string cimport string
from cuda.cudart import cudaError_t

from rmm._cuda.gpu import CUDARuntimeError, getDevice, setDevice

from rmm._cuda.stream cimport Stream

from rmm._cuda.stream import DEFAULT_STREAM

from rmm._lib.cuda_stream_view cimport cuda_stream_view
Expand All @@ -44,6 +46,7 @@ from rmm._lib.per_device_resource cimport (
cuda_device_id,
set_per_device_resource as cpp_set_per_device_resource,
)

from rmm.statistics import Statistics

# Transparent handle of a C++ exception
Expand Down Expand Up @@ -84,6 +87,10 @@ cdef extern from *:

# NOTE: Keep extern declarations in .pyx file as much as possible to avoid
# leaking dependencies when importing RMM Cython .pxd files

cdef extern from "rmm/error.hpp" namespace "rmm" nogil:
cdef cppclass out_of_memory

cdef extern from "rmm/mr/device/cuda_memory_resource.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass cuda_memory_resource(device_memory_resource):
Expand Down Expand Up @@ -125,7 +132,6 @@ cdef extern from "rmm/mr/device/cuda_async_memory_resource.hpp" \
win32
win32_kmt


cdef extern from "rmm/mr/device/pool_memory_resource.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass pool_memory_resource[Upstream](device_memory_resource):
Expand Down Expand Up @@ -231,12 +237,15 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \

cdef extern from "rmm/mr/device/failure_alternate_resource_adaptor.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass failure_alternate_resource_adaptor[
PrimaryUpstream, AlternateUpstream
](device_memory_resource):
cdef cppclass failure_alternate_resource_adaptor[ExceptionType](
device_memory_resource
):
# Notice, `failure_alternate_resource_adaptor` takes `device_async_resource_ref`
# as upstream arguments but we define them as `device_memory_resource*` and
# rely on implicit type conversion.
failure_alternate_resource_adaptor(
PrimaryUpstream* upstream_mr,
AlternateUpstream* alternate_upstream_mr,
device_memory_resource* upstream_mr,
device_memory_resource* alternate_upstream_mr,
) except +

cdef extern from "rmm/mr/device/prefetch_resource_adaptor.hpp" \
Expand Down Expand Up @@ -1061,9 +1070,7 @@ cdef class FailureAlternateResourceAdaptor(UpstreamResourceAdaptor):
self.alternate_upstream_mr = alternate_upstream_mr

self.c_obj.reset(
new failure_alternate_resource_adaptor[
device_memory_resource, device_memory_resource
](
new failure_alternate_resource_adaptor[out_of_memory](
upstream_mr.get_mr(),
alternate_upstream_mr.get_mr(),
)
Expand Down

0 comments on commit cbd7f43

Please sign in to comment.