Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProxifyHostFile: tracking of external objects #527

Merged
merged 13 commits into from
Feb 25, 2021
27 changes: 24 additions & 3 deletions dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import dask.dataframe
import distributed
from dask.base import compute_as_if_collection, tokenize
from dask.dataframe.core import DataFrame, _concat
from dask.dataframe.core import DataFrame, _concat as dd_concat
from dask.dataframe.shuffle import shuffle_group
from dask.delayed import delayed
from distributed import wait
from distributed.protocol import nested_deserialize, to_serialize

from ...proxify_host_file import ProxifyHostFile
from .. import comms


Expand Down Expand Up @@ -46,6 +47,7 @@ def sort_in_parts(
rank_to_out_part_ids: Dict[int, List[int]],
ignore_index: bool,
concat_dfs_of_same_output_partition: bool,
concat,
) -> Dict[int, List[List[DataFrame]]]:
""" Sort the list of grouped dataframes in `in_parts`

Expand Down Expand Up @@ -96,7 +98,7 @@ def sort_in_parts(
for i in range(len(rank_to_out_parts_list[rank])):
if len(rank_to_out_parts_list[rank][i]) > 1:
rank_to_out_parts_list[rank][i] = [
_concat(
concat(
rank_to_out_parts_list[rank][i], ignore_index=ignore_index
)
]
Expand Down Expand Up @@ -144,11 +146,30 @@ async def local_shuffle(
eps = s["eps"]
assert s["rank"] in workers

try:
hostfile = first(iter(in_parts[0].values()))._obj_pxy.get(
"hostfile", lambda: None
)()
except AttributeError:
hostfile = None

if isinstance(hostfile, ProxifyHostFile):

def concat(args, ignore_index=False):
if len(args) < 2:
return args[0]

return hostfile.add_external(dd_concat(args, ignore_index=ignore_index))

else:
concat = dd_concat

rank_to_out_parts_list = sort_in_parts(
in_parts,
rank_to_out_part_ids,
ignore_index,
concat_dfs_of_same_output_partition=True,
concat=concat,
)

# Communicate all the dataframe-partitions all-to-all. The result is
Expand Down Expand Up @@ -176,7 +197,7 @@ async def local_shuffle(
dfs.extend(rank_to_out_parts_list[myrank][i])
rank_to_out_parts_list[myrank][i] = None
if len(dfs) > 1:
ret.append(_concat(dfs, ignore_index=ignore_index))
ret.append(concat(dfs, ignore_index=ignore_index))
else:
ret.append(dfs[0])
return ret
Expand Down
56 changes: 43 additions & 13 deletions dask_cuda/proxify_device_objects.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, List, MutableMapping

from dask.utils import Dispatch

Expand All @@ -9,8 +9,9 @@

def proxify_device_objects(
obj: Any,
proxied_id_to_proxy: Dict[int, ProxyObject],
proxied_id_to_proxy: MutableMapping[int, ProxyObject],
found_proxies: List[ProxyObject],
excl_proxies: bool = False,
):
""" Wrap device objects in ProxyObject

Expand All @@ -22,41 +23,50 @@ def proxify_device_objects(
----------
obj: Any
Object to search through or wrap in a ProxyObject.
proxied_id_to_proxy: Dict[int, ProxyObject]
proxied_id_to_proxy: MutableMapping[int, ProxyObject]
Dict mapping the id() of proxied objects (CUDA device objects) to
their proxy and is updated with all new proxied objects found in `obj`.
found_proxies: List[ProxyObject]
List of found proxies in `obj`. Notice, this includes all proxies found,
including those already in `proxied_id_to_proxy`.
excl_proxies: bool
Don't add found objects that are already ProxyObject to found_proxies.

Returns
-------
ret: Any
A copy of `obj` where all CUDA device objects are wrapped in ProxyObject
"""
return dispatch(obj, proxied_id_to_proxy, found_proxies)
return dispatch(obj, proxied_id_to_proxy, found_proxies, excl_proxies)


def proxify(obj, proxied_id_to_proxy, found_proxies, subclass=None):
    """Wrap `obj` in a ProxyObject, reusing an existing proxy when possible.

    Parameters
    ----------
    obj: Any
        Device object to proxify.
    proxied_id_to_proxy: MutableMapping[int, ProxyObject]
        Mapping from ``id()`` of already-proxied objects to their proxy;
        updated in place with any new proxy created here.
    found_proxies: List[ProxyObject]
        Output list; the proxy for `obj` is appended to it.
    subclass: type, optional
        ProxyObject subclass to instantiate (e.g. FrameProxyObject).

    Returns
    -------
    ProxyObject
        The (possibly pre-existing) proxy wrapping `obj`.
    """
    _id = id(obj)
    if _id in proxied_id_to_proxy:
        ret = proxied_id_to_proxy[_id]
        finalize = ret._obj_pxy.get("external_finalize", None)
        if finalize:
            # The existing proxy is tracked as an "external" object; run its
            # finalizer (which unregisters it from the hostfile tally) and
            # replace it with a fresh, internally-owned proxy.
            finalize()
            proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass)
    else:
        proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass)
    found_proxies.append(ret)
    return ret


@dispatch.register(object)
def proxify_device_object_default(
    obj, proxied_id_to_proxy, found_proxies, excl_proxies
):
    """Fallback handler: proxify objects exposing the CUDA array interface.

    Objects without ``__cuda_array_interface__`` are returned untouched.
    """
    if not hasattr(obj, "__cuda_array_interface__"):
        return obj
    return proxify(obj, proxied_id_to_proxy, found_proxies)


@dispatch.register(ProxyObject)
def proxify_device_object_proxy_object(obj, proxied_id_to_proxy, found_proxies):

def proxify_device_object_proxy_object(
obj, proxied_id_to_proxy, found_proxies, excl_proxies
):
# We deserialize CUDA-serialized objects since it is very cheap and
# makes it easy to administrate device memory usage
if obj._obj_pxy_is_serialized() and "cuda" in obj._obj_pxy["serializers"]:
Expand All @@ -70,21 +80,39 @@ def proxify_device_object_proxy_object(obj, proxied_id_to_proxy, found_proxies):
else:
proxied_id_to_proxy[_id] = obj

found_proxies.append(obj)
finalize = obj._obj_pxy.get("external_finalize", None)
if finalize:
finalize()
madsbk marked this conversation as resolved.
Show resolved Hide resolved
obj = obj._obj_pxy_copy()
if not obj._obj_pxy_is_serialized():
_id = id(obj._obj_pxy["obj"])
proxied_id_to_proxy[_id] = obj

if not excl_proxies:
found_proxies.append(obj)
return obj


@dispatch.register(list)
@dispatch.register(tuple)
@dispatch.register(set)
@dispatch.register(frozenset)
def proxify_device_object_python_collection(
    seq, proxied_id_to_proxy, found_proxies, excl_proxies
):
    """Recursively proxify every element of a built-in Python collection.

    The returned collection has the same type as the input.
    """
    proxified = [
        dispatch(item, proxied_id_to_proxy, found_proxies, excl_proxies)
        for item in seq
    ]
    return type(seq)(proxified)


@dispatch.register(dict)
def proxify_device_object_python_dict(
    seq, proxied_id_to_proxy, found_proxies, excl_proxies
):
    """Return a new dict with every value recursively proxified.

    Keys are left untouched; only values are dispatched.
    """
    ret = {}
    for key, value in seq.items():
        ret[key] = dispatch(value, proxied_id_to_proxy, found_proxies, excl_proxies)
    return ret


# Implement cuDF specific proxification
Expand All @@ -107,7 +135,9 @@ class FrameProxyObject(ProxyObject, cudf._lib.table.Table):
@dispatch.register(cudf.DataFrame)
@dispatch.register(cudf.Series)
@dispatch.register(cudf.Index)
def proxify_device_object_cudf_dataframe(
    obj, proxied_id_to_proxy, found_proxies, excl_proxies
):
    """Proxify a cuDF DataFrame/Series/Index using FrameProxyObject."""
    return proxify(
        obj,
        proxied_id_to_proxy,
        found_proxies,
        subclass=FrameProxyObject,
    )
31 changes: 29 additions & 2 deletions dask_cuda/proxify_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def get_dev_buffer_to_proxies(self) -> DefaultDict[Hashable, List[ProxyObject]]:
with self.lock:
# Notice, multiple proxy object can point to different non-overlapping
# parts of the same device buffer.
ret = DefaultDict(list)
ret = defaultdict(list)
for proxy in self.proxies_tally.get_unspilled_proxies():
for dev_buffer in proxy._obj_pxy_get_device_memory_objects():
ret[dev_buffer].append(proxy)
Expand All @@ -181,13 +181,38 @@ def get_access_info(self) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]
total_dev_mem_usage += size
return total_dev_mem_usage, dev_buf_access

def add_external(self, obj):
    """Track an external object, i.e. one not stored in ``self.store``.

    Proxifies all device objects found in `obj` (excluding objects that
    are already proxies) and registers each new proxy in
    ``self.proxies_tally`` under its ``id()``. Registration uses weak
    references only, so a proxy is automatically unregistered (via
    `del_external`) when it is garbage collected.

    Parameters
    ----------
    obj: Any
        Object to search through and proxify.

    Returns
    -------
    Any
        A copy of `obj` where all CUDA device objects are wrapped in
        ProxyObject.
    """
    # Notice, since `self.store` isn't modified, no lock is needed
    found_proxies: List[ProxyObject] = []
    proxied_id_to_proxy = {}
    # Notice, we are excluding found objects that are already proxies
    ret = proxify_device_objects(
        obj, proxied_id_to_proxy, found_proxies, excl_proxies=True
    )
    last_access = time.time()
    self_weakref = weakref.ref(self)
    for p in found_proxies:
        name = id(p)
        # `weakref.finalize` does not keep `p` alive; it fires when `p`
        # is collected and removes the tally entry.
        finalize = weakref.finalize(p, self.del_external, name)
        external = weakref.proxy(p)
        p._obj_pxy["hostfile"] = self_weakref
        p._obj_pxy["last_access"] = last_access
        p._obj_pxy["external"] = external
        p._obj_pxy["external_finalize"] = finalize
        self.proxies_tally.add_key(name, [external])
    self.maybe_evict()
    return ret

def del_external(self, name):
    """Unregister an external object previously registered by `add_external`."""
    tally = self.proxies_tally
    tally.del_key(name)

def __setitem__(self, key, value):
with self.lock:
if key in self.store:
# Make sure we register the removal of an existing key
del self[key]

found_proxies = []
found_proxies: List[ProxyObject] = []
proxied_id_to_proxy = self.proxies_tally.get_proxied_id_to_proxy()
self.store[key] = proxify_device_objects(
value, proxied_id_to_proxy, found_proxies
Expand All @@ -197,6 +222,8 @@ def __setitem__(self, key, value):
for p in found_proxies:
p._obj_pxy["hostfile"] = self_weakref
p._obj_pxy["last_access"] = last_access
assert "external" not in p._obj_pxy

self.proxies_tally.add_key(key, found_proxies)
self.maybe_evict()

Expand Down
12 changes: 10 additions & 2 deletions dask_cuda/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ def __init__(
self._obj_pxy_lock = threading.RLock()
self._obj_pxy_cache = {}

def __del__(self):
"""In order to call `external_finalize()` ASAP, we call it here"""
external_finalize = self._obj_pxy.get("external_finalize", None)
if external_finalize is not None:
external_finalize()
madsbk marked this conversation as resolved.
Show resolved Hide resolved

def _obj_pxy_get_init_args(self, include_obj=True):
"""Return the attributes needed to initialize a ProxyObject

Expand Down Expand Up @@ -277,7 +283,8 @@ def _obj_pxy_serialize(self, serializers):
self._obj_pxy["serializers"] = serializers
hostfile = self._obj_pxy.get("hostfile", lambda: None)()
if hostfile is not None:
hostfile.proxies_tally.spill_proxy(self)
external = self._obj_pxy.get("external", self)
hostfile.proxies_tally.spill_proxy(external)

# Invalidate the (possible) cached "device_memory_objects"
self._obj_pxy_cache.pop("device_memory_objects", None)
Expand Down Expand Up @@ -311,7 +318,8 @@ def _obj_pxy_deserialize(self):
self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames)
self._obj_pxy["serializers"] = None
if hostfile is not None:
hostfile.proxies_tally.unspill_proxy(self)
external = self._obj_pxy.get("external", self)
hostfile.proxies_tally.unspill_proxy(external)

self._obj_pxy["last_access"] = time.time()
return self._obj_pxy["obj"]
Expand Down
64 changes: 60 additions & 4 deletions dask_cuda/tests/test_proxify_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ def test_one_item_limit():

# Check k1 is spilled because of the newer k2
k1 = dhf["k1"]
k2 = dhf["k2"]
assert k1._obj_pxy_is_serialized()
assert not dhf["k2"]._obj_pxy_is_serialized()
assert not k2._obj_pxy_is_serialized()

# Accessing k1 spills k2 and unspill k1
k1_val = k1[0]
assert k1_val == 1
k2 = dhf["k2"]
assert k2._obj_pxy_is_serialized()

# Duplicate arrays changes nothing
Expand All @@ -50,11 +50,11 @@ def test_one_item_limit():

# Deleting k2 does not change anything since k3 still holds a
# reference to the underlying proxy object
assert dhf.proxies_tally.get_dev_mem_usage() == 8
assert dhf.proxies_tally.get_dev_mem_usage() == itemsize
p1 = list(dhf.proxies_tally.get_unspilled_proxies())
assert len(p1) == 1
del dhf["k2"]
assert dhf.proxies_tally.get_dev_mem_usage() == 8
assert dhf.proxies_tally.get_dev_mem_usage() == itemsize
p2 = list(dhf.proxies_tally.get_unspilled_proxies())
assert len(p2) == 1
assert p1[0] is p2[0]
Expand Down Expand Up @@ -129,3 +129,59 @@ def test_cudf_get_device_memory_objects():
]
res = get_device_memory_objects(objects)
assert len(res) == 4, "We expect four buffer objects"


def test_externals():
    """Test external objects tracked by the hostfile without being stored.

    Externals (added via `add_external`) count toward the device memory
    limit and participate in spilling, even though they are not part of
    the store.
    """
    dhf = ProxifyHostFile(device_memory_limit=itemsize)
    dhf["k1"] = cupy.arange(1) + 1
    k1 = dhf["k1"]
    k2 = dhf.add_external(cupy.arange(1) + 1)
    # `k2` isn't part of the store but still triggers spilling of `k1`
    assert len(dhf) == 1
    assert k1._obj_pxy_is_serialized()
    assert not k2._obj_pxy_is_serialized()
    k1[0]  # Trigger spilling of `k2`
    assert not k1._obj_pxy_is_serialized()
    assert k2._obj_pxy_is_serialized()
    k2[0]  # Trigger spilling of `k1`
    assert k1._obj_pxy_is_serialized()
    assert not k2._obj_pxy_is_serialized()
    assert dhf.proxies_tally.get_dev_mem_usage() == itemsize
    # Removing `k2` also removes it from the tally
    del k2
    assert dhf.proxies_tally.get_dev_mem_usage() == 0
    assert len(list(dhf.proxies_tally.get_unspilled_proxies())) == 0


def test_externals_setitem():
    """Test interaction between externals and ``__setitem__``.

    Storing an external proxy in the hostfile should convert it into a
    regular (store-owned) proxy, stripping the external tracking state.
    """
    dhf = ProxifyHostFile(device_memory_limit=itemsize)
    k1 = dhf.add_external(cupy.arange(1))
    assert type(k1) is dask_cuda.proxy_object.ProxyObject
    assert len(dhf) == 0
    # External bookkeeping is present on a freshly added external proxy.
    assert "external" in k1._obj_pxy
    assert "external_finalize" in k1._obj_pxy
    # Storing the external proxy converts it: external state is removed.
    dhf["k1"] = k1
    k1 = dhf["k1"]
    assert type(k1) is dask_cuda.proxy_object.ProxyObject
    assert len(dhf) == 1
    assert "external" not in k1._obj_pxy
    assert "external_finalize" not in k1._obj_pxy

    # Same conversion when the external proxy is serialized before storing.
    k1 = dhf.add_external(cupy.arange(1))
    k1._obj_pxy_serialize(serializers=("dask", "pickle"))
    dhf["k1"] = k1
    k1 = dhf["k1"]
    assert type(k1) is dask_cuda.proxy_object.ProxyObject
    assert len(dhf) == 1
    assert "external" not in k1._obj_pxy
    assert "external_finalize" not in k1._obj_pxy

    # Re-adding an already-tracked object as external must not duplicate
    # tally entries nor change the device memory accounting.
    dhf["k1"] = cupy.arange(1)
    assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
    assert dhf.proxies_tally.get_dev_mem_usage() == itemsize
    k1 = dhf.add_external(k1)
    assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
    assert dhf.proxies_tally.get_dev_mem_usage() == itemsize
    k1 = dhf.add_external(dhf["k1"])
    assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
    assert dhf.proxies_tally.get_dev_mem_usage() == itemsize