Skip to content

Commit

Permalink
AlternateUpstream
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Sep 2, 2024
1 parent 6d76ec6 commit bcaef4c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 29 deletions.
78 changes: 56 additions & 22 deletions include/rmm/mr/device/failure_alternate_resource_adaptor.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,23 +32,40 @@ namespace mr {
* @file
*/

template <typename Upstream, typename ExceptionType = rmm::out_of_memory>
/**
* @brief A device memory resource that uses an alternate upstream resource when the primary throws
* a specified exception type.
*
* An instance of this resource must be constructed with two existing upstream resources in order
* to satisfy allocation requests.
*
* @tparam PrimaryUpstream The type of the primary upstream resource used for
* allocation/deallocation.
* @tparam AlternateUpstream The type of the alternate upstream resource used for
* allocation/deallocation when the primary fails.
* @tparam ExceptionType The type of exception that this adaptor should respond to.
*/
template <typename PrimaryUpstream,
typename AlternateUpstream,
typename ExceptionType = rmm::out_of_memory>
class failure_alternate_resource_adaptor final : public device_memory_resource {
public:
using exception_type = ExceptionType; ///< The type of exception this object catches/throws

/**
* @brief Construct a new `failure_alternate_resource_adaptor` using `upstream` to satisfy
* allocation requests.
* @brief Construct a new `failure_alternate_resource_adaptor` using `upstream` as the
* primary resource to satisfy allocation requests and, if that fails, `alternate_upstream`
* as the alternate.
*
* @throws rmm::logic_error if `upstream == nullptr`
* @throws rmm::logic_error if `upstream == nullptr` or `alternate_upstream == nullptr`
*
* @param upstream The resource used for allocating/deallocating device memory
* @param alternate_upstream The resource used for allocating device memory when `upstream`
* fails, and for deallocating the memory it allocated
*/
failure_alternate_resource_adaptor(Upstream* upstream, Upstream* alternate_upstream)
: upstream_{upstream}, alternate_upstream_{alternate_upstream}
failure_alternate_resource_adaptor(PrimaryUpstream* upstream,
AlternateUpstream* alternate_upstream)
: primary_upstream_{upstream}, alternate_upstream_{alternate_upstream}
{
RMM_EXPECTS(nullptr != upstream, "Unexpected null upstream resource pointer.");
RMM_EXPECTS(nullptr != alternate_upstream,
Expand All @@ -69,23 +86,37 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
*/
[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept
{
return upstream_;
return primary_upstream_;
}

/**
* @briefreturn{Upstream* to the upstream memory resource}
* @briefreturn{rmm::device_async_resource_ref to the alternate upstream resource}
*/
[[nodiscard]] Upstream* get_upstream() const noexcept { return upstream_; }
[[nodiscard]] rmm::device_async_resource_ref get_alternate_upstream_resource() const noexcept
{
return alternate_upstream_;
}

private:
using lock_guard = std::lock_guard<std::mutex>;
/**
* @briefreturn{PrimaryUpstream* to the upstream memory resource}
*/
[[nodiscard]] PrimaryUpstream* get_upstream() const noexcept { return primary_upstream_; }

/**
* @briefreturn{AlternateUpstream* to the alternate upstream memory resource}
*/
[[nodiscard]] AlternateUpstream* get_alternate_upstream() const noexcept
{
return alternate_upstream_;
}

private:
/**
* @brief Allocates memory of size at least `bytes` using the primary upstream resource, falling
* back to the alternate upstream resource if the primary throws `exception_type`.
*
* @throws `exception_type` if the requested allocation could not be fulfilled
* by the upstream resource.
* by the primary or the alternate upstream resource.
*
* @param bytes The size, in bytes, of the allocation
* @param stream Stream on which to perform the allocation
Expand All @@ -95,10 +126,10 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
{
void* ret{};
try {
ret = upstream_->allocate(bytes, stream);
ret = primary_upstream_->allocate(bytes, stream);
} catch (exception_type const& e) {
ret = alternate_upstream_->allocate(bytes, stream);
lock_guard lock(mtx_);
std::lock_guard<std::mutex> lock(mtx_);
alternate_allocations_.insert(ret);
}
return ret;
Expand All @@ -115,13 +146,13 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
{
std::size_t count{0};
{
lock_guard lock(mtx_);
std::lock_guard<std::mutex> lock(mtx_);
count = alternate_allocations_.erase(ptr);
}
if (count > 0) {
alternate_upstream_->deallocate(ptr, bytes, stream);
} else {
upstream_->deallocate(ptr, bytes, stream);
primary_upstream_->deallocate(ptr, bytes, stream);
}
}

Expand All @@ -135,13 +166,16 @@ class failure_alternate_resource_adaptor final : public device_memory_resource {
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto cast = dynamic_cast<failure_alternate_resource_adaptor<Upstream> const*>(&other);
if (cast == nullptr) { return upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource();
auto cast =
dynamic_cast<failure_alternate_resource_adaptor<PrimaryUpstream, AlternateUpstream> const*>(
&other);
if (cast == nullptr) { return primary_upstream_->is_equal(other); }
return get_upstream_resource() == cast->get_upstream_resource() &&
get_alternate_upstream_resource() == cast->get_alternate_upstream_resource();
}

Upstream* upstream_; // the upstream used for satisfying allocation requests
Upstream* alternate_upstream_; // the upstream used for satisfying alternate allocation requests
PrimaryUpstream* primary_upstream_; // the primary upstream
AlternateUpstream* alternate_upstream_; // the alternate upstream
std::unordered_set<void*> alternate_allocations_; // set of alternate allocations
mutable std::mutex mtx_; // Mutex for exclusive lock.
};
Expand Down
15 changes: 8 additions & 7 deletions python/rmm/rmm/_lib/memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,12 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \

cdef extern from "rmm/mr/device/failure_alternate_resource_adaptor.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass failure_alternate_resource_adaptor[Upstream](
device_memory_resource
):
cdef cppclass failure_alternate_resource_adaptor[
PrimaryUpstream, AlternateUpstream
](device_memory_resource):
failure_alternate_resource_adaptor(
Upstream* upstream_mr,
Upstream* alternate_upstream_mr,
PrimaryUpstream* upstream_mr,
AlternateUpstream* alternate_upstream_mr,
) except +

cdef extern from "rmm/mr/device/prefetch_resource_adaptor.hpp" \
Expand Down Expand Up @@ -1061,7 +1061,9 @@ cdef class FailureAlternateResourceAdaptor(UpstreamResourceAdaptor):
self.alternate_upstream_mr = alternate_upstream_mr

self.c_obj.reset(
new failure_alternate_resource_adaptor[device_memory_resource](
new failure_alternate_resource_adaptor[
device_memory_resource, device_memory_resource
](
upstream_mr.get_mr(),
alternate_upstream_mr.get_mr(),
)
Expand All @@ -1081,7 +1083,6 @@ cdef class FailureAlternateResourceAdaptor(UpstreamResourceAdaptor):
return self.alternate_upstream_mr



cdef class PrefetchResourceAdaptor(UpstreamResourceAdaptor):

def __cinit__(
Expand Down

0 comments on commit bcaef4c

Please sign in to comment.