Skip to content

Commit

Permalink
Start of implementing externals
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Feb 16, 2021
1 parent 635e01d commit 196b5b6
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
21 changes: 20 additions & 1 deletion dask_cuda/proxify_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,32 @@ def get_access_info(self) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]
total_dev_mem_usage += size
return total_dev_mem_usage, dev_buf_access

def add_external(self, obj):
found_proxies: List[ProxyObject] = []
proxied_id_to_proxy = self.proxies_tally.get_proxied_id_to_proxy()
ret = proxify_device_objects(obj, proxied_id_to_proxy, found_proxies)
last_access = time.time()
self_weakref = weakref.ref(self)
for p in found_proxies:
weakref.finalize(p, self.del_external, id(p))
external = weakref.proxy(p)
p._obj_pxy["hostfile"] = self_weakref
p._obj_pxy["last_access"] = last_access
p._obj_pxy["external"] = external
self.proxies_tally.add_key(id(p), [external])
self.maybe_evict()
return ret

def del_external(self, name):
self.proxies_tally.del_key(name)

def __setitem__(self, key, value):
with self.lock:
if key in self.store:
# Make sure we register the removal of an existing key
del self[key]

found_proxies = []
found_proxies: List[ProxyObject] = []
proxied_id_to_proxy = self.proxies_tally.get_proxied_id_to_proxy()
self.store[key] = proxify_device_objects(
value, proxied_id_to_proxy, found_proxies
Expand Down
6 changes: 4 additions & 2 deletions dask_cuda/proxy_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def _obj_pxy_serialize(self, serializers):
self._obj_pxy["serializers"] = serializers
hostfile = self._obj_pxy.get("hostfile", lambda: None)()
if hostfile is not None:
hostfile.proxies_tally.spill_proxy(self)
external = self._obj_pxy.get("external", self)
hostfile.proxies_tally.spill_proxy(external)

# Invalidate the (possible) cached "device_memory_objects"
self._obj_pxy_cache.pop("device_memory_objects", None)
Expand Down Expand Up @@ -311,7 +312,8 @@ def _obj_pxy_deserialize(self):
self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames)
self._obj_pxy["serializers"] = None
if hostfile is not None:
hostfile.proxies_tally.unspill_proxy(self)
external = self._obj_pxy.get("external", self)
hostfile.proxies_tally.unspill_proxy(external)

self._obj_pxy["last_access"] = time.time()
return self._obj_pxy["obj"]
Expand Down
22 changes: 22 additions & 0 deletions dask_cuda/tests/test_proxify_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,25 @@ def test_cudf_get_device_memory_objects():
]
res = get_device_memory_objects(objects)
assert len(res) == 4, "We expect four buffer objects"


def test_externals():
dhf = ProxifyHostFile(device_memory_limit=itemsize)
dhf["k1"] = cupy.arange(1) + 1
k1 = dhf["k1"]
k2 = dhf.add_external(cupy.arange(1) + 1)
# `k2` isn't part of the store but still triggers spilling of `k1`
assert len(dhf) == 1
assert k1._obj_pxy_is_serialized()
assert not k2._obj_pxy_is_serialized()
k1[0] # Trigger spilling of `k2`
assert not k1._obj_pxy_is_serialized()
assert k2._obj_pxy_is_serialized()
k2[0] # Trigger spilling of `k1`
assert k1._obj_pxy_is_serialized()
assert not k2._obj_pxy_is_serialized()
assert dhf.proxies_tally.get_dev_mem_usage() == itemsize
# Removing `k2` also removes it from the tally
del k2
assert dhf.proxies_tally.get_dev_mem_usage() == 0
assert len(list(dhf.proxies_tally.get_unspilled_proxies())) == 0

0 comments on commit 196b5b6

Please sign in to comment.