Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spill OOM Protection #16737

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 62 additions & 24 deletions python/cudf/cudf/core/buffer/spill_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def __init__(
self._device_memory_limit = device_memory_limit
self.statistics = SpillStatistics(statistic_level)

def _out_of_memory_handle(self, nbytes: int, *, retry_once=True) -> bool:
def _out_of_memory_handle(
self, nbytes: int, *, retry_once=True, verbose=True
) -> bool:
"""Try to handle an out-of-memory error by spilling

This can by used as the callback function to RMM's
Expand Down Expand Up @@ -269,15 +271,17 @@ def _out_of_memory_handle(self, nbytes: int, *, retry_once=True) -> bool:
if retry_once:
# Let's collect garbage and try one more time
gc.collect()
return self._out_of_memory_handle(nbytes, retry_once=False)

# TODO: write to log instead of stdout
print(
f"[WARNING] RMM allocation of {format_bytes(nbytes)} bytes "
"failed, spill-on-demand couldn't find any device memory to "
f"spill:\n{repr(self)}\ntraceback:\n{get_traceback()}\n"
f"{self.statistics}"
)
return self._out_of_memory_handle(
nbytes, retry_once=False, verbose=verbose
)

if verbose:
print(
f"[WARNING] RMM allocation of {format_bytes(nbytes)} bytes "
"failed, spill-on-demand couldn't find any device memory to "
f"spill:\n{repr(self)}\ntraceback:\n{get_traceback()}\n"
f"{self.statistics}"
)
return False # Since we didn't find anything to spill, we give up

def add(self, buffer: SpillableBufferOwner) -> None:
Expand Down Expand Up @@ -436,11 +440,17 @@ def get_global_manager() -> SpillManager | None:
return _global_manager


def set_spill_on_demand_globally() -> None:
def set_spill_on_demand_globally(
spill_oom_protection: int | None = None,
) -> None:
"""Enable spill on demand in the current global spill manager.

Warning: this modifies the current RMM memory resource. A memory resource
to handle out-of-memory errors is pushed onto the RMM memory resource stack.
Warning
-------
This modifies the current RMM memory resource. A memory resource to
handle out-of-memory errors is pushed onto the RMM memory resource
stack. Modifying or rmm.reinitialize the RMM stack will disable spill
on demand.

Raises
------
Expand All @@ -449,6 +459,13 @@ def set_spill_on_demand_globally() -> None:
ValueError
If a failure callback resource is already in the resource stack.
"""
if spill_oom_protection is None:
spill_oom_protection = get_option("spill_oom_protection")
if spill_oom_protection < 0 or spill_oom_protection > 100:
raise ValueError(
"spill_oom_protection must be an integer between 0 and 100 "
f"(was {spill_oom_protection})"
)

manager = get_global_manager()
if manager is None:
Expand All @@ -464,20 +481,38 @@ def set_spill_on_demand_globally() -> None:
"Spill on demand (or another failure callback resource) "
"is already registered"
)
rmm.mr.set_current_device_resource(
rmm.mr.FailureCallbackResourceAdaptor(
mr, manager._out_of_memory_handle

# Add a limiting resource to the RMM stack when spilling and the OOM
# protection should kick in before total device memory usage.
if spill_oom_protection > 0 and spill_oom_protection < 100:
_, total_dev_mem = rmm.mr.available_device_memory()
mr_limit = int(total_dev_mem * (spill_oom_protection / 100))
mr = rmm.mr.LimitingResourceAdaptor(mr, allocation_limit=mr_limit)

# Add the OOM handle to the RMM stack. When OOM protection is enabled,
# the OOM handle should be quiet.
oom_handle_func = manager._out_of_memory_handle
if spill_oom_protection > 0:
oom_handle_func = partial(oom_handle_func, verbose=False)
mr = rmm.mr.FailureCallbackResourceAdaptor(mr, oom_handle_func)

# Add the OOM protection to the RMM stack.
if spill_oom_protection > 0:
mr = rmm.mr.FailureAlternateResourceAdaptor(
mr, rmm.mr.ManagedMemoryResource()
)
)
rmm.mr.set_current_device_resource(mr)


@contextmanager
def spill_on_demand_globally():
def spill_on_demand_globally(spill_oom_protection: int | None = None):
"""Context to enable spill on demand temporarily.

Warning: this modifies the current RMM memory resource. A memory resource
to handle out-of-memory errors is pushed onto the RMM memory resource stack
when entering the context and popped again when exiting.
Warning
-------
This modifies the current RMM memory resource. A memory resource to
handle out-of-memory errors is pushed onto the RMM memory resource
stack when entering the context and popped again when exiting.

Raises
------
Expand All @@ -488,8 +523,11 @@ def spill_on_demand_globally():
ValueError
If the RMM memory source stack was changed while in the context.
"""
set_spill_on_demand_globally()
# Save the new memory resource stack for later cleanup
# Save the current memory resource for later cleanup
mr_old = rmm.mr.get_current_device_resource()

set_spill_on_demand_globally(spill_oom_protection)
# Save the new memory resource stack for later consistency check
mr_stack = get_rmm_memory_resource_stack(
rmm.mr.get_current_device_resource()
)
Expand All @@ -501,4 +539,4 @@ def spill_on_demand_globally():
raise ValueError(
"RMM memory source stack was changed while in the context"
)
rmm.mr.set_current_device_resource(mr_stack[1])
rmm.mr.set_current_device_resource(mr_old)
20 changes: 20 additions & 0 deletions python/cudf/cudf/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,26 @@ def _integer_and_none_validator(val):
_integer_and_none_validator,
)

_register_option(
"spill_oom_protection",
_env_get_int("CUDF_SPILL_OOM_PROTECTION", 100),
textwrap.dedent(
"""
If not 0, enables out-of-memory protection. The value specifies at
which procent of total device memory the protection should kick in.
0 - disabled.
1-99 - enable OOM protection when reaching 1%% to 99%% of total
device memory.
100 - enable OOM protection only on OOM errors (no headroom).

This has no effect if spill on demand is disabled, see the "spill"
and the "spill_on_demand" options.
Valid values are any positive integer. Default is 100 (no headroom).
"""
),
_integer_validator,
)

_register_option(
"spill_stats",
_env_get_int("CUDF_SPILL_STATS", 0),
Expand Down
55 changes: 54 additions & 1 deletion python/cudf/cudf/tests/test_spilling.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def spilled_and_unspilled(manager: SpillManager) -> tuple[int, int]:

@pytest.fixture
def manager(request):
"""Fixture to enable and make a spilling manager availabe"""
"""Fixture to enable and make a spilling manager available"""
kwargs = dict(getattr(request, "param", {}))
with warnings.catch_warnings():
warnings.simplefilter("error")
Expand All @@ -133,6 +133,16 @@ def manager(request):
set_global_manager(manager=None)


@pytest.fixture
def rmm_cleanup():
"""Reset the current RMM resource stack after the test"""
mr = rmm.mr.get_current_device_resource()
try:
yield
finally:
rmm.mr.set_current_device_resource(mr)


def test_spillable_buffer(manager: SpillManager):
buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False)
assert isinstance(buf, SpillableBuffer)
Expand Down Expand Up @@ -784,3 +794,46 @@ def test_spilling_and_copy_on_write(manager: SpillManager):
assert not a.is_spilled
assert a.owner.exposed
assert not b.owner.exposed


def test_oom_protection(manager: SpillManager, rmm_cleanup, capsys):
# Use a limit of 10% of device (256 aligned)
_, total_dev_mem = rmm.mr.available_device_memory()
alloc_limit = int(round(total_dev_mem * 0.1 / 256) * 256)

track_mr = rmm.mr.TrackingResourceAdaptor(
rmm.mr.get_current_device_resource()
)
rmm.mr.set_current_device_resource(
rmm.mr.LimitingResourceAdaptor(track_mr, allocation_limit=alloc_limit)
)

# With a disabled OOM protection, we expect a OOM error including a
# stdout warning.
with spill_on_demand_globally(spill_oom_protection=0):
a = as_buffer(data=rmm.DeviceBuffer(size=alloc_limit), exposed=True)
assert track_mr.get_allocated_bytes() == a.nbytes
with pytest.raises(MemoryError, match="Exceeded memory limit"):
as_buffer(data=rmm.DeviceBuffer(size=alloc_limit))
assert "[WARNING] RMM allocation" in capsys.readouterr().out
del a
assert track_mr.get_allocated_bytes() == 0

# With OOM protection (no headroom), `track_mr` only encounters the
# first alloction and no stdout warning.
with spill_on_demand_globally(spill_oom_protection=100):
a = as_buffer(data=rmm.DeviceBuffer(size=alloc_limit), exposed=True)
assert track_mr.get_allocated_bytes() == a.nbytes
b = as_buffer(data=rmm.DeviceBuffer(size=alloc_limit), exposed=True)
assert track_mr.get_allocated_bytes() == b.nbytes
del a
del b
assert track_mr.get_allocated_bytes() == 0
assert "[WARNING] RMM allocation" not in capsys.readouterr().out

# With OOM protection (5% headroom), `track_mr` doesn't encounters any
# alloctions because of the 5% headroom
with spill_on_demand_globally(spill_oom_protection=5):
as_buffer(data=rmm.DeviceBuffer(size=alloc_limit), exposed=True)
assert track_mr.get_allocated_bytes() == 0
assert "[WARNING] RMM allocation" not in capsys.readouterr().out
Loading