Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fallback resource adaptor #1665

Draft
wants to merge 18 commits into
base: branch-24.10
Choose a base branch
from
159 changes: 159 additions & 0 deletions include/rmm/mr/device/fallback_resource_adapater.hpp
madsbk marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <rmm/detail/error.hpp>
#include <rmm/detail/export.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cstddef>
#include <mutex>
#include <unordered_set>

namespace RMM_NAMESPACE {
namespace mr {
/**
* @addtogroup device_resource_adaptors
* @{
* @file
*/

/**
* @brief A device memory resource that uses an alternate upstream resource when the primary
* upstream resource throws a specified exception type.
*
* An instance of this resource must be constructed with two upstream resources to satisfy
* allocation requests.
*
* @tparam ExceptionType The type of exception that this adaptor should respond to.
*/
template <typename ExceptionType = rmm::out_of_memory>
class fallback_resource_adapater final : public device_memory_resource {
madsbk marked this conversation as resolved.
Show resolved Hide resolved
public:
using exception_type = ExceptionType; ///< The type of exception this object catches/throws

/**
* @brief Construct a new `fallback_resource_adapater` that uses `primary_upstream`
* to satisfy allocation requests and if that fails with `ExceptionType`, uses
* `alternate_upstream`.
*
* @param primary_upstream The primary resource used for allocating/deallocating device memory
* @param alternate_upstream The alternate resource used for allocating/deallocating device memory
* memory
*/
fallback_resource_adapater(device_async_resource_ref primary_upstream,
device_async_resource_ref alternate_upstream)
: primary_upstream_{primary_upstream}, alternate_upstream_{alternate_upstream}
{
}

fallback_resource_adapater() = delete;
~fallback_resource_adapater() override = default;
fallback_resource_adapater(fallback_resource_adapater const&) = delete;
fallback_resource_adapater& operator=(fallback_resource_adapater const&) = delete;
fallback_resource_adapater(fallback_resource_adapater&&) noexcept =
default; ///< @default_move_constructor
fallback_resource_adapater& operator=(fallback_resource_adapater&&) noexcept =
default; ///< @default_move_assignment{fallback_resource_adapater}

/**
* @briefreturn{rmm::device_async_resource_ref to the upstream resource}
*/
[[nodiscard]] rmm::device_async_resource_ref get_upstream_resource() const noexcept
{
return primary_upstream_;
}

/**
* @briefreturn{rmm::device_async_resource_ref to the alternate upstream resource}
*/
[[nodiscard]] rmm::device_async_resource_ref get_alternate_upstream_resource() const noexcept
{
return alternate_upstream_;
}

private:
/**
* @brief Allocates memory of size at least `bytes` using the upstream
* resource.
*
* @throws any exceptions thrown from the upstream resources, only `exception_type`
* thrown by the primary upstream is caught.
*
* @param bytes The size, in bytes, of the allocation
* @param stream Stream on which to perform the allocation
* @return void* Pointer to the newly allocated memory
*/
void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
{
void* ret{};
try {
ret = primary_upstream_.allocate_async(bytes, stream);
} catch (exception_type const& e) {
ret = alternate_upstream_.allocate_async(bytes, stream);
std::lock_guard<std::mutex> lock(mtx_);
alternate_allocations_.insert(ret);
}
return ret;
}

/**
* @brief Free allocation of size `bytes` pointed to by `ptr`
*
* @param ptr Pointer to be deallocated
* @param bytes Size of the allocation
* @param stream Stream on which to perform the deallocation
*/
void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) override
{
std::size_t count{0};
{
std::lock_guard<std::mutex> lock(mtx_);
count = alternate_allocations_.erase(ptr);
}
if (count > 0) {
alternate_upstream_.deallocate_async(ptr, bytes, stream);
} else {
primary_upstream_.deallocate_async(ptr, bytes, stream);
}
}

/**
* @brief Compare the resource to another.
*
* @param other The other resource to compare to
* @return true If the two resources are equivalent
* @return false If the two resources are not equal
*/
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto cast = dynamic_cast<fallback_resource_adapater const*>(&other);
if (cast == nullptr) { return false; }
return get_upstream_resource() == cast->get_upstream_resource() &&
get_alternate_upstream_resource() == cast->get_alternate_upstream_resource();
}

device_async_resource_ref primary_upstream_;
device_async_resource_ref alternate_upstream_;
std::unordered_set<void*> alternate_allocations_;
mutable std::mutex mtx_;
};

/** @} */ // end of group
} // namespace mr
} // namespace RMM_NAMESPACE
5 changes: 5 additions & 0 deletions python/rmm/rmm/_lib/memory_resource.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ cdef class TrackingResourceAdaptor(UpstreamResourceAdaptor):
cdef class FailureCallbackResourceAdaptor(UpstreamResourceAdaptor):
cdef object _callback

cdef class FallbackResourceAdaptor(UpstreamResourceAdaptor):
cdef readonly DeviceMemoryResource alternate_upstream_mr

cpdef DeviceMemoryResource get_alternate_upstream(self)

cdef class PrefetchResourceAdaptor(UpstreamResourceAdaptor):
pass

Expand Down
62 changes: 60 additions & 2 deletions python/rmm/rmm/_lib/memory_resource.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ from libcpp.string cimport string
from cuda.cudart import cudaError_t

from rmm._cuda.gpu import CUDARuntimeError, getDevice, setDevice

from rmm._cuda.stream cimport Stream

from rmm._cuda.stream import DEFAULT_STREAM

from rmm._lib.cuda_stream_view cimport cuda_stream_view
Expand All @@ -44,6 +46,7 @@ from rmm._lib.per_device_resource cimport (
cuda_device_id,
set_per_device_resource as cpp_set_per_device_resource,
)

from rmm.statistics import Statistics

# Transparent handle of a C++ exception
Expand Down Expand Up @@ -84,6 +87,10 @@ cdef extern from *:

# NOTE: Keep extern declarations in .pyx file as much as possible to avoid
# leaking dependencies when importing RMM Cython .pxd files

cdef extern from "rmm/error.hpp" namespace "rmm" nogil:
cdef cppclass out_of_memory

cdef extern from "rmm/mr/device/cuda_memory_resource.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass cuda_memory_resource(device_memory_resource):
Expand Down Expand Up @@ -125,7 +132,6 @@ cdef extern from "rmm/mr/device/cuda_async_memory_resource.hpp" \
win32
win32_kmt


cdef extern from "rmm/mr/device/pool_memory_resource.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass pool_memory_resource[Upstream](device_memory_resource):
Expand Down Expand Up @@ -229,6 +235,19 @@ cdef extern from "rmm/mr/device/failure_callback_resource_adaptor.hpp" \
void* callback_arg
) except +

cdef extern from "rmm/mr/device/fallback_resource_adapater.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass fallback_resource_adapater[ExceptionType](
device_memory_resource
):
# Notice, `fallback_resource_adapater` takes `device_async_resource_ref`
# as upstream arguments but we define them here as `device_memory_resource*` and
# rely on implicit type conversion.
fallback_resource_adapater(
device_memory_resource* upstream_mr,
device_memory_resource* alternate_upstream_mr,
) except +

cdef extern from "rmm/mr/device/prefetch_resource_adaptor.hpp" \
namespace "rmm::mr" nogil:
cdef cppclass prefetch_resource_adaptor[Upstream](device_memory_resource):
Expand Down Expand Up @@ -279,7 +298,6 @@ cdef class UpstreamResourceAdaptor(DeviceMemoryResource):
"""

def __cinit__(self, DeviceMemoryResource upstream_mr, *args, **kwargs):

if (upstream_mr is None):
raise Exception("Argument `upstream_mr` must not be None")

Expand Down Expand Up @@ -1039,6 +1057,46 @@ cdef class FailureCallbackResourceAdaptor(UpstreamResourceAdaptor):
"""
pass


cdef class FallbackResourceAdaptor(UpstreamResourceAdaptor):

def __cinit__(
self,
DeviceMemoryResource upstream_mr,
DeviceMemoryResource alternate_upstream_mr,
):
if (alternate_upstream_mr is None):
raise Exception("Argument `alternate_upstream_mr` must not be None")
self.alternate_upstream_mr = alternate_upstream_mr

self.c_obj.reset(
new fallback_resource_adapater[out_of_memory](
upstream_mr.get_mr(),
alternate_upstream_mr.get_mr(),
)
)

def __init__(
self,
DeviceMemoryResource upstream_mr,
DeviceMemoryResource alternate_upstream_mr,
):
"""
A memory resource that uses an alternate resource when memory allocation fails.

Parameters
----------
upstream : DeviceMemoryResource
The primary resource used for allocating/deallocating device memory
alternate_upstream : DeviceMemoryResource
The alternate resource used when the primary fails to allocate
"""
pass

cpdef DeviceMemoryResource get_alternate_upstream(self):
return self.alternate_upstream_mr


cdef class PrefetchResourceAdaptor(UpstreamResourceAdaptor):

def __cinit__(
Expand Down
2 changes: 2 additions & 0 deletions python/rmm/rmm/mr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
CudaMemoryResource,
DeviceMemoryResource,
FailureCallbackResourceAdaptor,
FallbackResourceAdaptor,
FixedSizeMemoryResource,
LimitingResourceAdaptor,
LoggingResourceAdaptor,
Expand Down Expand Up @@ -61,6 +62,7 @@
"SystemMemoryResource",
"TrackingResourceAdaptor",
"FailureCallbackResourceAdaptor",
"FallbackResourceAdaptor",
"UpstreamResourceAdaptor",
"_flush_logs",
"_initialize",
Expand Down
51 changes: 51 additions & 0 deletions python/rmm/rmm/tests/test_rmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,57 @@ def callback(nbytes: int) -> bool:
assert retried[0]


def test_fallback_resource_adapater():
base = rmm.mr.CudaMemoryResource()

def alloc_cb(size, stream, *, track: list[int], limit: int):
if size > limit:
raise MemoryError()
ret = base.allocate(size, stream)
track.append(ret)
return ret

def dealloc_cb(ptr, size, stream, *, track: list[int]):
track.append(ptr)
return base.deallocate(ptr, size, stream)

main_track = []
main_mr = rmm.mr.CallbackMemoryResource(
functools.partial(alloc_cb, track=main_track, limit=200),
functools.partial(dealloc_cb, track=main_track),
)
alternate_track = []
alternate_mr = rmm.mr.CallbackMemoryResource(
functools.partial(alloc_cb, track=alternate_track, limit=1000),
functools.partial(dealloc_cb, track=alternate_track),
)
mr = rmm.mr.FallbackResourceAdaptor(main_mr, alternate_mr)
assert main_mr is mr.get_upstream()
assert alternate_mr is mr.get_alternate_upstream()

# Delete the upstream memory resources here to check that they are
# kept alive by `mr`
del main_mr
del alternate_mr

# Buffer size within the limit of `main_mr`
rmm.DeviceBuffer(size=100, mr=mr)
# we expect an alloc and a dealloc of the same buffer in
# `main_track` and an empty `alternate_track`
assert len(main_track) == 2
assert main_track[0] == main_track[1]
assert len(alternate_track) == 0

# Buffer size outside the limit of `main_mr`
rmm.DeviceBuffer(size=500, mr=mr)
# we expect an alloc and a dealloc of the same buffer in
# `alternate_track` and an unchanged `main_mr`
assert len(main_track) == 2
assert main_track[0] == main_track[1]
assert len(alternate_track) == 2
assert alternate_track[0] == alternate_track[1]


@pytest.mark.parametrize("managed", [True, False])
def test_prefetch_resource_adaptor(managed):
if managed:
Expand Down
5 changes: 4 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ ConfigureTest(STATISTICS_TEST mr/device/statistics_mr_tests.cpp GPUS 1 PERCENT 1
# tracking adaptor tests
ConfigureTest(TRACKING_TEST mr/device/tracking_mr_tests.cpp GPUS 1 PERCENT 100)

# out-of-memory callback adaptor tests
# failure callback adaptor tests
ConfigureTest(FAILURE_CALLBACK_TEST mr/device/failure_callback_mr_tests.cpp)

# failure fallback adaptor tests
ConfigureTest(FAILURE_ALTERNATE_TEST mr/device/fallback_mr_tests.cpp)

# prefetch adaptor tests
ConfigureTest(PREFETCH_ADAPTOR_TEST mr/device/prefetch_resource_adaptor_tests.cpp)

Expand Down
Loading
Loading