From ce4a260e1bd24d267172a189e3953ac0547875a6 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 20 Aug 2021 09:21:43 +0200 Subject: [PATCH 01/27] clean up --- dask_cuda/proxy_object.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 649f400ed..27edb8e07 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -5,7 +5,7 @@ import threading import time from collections import OrderedDict -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, Iterable, List, Optional, Set import pandas @@ -264,12 +264,12 @@ def _obj_pxy_is_serialized(self): """Return whether the proxied object is serialized or not""" return self._obj_pxy["serializers"] is not None - def _obj_pxy_serialize(self, serializers): + def _obj_pxy_serialize(self, serializers: Iterable[str]): """Inplace serialization of the proxied object using the `serializers` Parameters ---------- - serializers: tuple[str] + serializers: Iterable[str] Tuple of serializers to use to serialize the proxied object. Returns @@ -281,9 +281,7 @@ def _obj_pxy_serialize(self, serializers): """ if not serializers: raise ValueError("Please specify a list of serializers") - - if type(serializers) is not tuple: - serializers = tuple(serializers) + serializers = tuple(serializers) with self._obj_pxy_lock: if self._obj_pxy["serializers"] is not None: From 617a91a450c3b906d0dd126eb431b41230061879 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 20 Aug 2021 09:48:30 +0200 Subject: [PATCH 02/27] mypy: type hints --- dask_cuda/proxify_host_file.py | 3 ++- dask_cuda/proxy_object.py | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 6dd5d6b6b..20eee76e4 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -4,6 +4,7 @@ import weakref from collections import defaultdict from typing import ( + Any, DefaultDict, Dict, Hashable, @@ -155,7 +156,7 @@ class ProxifyHostFile(MutableMapping): def __init__(self, device_memory_limit: int, compatibility_mode: bool = None): self.device_memory_limit = device_memory_limit - self.store = {} + self.store: Dict[Hashable, Any] = {} self.lock = threading.RLock() self.proxies_tally = ProxiesTally() if compatibility_mode is None: diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 27edb8e07..6b1f37390 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -112,7 +112,7 @@ def unproxy(obj): return obj -def _obj_pxy_cache_wrapper(attr_name): +def _obj_pxy_cache_wrapper(attr_name: str): """Caching the access of attr_name in ProxyObject._obj_pxy_cache""" def wrapper1(func): @@ -213,7 +213,7 @@ def __init__( "explicit_proxy": explicit_proxy, } self._obj_pxy_lock = threading.RLock() - self._obj_pxy_cache = {} + self._obj_pxy_cache: Dict[str, Any] = {} def __del__(self): """In order to call `external_finalize()` ASAP, we call it here""" @@ -420,8 +420,8 @@ def __repr__(self): ret += f" at {hex(id(self._obj_pxy['obj']))}>" return ret - @property - @_obj_pxy_cache_wrapper("type_serialized") + @property # type: ignore # mypy doesn't support decorated property + @_obj_pxy_cache_wrapper("type_serialized") # def __class__(self): return pickle.loads(self._obj_pxy["type_serialized"]) @@ -513,8 +513,8 @@ def __mod__(self, other): def __divmod__(self, other): return divmod(self._obj_pxy_deserialize(), other) - def __pow__(self, 
other, *args): - return pow(self._obj_pxy_deserialize(), other, *args) + def __pow__(self, other): + return pow(self._obj_pxy_deserialize(), other) def __lshift__(self, other): return self._obj_pxy_deserialize() << other From dab39646f770c52aff3faa23e65afe48230cb0a5 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 20 Aug 2021 12:01:04 +0200 Subject: [PATCH 03/27] ProxyObject: clean up --- dask_cuda/proxy_object.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 6b1f37390..01c41faef 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -291,15 +291,14 @@ def _obj_pxy_serialize(self, serializers: Iterable[str]): # The proxied object is serialized with other serializers self._obj_pxy_deserialize() - if self._obj_pxy["serializers"] is None: - self._obj_pxy["obj"] = distributed.protocol.serialize( - self._obj_pxy["obj"], serializers, on_error="raise" - ) - self._obj_pxy["serializers"] = serializers - hostfile = self._obj_pxy.get("hostfile", lambda: None)() - if hostfile is not None: - external = self._obj_pxy.get("external", self) - hostfile.proxies_tally.spill_proxy(external) + self._obj_pxy["obj"] = distributed.protocol.serialize( + self._obj_pxy["obj"], serializers, on_error="raise" + ) + self._obj_pxy["serializers"] = serializers + hostfile = self._obj_pxy.get("hostfile", lambda: None)() + if hostfile is not None: + external = self._obj_pxy.get("external", self) + hostfile.proxies_tally.spill_proxy(external) # Invalidate the (possible) cached "device_memory_objects" self._obj_pxy_cache.pop("device_memory_objects", None) From c9321004445a85b573918ea31906b4db19e931f3 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 20 Aug 2021 12:20:15 +0200 Subject: [PATCH 04/27] DeviceHostFile: clean up local_directory argument --- dask_cuda/device_host_file.py | 14 ++++++-------- dask_cuda/local_cuda_cluster.py | 4 +--- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py index 5e2463be0..c03fa2973 100644 --- a/dask_cuda/device_host_file.py +++ b/dask_cuda/device_host_file.py @@ -175,14 +175,12 @@ def __init__( local_directory=None, log_spilling=False, ): - if local_directory is None: - local_directory = dask.config.get("temporary-directory") or os.getcwd() - - if local_directory and not os.path.exists(local_directory): - os.makedirs(local_directory, exist_ok=True) - local_directory = os.path.join(local_directory, "dask-worker-space") - - self.disk_func_path = os.path.join(local_directory, "storage") + self.disk_func_path = os.path.join( + local_directory or dask.config.get("temporary-directory") or os.getcwd(), + "dask-worker-space", + "storage", + ) + os.makedirs(self.disk_func_path, exist_ok=True) self.host_func = dict() self.disk_func = Func( diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 4a706e67f..07c51c863 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -273,9 +273,7 @@ def __init__( { "device_memory_limit": self.device_memory_limit, "memory_limit": self.host_memory_limit, - "local_directory": local_directory - or dask.config.get("temporary-directory") - or os.getcwd(), + "local_directory": local_directory, "log_spilling": log_spilling, }, ) From ed19cb8bc4a03e1c885fa7b0d51e86fc6d5a6e93 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Fri, 20 Aug 2021 12:22:05 +0200 Subject: [PATCH 05/27] clean up --- dask_cuda/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py index b716e2a83..457306bcf 100644 --- a/dask_cuda/utils.py +++ b/dask_cuda/utils.py @@ -594,7 +594,6 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES): def parse_device_memory_limit(device_memory_limit, device_index=0): """Parse memory limit to be used by a CUDA device. - Parameters ---------- device_memory_limit: float, int, str or None From e7975d14725c0e9660f7e1db583a68bac1ecbb5e Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 20 Aug 2021 14:25:43 +0200 Subject: [PATCH 06/27] Clean up --- dask_cuda/get_device_memory_objects.py | 2 +- dask_cuda/proxy_object.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dask_cuda/get_device_memory_objects.py b/dask_cuda/get_device_memory_objects.py index deba96a06..bd00ba0f5 100644 --- a/dask_cuda/get_device_memory_objects.py +++ b/dask_cuda/get_device_memory_objects.py @@ -28,7 +28,7 @@ def get_device_memory_objects(obj) -> set: @dispatch.register(object) def get_device_memory_objects_default(obj): if hasattr(obj, "_obj_pxy"): - if obj._obj_pxy["serializers"] is None: + if not obj._obj_pxy_is_serialized(): return dispatch(obj._obj_pxy["obj"]) else: return [] diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 01c41faef..7edb5be81 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -284,7 +284,7 @@ def _obj_pxy_serialize(self, serializers: Iterable[str]): serializers = tuple(serializers) with self._obj_pxy_lock: - if self._obj_pxy["serializers"] is not None: + if self._obj_pxy_is_serialized(): if self._obj_pxy["serializers"] == serializers: return self._obj_pxy["obj"] # Nothing to be done else: @@ -318,7 +318,7 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): The proxied object (deserialized) """ with self._obj_pxy_lock: - if self._obj_pxy["serializers"] is not None: + if self._obj_pxy_is_serialized(): hostfile = self._obj_pxy.get("hostfile", lambda: None)() # When not deserializing a CUDA-serialized proxied, we might have # to evict because of the increased device memory usage. @@ -413,7 +413,7 @@ def __repr__(self): with self._obj_pxy_lock: typename = self._obj_pxy["typename"] ret = f"<{dask.utils.typename(type(self))} at {hex(id(self))} of {typename}" - if self._obj_pxy["serializers"] is not None: + if self._obj_pxy_is_serialized(): ret += f" (serialized={repr(self._obj_pxy['serializers'])})>" else: ret += f" at {hex(id(self._obj_pxy['obj']))}>" @@ -684,7 +684,7 @@ def obj_pxy_cuda_serialize(obj: ProxyObject): or another CUDA friendly communication library. As serializers, it uses "cuda", which means that proxied CUDA objects are _not_ spilled to main memory. """ - if obj._obj_pxy["serializers"] is not None: # Already serialized + if obj._obj_pxy_is_serialized(): # Already serialized header, frames = obj._obj_pxy["obj"] else: # Notice, since obj._obj_pxy_serialize() is a inplace operation, we make a From 716a17939a42b7a5e58f871ae73d367cbcfb36a9 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Thu, 26 Aug 2021 12:29:34 +0200 Subject: [PATCH 07/27] More type hints --- dask_cuda/proxy_object.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 7edb5be81..e58a48317 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -5,7 +5,7 @@ import threading import time from collections import OrderedDict -from typing import Any, Dict, Iterable, List, Optional, Set +from typing import Any, Dict, Iterable, List, Optional, Set, Type import pandas @@ -36,16 +36,17 @@ _FIXED_ATTRS = ["name", "__len__"] -def asproxy(obj, serializers=None, subclass=None) -> "ProxyObject": +def asproxy( + obj: object, serializers: Iterable[str] = None, subclass: Type["ProxyObject"] = None +) -> "ProxyObject": """Wrap `obj` in a ProxyObject object if it isn't already. Parameters ---------- obj: object Object to wrap in a ProxyObject object. - serializers: list(str), optional - List of serializers to use to serialize `obj`. If None, - no serialization is done. + serializers: Iterable[str], optional + Serializers to use to serialize `obj`. If None, no serialization is done. subclass: class, optional Specify a subclass of ProxyObject to create instead of ProxyObject. `subclass` must be pickable. @@ -183,9 +184,8 @@ class ProxyObject: subclass: bytes Pickled type to use instead of ProxyObject when deserializing. The type must inherit from ProxyObject. - serializers: list(str), optional - List of serializers to use to serialize `obj`. If None, `obj` - isn't serialized. + serializers: Iterable[str], optional + Serializers to use to serialize `obj`. If None, no serialization is done. explicit_proxy: bool Mark the proxy object as "explicit", which means that the user allows it as input argument to dask tasks even in compatibility-mode. @@ -199,7 +199,7 @@ def __init__( typename: str, is_cuda_object: bool, subclass: bytes, - serializers: Optional[List[str]], + serializers: Optional[Iterable[str]], explicit_proxy: bool, ): self._obj_pxy = { From 660ca0f254cc388e3c58f0a2c703d0d6f9ef9b04 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Thu, 26 Aug 2021 13:00:13 +0200 Subject: [PATCH 08/27] Disallowing proxy of collections In order to address ambiguity when serializing using multiple serializers, we disallow proxies of collections such as list or tuples. Instead, users should wrap each collection item in a proxy. 
--- dask_cuda/proxy_object.py | 5 +++-- dask_cuda/tests/test_proxy.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index e58a48317..7aba58328 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -5,7 +5,7 @@ import threading import time from collections import OrderedDict -from typing import Any, Dict, Iterable, List, Optional, Set, Type +from typing import Any, Dict, Iterable, Optional, Set, Type import pandas @@ -55,9 +55,10 @@ def asproxy( ------- The ProxyObject proxying `obj` """ - if hasattr(obj, "_obj_pxy"): # Already a proxy object ret = obj + elif isinstance(obj, (list, set, tuple, dict)): + raise ValueError(f"Cannot wrap a collection ({type(obj)}) in a proxy object") else: fixed_attr = {} for attr in _FIXED_ATTRS: diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py index 6d3f1c972..c208cb990 100644 --- a/dask_cuda/tests/test_proxy.py +++ b/dask_cuda/tests/test_proxy.py @@ -23,20 +23,20 @@ def test_proxy_object(serializers): """Check "transparency" of the proxy object""" - org = list(range(10)) + org = bytearray(range(10)) pxy = proxy_object.asproxy(org, serializers=serializers) assert len(org) == len(pxy) assert org[0] == pxy[0] assert 1 in pxy - assert -1 not in pxy + assert 10 not in pxy assert str(org) == str(pxy) assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy) - assert "list at " in repr(pxy) + assert "bytearray at " in repr(pxy) pxy._obj_pxy_serialize(serializers=("dask", "pickle")) assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy) - assert "list (serialized=('dask', 'pickle'))" in repr(pxy) + assert "bytearray (serialized=('dask', 'pickle'))" in repr(pxy) assert org == proxy_object.unproxy(pxy) assert org == proxy_object.unproxy(org) @@ -46,7 +46,7 @@ def test_proxy_object(serializers): @pytest.mark.parametrize("serializers_second", [None, ("dask", "pickle")]) def test_double_proxy_object(serializers_first, serializers_second): """Check asproxy() when creating a proxy object of a proxy object""" - org = list(range(10)) + org = bytearray(range(10)) pxy1 = proxy_object.asproxy(org, serializers=serializers_first) assert pxy1._obj_pxy["serializers"] == serializers_first pxy2 = proxy_object.asproxy(pxy1, serializers=serializers_second) From 871ac9733eb5a6cfe5615f6781967179e086b39b Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Thu, 26 Aug 2021 13:40:33 +0200 Subject: [PATCH 09/27] Now tracking a proxy's actual serializer --- dask_cuda/proxify_device_objects.py | 2 +- dask_cuda/proxy_object.py | 33 ++++++++++--------- dask_cuda/tests/test_proxify_host_file.py | 2 +- dask_cuda/tests/test_proxy.py | 40 +++++++++++++++++------ 4 files changed, 49 insertions(+), 28 deletions(-) diff --git a/dask_cuda/proxify_device_objects.py b/dask_cuda/proxify_device_objects.py index f3e3efb3f..70f90b0b8 100644 --- a/dask_cuda/proxify_device_objects.py +++ b/dask_cuda/proxify_device_objects.py @@ -192,7 +192,7 @@ def proxify_device_object_proxy_object( ): # We deserialize CUDA-serialized objects since it is very cheap and # makes it easy to administrate device memory usage - if obj._obj_pxy_is_serialized() and "cuda" in obj._obj_pxy["serializers"]: + if obj._obj_pxy_is_serialized() and obj._obj_pxy["serializer"] != "cuda": obj._obj_pxy_deserialize() # Check if `obj` is already known diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 7aba58328..9098255e0 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -13,6 +13,7 @@ import dask.array.core import dask.dataframe.methods import dask.dataframe.utils +import dask.utils import distributed.protocol import distributed.utils from dask.sizeof import sizeof @@ -55,7 +56,7 @@ def asproxy( ------- The ProxyObject proxying `obj` """ - if hasattr(obj, "_obj_pxy"): # Already a proxy object + if isinstance(obj, ProxyObject): # Already a proxy object ret = obj elif isinstance(obj, (list, set, tuple, dict)): raise ValueError(f"Cannot wrap a collection ({type(obj)}) in a proxy object") @@ -83,7 +84,7 @@ def asproxy( typename=dask.utils.typename(type(obj)), is_cuda_object=is_device_object(obj), subclass=subclass_serialized, - serializers=None, + serializer=None, explicit_proxy=False, ) if serializers is not None: @@ -185,7 +186,7 @@ class ProxyObject: subclass: bytes Pickled type to use instead of ProxyObject when deserializing. The type must inherit from ProxyObject. - serializers: Iterable[str], optional + serializers: str, optional Serializers to use to serialize `obj`. If None, no serialization is done. 
explicit_proxy: bool Mark the proxy object as "explicit", which means that the user allows it @@ -199,8 +200,8 @@ def __init__( type_serialized: bytes, typename: str, is_cuda_object: bool, - subclass: bytes, - serializers: Optional[Iterable[str]], + subclass: Optional[bytes], + serializer: Optional[str], explicit_proxy: bool, ): self._obj_pxy = { @@ -210,7 +211,7 @@ def __init__( "typename": typename, "is_cuda_object": is_cuda_object, "subclass": subclass, - "serializers": serializers, + "serializer": serializer, "explicit_proxy": explicit_proxy, } self._obj_pxy_lock = threading.RLock() @@ -243,7 +244,7 @@ def _obj_pxy_get_init_args(self, include_obj=True): "typename", "is_cuda_object", "subclass", - "serializers", + "serializer", "explicit_proxy", ] return OrderedDict([(a, self._obj_pxy[a]) for a in args]) @@ -263,7 +264,7 @@ def _obj_pxy_copy(self) -> "ProxyObject": def _obj_pxy_is_serialized(self): """Return whether the proxied object is serialized or not""" - return self._obj_pxy["serializers"] is not None + return self._obj_pxy["serializer"] is not None def _obj_pxy_serialize(self, serializers: Iterable[str]): """Inplace serialization of the proxied object using the `serializers` @@ -271,7 +272,7 @@ def _obj_pxy_serialize(self, serializers: Iterable[str]): Parameters ---------- serializers: Iterable[str] - Tuple of serializers to use to serialize the proxied object. + Serializers to use to serialize the proxied object. Returns ------- @@ -282,20 +283,20 @@ def _obj_pxy_serialize(self, serializers: Iterable[str]): """ if not serializers: raise ValueError("Please specify a list of serializers") - serializers = tuple(serializers) with self._obj_pxy_lock: if self._obj_pxy_is_serialized(): - if self._obj_pxy["serializers"] == serializers: + if self._obj_pxy["serializer"] in serializers: return self._obj_pxy["obj"] # Nothing to be done else: # The proxied object is serialized with other serializers self._obj_pxy_deserialize() - self._obj_pxy["obj"] = distributed.protocol.serialize( + header, _ = self._obj_pxy["obj"] = distributed.protocol.serialize( self._obj_pxy["obj"], serializers, on_error="raise" ) - self._obj_pxy["serializers"] = serializers + assert "is-collection" not in header # Collections not allowed + self._obj_pxy["serializer"] = header["serializer"] hostfile = self._obj_pxy.get("hostfile", lambda: None)() if hostfile is not None: external = self._obj_pxy.get("external", self) @@ -323,7 +324,7 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): hostfile = self._obj_pxy.get("hostfile", lambda: None)() # When not deserializing a CUDA-serialized proxied, we might have # to evict because of the increased device memory usage. 
- if maybe_evict and "cuda" not in self._obj_pxy["serializers"]: + if maybe_evict and self._obj_pxy["serializer"] != "cuda": if hostfile is not None: # In order to avoid a potential deadlock, we skip the # `maybe_evict()` call if another thread is also accessing @@ -336,7 +337,7 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): header, frames = self._obj_pxy["obj"] self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames) - self._obj_pxy["serializers"] = None + self._obj_pxy["serializer"] = None if hostfile is not None: external = self._obj_pxy.get("external", self) hostfile.proxies_tally.unspill_proxy(external) @@ -415,7 +416,7 @@ def __repr__(self): typename = self._obj_pxy["typename"] ret = f"<{dask.utils.typename(type(self))} at {hex(id(self))} of {typename}" if self._obj_pxy_is_serialized(): - ret += f" (serialized={repr(self._obj_pxy['serializers'])})>" + ret += f" (serialized={repr(self._obj_pxy['serializer'])})>" else: ret += f" at {hex(id(self._obj_pxy['obj']))}>" return ret diff --git a/dask_cuda/tests/test_proxify_host_file.py b/dask_cuda/tests/test_proxify_host_file.py index 2cbfafd8d..437c62f0a 100644 --- a/dask_cuda/tests/test_proxify_host_file.py +++ b/dask_cuda/tests/test_proxify_host_file.py @@ -87,7 +87,7 @@ def task(x): if jit_unspill: # Check that `x` is a proxy object and the proxied DataFrame is serialized assert "FrameProxyObject" in str(type(x)) - assert x._obj_pxy["serializers"] == ("dask", "pickle") + assert x._obj_pxy["serializer"] == "dask" else: assert type(x) == cudf.DataFrame assert len(x) == 10 # Trigger deserialization diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py index c208cb990..f0d1f7393 100644 --- a/dask_cuda/tests/test_proxy.py +++ b/dask_cuda/tests/test_proxy.py @@ -36,25 +36,45 @@ def test_proxy_object(serializers): pxy._obj_pxy_serialize(serializers=("dask", "pickle")) assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy) - assert "bytearray (serialized=('dask', 'pickle'))" in repr(pxy) + assert "bytearray (serialized='dask')" in repr(pxy) assert org == proxy_object.unproxy(pxy) assert org == proxy_object.unproxy(org) +class DummyObj: + """Class that only "pickle" can serialize""" + + def __reduce__(self): + return (DummyObj, ()) + + +def test_proxy_object_serializer(): + """Check the serializers argument""" + pxy = proxy_object.asproxy(DummyObj(), serializers=("dask", "pickle")) + assert pxy._obj_pxy["serializer"] == "pickle" + assert "DummyObj (serialized='pickle')" in repr(pxy) + + with pytest.raises(ValueError) as excinfo: + pxy = proxy_object.asproxy([42], serializers=("dask", "pickle")) + assert "Cannot wrap a collection" in str(excinfo.value) + + @pytest.mark.parametrize("serializers_first", [None, ("dask", "pickle")]) @pytest.mark.parametrize("serializers_second", [None, ("dask", "pickle")]) def test_double_proxy_object(serializers_first, serializers_second): """Check asproxy() when creating a proxy object of a proxy object""" + serializer1 = serializers_first[0] if serializers_first else None + serializer2 = serializers_second[0] if serializers_second else None org = bytearray(range(10)) pxy1 = proxy_object.asproxy(org, serializers=serializers_first) - assert pxy1._obj_pxy["serializers"] == serializers_first + assert pxy1._obj_pxy["serializer"] == serializer1 pxy2 = proxy_object.asproxy(pxy1, serializers=serializers_second) if serializers_second is None: # Check that `serializers=None` doesn't change the initial serializers - assert pxy2._obj_pxy["serializers"] == 
serializers_first + assert pxy2._obj_pxy["serializer"] == serializer1 else: - assert pxy2._obj_pxy["serializers"] == serializers_second + assert pxy2._obj_pxy["serializer"] == serializer2 assert pxy1 is pxy2 @@ -257,7 +277,7 @@ def task(x): if jit_unspill: # Check that `x` is a proxy object and the proxied DataFrame is serialized assert "FrameProxyObject" in str(type(x)) - assert x._obj_pxy["serializers"] == ("dask", "pickle") + assert x._obj_pxy["serializer"] == "dask" else: assert type(x) == cudf.DataFrame assert len(x) == 10 # Trigger deserialization @@ -292,7 +312,7 @@ def __dask_tokenize__(self): def _obj_pxy_deserialize(self): if self._obj_pxy["assert_on_deserializing"]: - assert self._obj_pxy["serializers"] is None + assert self._obj_pxy["serializer"] is None return super()._obj_pxy_deserialize() @@ -305,16 +325,16 @@ def test_communicating_proxy_objects(protocol, send_serializers): def task(x): # Check that the subclass survives the trip from client to worker assert isinstance(x, _PxyObjTest) - serializers_used = x._obj_pxy["serializers"] + serializers_used = x._obj_pxy["serializer"] # Check that `x` is serialized with the expected serializers if protocol == "ucx": if send_serializers is None: - assert serializers_used == ("cuda",) + assert serializers_used == "cuda" else: - assert serializers_used == send_serializers + assert serializers_used == send_serializers[0] else: - assert serializers_used == ("dask", "pickle") + assert serializers_used == "dask" with dask_cuda.LocalCUDACluster( n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx" From d98db987dd67db4a23ab745cf5dcbb7a333ee72d Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 27 Aug 2021 13:54:44 +0200 Subject: [PATCH 10/27] Re-implemented proxy tracking --- dask_cuda/explicit_comms/dataframe/shuffle.py | 12 +- dask_cuda/proxify_device_objects.py | 13 - dask_cuda/proxify_host_file.py | 379 +++++++++--------- dask_cuda/proxy_object.py | 76 ++-- dask_cuda/tests/test_proxify_host_file.py | 115 +++--- 5 files changed, 316 insertions(+), 279 deletions(-) diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index aeea71467..d14902b29 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -18,7 +18,7 @@ from distributed import wait from distributed.protocol import nested_deserialize, to_serialize -from ...proxify_host_file import ProxifyHostFile +from ...proxify_host_file import ProxyManager from .. 
import comms @@ -148,19 +148,19 @@ async def local_shuffle( eps = s["eps"] try: - hostfile = first(iter(in_parts[0].values()))._obj_pxy.get( - "hostfile", lambda: None + manager = first(iter(in_parts[0].values()))._obj_pxy.get( + "manager", lambda: None )() except AttributeError: - hostfile = None + manager = None - if isinstance(hostfile, ProxifyHostFile): + if isinstance(manager, ProxyManager): def concat(args, ignore_index=False): if len(args) < 2: return args[0] - return hostfile.add_external(dd_concat(args, ignore_index=ignore_index)) + return manager.proxify(dd_concat(args, ignore_index=ignore_index)) else: concat = dd_concat diff --git a/dask_cuda/proxify_device_objects.py b/dask_cuda/proxify_device_objects.py index 70f90b0b8..8c45ebd49 100644 --- a/dask_cuda/proxify_device_objects.py +++ b/dask_cuda/proxify_device_objects.py @@ -190,11 +190,6 @@ def proxify_device_object_default( def proxify_device_object_proxy_object( obj, proxied_id_to_proxy, found_proxies, excl_proxies ): - # We deserialize CUDA-serialized objects since it is very cheap and - # makes it easy to administrate device memory usage - if obj._obj_pxy_is_serialized() and obj._obj_pxy["serializer"] != "cuda": - obj._obj_pxy_deserialize() - # Check if `obj` is already known if not obj._obj_pxy_is_serialized(): _id = id(obj._obj_pxy["obj"]) @@ -203,14 +198,6 @@ def proxify_device_object_proxy_object( else: proxied_id_to_proxy[_id] = obj - finalize = obj._obj_pxy.get("external_finalize", None) - if finalize: - finalize() - obj = obj._obj_pxy_copy() - if not obj._obj_pxy_is_serialized(): - _id = id(obj._obj_pxy["obj"]) - proxied_id_to_proxy[_id] = obj - if not excl_proxies: found_proxies.append(obj) return obj diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 20eee76e4..1537cee12 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -1,3 +1,4 @@ +import abc import logging import threading import time @@ -11,9 +12,11 @@ Iterator, List, MutableMapping, + Optional, Set, Tuple, ) +from weakref import ReferenceType import dask from dask.sizeof import sizeof @@ -22,105 +25,224 @@ from .proxy_object import ProxyObject -class UnspilledProxies: - """Class to track current unspilled proxies""" +class Proxies(abc.ABC): + """Abstract base class to implement tracking of proxies + + This class is not threadsafe + """ def __init__(self): - self.dev_mem_usage = 0 + self._proxy_id_to_proxy: Dict[int, ReferenceType[ProxyObject]] = {} + self._mem_usage = 0 + + @abc.abstractmethod + def mem_usage_add(self, proxy: ProxyObject) -> None: + pass + + @abc.abstractmethod + def mem_usage_sub(self, proxy: ProxyObject) -> None: + pass + + def add(self, proxy: ProxyObject) -> None: + assert not self.contains_proxy_id(id(proxy)) + self._proxy_id_to_proxy[id(proxy)] = weakref.ref(proxy) + self.mem_usage_add(proxy) + + def remove(self, proxy_id: int) -> None: + weak_proxy = self._proxy_id_to_proxy.pop(proxy_id) + if weak_proxy is not None: + proxy = weak_proxy() + assert proxy is not None + self.mem_usage_sub(proxy) + + def __iter__(self) -> Iterator[ProxyObject]: + for p in self._proxy_id_to_proxy.values(): + ret = p() + if ret is not None: + yield ret + + def contains_proxy_id(self, proxy_id: int): + return proxy_id in self._proxy_id_to_proxy + + def mem_usage(self) -> int: + return self._mem_usage + + +class ProxiesOnHost(Proxies): + def mem_usage_add(self, proxy: ProxyObject): + self._mem_usage += sizeof(proxy) + + def mem_usage_sub(self, proxy: ProxyObject): + self._mem_usage = 
sizeof(proxy) + + +class ProxiesOnDevice(Proxies): + """Implement tracking of proxies on the GPU + + This is a bit more complicated than ProxiesOnHost because we has to + handle that multiple proxy objects can refer to the same underlying + device memory object. Thus, we have to track aliasing and make sure + we don't count down the memory usage prematurely. + """ + + def __init__(self): + super().__init__() self.proxy_id_to_dev_mems: DefaultDict[int, Set[Hashable]] = defaultdict(set) self.dev_mem_to_proxy_ids: DefaultDict[Hashable, Set[int]] = defaultdict(set) - def add(self, proxy: ProxyObject): + def mem_usage_add(self, proxy: ProxyObject): proxy_id = id(proxy) - if proxy_id not in self.proxy_id_to_dev_mems: - for dev_mem in proxy._obj_pxy_get_device_memory_objects(): - self.proxy_id_to_dev_mems[proxy_id].add(dev_mem) - ps = self.dev_mem_to_proxy_ids[dev_mem] - if len(ps) == 0: - self.dev_mem_usage += sizeof(dev_mem) - ps.add(proxy_id) - - def remove(self, proxy: ProxyObject): + for dev_mem in proxy._obj_pxy_get_device_memory_objects(): + self.proxy_id_to_dev_mems[proxy_id].add(dev_mem) + ps = self.dev_mem_to_proxy_ids[dev_mem] + if len(ps) == 0: + self._mem_usage += sizeof(dev_mem) + ps.add(proxy_id) + + def mem_usage_sub(self, proxy: ProxyObject): proxy_id = id(proxy) - if proxy_id in self.proxy_id_to_dev_mems: - for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id): - self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id) - if len(self.dev_mem_to_proxy_ids[dev_mem]) == 0: - del self.dev_mem_to_proxy_ids[dev_mem] - self.dev_mem_usage -= sizeof(dev_mem) + for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id): + self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id) + if len(self.dev_mem_to_proxy_ids[dev_mem]) == 0: + del self.dev_mem_to_proxy_ids[dev_mem] + self._mem_usage -= sizeof(dev_mem) - def __iter__(self): - return iter(self.proxy_id_to_dev_mems) - -class ProxiesTally: +class ProxyManager: """ - This class together with UnspilledProxies implements the tracking of current - objects in device memory and the total memory usage. It turns out having to - re-calculate device memory usage continuously is too expensive. - - We have to track four events: - - When adding a new key to the host file - - When removing a key from the host file - - When a proxy in the host file is deserialized - - When a proxy in the host file is serialized - - However, it gets a bit complicated because: - - The value of a key in the host file can contain many proxy objects and a single - proxy object can be referred from many keys - - Multiple proxy objects can refer to the same underlying device memory object - - Proxy objects are not hashable thus we have to use the `id()` as key in - dictionaries - - ProxiesTally and UnspilledProxies implements this by carefully maintaining - dictionaries that maps to/from keys, proxy objects, and device memory objects. + This class together with Proxies, ProxiesOnHost, and ProxiesOnDevice + implements the tracking of all known proxies and their total host/device + memory usage. It turns out having to re-calculate memory usage continuously + is too expensive. + + The idea is to have the ProxifyHostFile or the proxies themself update + their location (device or host). The manager then tallies the total memory usage. + + Notice, the manager only keeps weak references to the proxies. 
""" - def __init__(self): + def __init__(self, device_memory_limit: int): self.lock = threading.RLock() - self.proxy_id_to_proxy: Dict[int, ProxyObject] = {} - self.key_to_proxy_ids: DefaultDict[Hashable, Set[int]] = defaultdict(set) - self.proxy_id_to_keys: DefaultDict[int, Set[Hashable]] = defaultdict(set) - self.unspilled_proxies = UnspilledProxies() + self._host = ProxiesOnHost() + self._dev = ProxiesOnDevice() + self._device_memory_limit = device_memory_limit + + def __repr__(self) -> str: + return ( + f"" + ) + + def pprint(self): + ret = f"{self}:\n" + for proxy in self._host: + ret += f" host - {repr(proxy)}\n" + for proxy in self._dev: + ret += f" dev - {repr(proxy)}\n" + return ret[:-1] # Strip last newline + + @staticmethod + def serializer_target(serializer: Optional[str]) -> str: + if serializer in ("dask", "pickle"): + return "host" + else: + return "dev" - def add_key(self, key, proxies: List[ProxyObject]): + def contains(self, proxy_id: int): with self.lock: - for proxy in proxies: - proxy_id = id(proxy) - self.proxy_id_to_proxy[proxy_id] = proxy - self.key_to_proxy_ids[key].add(proxy_id) - self.proxy_id_to_keys[proxy_id].add(key) - if not proxy._obj_pxy_is_serialized(): - self.unspilled_proxies.add(proxy) - - def del_key(self, key): + return self._host.contains_proxy_id( + proxy_id + ) or self._dev.contains_proxy_id(proxy_id) + + def add(self, proxy: ProxyObject) -> None: with self.lock: - for proxy_id in self.key_to_proxy_ids.pop(key, ()): - self.proxy_id_to_keys[proxy_id].remove(key) - if len(self.proxy_id_to_keys[proxy_id]) == 0: - del self.proxy_id_to_keys[proxy_id] - self.unspilled_proxies.remove(self.proxy_id_to_proxy.pop(proxy_id)) + if not self.contains(id(proxy)): + if self.serializer_target(proxy._obj_pxy["serializer"]) == "host": + self._host.add(proxy) + else: + self._dev.add(proxy) - def spill_proxy(self, proxy: ProxyObject): + def remove(self, proxy_id: int) -> None: + with self.lock: + if not self.contains(proxy_id): + self._host.remove(proxy_id) + self._dev.remove(proxy_id) + + def move( + self, + proxy: ProxyObject, + from_serializer: Optional[str], + to_serializer: Optional[str], + ) -> None: with self.lock: - self.unspilled_proxies.remove(proxy) + src = self.serializer_target(from_serializer) + dst = self.serializer_target(to_serializer) + if src == "host" and dst == "dev": + self._host.remove(id(proxy)) + self._dev.add(proxy) + elif src == "dev" and dst == "host": + self._host.add(proxy) + self._dev.remove(id(proxy)) + + def proxify(self, obj: object) -> object: + with self.lock: + found_proxies: List[ProxyObject] = [] + proxied_id_to_proxy: Dict[int, ProxyObject] = {} + ret = proxify_device_objects(obj, proxied_id_to_proxy, found_proxies) + last_access = time.monotonic() + self_weakref = weakref.ref(self) + for p in found_proxies: + p._obj_pxy["manager"] = self_weakref + p._obj_pxy["last_access"] = last_access + p._obj_pxy["finalizer"] = weakref.finalize(p, self.remove, id(p)) + self.add(p) + self.maybe_evict() + return ret - def unspill_proxy(self, proxy: ProxyObject): + def get_dev_buffer_to_proxies(self) -> DefaultDict[Hashable, List[ProxyObject]]: with self.lock: - self.unspilled_proxies.add(proxy) + # Notice, multiple proxy object can point to different non-overlapping + # parts of the same device buffer. 
+ ret = defaultdict(list) + for proxy in self._dev: + for dev_buffer in proxy._obj_pxy_get_device_memory_objects(): + ret[dev_buffer].append(proxy) + return ret - def get_unspilled_proxies(self) -> Iterator[ProxyObject]: + def get_dev_access_info( + self, + ) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]]: with self.lock: - for proxy_id in self.unspilled_proxies: - ret = self.proxy_id_to_proxy[proxy_id] - assert not ret._obj_pxy_is_serialized() - yield ret + total_dev_mem_usage = 0 + dev_buf_access = [] + for dev_buf, proxies in self.get_dev_buffer_to_proxies().items(): + last_access = max(p._obj_pxy.get("last_access", 0) for p in proxies) + size = sizeof(dev_buf) + dev_buf_access.append((last_access, size, proxies)) + total_dev_mem_usage += size + assert total_dev_mem_usage == self._dev.mem_usage() + return total_dev_mem_usage, dev_buf_access - def get_proxied_id_to_proxy(self) -> Dict[int, ProxyObject]: - return {id(p._obj_pxy["obj"]): p for p in self.get_unspilled_proxies()} + def evict(self, proxy: ProxyObject): + proxy._obj_pxy_serialize(serializers=("dask", "pickle")) - def get_dev_mem_usage(self) -> int: - return self.unspilled_proxies.dev_mem_usage + def maybe_evict(self, extra_dev_mem=0): + if ( # Shortcut when not evicting + self._dev.mem_usage() + extra_dev_mem <= self._device_memory_limit + ): + return + + with self.lock: + total_dev_mem_usage, dev_buf_access = self.get_dev_access_info() + total_dev_mem_usage += extra_dev_mem + if total_dev_mem_usage > self._device_memory_limit: + dev_buf_access.sort(key=lambda x: (x[0], -x[1])) + for _, size, proxies in dev_buf_access: + for p in proxies: + self.evict(p) + total_dev_mem_usage -= size + if total_dev_mem_usage <= self._device_memory_limit: + break class ProxifyHostFile(MutableMapping): @@ -158,7 +280,7 @@ def __init__(self, device_memory_limit: int, compatibility_mode: bool = None): self.device_memory_limit = device_memory_limit self.store: Dict[Hashable, Any] = {} self.lock = threading.RLock() - self.proxies_tally = ProxiesTally() + self.manager = ProxyManager(device_memory_limit) if compatibility_mode is None: self.compatibility_mode = dask.config.get( "jit-unspill-compatibility-mode", default=False @@ -191,122 +313,21 @@ def fast(self): ) return None - def get_dev_buffer_to_proxies(self) -> DefaultDict[Hashable, List[ProxyObject]]: - with self.lock: - # Notice, multiple proxy object can point to different non-overlapping - # parts of the same device buffer. - ret = defaultdict(list) - for proxy in self.proxies_tally.get_unspilled_proxies(): - for dev_buffer in proxy._obj_pxy_get_device_memory_objects(): - ret[dev_buffer].append(proxy) - return ret - - def get_access_info(self) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]]: - with self.lock: - total_dev_mem_usage = 0 - dev_buf_access = [] - for dev_buf, proxies in self.get_dev_buffer_to_proxies().items(): - last_access = max(p._obj_pxy.get("last_access", 0) for p in proxies) - size = sizeof(dev_buf) - dev_buf_access.append((last_access, size, proxies)) - total_dev_mem_usage += size - return total_dev_mem_usage, dev_buf_access - - def add_external(self, obj): - """Add an external object to the hostfile that count against the - device_memory_limit but isn't part of the store. - - Normally, we use __setitem__ to store objects in the hostfile and make it - count against the device_memory_limit with the inherent consequence that - the objects are not freeable before subsequential calls to __delitem__. 
- This is a problem for long running tasks that want objects to count against - the device_memory_limit while freeing them ASAP without explicit calls to - __delitem__. - - Developer Notes - --------------- - In order to avoid holding references to the found proxies in `obj`, we - wrap them in `weakref.proxy(p)` and adds them to the `proxies_tally`. - In order to remove them from the `proxies_tally` again, we attach a - finalize(p) on the wrapped proxies that calls del_external(). - """ - - # Notice, since `self.store` isn't modified, no lock is needed - found_proxies: List[ProxyObject] = [] - proxied_id_to_proxy = {} - # Notice, we are excluding found objects that are already proxies - ret = proxify_device_objects( - obj, proxied_id_to_proxy, found_proxies, excl_proxies=True - ) - last_access = time.monotonic() - self_weakref = weakref.ref(self) - for p in found_proxies: - name = id(p) - finalize = weakref.finalize(p, self.del_external, name) - external = weakref.proxy(p) - p._obj_pxy["hostfile"] = self_weakref - p._obj_pxy["last_access"] = last_access - p._obj_pxy["external"] = external - p._obj_pxy["external_finalize"] = finalize - self.proxies_tally.add_key(name, [external]) - self.maybe_evict() - return ret - - def del_external(self, name): - self.proxies_tally.del_key(name) - def __setitem__(self, key, value): with self.lock: if key in self.store: # Make sure we register the removal of an existing key del self[key] - - found_proxies: List[ProxyObject] = [] - proxied_id_to_proxy = self.proxies_tally.get_proxied_id_to_proxy() - self.store[key] = proxify_device_objects( - value, proxied_id_to_proxy, found_proxies - ) - last_access = time.monotonic() - self_weakref = weakref.ref(self) - for p in found_proxies: - p._obj_pxy["hostfile"] = self_weakref - p._obj_pxy["last_access"] = last_access - assert "external" not in p._obj_pxy - - self.proxies_tally.add_key(key, found_proxies) - self.maybe_evict() + self.store[key] = self.manager.proxify(value) def __getitem__(self, key): with self.lock: ret = self.store[key] if self.compatibility_mode: ret = unproxify_device_objects(ret, skip_explicit_proxies=True) - self.maybe_evict() + self.manager.maybe_evict() return ret def __delitem__(self, key): with self.lock: del self.store[key] - self.proxies_tally.del_key(key) - - def evict(self, proxy: ProxyObject): - proxy._obj_pxy_serialize(serializers=("dask", "pickle")) - - def maybe_evict(self, extra_dev_mem=0): - if ( # Shortcut when not evicting - self.proxies_tally.get_dev_mem_usage() + extra_dev_mem - <= self.device_memory_limit - ): - return - - with self.lock: - total_dev_mem_usage, dev_buf_access = self.get_access_info() - total_dev_mem_usage += extra_dev_mem - if total_dev_mem_usage > self.device_memory_limit: - dev_buf_access.sort(key=lambda x: (x[0], -x[1])) - for _, size, proxies in dev_buf_access: - for p in proxies: - self.evict(p) - total_dev_mem_usage -= size - if total_dev_mem_usage <= self.device_memory_limit: - break diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 9098255e0..9cc061c8f 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -5,6 +5,9 @@ import threading import time from collections import OrderedDict +from contextlib import ( # TODO: use `contextlib.nullcontext()` from Python 3.7+ + suppress as nullcontext, +) from typing import Any, Dict, Iterable, Optional, Set, Type import pandas @@ -292,19 +295,23 @@ def _obj_pxy_serialize(self, serializers: Iterable[str]): # The proxied object is serialized with other serializers 
self._obj_pxy_deserialize() - header, _ = self._obj_pxy["obj"] = distributed.protocol.serialize( - self._obj_pxy["obj"], serializers, on_error="raise" - ) - assert "is-collection" not in header # Collections not allowed - self._obj_pxy["serializer"] = header["serializer"] - hostfile = self._obj_pxy.get("hostfile", lambda: None)() - if hostfile is not None: - external = self._obj_pxy.get("external", self) - hostfile.proxies_tally.spill_proxy(external) - - # Invalidate the (possible) cached "device_memory_objects" - self._obj_pxy_cache.pop("device_memory_objects", None) - return self._obj_pxy["obj"] + # Lock manager (if any) + manager = self._obj_pxy.get("manager", lambda: None)() + with (nullcontext() if manager is None else manager.lock): + header, _ = self._obj_pxy["obj"] = distributed.protocol.serialize( + self._obj_pxy["obj"], serializers, on_error="raise" + ) + assert "is-collection" not in header # Collections not allowed + org_ser, new_ser = self._obj_pxy["serializer"], header["serializer"] + self._obj_pxy["serializer"] = new_ser + + # Tell the manager (if any) that this proxy has change serializer + if manager: + manager.move(self, from_serializer=org_ser, to_serializer=new_ser) + + # Invalidate the (possible) cached "device_memory_objects" + self._obj_pxy_cache.pop("device_memory_objects", None) + return self._obj_pxy["obj"] def _obj_pxy_deserialize(self, maybe_evict: bool = True): """Inplace deserialization of the proxied object @@ -321,26 +328,33 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): """ with self._obj_pxy_lock: if self._obj_pxy_is_serialized(): - hostfile = self._obj_pxy.get("hostfile", lambda: None)() + manager = self._obj_pxy.get("manager", lambda: None)() + serializer = self._obj_pxy["serializer"] + # When not deserializing a CUDA-serialized proxied, we might have # to evict because of the increased device memory usage. - if maybe_evict and self._obj_pxy["serializer"] != "cuda": - if hostfile is not None: - # In order to avoid a potential deadlock, we skip the - # `maybe_evict()` call if another thread is also accessing - # the hostfile. - if hostfile.lock.acquire(blocking=False): - try: - hostfile.maybe_evict(self.__sizeof__()) - finally: - hostfile.lock.release() - - header, frames = self._obj_pxy["obj"] - self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames) - self._obj_pxy["serializer"] = None - if hostfile is not None: - external = self._obj_pxy.get("external", self) - hostfile.proxies_tally.unspill_proxy(external) + if manager and maybe_evict and serializer != "cuda": + # In order to avoid a potential deadlock, we skip the + # `maybe_evict()` call if another thread is also accessing + # the hostfile. 
+ if manager.lock.acquire(blocking=False): + try: + manager.maybe_evict(self.__sizeof__()) + finally: + manager.lock.release() + + # Lock manager (if any) + with (nullcontext() if manager is None else manager.lock): + header, frames = self._obj_pxy["obj"] + self._obj_pxy["obj"] = distributed.protocol.deserialize( + header, frames + ) + self._obj_pxy["serializer"] = None + # Tell the manager (if any) that this proxy has change serializer + if manager: + manager.move( + self, from_serializer=serializer, to_serializer=None + ) self._obj_pxy["last_access"] = time.monotonic() return self._obj_pxy["obj"] diff --git a/dask_cuda/tests/test_proxify_host_file.py b/dask_cuda/tests/test_proxify_host_file.py index 437c62f0a..1b026b5ac 100644 --- a/dask_cuda/tests/test_proxify_host_file.py +++ b/dask_cuda/tests/test_proxify_host_file.py @@ -1,3 +1,5 @@ +from typing import Iterable + import numpy as np import pandas import pytest @@ -12,9 +14,9 @@ import dask_cuda import dask_cuda.proxify_device_objects -import dask_cuda.proxy_object from dask_cuda.get_device_memory_objects import get_device_memory_objects from dask_cuda.proxify_host_file import ProxifyHostFile +from dask_cuda.proxy_object import ProxyObject cupy = pytest.importorskip("cupy") cupy.cuda.set_allocator(None) @@ -27,53 +29,80 @@ dask_cuda.proxify_device_objects.ignore_types = () +def is_proxies_equal(p1: Iterable[ProxyObject], p2: Iterable[ProxyObject]): + """Check that two collections of proxies contains the same proxies (unordered) + + In order to avoid deserializing proxy objects when comparing them, + this funcntion compares object IDs. + """ + + ids1 = sorted([id(p) for p in p1]) + ids2 = sorted([id(p) for p in p2]) + return ids1 == ids2 + + def test_one_item_limit(): dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes) - dhf["k1"] = one_item_array() + 42 - dhf["k2"] = one_item_array() + + a1 = one_item_array() + 42 + a2 = one_item_array() + dhf["k1"] = a1 + dhf["k2"] = a2 # Check k1 is spilled because of the newer k2 k1 = dhf["k1"] k2 = dhf["k2"] assert k1._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k1]) + assert is_proxies_equal(dhf.manager._dev, [k2]) # Accessing k1 spills k2 and unspill k1 k1_val = k1[0] assert k1_val == 42 assert k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k2]) + assert is_proxies_equal(dhf.manager._dev, [k1]) # Duplicate arrays changes nothing dhf["k3"] = [k1, k2] assert not k1._obj_pxy_is_serialized() assert k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k2]) + assert is_proxies_equal(dhf.manager._dev, [k1]) # Adding a new array spills k1 and k2 dhf["k4"] = one_item_array() + k4 = dhf["k4"] assert k1._obj_pxy_is_serialized() assert k2._obj_pxy_is_serialized() assert not dhf["k4"]._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k1, k2]) + assert is_proxies_equal(dhf.manager._dev, [k4]) # Accessing k2 spills k1 and k4 k2[0] assert k1._obj_pxy_is_serialized() assert dhf["k4"]._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k1, k4]) + assert is_proxies_equal(dhf.manager._dev, [k2]) # Deleting k2 does not change anything since k3 still holds a # reference to the underlying proxy object - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes - p1 = list(dhf.proxies_tally.get_unspilled_proxies()) - assert len(p1) == 1 + assert dhf.manager.get_dev_access_info()[0] == one_item_nbytes + assert 
is_proxies_equal(dhf.manager._host, [k1, k4]) + assert is_proxies_equal(dhf.manager._dev, [k2]) del dhf["k2"] - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes - p2 = list(dhf.proxies_tally.get_unspilled_proxies()) - assert len(p2) == 1 - assert p1[0] is p2[0] + assert is_proxies_equal(dhf.manager._host, [k1, k4]) + assert is_proxies_equal(dhf.manager._dev, [k2]) - # Overwriting "k3" with a non-cuda object, should be noticed + # Overwriting "k3" with a non-cuda object and deleting `k2` + # should empty the device dhf["k3"] = "non-cuda-object" - assert dhf.proxies_tally.get_dev_mem_usage() == 0 + del k2 + assert is_proxies_equal(dhf.manager._host, [k1, k4]) + assert is_proxies_equal(dhf.manager._dev, []) @pytest.mark.parametrize("jit_unspill", [True, False]) @@ -144,59 +173,45 @@ def test_cudf_get_device_memory_objects(): def test_externals(): + """Test adding objects directly to the manager + + Add an object directly to the manager makes it count against the + device_memory_limit but isn't part of the store. + + Normally, we use __setitem__ to store objects in the hostfile and make it + count against the device_memory_limit with the inherent consequence that + the objects are not freeable before subsequential calls to __delitem__. + This is a problem for long running tasks that want objects to count against + the device_memory_limit while freeing them ASAP without explicit calls to + __delitem__. + """ dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes) dhf["k1"] = one_item_array() k1 = dhf["k1"] - k2 = dhf.add_external(one_item_array()) + k2 = dhf.manager.proxify(one_item_array()) # `k2` isn't part of the store but still triggers spilling of `k1` assert len(dhf) == 1 assert k1._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k1]) + assert is_proxies_equal(dhf.manager._dev, [k2]) + k1[0] # Trigger spilling of `k2` assert not k1._obj_pxy_is_serialized() assert k2._obj_pxy_is_serialized() + assert is_proxies_equal(dhf.manager._host, [k2]) + assert is_proxies_equal(dhf.manager._dev, [k1]) + k2[0] # Trigger spilling of `k1` assert k1._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes + assert is_proxies_equal(dhf.manager._host, [k1]) + assert is_proxies_equal(dhf.manager._dev, [k2]) + # Removing `k2` also removes it from the tally del k2 - assert dhf.proxies_tally.get_dev_mem_usage() == 0 - assert len(list(dhf.proxies_tally.get_unspilled_proxies())) == 0 - - -def test_externals_setitem(): - dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes) - k1 = dhf.add_external(one_item_array()) - assert type(k1) is dask_cuda.proxy_object.ProxyObject - assert len(dhf) == 0 - assert "external" in k1._obj_pxy - assert "external_finalize" in k1._obj_pxy - dhf["k1"] = k1 - k1 = dhf["k1"] - assert type(k1) is dask_cuda.proxy_object.ProxyObject - assert len(dhf) == 1 - assert "external" not in k1._obj_pxy - assert "external_finalize" not in k1._obj_pxy - - k1 = dhf.add_external(one_item_array()) - k1._obj_pxy_serialize(serializers=("dask", "pickle")) - dhf["k1"] = k1 - k1 = dhf["k1"] - assert type(k1) is dask_cuda.proxy_object.ProxyObject - assert len(dhf) == 1 - assert "external" not in k1._obj_pxy - assert "external_finalize" not in k1._obj_pxy - - dhf["k1"] = one_item_array() - assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1 - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes - k1 = dhf.add_external(k1) - assert 
len(dhf.proxies_tally.proxy_id_to_proxy) == 1 - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes - k1 = dhf.add_external(dhf["k1"]) - assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1 - assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes + assert is_proxies_equal(dhf.manager._host, [k1]) + assert is_proxies_equal(dhf.manager._dev, []) def test_proxify_device_objects_of_cupy_array(): From 1406076f7cfc7df32b6ba6572bc7cf24b442b85d Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 6 Sep 2021 09:40:47 +0200 Subject: [PATCH 11/27] get_device_memory_objects(): now incl. serialized data --- dask_cuda/get_device_memory_objects.py | 5 +---- dask_cuda/proxy_object.py | 3 --- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/dask_cuda/get_device_memory_objects.py b/dask_cuda/get_device_memory_objects.py index bd00ba0f5..385f70793 100644 --- a/dask_cuda/get_device_memory_objects.py +++ b/dask_cuda/get_device_memory_objects.py @@ -28,10 +28,7 @@ def get_device_memory_objects(obj) -> set: @dispatch.register(object) def get_device_memory_objects_default(obj): if hasattr(obj, "_obj_pxy"): - if not obj._obj_pxy_is_serialized(): - return dispatch(obj._obj_pxy["obj"]) - else: - return [] + return dispatch(obj._obj_pxy["obj"]) if hasattr(obj, "data"): return dispatch(obj.data) if hasattr(obj, "_owner") and obj._owner is not None: diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 9cc061c8f..9c1af125f 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -374,9 +374,6 @@ def _obj_pxy_is_cuda_object(self) -> bool: def _obj_pxy_get_device_memory_objects(self) -> Set: """Return all device memory objects within the proxied object. - Calling this when the proxied object is serialized returns the - empty list. - Returns ------- ret : set From d8136dcc8be70b6826e732ecf8915c2378d8882d Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 6 Sep 2021 09:42:38 +0200 Subject: [PATCH 12/27] ProxyObject: call the new finalizer --- dask_cuda/proxy_object.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 9c1af125f..b839629dd 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -221,10 +221,8 @@ def __init__( self._obj_pxy_cache: Dict[str, Any] = {} def __del__(self): - """In order to call `external_finalize()` ASAP, we call it here""" - external_finalize = self._obj_pxy.get("external_finalize", None) - if external_finalize is not None: - external_finalize() + """In order to call `finalizer()` ASAP, we call it here""" + self._obj_pxy.get("finalizer", lambda: None)() def _obj_pxy_get_init_args(self, include_obj=True): """Return the attributes needed to initialize a ProxyObject From 704037fe9062ee9b61699dd4d6a8324f786d52b1 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 6 Sep 2021 09:44:52 +0200 Subject: [PATCH 13/27] proxify(): remove the external finalizer --- dask_cuda/proxify_device_objects.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/dask_cuda/proxify_device_objects.py b/dask_cuda/proxify_device_objects.py index 8c45ebd49..5783d36f0 100644 --- a/dask_cuda/proxify_device_objects.py +++ b/dask_cuda/proxify_device_objects.py @@ -165,14 +165,10 @@ def wrapper(*args, **kwargs): def proxify(obj, proxied_id_to_proxy, found_proxies, subclass=None): _id = id(obj) - if _id in proxied_id_to_proxy: - ret = proxied_id_to_proxy[_id] - finalize = ret._obj_pxy.get("external_finalize", None) - if finalize: - finalize() - proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass) - else: + if _id not in proxied_id_to_proxy: proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass) + else: + ret = proxied_id_to_proxy[_id] found_proxies.append(ret) return ret From 252e7172dba67febfa812014861add3ded572c73 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 6 Sep 2021 09:46:41 +0200 Subject: [PATCH 14/27] Fixed _mem_usage typo --- dask_cuda/proxify_host_file.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 1537cee12..4d52f30ee 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -73,13 +73,13 @@ def mem_usage_add(self, proxy: ProxyObject): self._mem_usage += sizeof(proxy) def mem_usage_sub(self, proxy: ProxyObject): - self._mem_usage = sizeof(proxy) + self._mem_usage -= sizeof(proxy) class ProxiesOnDevice(Proxies): """Implement tracking of proxies on the GPU - This is a bit more complicated than ProxiesOnHost because we has to + This is a bit more complicated than ProxiesOnHost because we have to handle that multiple proxy objects can refer to the same underlying device memory object. Thus, we have to track aliasing and make sure we don't count down the memory usage prematurely. From 2d7f03004f8c3f3127bcdf5c5513b70d3502834a Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Mon, 6 Sep 2021 09:48:05 +0200 Subject: [PATCH 15/27] Tracking now passes all tests --- dask_cuda/proxify_host_file.py | 53 +++++++++++++++-------- dask_cuda/tests/test_proxify_host_file.py | 4 ++ 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 4d52f30ee..edcc6b008 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -35,6 +35,9 @@ def __init__(self): self._proxy_id_to_proxy: Dict[int, ReferenceType[ProxyObject]] = {} self._mem_usage = 0 + def __len__(self) -> int: + return len(self._proxy_id_to_proxy) + @abc.abstractmethod def mem_usage_add(self, proxy: ProxyObject) -> None: pass @@ -49,11 +52,11 @@ def add(self, proxy: ProxyObject) -> None: self.mem_usage_add(proxy) def remove(self, proxy_id: int) -> None: - weak_proxy = self._proxy_id_to_proxy.pop(proxy_id) - if weak_proxy is not None: - proxy = weak_proxy() - assert proxy is not None - self.mem_usage_sub(proxy) + proxy = self._proxy_id_to_proxy.pop(proxy_id)() + assert proxy is not None + self.mem_usage_sub(proxy) + if len(self._proxy_id_to_proxy) == 0: + assert self._mem_usage == 0, self._mem_usage def __iter__(self) -> Iterator[ProxyObject]: for p in self._proxy_id_to_proxy.values(): @@ -92,6 +95,7 @@ def __init__(self): def mem_usage_add(self, proxy: ProxyObject): proxy_id = id(proxy) + assert proxy_id not in self.proxy_id_to_dev_mems for dev_mem in proxy._obj_pxy_get_device_memory_objects(): self.proxy_id_to_dev_mems[proxy_id].add(dev_mem) ps = self.dev_mem_to_proxy_ids[dev_mem] @@ -101,6 +105,7 @@ def mem_usage_add(self, proxy: ProxyObject): def mem_usage_sub(self, proxy: ProxyObject): proxy_id = id(proxy) + assert proxy_id in self.proxy_id_to_dev_mems for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id): self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id) if len(self.dev_mem_to_proxy_ids[dev_mem]) == 0: @@ -130,11 +135,18 @@ def __init__(self, device_memory_limit: int): def __repr__(self) -> str: return ( f"" + f" host={self._host.mem_usage()}({len(self._host)})" + f" dev={self._dev.mem_usage()}({len(self._dev)})>" ) - def pprint(self): - ret = f"{self}:\n" + def __len__(self) -> int: + return len(self._host) + len(self._dev) + + def pprint(self) -> str: + ret = f"{self}:" + if len(self) == 0: + return ret + " Empty" + ret += "\n" for proxy in self._host: ret += f" host - {repr(proxy)}\n" for proxy in self._dev: @@ -148,7 +160,7 @@ def serializer_target(serializer: Optional[str]) -> str: else: return "dev" - def contains(self, proxy_id: int): + def contains(self, proxy_id: int) -> bool: with self.lock: return self._host.contains_proxy_id( proxy_id @@ -164,9 +176,15 @@ def add(self, proxy: ProxyObject) -> None: def remove(self, proxy_id: int) -> None: with self.lock: - if not self.contains(proxy_id): - self._host.remove(proxy_id) - self._dev.remove(proxy_id) + # Find where the proxy is located and remove it + proxies = None + if self._host.contains_proxy_id(proxy_id): + proxies = self._host + if self._dev.contains_proxy_id(proxy_id): + assert proxies is None, "Proxy in multiple locations" + proxies = self._dev + assert proxies is not None, "Trying to remove unknown proxy" + proxies.remove(proxy_id) def move( self, @@ -192,10 +210,11 @@ def proxify(self, obj: object) -> object: last_access = time.monotonic() self_weakref = weakref.ref(self) for p in found_proxies: - p._obj_pxy["manager"] = self_weakref p._obj_pxy["last_access"] = last_access - p._obj_pxy["finalizer"] = weakref.finalize(p, 
self.remove, id(p)) - self.add(p) + if not self.contains(id(p)): + p._obj_pxy["manager"] = self_weakref + p._obj_pxy["finalizer"] = weakref.finalize(p, self.remove, id(p)) + self.add(p) self.maybe_evict() return ret @@ -223,10 +242,10 @@ def get_dev_access_info( assert total_dev_mem_usage == self._dev.mem_usage() return total_dev_mem_usage, dev_buf_access - def evict(self, proxy: ProxyObject): + def evict(self, proxy: ProxyObject) -> None: proxy._obj_pxy_serialize(serializers=("dask", "pickle")) - def maybe_evict(self, extra_dev_mem=0): + def maybe_evict(self, extra_dev_mem=0) -> None: if ( # Shortcut when not evicting self._dev.mem_usage() + extra_dev_mem <= self._device_memory_limit ): diff --git a/dask_cuda/tests/test_proxify_host_file.py b/dask_cuda/tests/test_proxify_host_file.py index 1b026b5ac..05b5223c8 100644 --- a/dask_cuda/tests/test_proxify_host_file.py +++ b/dask_cuda/tests/test_proxify_host_file.py @@ -195,23 +195,27 @@ def test_externals(): assert not k2._obj_pxy_is_serialized() assert is_proxies_equal(dhf.manager._host, [k1]) assert is_proxies_equal(dhf.manager._dev, [k2]) + assert dhf.manager._dev._mem_usage == one_item_nbytes k1[0] # Trigger spilling of `k2` assert not k1._obj_pxy_is_serialized() assert k2._obj_pxy_is_serialized() assert is_proxies_equal(dhf.manager._host, [k2]) assert is_proxies_equal(dhf.manager._dev, [k1]) + assert dhf.manager._dev._mem_usage == one_item_nbytes k2[0] # Trigger spilling of `k1` assert k1._obj_pxy_is_serialized() assert not k2._obj_pxy_is_serialized() assert is_proxies_equal(dhf.manager._host, [k1]) assert is_proxies_equal(dhf.manager._dev, [k2]) + assert dhf.manager._dev._mem_usage == one_item_nbytes # Removing `k2` also removes it from the tally del k2 assert is_proxies_equal(dhf.manager._host, [k1]) assert is_proxies_equal(dhf.manager._dev, []) + assert dhf.manager._dev._mem_usage == 0 def test_proxify_device_objects_of_cupy_array(): From 531932f8b06103839adc34b24f1d2a1d34de708e Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
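
The eviction policy in maybe_evict() boils down to: skip if already under the limit, otherwise sort device buffers by last access (oldest first, larger first on ties) and spill until usage drops below the limit. A self-contained sketch of that loop; the (last_access, size, proxies) tuples mirror what get_dev_access_info() returns, and spill is a hypothetical callback standing in for host serialization:

from typing import Callable, List, Tuple

def maybe_evict_sketch(
    dev_buf_access: List[Tuple[float, int, list]],  # (last_access, size, proxies)
    dev_mem_usage: int,
    limit: int,
    spill: Callable[[object], None],
) -> int:
    """Spill oldest (and, on ties, largest) buffers until usage <= limit."""
    if dev_mem_usage <= limit:
        return dev_mem_usage  # shortcut when not evicting
    dev_buf_access.sort(key=lambda x: (x[0], -x[1]))
    for _, size, proxies in dev_buf_access:
        for p in proxies:
            spill(p)
        dev_mem_usage -= size
        if dev_mem_usage <= limit:
            break
    return dev_mem_usage

spilled: list = []
usage = maybe_evict_sketch(
    [(2.0, 10, ["b"]), (1.0, 30, ["a"])],
    dev_mem_usage=40,
    limit=20,
    spill=spilled.append,
)
assert spilled == ["a"] and usage == 10
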
Kristensen" Date: Mon, 6 Sep 2021 09:57:30 +0200 Subject: [PATCH 16/27] Some docs --- dask_cuda/proxify_host_file.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index edcc6b008..27dd04325 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -40,18 +40,20 @@ def __len__(self) -> int: @abc.abstractmethod def mem_usage_add(self, proxy: ProxyObject) -> None: - pass + """Given a new proxy, update `self._mem_usage`""" @abc.abstractmethod def mem_usage_sub(self, proxy: ProxyObject) -> None: - pass + """Removal of proxy, update `self._mem_usage`""" def add(self, proxy: ProxyObject) -> None: + """Add a proxy for tracking, calls `self.mem_usage_add`""" assert not self.contains_proxy_id(id(proxy)) self._proxy_id_to_proxy[id(proxy)] = weakref.ref(proxy) self.mem_usage_add(proxy) def remove(self, proxy_id: int) -> None: + """Remove proxy from tracking, calls `self.mem_usage_sub`""" proxy = self._proxy_id_to_proxy.pop(proxy_id)() assert proxy is not None self.mem_usage_sub(proxy) @@ -64,7 +66,7 @@ def __iter__(self) -> Iterator[ProxyObject]: if ret is not None: yield ret - def contains_proxy_id(self, proxy_id: int): + def contains_proxy_id(self, proxy_id: int) -> bool: return proxy_id in self._proxy_id_to_proxy def mem_usage(self) -> int: @@ -72,6 +74,11 @@ def mem_usage(self) -> int: class ProxiesOnHost(Proxies): + """Implement tracking of proxies on the CPU + + This uses dask.sizeof to update memory usage. + """ + def mem_usage_add(self, proxy: ProxyObject): self._mem_usage += sizeof(proxy) From 7d5ab18345ee0edbe7a40f6b04166b602153b3a6 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Mon, 6 Sep 2021 17:04:48 +0200 Subject: [PATCH 17/27] Minor clean up --- dask_cuda/proxy_object.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index b839629dd..49ae20858 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -303,7 +303,7 @@ def _obj_pxy_serialize(self, serializers: Iterable[str]): org_ser, new_ser = self._obj_pxy["serializer"], header["serializer"] self._obj_pxy["serializer"] = new_ser - # Tell the manager (if any) that this proxy has change serializer + # Tell the manager (if any) that this proxy has changed serializer if manager: manager.move(self, from_serializer=org_ser, to_serializer=new_ser) @@ -317,7 +317,7 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): Parameters ---------- maybe_evict: bool - Before deserializing, call associated hostfile.maybe_evict() + Before deserializing, maybe evict managered proxy objects Returns ------- @@ -365,8 +365,7 @@ def _obj_pxy_is_cuda_object(self) -> bool: ret : boolean Is the proxied object a CUDA object? """ - with self._obj_pxy_lock: - return self._obj_pxy["is_cuda_object"] + return self._obj_pxy["is_cuda_object"] @_obj_pxy_cache_wrapper("device_memory_objects") def _obj_pxy_get_device_memory_objects(self) -> Set: From c15c394ade3ef3097ae25d6d6491604d2eafd118 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
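
The documented Proxies base class is, at its core, a weak-reference registry plus a running byte tally, with subclasses deciding how a proxy's footprint is measured. A minimal standalone sketch of that shape (Tracker, HostTracker, and Payload are illustrative names, and sys.getsizeof stands in for dask.sizeof):

import abc
import sys
import weakref
from typing import Dict

class Tracker(abc.ABC):
    """Track objects by weak reference and tally their memory usage."""

    def __init__(self) -> None:
        self._refs: Dict[int, weakref.ref] = {}
        self._mem_usage = 0

    @abc.abstractmethod
    def nbytes(self, obj) -> int:
        """How much memory does `obj` account for?"""

    def add(self, obj) -> None:
        assert id(obj) not in self._refs
        self._refs[id(obj)] = weakref.ref(obj)
        self._mem_usage += self.nbytes(obj)

    def remove(self, obj) -> None:
        del self._refs[id(obj)]
        self._mem_usage -= self.nbytes(obj)

class HostTracker(Tracker):
    def nbytes(self, obj) -> int:
        return sys.getsizeof(obj)

class Payload:
    pass

t = HostTracker()
p = Payload()
t.add(p)
assert t._mem_usage > 0
t.remove(p)
assert t._mem_usage == 0
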
Kristensen" Date: Tue, 7 Sep 2021 08:23:35 +0200 Subject: [PATCH 18/27] minor doc clean up --- dask_cuda/proxy_object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 49ae20858..c8e94f992 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -333,7 +333,7 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): # to evict because of the increased device memory usage. if manager and maybe_evict and serializer != "cuda": # In order to avoid a potential deadlock, we skip the - # `maybe_evict()` call if another thread is also accessing + # evict call if another thread is also accessing # the hostfile. if manager.lock.acquire(blocking=False): try: From 4cc1a3e332dc064343ad0e6fc2659ffa245b3d39 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 7 Sep 2021 08:51:40 +0200 Subject: [PATCH 19/27] Proxies now use use the manager's lock --- dask_cuda/explicit_comms/dataframe/shuffle.py | 4 +--- dask_cuda/proxify_host_file.py | 5 ++--- dask_cuda/proxy_object.py | 22 ++++++++++++++----- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index d14902b29..cce5480e7 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -148,9 +148,7 @@ async def local_shuffle( eps = s["eps"] try: - manager = first(iter(in_parts[0].values()))._obj_pxy.get( - "manager", lambda: None - )() + manager = first(iter(in_parts[0].values()))._obj_pxy.get("manager", None) except AttributeError: manager = None diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 27dd04325..7a18116b6 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -184,7 +184,7 @@ def add(self, proxy: ProxyObject) -> None: def remove(self, proxy_id: int) -> None: with self.lock: # Find where the proxy is located and remove it - proxies = None + proxies: Optional[Proxies] = None if self._host.contains_proxy_id(proxy_id): proxies = self._host if self._dev.contains_proxy_id(proxy_id): @@ -215,11 +215,10 @@ def proxify(self, obj: object) -> object: proxied_id_to_proxy: Dict[int, ProxyObject] = {} ret = proxify_device_objects(obj, proxied_id_to_proxy, found_proxies) last_access = time.monotonic() - self_weakref = weakref.ref(self) for p in found_proxies: p._obj_pxy["last_access"] = last_access if not self.contains(id(p)): - p._obj_pxy["manager"] = self_weakref + p._obj_pxy_register_manager(self) p._obj_pxy["finalizer"] = weakref.finalize(p, self.remove, id(p)) self.add(p) self.maybe_evict() diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index c8e94f992..a95ad0e79 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -8,7 +8,7 @@ from contextlib import ( # TODO: use `contextlib.nullcontext()` from Python 3.7+ suppress as nullcontext, ) -from typing import Any, Dict, Iterable, Optional, Set, Type +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Type import pandas @@ -35,6 +35,10 @@ from .get_device_memory_objects import get_device_memory_objects from .is_device_object import is_device_object +if TYPE_CHECKING: + from .proxify_host_file import ProxyManager + + # List of attributes that should be copied to the proxy at creation, which makes # them accessible without deserialization of the proxied object _FIXED_ATTRS = ["name", "__len__"] @@ -224,7 +228,7 @@ def __del__(self): 
"""In order to call `finalizer()` ASAP, we call it here""" self._obj_pxy.get("finalizer", lambda: None)() - def _obj_pxy_get_init_args(self, include_obj=True): + def _obj_pxy_get_init_args(self, include_obj=True) -> OrderedDict: """Return the attributes needed to initialize a ProxyObject Notice, the returned dictionary is ordered as the __init__() arguments @@ -263,7 +267,13 @@ def _obj_pxy_copy(self) -> "ProxyObject": args["obj"] = self._obj_pxy["obj"] return type(self)(**args) - def _obj_pxy_is_serialized(self): + def _obj_pxy_register_manager(self, manager: "ProxyManager") -> None: + with self._obj_pxy_lock: + assert "manager" not in self._obj_pxy + self._obj_pxy["manager"] = manager + self._obj_pxy_lock = manager.lock + + def _obj_pxy_is_serialized(self) -> bool: """Return whether the proxied object is serialized or not""" return self._obj_pxy["serializer"] is not None @@ -294,7 +304,7 @@ def _obj_pxy_serialize(self, serializers: Iterable[str]): self._obj_pxy_deserialize() # Lock manager (if any) - manager = self._obj_pxy.get("manager", lambda: None)() + manager = self._obj_pxy.get("manager", None) with (nullcontext() if manager is None else manager.lock): header, _ = self._obj_pxy["obj"] = distributed.protocol.serialize( self._obj_pxy["obj"], serializers, on_error="raise" @@ -326,7 +336,7 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): """ with self._obj_pxy_lock: if self._obj_pxy_is_serialized(): - manager = self._obj_pxy.get("manager", lambda: None)() + manager = self._obj_pxy.get("manager", None) serializer = self._obj_pxy["serializer"] # When not deserializing a CUDA-serialized proxied, we might have @@ -368,7 +378,7 @@ def _obj_pxy_is_cuda_object(self) -> bool: return self._obj_pxy["is_cuda_object"] @_obj_pxy_cache_wrapper("device_memory_objects") - def _obj_pxy_get_device_memory_objects(self) -> Set: + def _obj_pxy_get_device_memory_objects(self) -> set: """Return all device memory objects within the proxied object. Returns From 19587c34aff190b9a8b0880d271d53328648df42 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Tue, 7 Sep 2021 09:39:52 +0200 Subject: [PATCH 20/27] Docs and clean up --- dask_cuda/proxify_host_file.py | 5 ++-- dask_cuda/proxy_object.py | 55 +++++++++++++++++++++------------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 7a18116b6..d4e98a27c 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -218,8 +218,9 @@ def proxify(self, obj: object) -> object: for p in found_proxies: p._obj_pxy["last_access"] = last_access if not self.contains(id(p)): - p._obj_pxy_register_manager(self) - p._obj_pxy["finalizer"] = weakref.finalize(p, self.remove, id(p)) + p._obj_pxy_register_manager( + self, weakref.finalize(p, self.remove, id(p)) + ) self.add(p) self.maybe_evict() return ret diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index a95ad0e79..b573be0c6 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -4,6 +4,7 @@ import pickle import threading import time +import weakref from collections import OrderedDict from contextlib import ( # TODO: use `contextlib.nullcontext()` from Python 3.7+ suppress as nullcontext, @@ -267,11 +268,29 @@ def _obj_pxy_copy(self) -> "ProxyObject": args["obj"] = self._obj_pxy["obj"] return type(self)(**args) - def _obj_pxy_register_manager(self, manager: "ProxyManager") -> None: - with self._obj_pxy_lock: - assert "manager" not in self._obj_pxy - self._obj_pxy["manager"] = manager - self._obj_pxy_lock = manager.lock + def _obj_pxy_register_manager( + self, manager: "ProxyManager", finalizer: weakref.finalize + ) -> None: + """Register a manager + + The manager tallies the total memory usage of proxies and + evicts/serialize proxy objects as needed. + + In order to prevent deadlocks, the proxy now use the lock of the + manager. + + Parameters + ---------- + manager: ProxyManager + The manager to manage this proxy object + finalizer: weakref.finalize + The finalizer that should unregister the proxy from the manager + """ + assert "manager" not in self._obj_pxy + assert "finalizer" not in self._obj_pxy + self._obj_pxy["manager"] = manager + self._obj_pxy["finalizer"] = finalizer + self._obj_pxy_lock = manager.lock def _obj_pxy_is_serialized(self) -> bool: """Return whether the proxied object is serialized or not""" @@ -304,7 +323,7 @@ def _obj_pxy_serialize(self, serializers: Iterable[str]): self._obj_pxy_deserialize() # Lock manager (if any) - manager = self._obj_pxy.get("manager", None) + manager: "ProxyManager" = self._obj_pxy.get("manager", None) with (nullcontext() if manager is None else manager.lock): header, _ = self._obj_pxy["obj"] = distributed.protocol.serialize( self._obj_pxy["obj"], serializers, on_error="raise" @@ -336,29 +355,25 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True): """ with self._obj_pxy_lock: if self._obj_pxy_is_serialized(): - manager = self._obj_pxy.get("manager", None) + manager: "ProxyManager" = self._obj_pxy.get("manager", None) serializer = self._obj_pxy["serializer"] - # When not deserializing a CUDA-serialized proxied, we might have - # to evict because of the increased device memory usage. - if manager and maybe_evict and serializer != "cuda": - # In order to avoid a potential deadlock, we skip the - # evict call if another thread is also accessing - # the hostfile. 
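
# A stripped-down illustration of the lock unification described above: once a
# proxy registers with a manager it adopts the manager's RLock, so nested
# proxy-to-manager calls reacquire the same reentrant lock instead of taking
# two different locks in varying orders, which is the classic deadlock recipe.
# MiniManager and MiniProxy are hypothetical stand-ins, not the real classes.
import threading

class MiniManager:
    def __init__(self) -> None:
        self.lock = threading.RLock()

class MiniProxy:
    def __init__(self) -> None:
        self._lock = threading.RLock()  # private lock while unmanaged
        self._manager = None

    def register_manager(self, manager: MiniManager) -> None:
        assert self._manager is None
        self._manager = manager
        self._lock = manager.lock  # from now on: one shared RLock

    def touch(self) -> None:
        with self._lock:  # proxy-level critical section
            if self._manager is not None:
                with self._manager.lock:  # reentrant reacquire, never blocks here
                    pass

m = MiniManager()
p = MiniProxy()
p.register_manager(m)
p.touch()
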
- if manager.lock.acquire(blocking=False): - try: - manager.maybe_evict(self.__sizeof__()) - finally: - manager.lock.release() - # Lock manager (if any) with (nullcontext() if manager is None else manager.lock): + + # When not deserializing a CUDA-serialized proxied, tell the + # manager that it might have to evict because of the increased + # device memory usage. + if manager and maybe_evict and serializer != "cuda": + manager.maybe_evict(self.__sizeof__()) + + # Deserialize the proxied object header, frames = self._obj_pxy["obj"] self._obj_pxy["obj"] = distributed.protocol.deserialize( header, frames ) self._obj_pxy["serializer"] = None - # Tell the manager (if any) that this proxy has change serializer + # Tell the manager (if any) that this proxy has changed serializer if manager: manager.move( self, from_serializer=serializer, to_serializer=None From 3b614feaacb202f4fce1a0488989011f80d08a17 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 7 Sep 2021 03:18:32 -0700 Subject: [PATCH 21/27] mem_usage_add(): proxy can have an empty set of dev buffers --- dask_cuda/proxify_host_file.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index d4e98a27c..412833674 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -97,12 +97,13 @@ class ProxiesOnDevice(Proxies): def __init__(self): super().__init__() - self.proxy_id_to_dev_mems: DefaultDict[int, Set[Hashable]] = defaultdict(set) + self.proxy_id_to_dev_mems: Dict[int, Set[Hashable]] = {} self.dev_mem_to_proxy_ids: DefaultDict[Hashable, Set[int]] = defaultdict(set) def mem_usage_add(self, proxy: ProxyObject): proxy_id = id(proxy) assert proxy_id not in self.proxy_id_to_dev_mems + self.proxy_id_to_dev_mems[proxy_id] = set() for dev_mem in proxy._obj_pxy_get_device_memory_objects(): self.proxy_id_to_dev_mems[proxy_id].add(dev_mem) ps = self.dev_mem_to_proxy_ids[dev_mem] @@ -112,7 +113,6 @@ def mem_usage_add(self, proxy: ProxyObject): def mem_usage_sub(self, proxy: ProxyObject): proxy_id = id(proxy) - assert proxy_id in self.proxy_id_to_dev_mems for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id): self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id) if len(self.dev_mem_to_proxy_ids[dev_mem]) == 0: From a70e664dbeb893f081fccb789960b5a7902d4823 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 7 Sep 2021 12:20:18 +0200 Subject: [PATCH 22/27] Proxies.remove(): raise warning when the tally isn't going to zero --- dask_cuda/proxify_host_file.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 412833674..7f3c0b5fe 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -2,6 +2,7 @@ import logging import threading import time +import warnings import weakref from collections import defaultdict from typing import ( @@ -58,7 +59,13 @@ def remove(self, proxy_id: int) -> None: assert proxy is not None self.mem_usage_sub(proxy) if len(self._proxy_id_to_proxy) == 0: - assert self._mem_usage == 0, self._mem_usage + if self._mem_usage != 0: + warnings.warn( + "ProxyManager is empty but the tally of " + f"{self} is {self._mem_usage} bytes. " + "Resetting the tally." + ) + self._mem_usage = 0 def __iter__(self) -> Iterator[ProxyObject]: for p in self._proxy_id_to_proxy.values(): From 451d7525a0ea266f1c29244a438c4ea08f2fa2d6 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
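
The aliasing bookkeeping that PATCH 21 touches reduces to two maps, proxy id to buffers and buffer to proxy ids: a buffer's size is added to the tally only when its first referencing proxy arrives and subtracted only when its last one leaves. A minimal sketch with integers standing in for device buffers; the real code recomputes sizes with dask.sizeof, here they are passed in explicitly:

from collections import defaultdict
from typing import DefaultDict, Dict, Set

class DeviceTally:
    def __init__(self) -> None:
        self.proxy_to_bufs: Dict[int, Set[int]] = {}
        self.buf_to_proxies: DefaultDict[int, Set[int]] = defaultdict(set)
        self.mem_usage = 0

    def add(self, proxy_id: int, bufs: Dict[int, int]) -> None:
        """Register `proxy_id` referencing `bufs` (buffer id -> nbytes)."""
        assert proxy_id not in self.proxy_to_bufs
        self.proxy_to_bufs[proxy_id] = set()  # may legitimately stay empty
        for buf, nbytes in bufs.items():
            self.proxy_to_bufs[proxy_id].add(buf)
            owners = self.buf_to_proxies[buf]
            if not owners:
                self.mem_usage += nbytes  # first proxy to reference this buffer
            owners.add(proxy_id)

    def remove(self, proxy_id: int, bufs: Dict[int, int]) -> None:
        for buf in self.proxy_to_bufs.pop(proxy_id):
            owners = self.buf_to_proxies[buf]
            owners.remove(proxy_id)
            if not owners:
                del self.buf_to_proxies[buf]
                self.mem_usage -= bufs[buf]  # last reference is gone

t = DeviceTally()
t.add(1, {100: 64})
t.add(2, {100: 64})  # same underlying buffer: counted once
assert t.mem_usage == 64
t.remove(1, {100: 64})
assert t.mem_usage == 64
t.remove(2, {100: 64})
assert t.mem_usage == 0
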
Kristensen" Date: Tue, 7 Sep 2021 12:44:57 +0200 Subject: [PATCH 23/27] now using the proxy's dtor instead of a finalizer --- dask_cuda/proxify_host_file.py | 21 +++++++++------------ dask_cuda/proxy_object.py | 15 +++++---------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 7f3c0b5fe..b69d860d2 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -53,10 +53,9 @@ def add(self, proxy: ProxyObject) -> None: self._proxy_id_to_proxy[id(proxy)] = weakref.ref(proxy) self.mem_usage_add(proxy) - def remove(self, proxy_id: int) -> None: + def remove(self, proxy: ProxyObject) -> None: """Remove proxy from tracking, calls `self.mem_usage_sub`""" - proxy = self._proxy_id_to_proxy.pop(proxy_id)() - assert proxy is not None + del self._proxy_id_to_proxy[id(proxy)] self.mem_usage_sub(proxy) if len(self._proxy_id_to_proxy) == 0: if self._mem_usage != 0: @@ -188,17 +187,17 @@ def add(self, proxy: ProxyObject) -> None: else: self._dev.add(proxy) - def remove(self, proxy_id: int) -> None: + def remove(self, proxy: ProxyObject) -> None: with self.lock: # Find where the proxy is located and remove it proxies: Optional[Proxies] = None - if self._host.contains_proxy_id(proxy_id): + if self._host.contains_proxy_id(id(proxy)): proxies = self._host - if self._dev.contains_proxy_id(proxy_id): + if self._dev.contains_proxy_id(id(proxy)): assert proxies is None, "Proxy in multiple locations" proxies = self._dev assert proxies is not None, "Trying to remove unknown proxy" - proxies.remove(proxy_id) + proxies.remove(proxy) def move( self, @@ -210,11 +209,11 @@ def move( src = self.serializer_target(from_serializer) dst = self.serializer_target(to_serializer) if src == "host" and dst == "dev": - self._host.remove(id(proxy)) + self._host.remove(proxy) self._dev.add(proxy) elif src == "dev" and dst == "host": self._host.add(proxy) - self._dev.remove(id(proxy)) + self._dev.remove(proxy) def proxify(self, obj: object) -> object: with self.lock: @@ -225,9 +224,7 @@ def proxify(self, obj: object) -> object: for p in found_proxies: p._obj_pxy["last_access"] = last_access if not self.contains(id(p)): - p._obj_pxy_register_manager( - self, weakref.finalize(p, self.remove, id(p)) - ) + p._obj_pxy_register_manager(self) self.add(p) self.maybe_evict() return ret diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index b573be0c6..4bfb42a31 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -4,7 +4,6 @@ import pickle import threading import time -import weakref from collections import OrderedDict from contextlib import ( # TODO: use `contextlib.nullcontext()` from Python 3.7+ suppress as nullcontext, @@ -226,8 +225,10 @@ def __init__( self._obj_pxy_cache: Dict[str, Any] = {} def __del__(self): - """In order to call `finalizer()` ASAP, we call it here""" - self._obj_pxy.get("finalizer", lambda: None)() + """We have to unregister us from the manager if any""" + manager: "ProxyManager" = self._obj_pxy.get("manager", None) + if manager is not None: + manager.remove(self) def _obj_pxy_get_init_args(self, include_obj=True) -> OrderedDict: """Return the attributes needed to initialize a ProxyObject @@ -268,9 +269,7 @@ def _obj_pxy_copy(self) -> "ProxyObject": args["obj"] = self._obj_pxy["obj"] return type(self)(**args) - def _obj_pxy_register_manager( - self, manager: "ProxyManager", finalizer: weakref.finalize - ) -> None: + def _obj_pxy_register_manager(self, manager: "ProxyManager") 
-> None: """Register a manager The manager tallies the total memory usage of proxies and @@ -283,13 +282,9 @@ def _obj_pxy_register_manager( ---------- manager: ProxyManager The manager to manage this proxy object - finalizer: weakref.finalize - The finalizer that should unregister the proxy from the manager """ assert "manager" not in self._obj_pxy - assert "finalizer" not in self._obj_pxy self._obj_pxy["manager"] = manager - self._obj_pxy["finalizer"] = finalizer self._obj_pxy_lock = manager.lock def _obj_pxy_is_serialized(self) -> bool: From 23c8b4b5717c25f3cb0c5034ca0eb3abb0170495 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Tue, 7 Sep 2021 16:07:31 +0200 Subject: [PATCH 24/27] Implements get_proxies_by_serializer() --- dask_cuda/proxify_host_file.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index b69d860d2..7034b73bd 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -166,12 +166,11 @@ def pprint(self) -> str: ret += f" dev - {repr(proxy)}\n" return ret[:-1] # Strip last newline - @staticmethod - def serializer_target(serializer: Optional[str]) -> str: + def get_proxies_by_serializer(self, serializer: Optional[str]) -> Proxies: if serializer in ("dask", "pickle"): - return "host" + return self._host else: - return "dev" + return self._dev def contains(self, proxy_id: int) -> bool: with self.lock: @@ -182,10 +181,7 @@ def contains(self, proxy_id: int) -> bool: def add(self, proxy: ProxyObject) -> None: with self.lock: if not self.contains(id(proxy)): - if self.serializer_target(proxy._obj_pxy["serializer"]) == "host": - self._host.add(proxy) - else: - self._dev.add(proxy) + self.get_proxies_by_serializer(proxy._obj_pxy["serializer"]).add(proxy) def remove(self, proxy: ProxyObject) -> None: with self.lock: @@ -206,14 +202,11 @@ def move( to_serializer: Optional[str], ) -> None: with self.lock: - src = self.serializer_target(from_serializer) - dst = self.serializer_target(to_serializer) - if src == "host" and dst == "dev": - self._host.remove(proxy) - self._dev.add(proxy) - elif src == "dev" and dst == "host": - self._host.add(proxy) - self._dev.remove(proxy) + src = self.get_proxies_by_serializer(from_serializer) + dst = self.get_proxies_by_serializer(to_serializer) + if src is not dst: + src.remove(proxy) + dst.add(proxy) def proxify(self, obj: object) -> object: with self.lock: From a4fdb602e98ddc1cb183759599351b900fc03f6b Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
Kristensen" Date: Tue, 7 Sep 2021 16:47:01 +0200 Subject: [PATCH 25/27] removed evict() --- dask_cuda/proxify_host_file.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 7034b73bd..f0e4c9d3c 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -246,9 +246,6 @@ def get_dev_access_info( assert total_dev_mem_usage == self._dev.mem_usage() return total_dev_mem_usage, dev_buf_access - def evict(self, proxy: ProxyObject) -> None: - proxy._obj_pxy_serialize(serializers=("dask", "pickle")) - def maybe_evict(self, extra_dev_mem=0) -> None: if ( # Shortcut when not evicting self._dev.mem_usage() + extra_dev_mem <= self._device_memory_limit @@ -262,7 +259,8 @@ def maybe_evict(self, extra_dev_mem=0) -> None: dev_buf_access.sort(key=lambda x: (x[0], -x[1])) for _, size, proxies in dev_buf_access: for p in proxies: - self.evict(p) + # Serialize to disk, which "dask" and "pickle" does + p._obj_pxy_serialize(serializers=("dask", "pickle")) total_dev_mem_usage -= size if total_dev_mem_usage <= self._device_memory_limit: break From 5c4ca1a0331eb6566020893211a228911d995bd4 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 8 Sep 2021 08:23:39 +0200 Subject: [PATCH 26/27] Clean up by @pentschev Co-authored-by: Peter Andreas Entschev --- dask_cuda/proxify_device_objects.py | 5 ++--- dask_cuda/proxify_host_file.py | 2 +- dask_cuda/proxy_object.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/dask_cuda/proxify_device_objects.py b/dask_cuda/proxify_device_objects.py index 5783d36f0..cd067d3d1 100644 --- a/dask_cuda/proxify_device_objects.py +++ b/dask_cuda/proxify_device_objects.py @@ -166,9 +166,8 @@ def wrapper(*args, **kwargs): def proxify(obj, proxied_id_to_proxy, found_proxies, subclass=None): _id = id(obj) if _id not in proxied_id_to_proxy: - proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass) - else: - ret = proxied_id_to_proxy[_id] + proxied_id_to_proxy[_id] = asproxy(obj, subclass=subclass) + ret = proxied_id_to_proxy[_id] found_proxies.append(ret) return ret diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index f0e4c9d3c..34d9c828d 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -133,7 +133,7 @@ class ProxyManager: memory usage. It turns out having to re-calculate memory usage continuously is too expensive. - The idea is to have the ProxifyHostFile or the proxies themself update + The idea is to have the ProxifyHostFile or the proxies themselves update their location (device or host). The manager then tallies the total memory usage. Notice, the manager only keeps weak references to the proxies. diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py index 4bfb42a31..5dd8651b4 100644 --- a/dask_cuda/proxy_object.py +++ b/dask_cuda/proxy_object.py @@ -450,7 +450,7 @@ def __repr__(self): return ret @property # type: ignore # mypy doesn't support decorated property - @_obj_pxy_cache_wrapper("type_serialized") # + @_obj_pxy_cache_wrapper("type_serialized") def __class__(self): return pickle.loads(self._obj_pxy["type_serialized"]) From 4225364cc5b83351c6ec8e1d0869feef85be8b80 Mon Sep 17 00:00:00 2001 From: "Mads R. B. 
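
With evict() removed, spilling is just the call to _obj_pxy_serialize(serializers=("dask", "pickle")) shown above, i.e. the ordinary distributed serialization round trip. A small standalone example of that round trip on a plain host object (requires the distributed package to be installed):

from distributed.protocol import deserialize, serialize

obj = {"x": list(range(10))}

# Spill: turn the object into a header dict plus a list of byte frames.
header, frames = serialize(obj, serializers=("dask", "pickle"), on_error="raise")

# Unspill: reconstruct the original object from header and frames.
assert deserialize(header, frames) == obj
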
Kristensen" Date: Wed, 8 Sep 2021 08:35:40 +0200 Subject: [PATCH 27/27] Renamed mem_usage_sub => mem_usage_remove --- dask_cuda/proxify_host_file.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py index 34d9c828d..a056ad5b5 100644 --- a/dask_cuda/proxify_host_file.py +++ b/dask_cuda/proxify_host_file.py @@ -44,7 +44,7 @@ def mem_usage_add(self, proxy: ProxyObject) -> None: """Given a new proxy, update `self._mem_usage`""" @abc.abstractmethod - def mem_usage_sub(self, proxy: ProxyObject) -> None: + def mem_usage_remove(self, proxy: ProxyObject) -> None: """Removal of proxy, update `self._mem_usage`""" def add(self, proxy: ProxyObject) -> None: @@ -54,9 +54,9 @@ def add(self, proxy: ProxyObject) -> None: self.mem_usage_add(proxy) def remove(self, proxy: ProxyObject) -> None: - """Remove proxy from tracking, calls `self.mem_usage_sub`""" + """Remove proxy from tracking, calls `self.mem_usage_remove`""" del self._proxy_id_to_proxy[id(proxy)] - self.mem_usage_sub(proxy) + self.mem_usage_remove(proxy) if len(self._proxy_id_to_proxy) == 0: if self._mem_usage != 0: warnings.warn( @@ -88,7 +88,7 @@ class ProxiesOnHost(Proxies): def mem_usage_add(self, proxy: ProxyObject): self._mem_usage += sizeof(proxy) - def mem_usage_sub(self, proxy: ProxyObject): + def mem_usage_remove(self, proxy: ProxyObject): self._mem_usage -= sizeof(proxy) @@ -117,7 +117,7 @@ def mem_usage_add(self, proxy: ProxyObject): self._mem_usage += sizeof(dev_mem) ps.add(proxy_id) - def mem_usage_sub(self, proxy: ProxyObject): + def mem_usage_remove(self, proxy: ProxyObject): proxy_id = id(proxy) for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id): self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id)
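
One defensive detail carried through these patches: when the last tracked proxy is removed but the byte tally has drifted away from zero, Proxies warns and resets rather than asserting. A tiny illustration of that pattern together with the renamed mem_usage_add/mem_usage_remove hook pair (Tally is an illustrative class, not the real Proxies):

import warnings

class Tally:
    def __init__(self) -> None:
        self._tracked: dict = {}
        self._mem_usage = 0

    def mem_usage_add(self, key, nbytes: int) -> None:
        self._tracked[key] = nbytes
        self._mem_usage += nbytes

    def mem_usage_remove(self, key, nbytes: int) -> None:
        del self._tracked[key]
        self._mem_usage -= nbytes
        if not self._tracked and self._mem_usage != 0:
            # Sizes drifted (e.g. an object changed while tracked):
            # report and reset instead of carrying the error forward.
            warnings.warn(f"tally is {self._mem_usage} bytes; resetting")
            self._mem_usage = 0

t = Tally()
t.mem_usage_add("a", 100)
t.mem_usage_remove("a", 90)  # mismatched size -> warning + reset to 0
assert t._mem_usage == 0
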