From 12d9946c0fd36049ecc6667a6f58c136513fd822 Mon Sep 17 00:00:00 2001
From: "Mads R. B. Kristensen"
Date: Wed, 4 Sep 2024 11:27:47 +0200
Subject: [PATCH 1/2] impl. spill_oom_protection

---
 python/cudf/cudf/core/buffer/spill_manager.py | 75 ++++++++++++++-----
 python/cudf/cudf/options.py                   | 20 +++++
 2 files changed, 75 insertions(+), 20 deletions(-)

diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py
index ed351a6b107..5f2cf69d477 100644
--- a/python/cudf/cudf/core/buffer/spill_manager.py
+++ b/python/cudf/cudf/core/buffer/spill_manager.py
@@ -236,7 +236,9 @@ def __init__(
         self._device_memory_limit = device_memory_limit
         self.statistics = SpillStatistics(statistic_level)
 
-    def _out_of_memory_handle(self, nbytes: int, *, retry_once=True) -> bool:
+    def _out_of_memory_handle(
+        self, nbytes: int, *, retry_once=True, verbose=True
+    ) -> bool:
         """Try to handle an out-of-memory error by spilling
 
         This can by used as the callback function to RMM's
@@ -269,15 +271,17 @@ def _out_of_memory_handle(
         if retry_once:
             # Let's collect garbage and try one more time
            gc.collect()
-            return self._out_of_memory_handle(nbytes, retry_once=False)
-
-        # TODO: write to log instead of stdout
-        print(
-            f"[WARNING] RMM allocation of {format_bytes(nbytes)} bytes "
-            "failed, spill-on-demand couldn't find any device memory to "
-            f"spill:\n{repr(self)}\ntraceback:\n{get_traceback()}\n"
-            f"{self.statistics}"
-        )
+            return self._out_of_memory_handle(
+                nbytes, retry_once=False, verbose=verbose
+            )
+
+        if verbose:
+            print(
+                f"[WARNING] RMM allocation of {format_bytes(nbytes)} bytes "
+                "failed, spill-on-demand couldn't find any device memory to "
+                f"spill:\n{repr(self)}\ntraceback:\n{get_traceback()}\n"
+                f"{self.statistics}"
+            )
         return False  # Since we didn't find anything to spill, we give up
 
     def add(self, buffer: SpillableBufferOwner) -> None:
@@ -436,11 +440,17 @@ def get_global_manager() -> SpillManager | None:
     return _global_manager
 
 
-def set_spill_on_demand_globally() -> None:
+def set_spill_on_demand_globally(
+    spill_oom_protection: int | None = None,
+) -> None:
     """Enable spill on demand in the current global spill manager.
 
-    Warning: this modifies the current RMM memory resource. A memory resource
-    to handle out-of-memory errors is pushed onto the RMM memory resource stack.
+    Warning
+    -------
+    This modifies the current RMM memory resource. A memory resource to
+    handle out-of-memory errors is pushed onto the RMM memory resource
+    stack. Modifying the RMM stack or calling rmm.reinitialize() will
+    disable spill on demand.
 
     Raises
     ------
@@ -449,6 +459,13 @@ def set_spill_on_demand_globally() -> None:
     ValueError
         If a failure callback resource is already in the resource stack.
     """
+    if spill_oom_protection is None:
+        spill_oom_protection = get_option("spill_oom_protection")
+    if spill_oom_protection < 0 or spill_oom_protection > 100:
+        raise ValueError(
+            "spill_oom_protection must be an integer between 0 and 100 "
+            f"(was {spill_oom_protection})"
+        )
 
     manager = get_global_manager()
     if manager is None:
@@ -464,20 +481,38 @@ def set_spill_on_demand_globally() -> None:
             "Spill on demand (or another failure callback resource) "
             "is already registered"
         )
-    rmm.mr.set_current_device_resource(
-        rmm.mr.FailureCallbackResourceAdaptor(
-            mr, manager._out_of_memory_handle
+
+    # Add a limiting resource to the RMM stack when the OOM protection should
+    # kick in before total device memory usage.
+    if spill_oom_protection > 0 and spill_oom_protection < 100:
+        _, total_dev_mem = rmm.mr.available_device_memory()
+        mr_limit = int(total_dev_mem * (spill_oom_protection / 100))
+        mr = rmm.mr.LimitingResourceAdaptor(mr, allocation_limit=mr_limit)
+
+    # Add the OOM handle to the RMM stack. When OOM protection is enabled,
+    # the OOM handle should be quiet.
+    oom_handle_func = manager._out_of_memory_handle
+    if spill_oom_protection > 0:
+        oom_handle_func = partial(oom_handle_func, verbose=False)
+    mr = rmm.mr.FailureCallbackResourceAdaptor(mr, oom_handle_func)
+
+    # Add the OOM protection to the RMM stack.
+    if spill_oom_protection > 0:
+        mr = rmm.mr.FailureAlternateResourceAdaptor(
+            mr, rmm.mr.ManagedMemoryResource()
         )
-    )
+    rmm.mr.set_current_device_resource(mr)
 
 
 @contextmanager
 def spill_on_demand_globally():
     """Context to enable spill on demand temporarily.
 
-    Warning: this modifies the current RMM memory resource. A memory resource
-    to handle out-of-memory errors is pushed onto the RMM memory resource stack
-    when entering the context and popped again when exiting.
+    Warning
+    -------
+    This modifies the current RMM memory resource. A memory resource to
+    handle out-of-memory errors is pushed onto the RMM memory resource
+    stack when entering the context and popped again when exiting.
 
     Raises
     ------
diff --git a/python/cudf/cudf/options.py b/python/cudf/cudf/options.py
index df7bbe22a61..4d69b52cc32 100644
--- a/python/cudf/cudf/options.py
+++ b/python/cudf/cudf/options.py
@@ -278,6 +278,26 @@ def _integer_and_none_validator(val):
     _integer_and_none_validator,
 )
 
+_register_option(
+    "spill_oom_protection",
+    _env_get_int("CUDF_SPILL_OOM_PROTECTION", 100),
+    textwrap.dedent(
+        """
+        If not 0, enables out-of-memory protection. The value specifies the
+        percentage of total device memory at which the protection kicks in.
+        0    - disabled.
+        1-99 - enable OOM protection when reaching 1%% to 99%% of total
+               device memory.
+        100  - enable OOM protection only on OOM errors (no headroom).
+
+        This has no effect if spill on demand is disabled, see the "spill"
+        and the "spill_on_demand" options.
+        Valid values are integers from 0 to 100. Default is 0 (disabled).
+        """
+    ),
+    _integer_validator,
+)
+
 _register_option(
     "spill_stats",
     _env_get_int("CUDF_SPILL_STATS", 0),

From c246aca84073b3c75d460db21360c1b8396b504d Mon Sep 17 00:00:00 2001
From: "Mads R. B. Kristensen"
Date: Wed, 4 Sep 2024 14:18:31 +0200
Subject: [PATCH 2/2] test

---
 python/cudf/cudf/core/buffer/spill_manager.py | 15 +++--
 python/cudf/cudf/options.py                   |  2 +-
 python/cudf/cudf/tests/test_spilling.py       | 55 ++++++++++++++++++-
 3 files changed, 64 insertions(+), 8 deletions(-)

diff --git a/python/cudf/cudf/core/buffer/spill_manager.py b/python/cudf/cudf/core/buffer/spill_manager.py
index 5f2cf69d477..7097e28d55f 100644
--- a/python/cudf/cudf/core/buffer/spill_manager.py
+++ b/python/cudf/cudf/core/buffer/spill_manager.py
@@ -482,8 +482,8 @@ def set_spill_on_demand_globally(
             "is already registered"
         )
 
-    # Add a limiting resource to the RMM stack when the OOM protection should
-    # kick in before total device memory usage.
+    # Add a limiting resource to the RMM stack when spilling and the OOM
+    # protection should kick in before total device memory is used.
     if spill_oom_protection > 0 and spill_oom_protection < 100:
         _, total_dev_mem = rmm.mr.available_device_memory()
         mr_limit = int(total_dev_mem * (spill_oom_protection / 100))
@@ -505,7 +505,7 @@ def set_spill_on_demand_globally(
 
 
 @contextmanager
-def spill_on_demand_globally():
+def spill_on_demand_globally(spill_oom_protection: int | None = None):
     """Context to enable spill on demand temporarily.
 
     Warning
@@ -523,8 +523,11 @@ def spill_on_demand_globally():
     ValueError
         If the RMM memory source stack was changed while in the context.
     """
-    set_spill_on_demand_globally()
-    # Save the new memory resource stack for later cleanup
+    # Save the current memory resource for later cleanup
+    mr_old = rmm.mr.get_current_device_resource()
+
+    set_spill_on_demand_globally(spill_oom_protection)
+    # Save the new memory resource stack for later consistency check
     mr_stack = get_rmm_memory_resource_stack(
         rmm.mr.get_current_device_resource()
     )
@@ -536,4 +539,4 @@ def spill_on_demand_globally():
             raise ValueError(
                 "RMM memory source stack was changed while in the context"
             )
-        rmm.mr.set_current_device_resource(mr_stack[1])
+        rmm.mr.set_current_device_resource(mr_old)
diff --git a/python/cudf/cudf/options.py b/python/cudf/cudf/options.py
index 4d69b52cc32..9c4c9353e40 100644
--- a/python/cudf/cudf/options.py
+++ b/python/cudf/cudf/options.py
@@ -292,7 +292,7 @@ def _integer_and_none_validator(val):
 
         This has no effect if spill on demand is disabled, see the "spill"
        and the "spill_on_demand" options.
-        Valid values are integers from 0 to 100. Default is 0 (disabled).
+        Valid values are integers from 0 to 100. Default is 100 (no headroom).
         """
     ),
     _integer_validator,
diff --git a/python/cudf/cudf/tests/test_spilling.py b/python/cudf/cudf/tests/test_spilling.py
index 7af83a99d60..d45c89172cb 100644
--- a/python/cudf/cudf/tests/test_spilling.py
+++ b/python/cudf/cudf/tests/test_spilling.py
@@ -118,7 +118,7 @@ def spilled_and_unspilled(manager: SpillManager) -> tuple[int, int]:
 
 @pytest.fixture
 def manager(request):
-    """Fixture to enable and make a spilling manager availabe"""
+    """Fixture to enable and make a spilling manager available"""
     kwargs = dict(getattr(request, "param", {}))
     with warnings.catch_warnings():
         warnings.simplefilter("error")
@@ -133,6 +133,16 @@ def manager(request):
         set_global_manager(manager=None)
 
 
+@pytest.fixture
+def rmm_cleanup():
+    """Reset the current RMM resource stack after the test"""
+    mr = rmm.mr.get_current_device_resource()
+    try:
+        yield
+    finally:
+        rmm.mr.set_current_device_resource(mr)
+
+
 def test_spillable_buffer(manager: SpillManager):
     buf = as_buffer(data=rmm.DeviceBuffer(size=10), exposed=False)
     assert isinstance(buf, SpillableBuffer)
@@ -784,3 +794,46 @@ def test_spilling_and_copy_on_write(manager: SpillManager):
     assert not a.is_spilled
     assert a.owner.exposed
     assert not b.owner.exposed
+
+
+def test_oom_protection(manager: SpillManager, rmm_cleanup, capsys):
+    # Use a limit of 10% of device memory (256-byte aligned)
+    _, total_dev_mem = rmm.mr.available_device_memory()
+    alloc_limit = int(round(total_dev_mem * 0.1 / 256) * 256)
+
+    track_mr = rmm.mr.TrackingResourceAdaptor(
+        rmm.mr.get_current_device_resource()
+    )
+    rmm.mr.set_current_device_resource(
+        rmm.mr.LimitingResourceAdaptor(track_mr, allocation_limit=alloc_limit)
+    )
+
+    # With OOM protection disabled, we expect an OOM error and a
+    # warning printed to stdout.
+    with spill_on_demand_globally(spill_oom_protection=0):
+        a = as_buffer(data=rmm.DeviceBuffer(size=alloc_limit), exposed=True)
+        assert track_mr.get_allocated_bytes() == a.nbytes
+        with pytest.raises(MemoryError, match="Exceeded memory limit"):
+            as_buffer(data=rmm.DeviceBuffer(size=alloc_limit))
+        assert "[WARNING] RMM allocation" in capsys.readouterr().out
+        del a
+        assert track_mr.get_allocated_bytes() == 0
+
+    # With OOM protection (no headroom), `track_mr` only sees the
+    # first allocation and no stdout warning is printed.
+    with spill_on_demand_globally(spill_oom_protection=100):
+        a = as_buffer(data=rmm.DeviceBuffer(size=alloc_limit), exposed=True)
+        assert track_mr.get_allocated_bytes() == a.nbytes
+        b = as_buffer(data=rmm.DeviceBuffer(size=alloc_limit), exposed=True)
+        assert track_mr.get_allocated_bytes() == b.nbytes
+        del a
+        del b
+        assert track_mr.get_allocated_bytes() == 0
+        assert "[WARNING] RMM allocation" not in capsys.readouterr().out
+
+    # With OOM protection (5% headroom), `track_mr` doesn't see any
+    # allocations because of the 5% headroom.
+    with spill_on_demand_globally(spill_oom_protection=5):
+        as_buffer(data=rmm.DeviceBuffer(size=alloc_limit), exposed=True)
+        assert track_mr.get_allocated_bytes() == 0
+        assert "[WARNING] RMM allocation" not in capsys.readouterr().out
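
A minimal usage sketch of the new option, not taken from the patches themselves. It assumes both patches are applied, that spilling is enabled before any device buffers are created (e.g. `CUDF_SPILL=on` or `cudf.set_option("spill", True)`), and that the installed RMM build provides the `FailureAlternateResourceAdaptor` that `set_spill_on_demand_globally` uses when `spill_oom_protection > 0`:

# Usage sketch (assumption: the two patches above are applied and rmm
# provides FailureAlternateResourceAdaptor for the managed-memory fallback).
import numpy as np

import cudf
from cudf.core.buffer.spill_manager import spill_on_demand_globally

cudf.set_option("spill", True)  # spilling must be enabled first

# With spill_oom_protection=90, a LimitingResourceAdaptor caps device usage
# at 90% of total memory, so spilling starts with a 10% headroom; if spilling
# cannot free enough memory, allocations fall back to managed memory via
# FailureAlternateResourceAdaptor and the OOM handler stays quiet.
with spill_on_demand_globally(spill_oom_protection=90):
    df = cudf.DataFrame({"a": np.arange(10_000_000)})
    print(df["a"].sum())

# The same default can come from the option (or CUDF_SPILL_OOM_PROTECTION)
# and is picked up when no argument is passed:
cudf.set_option("spill_oom_protection", 90)
with spill_on_demand_globally():
    pass

With 0 the behaviour matches today's spill-on-demand (a verbose warning on a true OOM), and with 100 the managed-memory fallback only engages on an actual OOM error, with no headroom reserved.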