diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py
index 985bedd87..0e0f0d9a8 100644
--- a/dask_cuda/explicit_comms/dataframe/shuffle.py
+++ b/dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -11,12 +11,13 @@ import dask.dataframe
 import distributed
 from dask.base import compute_as_if_collection, tokenize
-from dask.dataframe.core import DataFrame, _concat
+from dask.dataframe.core import DataFrame, _concat as dd_concat
 from dask.dataframe.shuffle import shuffle_group
 from dask.delayed import delayed
 from distributed import wait
 from distributed.protocol import nested_deserialize, to_serialize
 
+from ...proxify_host_file import ProxifyHostFile
 from .. import comms
 
@@ -46,6 +47,7 @@ def sort_in_parts(
     rank_to_out_part_ids: Dict[int, List[int]],
     ignore_index: bool,
     concat_dfs_of_same_output_partition: bool,
+    concat,
 ) -> Dict[int, List[List[DataFrame]]]:
     """
     Sort the list of grouped dataframes in `in_parts`
@@ -96,7 +98,7 @@ def sort_in_parts(
         for i in range(len(rank_to_out_parts_list[rank])):
             if len(rank_to_out_parts_list[rank][i]) > 1:
                 rank_to_out_parts_list[rank][i] = [
-                    _concat(
+                    concat(
                         rank_to_out_parts_list[rank][i], ignore_index=ignore_index
                     )
                 ]
@@ -144,11 +146,30 @@ async def local_shuffle(
     eps = s["eps"]
     assert s["rank"] in workers
 
+    try:
+        hostfile = first(iter(in_parts[0].values()))._obj_pxy.get(
+            "hostfile", lambda: None
+        )()
+    except AttributeError:
+        hostfile = None
+
+    if isinstance(hostfile, ProxifyHostFile):
+
+        def concat(args, ignore_index=False):
+            if len(args) < 2:
+                return args[0]
+
+            return hostfile.add_external(dd_concat(args, ignore_index=ignore_index))
+
+    else:
+        concat = dd_concat
+
     rank_to_out_parts_list = sort_in_parts(
         in_parts,
         rank_to_out_part_ids,
         ignore_index,
         concat_dfs_of_same_output_partition=True,
+        concat=concat,
     )
 
     # Communicate all the dataframe-partitions all-to-all. The result is
@@ -176,7 +197,7 @@ async def local_shuffle(
             dfs.extend(rank_to_out_parts_list[myrank][i])
             rank_to_out_parts_list[myrank][i] = None
         if len(dfs) > 1:
-            ret.append(_concat(dfs, ignore_index=ignore_index))
+            ret.append(concat(dfs, ignore_index=ignore_index))
         else:
             ret.append(dfs[0])
     return ret
diff --git a/dask_cuda/proxify_device_objects.py b/dask_cuda/proxify_device_objects.py
index 7c166003f..9baf5be8f 100644
--- a/dask_cuda/proxify_device_objects.py
+++ b/dask_cuda/proxify_device_objects.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, List, MutableMapping
 
 from dask.utils import Dispatch
 
@@ -9,8 +9,9 @@
 
 def proxify_device_objects(
     obj: Any,
-    proxied_id_to_proxy: Dict[int, ProxyObject],
+    proxied_id_to_proxy: MutableMapping[int, ProxyObject],
     found_proxies: List[ProxyObject],
+    excl_proxies: bool = False,
 ):
     """
     Wrap device objects in ProxyObject
@@ -22,25 +23,31 @@ def proxify_device_objects(
     ----------
     obj: Any
         Object to search through or wrap in a ProxyObject.
-    proxied_id_to_proxy: Dict[int, ProxyObject]
+    proxied_id_to_proxy: MutableMapping[int, ProxyObject]
         Dict mapping the id() of proxied objects (CUDA device objects) to their
         proxy and is updated with all new proxied objects found in `obj`.
     found_proxies: List[ProxyObject]
         List of found proxies in `obj`. Notice, this includes all proxies found,
         including those already in `proxied_id_to_proxy`.
+    excl_proxies: bool
+        Don't add found objects that are already ProxyObjects to `found_proxies`.
 
     Returns
     -------
     ret: Any
         A copy of `obj` where all CUDA device objects are wrapped in ProxyObject
     """
-    return dispatch(obj, proxied_id_to_proxy, found_proxies)
+    return dispatch(obj, proxied_id_to_proxy, found_proxies, excl_proxies)
 
 
 def proxify(obj, proxied_id_to_proxy, found_proxies, subclass=None):
     _id = id(obj)
     if _id in proxied_id_to_proxy:
         ret = proxied_id_to_proxy[_id]
+        finalize = ret._obj_pxy.get("external_finalize", None)
+        if finalize:
+            finalize()
+            proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass)
     else:
         proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass)
     found_proxies.append(ret)
@@ -48,15 +55,18 @@ def proxify(obj, proxied_id_to_proxy, found_proxies, subclass=None):
 
 
 @dispatch.register(object)
-def proxify_device_object_default(obj, proxied_id_to_proxy, found_proxies):
+def proxify_device_object_default(
+    obj, proxied_id_to_proxy, found_proxies, excl_proxies
+):
     if hasattr(obj, "__cuda_array_interface__"):
         return proxify(obj, proxied_id_to_proxy, found_proxies)
     return obj
 
 
 @dispatch.register(ProxyObject)
-def proxify_device_object_proxy_object(obj, proxied_id_to_proxy, found_proxies):
-
+def proxify_device_object_proxy_object(
+    obj, proxied_id_to_proxy, found_proxies, excl_proxies
+):
     # We deserialize CUDA-serialized objects since it is very cheap and
     # makes it easy to administrate device memory usage
     if obj._obj_pxy_is_serialized() and "cuda" in obj._obj_pxy["serializers"]:
@@ -70,7 +80,16 @@ def proxify_device_object_proxy_object(obj, proxied_id_to_proxy, found_proxies)
         else:
             proxied_id_to_proxy[_id] = obj
 
-    found_proxies.append(obj)
+    finalize = obj._obj_pxy.get("external_finalize", None)
+    if finalize:
+        finalize()
+        obj = obj._obj_pxy_copy()
+        if not obj._obj_pxy_is_serialized():
+            _id = id(obj._obj_pxy["obj"])
+            proxied_id_to_proxy[_id] = obj
+
+    if not excl_proxies:
+        found_proxies.append(obj)
     return obj
 
 
@@ -78,13 +97,22 @@ def proxify_device_object_proxy_object(obj, proxied_id_to_proxy, found_proxies)
 @dispatch.register(tuple)
 @dispatch.register(set)
 @dispatch.register(frozenset)
-def proxify_device_object_python_collection(seq, proxied_id_to_proxy, found_proxies):
-    return type(seq)(dispatch(o, proxied_id_to_proxy, found_proxies) for o in seq)
+def proxify_device_object_python_collection(
+    seq, proxied_id_to_proxy, found_proxies, excl_proxies
+):
+    return type(seq)(
+        dispatch(o, proxied_id_to_proxy, found_proxies, excl_proxies) for o in seq
+    )
 
 
 @dispatch.register(dict)
-def proxify_device_object_python_dict(seq, proxied_id_to_proxy, found_proxies):
-    return {k: dispatch(v, proxied_id_to_proxy, found_proxies) for k, v in seq.items()}
+def proxify_device_object_python_dict(
+    seq, proxied_id_to_proxy, found_proxies, excl_proxies
+):
+    return {
+        k: dispatch(v, proxied_id_to_proxy, found_proxies, excl_proxies)
+        for k, v in seq.items()
+    }
 
 
 # Implement cuDF specific proxification
@@ -107,7 +135,9 @@ class FrameProxyObject(ProxyObject, cudf._lib.table.Table):
     @dispatch.register(cudf.DataFrame)
     @dispatch.register(cudf.Series)
     @dispatch.register(cudf.Index)
-    def proxify_device_object_cudf_dataframe(obj, proxied_id_to_proxy, found_proxies):
+    def proxify_device_object_cudf_dataframe(
+        obj, proxied_id_to_proxy, found_proxies, excl_proxies
+    ):
         return proxify(
             obj, proxied_id_to_proxy, found_proxies, subclass=FrameProxyObject
         )
diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py
index a47669c37..477f94b5e 100644
--- a/dask_cuda/proxify_host_file.py
+++ b/dask_cuda/proxify_host_file.py
@@ -164,7 +164,7 @@ def get_dev_buffer_to_proxies(self) -> DefaultDict[Hashable, List[ProxyObject]]:
         with self.lock:
             # Notice, multiple proxy object can point to different non-overlapping
             # parts of the same device buffer.
-            ret = DefaultDict(list)
+            ret = defaultdict(list)
             for proxy in self.proxies_tally.get_unspilled_proxies():
                 for dev_buffer in proxy._obj_pxy_get_device_memory_objects():
                     ret[dev_buffer].append(proxy)
@@ -181,22 +181,67 @@ def get_access_info(self) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]
                 total_dev_mem_usage += size
         return total_dev_mem_usage, dev_buf_access
 
+    def add_external(self, obj):
+        """Add an external object to the hostfile that counts against the
+        device_memory_limit but isn't part of the store.
+
+        Normally, we use __setitem__ to store objects in the hostfile and make
+        them count against the device_memory_limit, with the inherent
+        consequence that the objects are not freeable before subsequent calls
+        to __delitem__. This is a problem for long-running tasks that want
+        objects to count against the device_memory_limit while freeing them
+        ASAP without explicit calls to __delitem__.
+
+        Developer Notes
+        ---------------
+        In order to avoid holding references to the found proxies in `obj`, we
+        wrap them in `weakref.proxy(p)` and add them to the `proxies_tally`.
+        In order to remove them from the `proxies_tally` again, we attach a
+        `weakref.finalize(p)` on the wrapped proxies that calls del_external().
+        """
+
+        # Notice, since `self.store` isn't modified, no lock is needed
+        found_proxies: List[ProxyObject] = []
+        proxied_id_to_proxy = {}
+        # Notice, we are excluding found objects that are already proxies
+        ret = proxify_device_objects(
+            obj, proxied_id_to_proxy, found_proxies, excl_proxies=True
+        )
+        last_access = time.monotonic()
+        self_weakref = weakref.ref(self)
+        for p in found_proxies:
+            name = id(p)
+            finalize = weakref.finalize(p, self.del_external, name)
+            external = weakref.proxy(p)
+            p._obj_pxy["hostfile"] = self_weakref
+            p._obj_pxy["last_access"] = last_access
+            p._obj_pxy["external"] = external
+            p._obj_pxy["external_finalize"] = finalize
+            self.proxies_tally.add_key(name, [external])
+        self.maybe_evict()
+        return ret
+
+    def del_external(self, name):
+        self.proxies_tally.del_key(name)
+
     def __setitem__(self, key, value):
         with self.lock:
             if key in self.store:
                 # Make sure we register the removal of an existing key
                 del self[key]
 
-            found_proxies = []
+            found_proxies: List[ProxyObject] = []
             proxied_id_to_proxy = self.proxies_tally.get_proxied_id_to_proxy()
             self.store[key] = proxify_device_objects(
                 value, proxied_id_to_proxy, found_proxies
             )
-            last_access = time.time()
+            last_access = time.monotonic()
             self_weakref = weakref.ref(self)
             for p in found_proxies:
                 p._obj_pxy["hostfile"] = self_weakref
                 p._obj_pxy["last_access"] = last_access
+                assert "external" not in p._obj_pxy
+
             self.proxies_tally.add_key(key, found_proxies)
             self.maybe_evict()
 
diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py
index 26aa8ce77..206677d8d 100644
--- a/dask_cuda/proxy_object.py
+++ b/dask_cuda/proxy_object.py
@@ -199,6 +199,12 @@ def __init__(
         self._obj_pxy_lock = threading.RLock()
         self._obj_pxy_cache = {}
 
+    def __del__(self):
+        """Make sure `external_finalize()` is called ASAP, i.e. on deletion"""
+        external_finalize = self._obj_pxy.get("external_finalize", None)
+        if external_finalize is not None:
+            external_finalize()
+
     def _obj_pxy_get_init_args(self, include_obj=True):
         """Return the attributes needed to initialize a ProxyObject
 
@@ -277,7 +283,8 @@ def _obj_pxy_serialize(self, serializers):
             self._obj_pxy["serializers"] = serializers
             hostfile = self._obj_pxy.get("hostfile", lambda: None)()
             if hostfile is not None:
-                hostfile.proxies_tally.spill_proxy(self)
+                external = self._obj_pxy.get("external", self)
+                hostfile.proxies_tally.spill_proxy(external)
 
             # Invalidate the (possible) cached "device_memory_objects"
             self._obj_pxy_cache.pop("device_memory_objects", None)
@@ -311,9 +318,10 @@ def _obj_pxy_deserialize(self):
                 self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames)
                 self._obj_pxy["serializers"] = None
                 if hostfile is not None:
-                    hostfile.proxies_tally.unspill_proxy(self)
+                    external = self._obj_pxy.get("external", self)
+                    hostfile.proxies_tally.unspill_proxy(external)
 
-            self._obj_pxy["last_access"] = time.time()
+            self._obj_pxy["last_access"] = time.monotonic()
             return self._obj_pxy["obj"]
 
     def _obj_pxy_is_cuda_object(self) -> bool:
diff --git a/dask_cuda/tests/test_proxify_host_file.py b/dask_cuda/tests/test_proxify_host_file.py
index a03c784b0..7e3053a2e 100644
--- a/dask_cuda/tests/test_proxify_host_file.py
+++ b/dask_cuda/tests/test_proxify_host_file.py
@@ -12,23 +12,24 @@
 cupy = pytest.importorskip("cupy")
 cupy.cuda.set_allocator(None)
 
-itemsize = cupy.arange(1).nbytes
+one_item_array = lambda: cupy.arange(1)
+one_item_nbytes = one_item_array().nbytes
 
 
 def test_one_item_limit():
-    dhf = ProxifyHostFile(device_memory_limit=itemsize)
-    dhf["k1"] = cupy.arange(1) + 1
-    dhf["k2"] = cupy.arange(1) + 2
+    dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes)
+    dhf["k1"] = one_item_array() + 42
+    dhf["k2"] = one_item_array()
 
     # Check k1 is spilled because of the newer k2
     k1 = dhf["k1"]
+    k2 = dhf["k2"]
     assert k1._obj_pxy_is_serialized()
-    assert not dhf["k2"]._obj_pxy_is_serialized()
+    assert not k2._obj_pxy_is_serialized()
 
     # Accessing k1 spills k2 and unspill k1
     k1_val = k1[0]
-    assert k1_val == 1
-    k2 = dhf["k2"]
+    assert k1_val == 42
     assert k2._obj_pxy_is_serialized()
 
     # Duplicate arrays changes nothing
@@ -37,7 +38,7 @@ def test_one_item_limit():
     assert k2._obj_pxy_is_serialized()
 
     # Adding a new array spills k1 and k2
-    dhf["k4"] = cupy.arange(1) + 4
+    dhf["k4"] = one_item_array()
     assert k1._obj_pxy_is_serialized()
     assert k2._obj_pxy_is_serialized()
     assert not dhf["k4"]._obj_pxy_is_serialized()
@@ -50,11 +51,11 @@ def test_one_item_limit():
 
     # Deleting k2 does not change anything since k3 still holds a
     # reference to the underlying proxy object
-    assert dhf.proxies_tally.get_dev_mem_usage() == 8
+    assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
     p1 = list(dhf.proxies_tally.get_unspilled_proxies())
     assert len(p1) == 1
     del dhf["k2"]
-    assert dhf.proxies_tally.get_dev_mem_usage() == 8
+    assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
     p2 = list(dhf.proxies_tally.get_unspilled_proxies())
     assert len(p2) == 1
     assert p1[0] is p2[0]
@@ -114,7 +115,7 @@ def test_dataframes_share_dev_mem():
     assert not v1._obj_pxy_is_serialized()
     assert not v2._obj_pxy_is_serialized()
     # Now the device_memory_limit is exceeded, which should evict both dataframes
-    dhf["k1"] = cupy.arange(1)
+    dhf["k1"] = one_item_array()
     assert v1._obj_pxy_is_serialized()
     assert v2._obj_pxy_is_serialized()
 
@@ -129,3 +130,59 @@ def test_cudf_get_device_memory_objects():
     ]
     res = get_device_memory_objects(objects)
     assert len(res) == 4, "We expect four buffer objects"
+
+
+def test_externals():
+    dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes)
+    dhf["k1"] = one_item_array()
+    k1 = dhf["k1"]
+    k2 = dhf.add_external(one_item_array())
+    # `k2` isn't part of the store but still triggers spilling of `k1`
+    assert len(dhf) == 1
+    assert k1._obj_pxy_is_serialized()
+    assert not k2._obj_pxy_is_serialized()
+    k1[0]  # Trigger spilling of `k2`
+    assert not k1._obj_pxy_is_serialized()
+    assert k2._obj_pxy_is_serialized()
+    k2[0]  # Trigger spilling of `k1`
+    assert k1._obj_pxy_is_serialized()
+    assert not k2._obj_pxy_is_serialized()
+    assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
+    # Removing `k2` also removes it from the tally
+    del k2
+    assert dhf.proxies_tally.get_dev_mem_usage() == 0
+    assert len(list(dhf.proxies_tally.get_unspilled_proxies())) == 0
+
+
+def test_externals_setitem():
+    dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes)
+    k1 = dhf.add_external(one_item_array())
+    assert type(k1) is dask_cuda.proxy_object.ProxyObject
+    assert len(dhf) == 0
+    assert "external" in k1._obj_pxy
+    assert "external_finalize" in k1._obj_pxy
+    dhf["k1"] = k1
+    k1 = dhf["k1"]
+    assert type(k1) is dask_cuda.proxy_object.ProxyObject
+    assert len(dhf) == 1
+    assert "external" not in k1._obj_pxy
+    assert "external_finalize" not in k1._obj_pxy
+
+    k1 = dhf.add_external(one_item_array())
+    k1._obj_pxy_serialize(serializers=("dask", "pickle"))
+    dhf["k1"] = k1
+    k1 = dhf["k1"]
+    assert type(k1) is dask_cuda.proxy_object.ProxyObject
+    assert len(dhf) == 1
+    assert "external" not in k1._obj_pxy
+    assert "external_finalize" not in k1._obj_pxy
+
+    dhf["k1"] = one_item_array()
+    assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
+    assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
+    k1 = dhf.add_external(k1)
+    assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
+    assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
+    k1 = dhf.add_external(dhf["k1"])
+    assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
+    assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
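
Reviewer note: for anyone skimming, the behavior the new tests pin down condenses to the sketch below. It is not part of the patch; it assumes a CUDA-capable machine with cupy installed, and the key name "stored" is just an illustrative choice. It mirrors test_externals above: an object registered via add_external() counts against device_memory_limit (and can evict stored keys), but its device memory is reclaimed the moment the last reference to it dies, with no __delitem__ needed.

    import cupy
    from dask_cuda.proxify_host_file import ProxifyHostFile

    nbytes = cupy.arange(1).nbytes
    dhf = ProxifyHostFile(device_memory_limit=nbytes)

    dhf["stored"] = cupy.arange(1)  # regular entry: freed only by __delitem__
    k1 = dhf["stored"]
    tmp = dhf.add_external(cupy.arange(1))  # external: tracked but not stored

    assert len(dhf) == 1  # the external is not part of the store...
    assert k1._obj_pxy_is_serialized()  # ...yet it evicted "stored" to host memory

    del tmp  # dropping the last reference frees the device memory right away
    assert dhf.proxies_tally.get_dev_mem_usage() == 0

This is the property the explicit-comms shuffle exploits: the intermediate concatenations produced inside local_shuffle() stay spillable while alive, yet disappear from the tally as soon as the task drops them.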
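
Reviewer note: the Developer Notes in add_external() are terse, so here is a minimal GPU-free sketch of the same lifetime pattern. Tally and Tracked are hypothetical stand-ins, not dask_cuda classes: the tally holds only a weakref.proxy of each registered object, and a weakref.finalize callback unregisters the entry once the last strong reference dies, which is how del_external() gets invoked without an explicit delete.

    import weakref

    class Tally:
        # Hypothetical stand-in for ProxiesTally: maps a name to live externals
        def __init__(self):
            self.keys = {}

        def add_key(self, name, proxies):
            self.keys[name] = proxies

        def del_key(self, name):
            del self.keys[name]

    class Tracked:
        # Hypothetical stand-in for ProxyObject
        def __init__(self, obj):
            self.obj = obj

    tally = Tally()

    def add_external(p):
        name = id(p)
        # Fires when `p` is garbage collected and removes the tally entry
        weakref.finalize(p, tally.del_key, name)
        # Only a weak proxy goes into the tally, so the tally itself can
        # never keep `p` alive
        tally.add_key(name, [weakref.proxy(p)])
        return p

    obj = add_external(Tracked("device-buffer"))
    assert len(tally.keys) == 1
    del obj  # last strong reference gone; on CPython the finalizer runs now
    assert len(tally.keys) == 0

The patch additionally calls the finalizer from ProxyObject.__del__, so cleanup is not deferred to a later GC cycle or interpreter shutdown.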