Unify Copy-On-Write and Spilling (#15436)
This is the final step to unify COW and spilling. Now, `SpillableBuffer` inherits from `ExposureTrackedBuffer` so the final class hierarchy becomes: 
``` 
SpillableBufferOwner -> BufferOwner 
SpillableBuffer -> ExposureTrackedBuffer -> Buffer 
```

Additionally, spill-on-demand is now set globally using `set_spill_on_demand_globally()` instead of in the `SpillManager` constructor.
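For illustration, a minimal sketch of the new setup (assumes spilling has been enabled via cudf's `spill` option so the global manager exists; import path as in this diff):

```python
import cudf
from cudf.core.buffer.spill_manager import set_spill_on_demand_globally

# Spilling must be enabled before the global SpillManager is first created.
cudf.set_option("spill", True)

# Register the RMM out-of-memory callback once, globally, instead of passing
# spill_on_demand=True to the SpillManager constructor.
set_spill_on_demand_globally()
```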

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #15436
madsbk authored Apr 20, 2024
1 parent d37636d commit 96903bb
Showing 8 changed files with 284 additions and 135 deletions.
69 changes: 44 additions & 25 deletions python/cudf/cudf/core/buffer/buffer.py
@@ -106,8 +106,25 @@ class BufferOwner(Serializable):
been accessed outside of BufferOwner. In this case, we have no control
over knowing if the data is being modified by a third party.
Use `_from_device_memory` and `_from_host_memory` to create
Use `from_device_memory` and `from_host_memory` to create
a new instance from either device or host memory respectively.
Parameters
----------
ptr
An integer representing a pointer to memory.
size
The size of the memory in bytes
owner
Python object to which the lifetime of the memory allocation is tied.
This buffer will keep a reference to `owner`.
exposed
Whether the underlying memory has been exposed to third-party access
Raises
------
ValueError
If size is negative
"""

_ptr: int
@@ -117,14 +134,25 @@ class BufferOwner(Serializable):
# The set of buffers that point to this owner.
_slices: weakref.WeakSet[Buffer]

def __init__(self):
raise ValueError(
f"do not create a {self.__class__} directly, please "
"use the factory function `cudf.core.buffer.as_buffer`"
)
def __init__(
self,
*,
ptr: int,
size: int,
owner: object,
exposed: bool,
):
if size < 0:
raise ValueError("size cannot be negative")

self._ptr = ptr
self._size = size
self._owner = owner
self._exposed = exposed
self._slices = weakref.WeakSet()

@classmethod
def _from_device_memory(cls, data: Any, exposed: bool) -> Self:
def from_device_memory(cls, data: Any, exposed: bool) -> Self:
"""Create from an object providing a `__cuda_array_interface__`.
No data is being copied.
@@ -151,24 +179,15 @@ def _from_device_memory(cls, data: Any, exposed: bool) -> Self:
If the resulting buffer has negative size
"""

# Bypass `__init__` and initialize attributes manually
ret = cls.__new__(cls)
ret._owner = data
ret._exposed = exposed
ret._slices = weakref.WeakSet()
if isinstance(data, rmm.DeviceBuffer): # Common case shortcut
ret._ptr = data.ptr
ret._size = data.size
ptr = data.ptr
size = data.size
else:
ret._ptr, ret._size = get_ptr_and_size(
data.__cuda_array_interface__
)
if ret.size < 0:
raise ValueError("size cannot be negative")
return ret
ptr, size = get_ptr_and_size(data.__cuda_array_interface__)
return cls(ptr=ptr, size=size, owner=data, exposed=exposed)

@classmethod
def _from_host_memory(cls, data: Any) -> Self:
def from_host_memory(cls, data: Any) -> Self:
"""Create an owner from a buffer or array-like object
Data must implement `__array_interface__`, the buffer protocol, and/or
@@ -196,7 +215,7 @@ def _from_host_memory(cls, data: Any) -> Self:
# Copy to device memory
buf = rmm.DeviceBuffer(ptr=ptr, size=size)
# Create from device memory
return cls._from_device_memory(buf, exposed=False)
return cls.from_device_memory(buf, exposed=False)

@property
def size(self) -> int:
@@ -375,7 +394,7 @@ def copy(self, deep: bool = True) -> Self:
)

# Otherwise, we create a new copy of the memory
owner = self._owner._from_device_memory(
owner = self._owner.from_device_memory(
rmm.DeviceBuffer(
ptr=self._owner.get_ptr(mode="read") + self._offset,
size=self.size,
@@ -439,9 +458,9 @@ def deserialize(cls, header: dict, frames: list) -> Self:

owner_type: BufferOwner = pickle.loads(header["owner-type-serialized"])
if hasattr(frame, "__cuda_array_interface__"):
owner = owner_type._from_device_memory(frame, exposed=False)
owner = owner_type.from_device_memory(frame, exposed=False)
else:
owner = owner_type._from_host_memory(frame)
owner = owner_type.from_host_memory(frame)
return cls(
owner=owner,
offset=0,
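As a hedged sketch of the renamed constructors above (these are internal classes, not cudf's public API; names and signatures are taken from this diff):

```python
import rmm
from cudf.core.buffer.buffer import Buffer, BufferOwner

dbuf = rmm.DeviceBuffer(size=64)  # 64 bytes of device memory

# `from_device_memory` replaces the old `_from_device_memory`; no data is copied.
owner = BufferOwner.from_device_memory(dbuf, exposed=False)

# A Buffer is a slice (offset, size) over its owner.
buf = Buffer(owner=owner, offset=0, size=owner.size)
assert buf.size == dbuf.size
```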
18 changes: 6 additions & 12 deletions python/cudf/cudf/core/buffer/exposure_tracked_buffer.py
@@ -23,20 +23,14 @@ class ExposureTrackedBuffer(Buffer):
The size of the slice (in bytes)
"""

_owner: BufferOwner

def __init__(
self,
owner: BufferOwner,
offset: int = 0,
size: Optional[int] = None,
) -> None:
super().__init__(owner=owner, offset=offset, size=size)
self._owner._slices.add(self)

@property
def exposed(self) -> bool:
return self._owner.exposed
self.owner._slices.add(self)

def get_ptr(self, *, mode: Literal["read", "write"]) -> int:
if mode == "write" and cudf.get_option("copy_on_write"):
@@ -72,7 +66,7 @@ def copy(self, deep: bool = True) -> Self:
copy-on-write option (see above).
"""
if cudf.get_option("copy_on_write"):
return super().copy(deep=deep or self.exposed)
return super().copy(deep=deep or self.owner.exposed)
return super().copy(deep=deep)

@property
@@ -98,11 +92,11 @@ def make_single_owner_inplace(self) -> None:
Buffer representing the same device memory as `data`
"""

if len(self._owner._slices) > 1:
# If this is not the only slice pointing to `self._owner`, we
# point to a new deep copy of the owner.
if len(self.owner._slices) > 1:
# If this is not the only slice pointing to `self.owner`, we
# point to a new copy of our slice of `self.owner`.
t = self.copy(deep=True)
self._owner = t._owner
self._owner = t.owner
self._offset = t._offset
self._size = t._size
self._owner._slices.add(self)
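From the user side, the exposure tracking above is what backs copy-on-write; an illustrative sketch (behavior per cudf's `copy_on_write` option, the buffer classes themselves stay internal):

```python
import cudf

cudf.set_option("copy_on_write", True)

s1 = cudf.Series([1, 2, 3])
s2 = s1.copy(deep=False)  # shallow copy: both Series share one buffer slice

s2[0] = 10                # writing to shared data makes a private copy first
assert s1[0] == 1         # s1 is unaffected
```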
101 changes: 75 additions & 26 deletions python/cudf/cudf/core/buffer/spill_manager.py
@@ -10,6 +10,7 @@
import warnings
import weakref
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional, Tuple
@@ -201,10 +202,6 @@ class SpillManager:
This class implements tracking of all known spillable buffers, on-demand
spilling of said buffers, and (optionally) maintains a memory usage limit.
When `spill_on_demand=True`, the manager registers an RMM out-of-memory
error handler, which will spill spillable buffers in order to free up
memory.
When `device_memory_limit=<limit-in-bytes>`, the manager will try to keep
the device memory usage below the specified limit by continuously spilling
spillable buffers, which introduces a modest overhead.
@@ -213,8 +210,6 @@
Parameters
----------
spill_on_demand : bool
Enable spill on demand.
device_memory_limit: int, optional
If not None, this is the device memory limit in bytes that triggers
device to host spilling. The global manager sets this to the value
@@ -230,30 +225,15 @@
def __init__(
self,
*,
spill_on_demand: bool = False,
device_memory_limit: Optional[int] = None,
statistic_level: int = 0,
) -> None:
self._lock = threading.Lock()
self._buffers = weakref.WeakValueDictionary()
self._id_counter = 0
self._spill_on_demand = spill_on_demand
self._device_memory_limit = device_memory_limit
self.statistics = SpillStatistics(statistic_level)

if self._spill_on_demand:
# Set the RMM out-of-memory handle if not already set
mr = rmm.mr.get_current_device_resource()
if all(
not isinstance(m, rmm.mr.FailureCallbackResourceAdaptor)
for m in get_rmm_memory_resource_stack(mr)
):
rmm.mr.set_current_device_resource(
rmm.mr.FailureCallbackResourceAdaptor(
mr, self._out_of_memory_handle
)
)

def _out_of_memory_handle(self, nbytes: int, *, retry_once=True) -> bool:
"""Try to handle an out-of-memory error by spilling
@@ -408,8 +388,7 @@ def __repr__(self) -> str:
dev_limit = format_bytes(self._device_memory_limit)

return (
f"<SpillManager spill_on_demand={self._spill_on_demand} "
f"device_memory_limit={dev_limit} | "
f"<SpillManager device_memory_limit={dev_limit} | "
f"{format_bytes(spilled)} spilled | "
f"{format_bytes(unspilled)} ({unspillable_ratio:.0%}) "
f"unspilled (unspillable)>"
@@ -442,12 +421,82 @@ def get_global_manager() -> Optional[SpillManager]:
"""Get the global manager or None if spilling is disabled"""
global _global_manager_uninitialized
if _global_manager_uninitialized:
manager = None
if get_option("spill"):
manager = SpillManager(
spill_on_demand=get_option("spill_on_demand"),
device_memory_limit=get_option("spill_device_limit"),
statistic_level=get_option("spill_stats"),
)
set_global_manager(manager)
set_global_manager(manager)
if get_option("spill_on_demand"):
set_spill_on_demand_globally()
else:
set_global_manager(None)
return _global_manager


def set_spill_on_demand_globally() -> None:
"""Enable spill on demand in the current global spill manager.
Warning: this modifies the current RMM memory resource. A memory resource
to handle out-of-memory errors is pushed onto the RMM memory resource stack.
Raises
------
ValueError
If no global spill manager exists (spilling is disabled).
ValueError
If a failure callback resource is already in the resource stack.
"""

manager = get_global_manager()
if manager is None:
raise ValueError(
"Cannot enable spill on demand with no global spill manager"
)
mr = rmm.mr.get_current_device_resource()
if any(
isinstance(m, rmm.mr.FailureCallbackResourceAdaptor)
for m in get_rmm_memory_resource_stack(mr)
):
raise ValueError(
"Spill on demand (or another failure callback resource) "
"is already registered"
)
rmm.mr.set_current_device_resource(
rmm.mr.FailureCallbackResourceAdaptor(
mr, manager._out_of_memory_handle
)
)


@contextmanager
def spill_on_demand_globally():
"""Context to enable spill on demand temporarily.
Warning: this modifies the current RMM memory resource. A memory resource
to handle out-of-memory errors is pushed onto the RMM memory resource stack
when entering the context and popped again when exiting.
Raises
------
ValueError
If no global spill manager exists (spilling is disabled).
ValueError
If a failure callback resource is already in the resource stack.
ValueError
If the RMM memory resource stack was changed while in the context.
"""
set_spill_on_demand_globally()
# Save the new memory resource stack for later cleanup
mr_stack = get_rmm_memory_resource_stack(
rmm.mr.get_current_device_resource()
)
try:
yield
finally:
mr = rmm.mr.get_current_device_resource()
if mr_stack != get_rmm_memory_resource_stack(mr):
raise ValueError(
"RMM memory source stack was changed while in the context"
)
rmm.mr.set_current_device_resource(mr_stack[1])
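A hedged usage sketch of the new context manager (again assuming spilling is enabled so a global manager exists):

```python
import cudf
from cudf.core.buffer.spill_manager import spill_on_demand_globally

cudf.set_option("spill", True)

with spill_on_demand_globally():
    # Inside the context, an RMM out-of-memory error triggers spilling of
    # spillable buffers to host memory before the allocation is retried.
    df = cudf.DataFrame({"a": range(1_000_000)})
    _ = df.sum()
# On exit, the failure-callback resource is popped off the RMM resource stack.
```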
