diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py index 5e2463be0..c03fa2973 100644 --- a/dask_cuda/device_host_file.py +++ b/dask_cuda/device_host_file.py @@ -175,14 +175,12 @@ def __init__( local_directory=None, log_spilling=False, ): - if local_directory is None: - local_directory = dask.config.get("temporary-directory") or os.getcwd() - - if local_directory and not os.path.exists(local_directory): - os.makedirs(local_directory, exist_ok=True) - local_directory = os.path.join(local_directory, "dask-worker-space") - - self.disk_func_path = os.path.join(local_directory, "storage") + self.disk_func_path = os.path.join( + local_directory or dask.config.get("temporary-directory") or os.getcwd(), + "dask-worker-space", + "storage", + ) + os.makedirs(self.disk_func_path, exist_ok=True) self.host_func = dict() self.disk_func = Func( diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index aeea71467..cce5480e7 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -18,7 +18,7 @@ from distributed import wait from distributed.protocol import nested_deserialize, to_serialize -from ...proxify_host_file import ProxifyHostFile +from ...proxify_host_file import ProxyManager from .. 
import comms @@ -148,19 +148,17 @@ async def local_shuffle( eps = s["eps"] try: - hostfile = first(iter(in_parts[0].values()))._obj_pxy.get( - "hostfile", lambda: None - )() + manager = first(iter(in_parts[0].values()))._obj_pxy.get("manager", None) except AttributeError: - hostfile = None + manager = None - if isinstance(hostfile, ProxifyHostFile): + if isinstance(manager, ProxyManager): def concat(args, ignore_index=False): if len(args) < 2: return args[0] - return hostfile.add_external(dd_concat(args, ignore_index=ignore_index)) + return manager.proxify(dd_concat(args, ignore_index=ignore_index)) else: concat = dd_concat diff --git a/dask_cuda/get_device_memory_objects.py b/dask_cuda/get_device_memory_objects.py index deba96a06..385f70793 100644 --- a/dask_cuda/get_device_memory_objects.py +++ b/dask_cuda/get_device_memory_objects.py @@ -28,10 +28,7 @@ def get_device_memory_objects(obj) -> set: @dispatch.register(object) def get_device_memory_objects_default(obj): if hasattr(obj, "_obj_pxy"): - if obj._obj_pxy["serializers"] is None: - return dispatch(obj._obj_pxy["obj"]) - else: - return [] + return dispatch(obj._obj_pxy["obj"]) if hasattr(obj, "data"): return dispatch(obj.data) if hasattr(obj, "_owner") and obj._owner is not None: diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 4a706e67f..07c51c863 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -273,9 +273,7 @@ def __init__( { "device_memory_limit": self.device_memory_limit, "memory_limit": self.host_memory_limit, - "local_directory": local_directory - or dask.config.get("temporary-directory") - or os.getcwd(), + "local_directory": local_directory, "log_spilling": log_spilling, }, ) diff --git a/dask_cuda/proxify_device_objects.py b/dask_cuda/proxify_device_objects.py index f3e3efb3f..cd067d3d1 100644 --- a/dask_cuda/proxify_device_objects.py +++ b/dask_cuda/proxify_device_objects.py @@ -165,14 +165,9 @@ def wrapper(*args, **kwargs): 
def proxify(obj, proxied_id_to_proxy, found_proxies, subclass=None): _id = id(obj) - if _id in proxied_id_to_proxy: - ret = proxied_id_to_proxy[_id] - finalize = ret._obj_pxy.get("external_finalize", None) - if finalize: - finalize() - proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass) - else: - proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass) + if _id not in proxied_id_to_proxy: + proxied_id_to_proxy[_id] = asproxy(obj, subclass=subclass) + ret = proxied_id_to_proxy[_id] found_proxies.append(ret) return ret @@ -190,11 +185,6 @@ def proxify_device_object_default( def proxify_device_object_proxy_object( obj, proxied_id_to_proxy, found_proxies, excl_proxies ): - # We deserialize CUDA-serialized objects since it is very cheap and - # makes it easy to administrate device memory usage - if obj._obj_pxy_is_serialized() and "cuda" in obj._obj_pxy["serializers"]: - obj._obj_pxy_deserialize() - # Check if `obj` is already known if not obj._obj_pxy_is_serialized(): _id = id(obj._obj_pxy["obj"]) @@ -203,14 +193,6 @@ def proxify_device_object_proxy_object( else: proxied_id_to_proxy[_id] = obj - finalize = obj._obj_pxy.get("external_finalize", None) - if finalize: - finalize() - obj = obj._obj_pxy_copy() - if not obj._obj_pxy_is_serialized(): - _id = id(obj._obj_pxy["obj"]) - proxied_id_to_proxy[_id] = obj - if not excl_proxies: found_proxies.append(obj) return obj diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 6dd5d6b6b..a056ad5b5 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -1,18 +1,23 @@ +import abc import logging import threading import time +import warnings import weakref from collections import defaultdict from typing import ( + Any, DefaultDict, Dict, Hashable, Iterator, List, MutableMapping, + Optional, Set, Tuple, ) +from weakref import ReferenceType import dask from dask.sizeof import sizeof @@ -21,105 +26,244 @@ from .proxy_object import ProxyObject -class 
UnspilledProxies: - """Class to track current unspilled proxies""" +class Proxies(abc.ABC): + """Abstract base class to implement tracking of proxies + + This class is not threadsafe + """ def __init__(self): - self.dev_mem_usage = 0 - self.proxy_id_to_dev_mems: DefaultDict[int, Set[Hashable]] = defaultdict(set) + self._proxy_id_to_proxy: Dict[int, ReferenceType[ProxyObject]] = {} + self._mem_usage = 0 + + def __len__(self) -> int: + return len(self._proxy_id_to_proxy) + + @abc.abstractmethod + def mem_usage_add(self, proxy: ProxyObject) -> None: + """Given a new proxy, update `self._mem_usage`""" + + @abc.abstractmethod + def mem_usage_remove(self, proxy: ProxyObject) -> None: + """Removal of proxy, update `self._mem_usage`""" + + def add(self, proxy: ProxyObject) -> None: + """Add a proxy for tracking, calls `self.mem_usage_add`""" + assert not self.contains_proxy_id(id(proxy)) + self._proxy_id_to_proxy[id(proxy)] = weakref.ref(proxy) + self.mem_usage_add(proxy) + + def remove(self, proxy: ProxyObject) -> None: + """Remove proxy from tracking, calls `self.mem_usage_remove`""" + del self._proxy_id_to_proxy[id(proxy)] + self.mem_usage_remove(proxy) + if len(self._proxy_id_to_proxy) == 0: + if self._mem_usage != 0: + warnings.warn( + "ProxyManager is empty but the tally of " + f"{self} is {self._mem_usage} bytes. " + "Resetting the tally." + ) + self._mem_usage = 0 + + def __iter__(self) -> Iterator[ProxyObject]: + for p in self._proxy_id_to_proxy.values(): + ret = p() + if ret is not None: + yield ret + + def contains_proxy_id(self, proxy_id: int) -> bool: + return proxy_id in self._proxy_id_to_proxy + + def mem_usage(self) -> int: + return self._mem_usage + + +class ProxiesOnHost(Proxies): + """Implement tracking of proxies on the CPU + + This uses dask.sizeof to update memory usage. 
+ """ + + def mem_usage_add(self, proxy: ProxyObject): + self._mem_usage += sizeof(proxy) + + def mem_usage_remove(self, proxy: ProxyObject): + self._mem_usage -= sizeof(proxy) + + +class ProxiesOnDevice(Proxies): + """Implement tracking of proxies on the GPU + + This is a bit more complicated than ProxiesOnHost because we have to + handle that multiple proxy objects can refer to the same underlying + device memory object. Thus, we have to track aliasing and make sure + we don't count down the memory usage prematurely. + """ + + def __init__(self): + super().__init__() + self.proxy_id_to_dev_mems: Dict[int, Set[Hashable]] = {} self.dev_mem_to_proxy_ids: DefaultDict[Hashable, Set[int]] = defaultdict(set) - def add(self, proxy: ProxyObject): + def mem_usage_add(self, proxy: ProxyObject): proxy_id = id(proxy) - if proxy_id not in self.proxy_id_to_dev_mems: - for dev_mem in proxy._obj_pxy_get_device_memory_objects(): - self.proxy_id_to_dev_mems[proxy_id].add(dev_mem) - ps = self.dev_mem_to_proxy_ids[dev_mem] - if len(ps) == 0: - self.dev_mem_usage += sizeof(dev_mem) - ps.add(proxy_id) - - def remove(self, proxy: ProxyObject): + assert proxy_id not in self.proxy_id_to_dev_mems + self.proxy_id_to_dev_mems[proxy_id] = set() + for dev_mem in proxy._obj_pxy_get_device_memory_objects(): + self.proxy_id_to_dev_mems[proxy_id].add(dev_mem) + ps = self.dev_mem_to_proxy_ids[dev_mem] + if len(ps) == 0: + self._mem_usage += sizeof(dev_mem) + ps.add(proxy_id) + + def mem_usage_remove(self, proxy: ProxyObject): proxy_id = id(proxy) - if proxy_id in self.proxy_id_to_dev_mems: - for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id): - self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id) - if len(self.dev_mem_to_proxy_ids[dev_mem]) == 0: - del self.dev_mem_to_proxy_ids[dev_mem] - self.dev_mem_usage -= sizeof(dev_mem) + for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id): + self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id) + if len(self.dev_mem_to_proxy_ids[dev_mem]) == 0: + del 
self.dev_mem_to_proxy_ids[dev_mem] + self._mem_usage -= sizeof(dev_mem) - def __iter__(self): - return iter(self.proxy_id_to_dev_mems) - -class ProxiesTally: +class ProxyManager: """ - This class together with UnspilledProxies implements the tracking of current - objects in device memory and the total memory usage. It turns out having to - re-calculate device memory usage continuously is too expensive. - - We have to track four events: - - When adding a new key to the host file - - When removing a key from the host file - - When a proxy in the host file is deserialized - - When a proxy in the host file is serialized - - However, it gets a bit complicated because: - - The value of a key in the host file can contain many proxy objects and a single - proxy object can be referred from many keys - - Multiple proxy objects can refer to the same underlying device memory object - - Proxy objects are not hashable thus we have to use the `id()` as key in - dictionaries - - ProxiesTally and UnspilledProxies implements this by carefully maintaining - dictionaries that maps to/from keys, proxy objects, and device memory objects. + This class together with Proxies, ProxiesOnHost, and ProxiesOnDevice + implements the tracking of all known proxies and their total host/device + memory usage. It turns out having to re-calculate memory usage continuously + is too expensive. + + The idea is to have the ProxifyHostFile or the proxies themselves update + their location (device or host). The manager then tallies the total memory usage. + + Notice, the manager only keeps weak references to the proxies. 
""" - def __init__(self): + def __init__(self, device_memory_limit: int): self.lock = threading.RLock() - self.proxy_id_to_proxy: Dict[int, ProxyObject] = {} - self.key_to_proxy_ids: DefaultDict[Hashable, Set[int]] = defaultdict(set) - self.proxy_id_to_keys: DefaultDict[int, Set[Hashable]] = defaultdict(set) - self.unspilled_proxies = UnspilledProxies() + self._host = ProxiesOnHost() + self._dev = ProxiesOnDevice() + self._device_memory_limit = device_memory_limit + + def __repr__(self) -> str: + return ( + f"" + ) + + def __len__(self) -> int: + return len(self._host) + len(self._dev) + + def pprint(self) -> str: + ret = f"{self}:" + if len(self) == 0: + return ret + " Empty" + ret += "\n" + for proxy in self._host: + ret += f" host - {repr(proxy)}\n" + for proxy in self._dev: + ret += f" dev - {repr(proxy)}\n" + return ret[:-1] # Strip last newline + + def get_proxies_by_serializer(self, serializer: Optional[str]) -> Proxies: + if serializer in ("dask", "pickle"): + return self._host + else: + return self._dev - def add_key(self, key, proxies: List[ProxyObject]): + def contains(self, proxy_id: int) -> bool: with self.lock: - for proxy in proxies: - proxy_id = id(proxy) - self.proxy_id_to_proxy[proxy_id] = proxy - self.key_to_proxy_ids[key].add(proxy_id) - self.proxy_id_to_keys[proxy_id].add(key) - if not proxy._obj_pxy_is_serialized(): - self.unspilled_proxies.add(proxy) - - def del_key(self, key): + return self._host.contains_proxy_id( + proxy_id + ) or self._dev.contains_proxy_id(proxy_id) + + def add(self, proxy: ProxyObject) -> None: with self.lock: - for proxy_id in self.key_to_proxy_ids.pop(key, ()): - self.proxy_id_to_keys[proxy_id].remove(key) - if len(self.proxy_id_to_keys[proxy_id]) == 0: - del self.proxy_id_to_keys[proxy_id] - self.unspilled_proxies.remove(self.proxy_id_to_proxy.pop(proxy_id)) + if not self.contains(id(proxy)): + self.get_proxies_by_serializer(proxy._obj_pxy["serializer"]).add(proxy) - def spill_proxy(self, proxy: ProxyObject): + def 
remove(self, proxy: ProxyObject) -> None: + with self.lock: + # Find where the proxy is located and remove it + proxies: Optional[Proxies] = None + if self._host.contains_proxy_id(id(proxy)): + proxies = self._host + if self._dev.contains_proxy_id(id(proxy)): + assert proxies is None, "Proxy in multiple locations" + proxies = self._dev + assert proxies is not None, "Trying to remove unknown proxy" + proxies.remove(proxy) + + def move( + self, + proxy: ProxyObject, + from_serializer: Optional[str], + to_serializer: Optional[str], + ) -> None: with self.lock: - self.unspilled_proxies.remove(proxy) + src = self.get_proxies_by_serializer(from_serializer) + dst = self.get_proxies_by_serializer(to_serializer) + if src is not dst: + src.remove(proxy) + dst.add(proxy) - def unspill_proxy(self, proxy: ProxyObject): + def proxify(self, obj: object) -> object: with self.lock: - self.unspilled_proxies.add(proxy) + found_proxies: List[ProxyObject] = [] + proxied_id_to_proxy: Dict[int, ProxyObject] = {} + ret = proxify_device_objects(obj, proxied_id_to_proxy, found_proxies) + last_access = time.monotonic() + for p in found_proxies: + p._obj_pxy["last_access"] = last_access + if not self.contains(id(p)): + p._obj_pxy_register_manager(self) + self.add(p) + self.maybe_evict() + return ret - def get_unspilled_proxies(self) -> Iterator[ProxyObject]: + def get_dev_buffer_to_proxies(self) -> DefaultDict[Hashable, List[ProxyObject]]: with self.lock: - for proxy_id in self.unspilled_proxies: - ret = self.proxy_id_to_proxy[proxy_id] - assert not ret._obj_pxy_is_serialized() - yield ret + # Notice, multiple proxy object can point to different non-overlapping + # parts of the same device buffer. 
+ ret = defaultdict(list) + for proxy in self._dev: + for dev_buffer in proxy._obj_pxy_get_device_memory_objects(): + ret[dev_buffer].append(proxy) + return ret - def get_proxied_id_to_proxy(self) -> Dict[int, ProxyObject]: - return {id(p._obj_pxy["obj"]): p for p in self.get_unspilled_proxies()} + def get_dev_access_info( + self, + ) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]]: + with self.lock: + total_dev_mem_usage = 0 + dev_buf_access = [] + for dev_buf, proxies in self.get_dev_buffer_to_proxies().items(): + last_access = max(p._obj_pxy.get("last_access", 0) for p in proxies) + size = sizeof(dev_buf) + dev_buf_access.append((last_access, size, proxies)) + total_dev_mem_usage += size + assert total_dev_mem_usage == self._dev.mem_usage() + return total_dev_mem_usage, dev_buf_access + + def maybe_evict(self, extra_dev_mem=0) -> None: + if ( # Shortcut when not evicting + self._dev.mem_usage() + extra_dev_mem <= self._device_memory_limit + ): + return - def get_dev_mem_usage(self) -> int: - return self.unspilled_proxies.dev_mem_usage + with self.lock: + total_dev_mem_usage, dev_buf_access = self.get_dev_access_info() + total_dev_mem_usage += extra_dev_mem + if total_dev_mem_usage > self._device_memory_limit: + dev_buf_access.sort(key=lambda x: (x[0], -x[1])) + for _, size, proxies in dev_buf_access: + for p in proxies: + # Serialize to disk, which "dask" and "pickle" does + p._obj_pxy_serialize(serializers=("dask", "pickle")) + total_dev_mem_usage -= size + if total_dev_mem_usage <= self._device_memory_limit: + break class ProxifyHostFile(MutableMapping): @@ -155,9 +299,9 @@ class ProxifyHostFile(MutableMapping): def __init__(self, device_memory_limit: int, compatibility_mode: bool = None): self.device_memory_limit = device_memory_limit - self.store = {} + self.store: Dict[Hashable, Any] = {} self.lock = threading.RLock() - self.proxies_tally = ProxiesTally() + self.manager = ProxyManager(device_memory_limit) if compatibility_mode is None: 
self.compatibility_mode = dask.config.get( "jit-unspill-compatibility-mode", default=False @@ -190,122 +334,21 @@ def fast(self): ) return None - def get_dev_buffer_to_proxies(self) -> DefaultDict[Hashable, List[ProxyObject]]: - with self.lock: - # Notice, multiple proxy object can point to different non-overlapping - # parts of the same device buffer. - ret = defaultdict(list) - for proxy in self.proxies_tally.get_unspilled_proxies(): - for dev_buffer in proxy._obj_pxy_get_device_memory_objects(): - ret[dev_buffer].append(proxy) - return ret - - def get_access_info(self) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]]: - with self.lock: - total_dev_mem_usage = 0 - dev_buf_access = [] - for dev_buf, proxies in self.get_dev_buffer_to_proxies().items(): - last_access = max(p._obj_pxy.get("last_access", 0) for p in proxies) - size = sizeof(dev_buf) - dev_buf_access.append((last_access, size, proxies)) - total_dev_mem_usage += size - return total_dev_mem_usage, dev_buf_access - - def add_external(self, obj): - """Add an external object to the hostfile that count against the - device_memory_limit but isn't part of the store. - - Normally, we use __setitem__ to store objects in the hostfile and make it - count against the device_memory_limit with the inherent consequence that - the objects are not freeable before subsequential calls to __delitem__. - This is a problem for long running tasks that want objects to count against - the device_memory_limit while freeing them ASAP without explicit calls to - __delitem__. - - Developer Notes - --------------- - In order to avoid holding references to the found proxies in `obj`, we - wrap them in `weakref.proxy(p)` and adds them to the `proxies_tally`. - In order to remove them from the `proxies_tally` again, we attach a - finalize(p) on the wrapped proxies that calls del_external(). 
- """ - - # Notice, since `self.store` isn't modified, no lock is needed - found_proxies: List[ProxyObject] = [] - proxied_id_to_proxy = {} - # Notice, we are excluding found objects that are already proxies - ret = proxify_device_objects( - obj, proxied_id_to_proxy, found_proxies, excl_proxies=True - ) - last_access = time.monotonic() - self_weakref = weakref.ref(self) - for p in found_proxies: - name = id(p) - finalize = weakref.finalize(p, self.del_external, name) - external = weakref.proxy(p) - p._obj_pxy["hostfile"] = self_weakref - p._obj_pxy["last_access"] = last_access - p._obj_pxy["external"] = external - p._obj_pxy["external_finalize"] = finalize - self.proxies_tally.add_key(name, [external]) - self.maybe_evict() - return ret - - def del_external(self, name): - self.proxies_tally.del_key(name) - def __setitem__(self, key, value): with self.lock: if key in self.store: # Make sure we register the removal of an existing key del self[key] - - found_proxies: List[ProxyObject] = [] - proxied_id_to_proxy = self.proxies_tally.get_proxied_id_to_proxy() - self.store[key] = proxify_device_objects( - value, proxied_id_to_proxy, found_proxies - ) - last_access = time.monotonic() - self_weakref = weakref.ref(self) - for p in found_proxies: - p._obj_pxy["hostfile"] = self_weakref - p._obj_pxy["last_access"] = last_access - assert "external" not in p._obj_pxy - - self.proxies_tally.add_key(key, found_proxies) - self.maybe_evict() + self.store[key] = self.manager.proxify(value) def __getitem__(self, key): with self.lock: ret = self.store[key] if self.compatibility_mode: ret = unproxify_device_objects(ret, skip_explicit_proxies=True) - self.maybe_evict() + self.manager.maybe_evict() return ret def __delitem__(self, key): with self.lock: del self.store[key] - self.proxies_tally.del_key(key) - - def evict(self, proxy: ProxyObject): - proxy._obj_pxy_serialize(serializers=("dask", "pickle")) - - def maybe_evict(self, extra_dev_mem=0): - if ( # Shortcut when not evicting - 
self.proxies_tally.get_dev_mem_usage() + extra_dev_mem - <= self.device_memory_limit - ): - return - - with self.lock: - total_dev_mem_usage, dev_buf_access = self.get_access_info() - total_dev_mem_usage += extra_dev_mem - if total_dev_mem_usage > self.device_memory_limit: - dev_buf_access.sort(key=lambda x: (x[0], -x[1])) - for _, size, proxies in dev_buf_access: - for p in proxies: - self.evict(p) - total_dev_mem_usage -= size - if total_dev_mem_usage <= self.device_memory_limit: - break diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 649f400ed..5dd8651b4 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -5,7 +5,10 @@ import threading import time from collections import OrderedDict -from typing import Any, Dict, List, Optional, Set +from contextlib import ( # TODO: use `contextlib.nullcontext()` from Python 3.7+ + suppress as nullcontext, +) +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Type import pandas @@ -13,6 +16,7 @@ import dask.array.core import dask.dataframe.methods import dask.dataframe.utils +import dask.utils import distributed.protocol import distributed.utils from dask.sizeof import sizeof @@ -31,21 +35,26 @@ from .get_device_memory_objects import get_device_memory_objects from .is_device_object import is_device_object +if TYPE_CHECKING: + from .proxify_host_file import ProxyManager + + # List of attributes that should be copied to the proxy at creation, which makes # them accessible without deserialization of the proxied object _FIXED_ATTRS = ["name", "__len__"] -def asproxy(obj, serializers=None, subclass=None) -> "ProxyObject": +def asproxy( + obj: object, serializers: Iterable[str] = None, subclass: Type["ProxyObject"] = None +) -> "ProxyObject": """Wrap `obj` in a ProxyObject object if it isn't already. Parameters ---------- obj: object Object to wrap in a ProxyObject object. - serializers: list(str), optional - List of serializers to use to serialize `obj`. 
If None, - no serialization is done. + serializers: Iterable[str], optional + Serializers to use to serialize `obj`. If None, no serialization is done. subclass: class, optional Specify a subclass of ProxyObject to create instead of ProxyObject. `subclass` must be pickable. @@ -54,9 +63,10 @@ def asproxy(obj, serializers=None, subclass=None) -> "ProxyObject": ------- The ProxyObject proxying `obj` """ - - if hasattr(obj, "_obj_pxy"): # Already a proxy object + if isinstance(obj, ProxyObject): # Already a proxy object ret = obj + elif isinstance(obj, (list, set, tuple, dict)): + raise ValueError(f"Cannot wrap a collection ({type(obj)}) in a proxy object") else: fixed_attr = {} for attr in _FIXED_ATTRS: @@ -81,7 +91,7 @@ def asproxy(obj, serializers=None, subclass=None) -> "ProxyObject": typename=dask.utils.typename(type(obj)), is_cuda_object=is_device_object(obj), subclass=subclass_serialized, - serializers=None, + serializer=None, explicit_proxy=False, ) if serializers is not None: @@ -112,7 +122,7 @@ def unproxy(obj): return obj -def _obj_pxy_cache_wrapper(attr_name): +def _obj_pxy_cache_wrapper(attr_name: str): """Caching the access of attr_name in ProxyObject._obj_pxy_cache""" def wrapper1(func): @@ -183,9 +193,8 @@ class ProxyObject: subclass: bytes Pickled type to use instead of ProxyObject when deserializing. The type must inherit from ProxyObject. - serializers: list(str), optional - List of serializers to use to serialize `obj`. If None, `obj` - isn't serialized. + serializers: str, optional + Serializers to use to serialize `obj`. If None, no serialization is done. explicit_proxy: bool Mark the proxy object as "explicit", which means that the user allows it as input argument to dask tasks even in compatibility-mode. 
@@ -198,8 +207,8 @@ def __init__( type_serialized: bytes, typename: str, is_cuda_object: bool, - subclass: bytes, - serializers: Optional[List[str]], + subclass: Optional[bytes], + serializer: Optional[str], explicit_proxy: bool, ): self._obj_pxy = { @@ -209,19 +218,19 @@ def __init__( "typename": typename, "is_cuda_object": is_cuda_object, "subclass": subclass, - "serializers": serializers, + "serializer": serializer, "explicit_proxy": explicit_proxy, } self._obj_pxy_lock = threading.RLock() - self._obj_pxy_cache = {} + self._obj_pxy_cache: Dict[str, Any] = {} def __del__(self): - """In order to call `external_finalize()` ASAP, we call it here""" - external_finalize = self._obj_pxy.get("external_finalize", None) - if external_finalize is not None: - external_finalize() + """We have to unregister us from the manager if any""" + manager: "ProxyManager" = self._obj_pxy.get("manager", None) + if manager is not None: + manager.remove(self) - def _obj_pxy_get_init_args(self, include_obj=True): + def _obj_pxy_get_init_args(self, include_obj=True) -> OrderedDict: """Return the attributes needed to initialize a ProxyObject Notice, the returned dictionary is ordered as the __init__() arguments @@ -242,7 +251,7 @@ def _obj_pxy_get_init_args(self, include_obj=True): "typename", "is_cuda_object", "subclass", - "serializers", + "serializer", "explicit_proxy", ] return OrderedDict([(a, self._obj_pxy[a]) for a in args]) @@ -260,17 +269,35 @@ def _obj_pxy_copy(self) -> "ProxyObject": args["obj"] = self._obj_pxy["obj"] return type(self)(**args) - def _obj_pxy_is_serialized(self): + def _obj_pxy_register_manager(self, manager: "ProxyManager") -> None: + """Register a manager + + The manager tallies the total memory usage of proxies and + evicts/serialize proxy objects as needed. + + In order to prevent deadlocks, the proxy now use the lock of the + manager. 
+ + Parameters + ---------- + manager: ProxyManager + The manager to manage this proxy object + """ + assert "manager" not in self._obj_pxy + self._obj_pxy["manager"] = manager + self._obj_pxy_lock = manager.lock + + def _obj_pxy_is_serialized(self) -> bool: """Return whether the proxied object is serialized or not""" - return self._obj_pxy["serializers"] is not None + return self._obj_pxy["serializer"] is not None - def _obj_pxy_serialize(self, serializers): + def _obj_pxy_serialize(self, serializers: Iterable[str]): """Inplace serialization of the proxied object using the `serializers` Parameters ---------- - serializers: tuple[str] - Tuple of serializers to use to serialize the proxied object. + serializers: Iterable[str] + Serializers to use to serialize the proxied object. Returns ------- @@ -282,30 +309,31 @@ def _obj_pxy_serialize(self, serializers): if not serializers: raise ValueError("Please specify a list of serializers") - if type(serializers) is not tuple: - serializers = tuple(serializers) - with self._obj_pxy_lock: - if self._obj_pxy["serializers"] is not None: - if self._obj_pxy["serializers"] == serializers: + if self._obj_pxy_is_serialized(): + if self._obj_pxy["serializer"] in serializers: return self._obj_pxy["obj"] # Nothing to be done else: # The proxied object is serialized with other serializers self._obj_pxy_deserialize() - if self._obj_pxy["serializers"] is None: - self._obj_pxy["obj"] = distributed.protocol.serialize( + # Lock manager (if any) + manager: "ProxyManager" = self._obj_pxy.get("manager", None) + with (nullcontext() if manager is None else manager.lock): + header, _ = self._obj_pxy["obj"] = distributed.protocol.serialize( self._obj_pxy["obj"], serializers, on_error="raise" ) - self._obj_pxy["serializers"] = serializers - hostfile = self._obj_pxy.get("hostfile", lambda: None)() - if hostfile is not None: - external = self._obj_pxy.get("external", self) - hostfile.proxies_tally.spill_proxy(external) - - # Invalidate the 
(possible) cached "device_memory_objects" - self._obj_pxy_cache.pop("device_memory_objects", None) - return self._obj_pxy["obj"] + assert "is-collection" not in header # Collections not allowed + org_ser, new_ser = self._obj_pxy["serializer"], header["serializer"] + self._obj_pxy["serializer"] = new_ser + + # Tell the manager (if any) that this proxy has changed serializer + if manager: + manager.move(self, from_serializer=org_ser, to_serializer=new_ser) + + # Invalidate the (possible) cached "device_memory_objects" + self._obj_pxy_cache.pop("device_memory_objects", None) + return self._obj_pxy["obj"] def _obj_pxy_deserialize(self, maybe_evict: bool = True): """Inplace deserialization of the proxied object @@ -313,7 +341,7 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): Parameters ---------- maybe_evict: bool - Before deserializing, call associated hostfile.maybe_evict() + Before deserializing, maybe evict managered proxy objects Returns ------- @@ -321,27 +349,30 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): The proxied object (deserialized) """ with self._obj_pxy_lock: - if self._obj_pxy["serializers"] is not None: - hostfile = self._obj_pxy.get("hostfile", lambda: None)() - # When not deserializing a CUDA-serialized proxied, we might have - # to evict because of the increased device memory usage. - if maybe_evict and "cuda" not in self._obj_pxy["serializers"]: - if hostfile is not None: - # In order to avoid a potential deadlock, we skip the - # `maybe_evict()` call if another thread is also accessing - # the hostfile. 
- if hostfile.lock.acquire(blocking=False): - try: - hostfile.maybe_evict(self.__sizeof__()) - finally: - hostfile.lock.release() - - header, frames = self._obj_pxy["obj"] - self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames) - self._obj_pxy["serializers"] = None - if hostfile is not None: - external = self._obj_pxy.get("external", self) - hostfile.proxies_tally.unspill_proxy(external) + if self._obj_pxy_is_serialized(): + manager: "ProxyManager" = self._obj_pxy.get("manager", None) + serializer = self._obj_pxy["serializer"] + + # Lock manager (if any) + with (nullcontext() if manager is None else manager.lock): + + # When not deserializing a CUDA-serialized proxied, tell the + # manager that it might have to evict because of the increased + # device memory usage. + if manager and maybe_evict and serializer != "cuda": + manager.maybe_evict(self.__sizeof__()) + + # Deserialize the proxied object + header, frames = self._obj_pxy["obj"] + self._obj_pxy["obj"] = distributed.protocol.deserialize( + header, frames + ) + self._obj_pxy["serializer"] = None + # Tell the manager (if any) that this proxy has changed serializer + if manager: + manager.move( + self, from_serializer=serializer, to_serializer=None + ) self._obj_pxy["last_access"] = time.monotonic() return self._obj_pxy["obj"] @@ -354,16 +385,12 @@ def _obj_pxy_is_cuda_object(self) -> bool: ret : boolean Is the proxied object a CUDA object? """ - with self._obj_pxy_lock: - return self._obj_pxy["is_cuda_object"] + return self._obj_pxy["is_cuda_object"] @_obj_pxy_cache_wrapper("device_memory_objects") - def _obj_pxy_get_device_memory_objects(self) -> Set: + def _obj_pxy_get_device_memory_objects(self) -> set: """Return all device memory objects within the proxied object. - Calling this when the proxied object is serialized returns the - empty list. 
- Returns ------- ret : set @@ -416,13 +443,13 @@ def __repr__(self): with self._obj_pxy_lock: typename = self._obj_pxy["typename"] ret = f"<{dask.utils.typename(type(self))} at {hex(id(self))} of {typename}" - if self._obj_pxy["serializers"] is not None: - ret += f" (serialized={repr(self._obj_pxy['serializers'])})>" + if self._obj_pxy_is_serialized(): + ret += f" (serialized={repr(self._obj_pxy['serializer'])})>" else: ret += f" at {hex(id(self._obj_pxy['obj']))}>" return ret - @property + @property # type: ignore # mypy doesn't support decorated property @_obj_pxy_cache_wrapper("type_serialized") def __class__(self): return pickle.loads(self._obj_pxy["type_serialized"]) @@ -515,8 +542,8 @@ def __mod__(self, other): def __divmod__(self, other): return divmod(self._obj_pxy_deserialize(), other) - def __pow__(self, other, *args): - return pow(self._obj_pxy_deserialize(), other, *args) + def __pow__(self, other): + return pow(self._obj_pxy_deserialize(), other) def __lshift__(self, other): return self._obj_pxy_deserialize() << other @@ -687,7 +714,7 @@ def obj_pxy_cuda_serialize(obj: ProxyObject): or another CUDA friendly communication library. As serializers, it uses "cuda", which means that proxied CUDA objects are _not_ spilled to main memory. 
""" - if obj._obj_pxy["serializers"] is not None: # Already serialized + if obj._obj_pxy_is_serialized(): # Already serialized header, frames = obj._obj_pxy["obj"] else: # Notice, since obj._obj_pxy_serialize() is a inplace operation, we make a diff --git a/dask_cuda/tests/test_proxify_host_file.py b/dask_cuda/tests/test_proxify_host_file.py index 2cbfafd8d..05b5223c8 100644 --- a/dask_cuda/tests/test_proxify_host_file.py +++ b/dask_cuda/tests/test_proxify_host_file.py @@ -1,3 +1,5 @@ +from typing import Iterable + import numpy as np import pandas import pytest @@ -12,9 +14,9 @@ import dask_cuda import dask_cuda.proxify_device_objects -import dask_cuda.proxy_object from dask_cuda.get_device_memory_objects import get_device_memory_objects from dask_cuda.proxify_host_file import ProxifyHostFile +from dask_cuda.proxy_object import ProxyObject cupy = pytest.importorskip("cupy") cupy.cuda.set_allocator(None) @@ -27,53 +29,80 @@ dask_cuda.proxify_device_objects.ignore_types = () +def is_proxies_equal(p1: Iterable[ProxyObject], p2: Iterable[ProxyObject]): + """Check that two collections of proxies contains the same proxies (unordered) + + In order to avoid deserializing proxy objects when comparing them, + this funcntion compares object IDs. 
+ """ + + ids1 = sorted([id(p) for p in p1]) + ids2 = sorted([id(p) for p in p2]) + return ids1 == ids2 + + def test_one_item_limit(): dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes) - dhf["k1"] = one_item_array() + 42 - dhf["k2"] = one_item_array() + + a1 = one_item_array() + 42 + a2 = one_item_array() + dhf["k1"] = a1 + dhf["k2"] = a2 # Check k1 is spilled because of the newer k2 k1 = dhf["k1"] k2 = dhf["k2"] assert k1._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k1]) + assert is_proxies_equal(dhf.manager._dev, [k2]) # Accessing k1 spills k2 and unspill k1 k1_val = k1[0] assert k1_val == 42 assert k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k2]) + assert is_proxies_equal(dhf.manager._dev, [k1]) # Duplicate arrays changes nothing dhf["k3"] = [k1, k2] assert not k1._obj_pxy_is_serialized() assert k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k2]) + assert is_proxies_equal(dhf.manager._dev, [k1]) # Adding a new array spills k1 and k2 dhf["k4"] = one_item_array() + k4 = dhf["k4"] assert k1._obj_pxy_is_serialized() assert k2._obj_pxy_is_serialized() assert not dhf["k4"]._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k1, k2]) + assert is_proxies_equal(dhf.manager._dev, [k4]) # Accessing k2 spills k1 and k4 k2[0] assert k1._obj_pxy_is_serialized() assert dhf["k4"]._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k1, k4]) + assert is_proxies_equal(dhf.manager._dev, [k2]) # Deleting k2 does not change anything since k3 still holds a # reference to the underlying proxy object - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes - p1 = list(dhf.proxies_tally.get_unspilled_proxies()) - assert len(p1) == 1 + assert dhf.manager.get_dev_access_info()[0] == one_item_nbytes + assert is_proxies_equal(dhf.manager._host, [k1, k4]) + assert 
is_proxies_equal(dhf.manager._dev, [k2]) del dhf["k2"] - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes - p2 = list(dhf.proxies_tally.get_unspilled_proxies()) - assert len(p2) == 1 - assert p1[0] is p2[0] + assert is_proxies_equal(dhf.manager._host, [k1, k4]) + assert is_proxies_equal(dhf.manager._dev, [k2]) - # Overwriting "k3" with a non-cuda object, should be noticed + # Overwriting "k3" with a non-cuda object and deleting `k2` + # should empty the device dhf["k3"] = "non-cuda-object" - assert dhf.proxies_tally.get_dev_mem_usage() == 0 + del k2 + assert is_proxies_equal(dhf.manager._host, [k1, k4]) + assert is_proxies_equal(dhf.manager._dev, []) @pytest.mark.parametrize("jit_unspill", [True, False]) @@ -87,7 +116,7 @@ def task(x): if jit_unspill: # Check that `x` is a proxy object and the proxied DataFrame is serialized assert "FrameProxyObject" in str(type(x)) - assert x._obj_pxy["serializers"] == ("dask", "pickle") + assert x._obj_pxy["serializer"] == "dask" else: assert type(x) == cudf.DataFrame assert len(x) == 10 # Trigger deserialization @@ -144,59 +173,49 @@ def test_cudf_get_device_memory_objects(): def test_externals(): + """Test adding objects directly to the manager + + Adding an object directly to the manager makes it count against the + device_memory_limit but isn't part of the store. + + Normally, we use __setitem__ to store objects in the hostfile and make it + count against the device_memory_limit with the inherent consequence that + the objects are not freeable before subsequent calls to __delitem__. + This is a problem for long running tasks that want objects to count against + the device_memory_limit while freeing them ASAP without explicit calls to + __delitem__. 
+ """ dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes) dhf["k1"] = one_item_array() k1 = dhf["k1"] - k2 = dhf.add_external(one_item_array()) + k2 = dhf.manager.proxify(one_item_array()) # `k2` isn't part of the store but still triggers spilling of `k1` assert len(dhf) == 1 assert k1._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k1]) + assert is_proxies_equal(dhf.manager._dev, [k2]) + assert dhf.manager._dev._mem_usage == one_item_nbytes + k1[0] # Trigger spilling of `k2` assert not k1._obj_pxy_is_serialized() assert k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k2]) + assert is_proxies_equal(dhf.manager._dev, [k1]) + assert dhf.manager._dev._mem_usage == one_item_nbytes + k2[0] # Trigger spilling of `k1` assert k1._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes + assert is_proxies_equal(dhf.manager._host, [k1]) + assert is_proxies_equal(dhf.manager._dev, [k2]) + assert dhf.manager._dev._mem_usage == one_item_nbytes + # Removing `k2` also removes it from the tally del k2 - assert dhf.proxies_tally.get_dev_mem_usage() == 0 - assert len(list(dhf.proxies_tally.get_unspilled_proxies())) == 0 - - -def test_externals_setitem(): - dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes) - k1 = dhf.add_external(one_item_array()) - assert type(k1) is dask_cuda.proxy_object.ProxyObject - assert len(dhf) == 0 - assert "external" in k1._obj_pxy - assert "external_finalize" in k1._obj_pxy - dhf["k1"] = k1 - k1 = dhf["k1"] - assert type(k1) is dask_cuda.proxy_object.ProxyObject - assert len(dhf) == 1 - assert "external" not in k1._obj_pxy - assert "external_finalize" not in k1._obj_pxy - - k1 = dhf.add_external(one_item_array()) - k1._obj_pxy_serialize(serializers=("dask", "pickle")) - dhf["k1"] = k1 - k1 = dhf["k1"] - assert type(k1) is dask_cuda.proxy_object.ProxyObject - assert len(dhf) == 1 - 
assert "external" not in k1._obj_pxy - assert "external_finalize" not in k1._obj_pxy - - dhf["k1"] = one_item_array() - assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1 - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes - k1 = dhf.add_external(k1) - assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1 - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes - k1 = dhf.add_external(dhf["k1"]) - assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1 - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes + assert is_proxies_equal(dhf.manager._host, [k1]) + assert is_proxies_equal(dhf.manager._dev, []) + assert dhf.manager._dev._mem_usage == 0 def test_proxify_device_objects_of_cupy_array(): diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py index 6d3f1c972..f0d1f7393 100644 --- a/dask_cuda/tests/test_proxy.py +++ b/dask_cuda/tests/test_proxy.py @@ -23,38 +23,58 @@ def test_proxy_object(serializers): """Check "transparency" of the proxy object""" - org = list(range(10)) + org = bytearray(range(10)) pxy = proxy_object.asproxy(org, serializers=serializers) assert len(org) == len(pxy) assert org[0] == pxy[0] assert 1 in pxy - assert -1 not in pxy + assert 10 not in pxy assert str(org) == str(pxy) assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy) - assert "list at " in repr(pxy) + assert "bytearray at " in repr(pxy) pxy._obj_pxy_serialize(serializers=("dask", "pickle")) assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy) - assert "list (serialized=('dask', 'pickle'))" in repr(pxy) + assert "bytearray (serialized='dask')" in repr(pxy) assert org == proxy_object.unproxy(pxy) assert org == proxy_object.unproxy(org) +class DummyObj: + """Class that only "pickle" can serialize""" + + def __reduce__(self): + return (DummyObj, ()) + + +def test_proxy_object_serializer(): + """Check the serializers argument""" + pxy = proxy_object.asproxy(DummyObj(), serializers=("dask", "pickle")) + assert 
pxy._obj_pxy["serializer"] == "pickle" + assert "DummyObj (serialized='pickle')" in repr(pxy) + + with pytest.raises(ValueError) as excinfo: + pxy = proxy_object.asproxy([42], serializers=("dask", "pickle")) + assert "Cannot wrap a collection" in str(excinfo.value) + + @pytest.mark.parametrize("serializers_first", [None, ("dask", "pickle")]) @pytest.mark.parametrize("serializers_second", [None, ("dask", "pickle")]) def test_double_proxy_object(serializers_first, serializers_second): """Check asproxy() when creating a proxy object of a proxy object""" - org = list(range(10)) + serializer1 = serializers_first[0] if serializers_first else None + serializer2 = serializers_second[0] if serializers_second else None + org = bytearray(range(10)) pxy1 = proxy_object.asproxy(org, serializers=serializers_first) - assert pxy1._obj_pxy["serializers"] == serializers_first + assert pxy1._obj_pxy["serializer"] == serializer1 pxy2 = proxy_object.asproxy(pxy1, serializers=serializers_second) if serializers_second is None: # Check that `serializers=None` doesn't change the initial serializers - assert pxy2._obj_pxy["serializers"] == serializers_first + assert pxy2._obj_pxy["serializer"] == serializer1 else: - assert pxy2._obj_pxy["serializers"] == serializers_second + assert pxy2._obj_pxy["serializer"] == serializer2 assert pxy1 is pxy2 @@ -257,7 +277,7 @@ def task(x): if jit_unspill: # Check that `x` is a proxy object and the proxied DataFrame is serialized assert "FrameProxyObject" in str(type(x)) - assert x._obj_pxy["serializers"] == ("dask", "pickle") + assert x._obj_pxy["serializer"] == "dask" else: assert type(x) == cudf.DataFrame assert len(x) == 10 # Trigger deserialization @@ -292,7 +312,7 @@ def __dask_tokenize__(self): def _obj_pxy_deserialize(self): if self._obj_pxy["assert_on_deserializing"]: - assert self._obj_pxy["serializers"] is None + assert self._obj_pxy["serializer"] is None return super()._obj_pxy_deserialize() @@ -305,16 +325,16 @@ def 
test_communicating_proxy_objects(protocol, send_serializers): def task(x): # Check that the subclass survives the trip from client to worker assert isinstance(x, _PxyObjTest) - serializers_used = x._obj_pxy["serializers"] + serializers_used = x._obj_pxy["serializer"] # Check that `x` is serialized with the expected serializers if protocol == "ucx": if send_serializers is None: - assert serializers_used == ("cuda",) + assert serializers_used == "cuda" else: - assert serializers_used == send_serializers + assert serializers_used == send_serializers[0] else: - assert serializers_used == ("dask", "pickle") + assert serializers_used == "dask" with dask_cuda.LocalCUDACluster( n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx" diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py index b716e2a83..457306bcf 100644 --- a/dask_cuda/utils.py +++ b/dask_cuda/utils.py @@ -594,7 +594,6 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES): def parse_device_memory_limit(device_memory_limit, device_index=0): """Parse memory limit to be used by a CUDA device. - Parameters ---------- device_memory_limit: float, int, str or None