From f28c719962f36976dc514a50b9cbb1ce62d5b704 Mon Sep 17 00:00:00 2001
From: Gabe Joseph
Date: Tue, 20 Jul 2021 06:21:30 -0800
Subject: [PATCH] Actor: don't hold key references on workers (#4937)

Fixes #4936

When constructing an Actor handle, if there is a current worker, make our Future a weakref.
---
 distributed/actor.py            | 14 ++++++---
 distributed/tests/test_actor.py | 55 +++++++++++++++++++++++++++++++++
 distributed/worker.py           |  2 +-
 3 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/distributed/actor.py b/distributed/actor.py
index 2ebbba53a1..19828281dc 100644
--- a/distributed/actor.py
+++ b/distributed/actor.py
@@ -3,11 +3,11 @@
 import threading
 from queue import Queue
 
-from .client import Future, default_client
+from .client import Future
 from .protocol import to_serialize
 from .utils import iscoroutinefunction, sync, thread_state
 from .utils_comm import WrappedKey
-from .worker import get_worker
+from .worker import get_client, get_worker
 
 
 class Actor(WrappedKey):
@@ -59,12 +59,15 @@ def __init__(self, cls, address, key, worker=None):
             self._client = None
         else:
             try:
+                # TODO: `get_worker` may return the wrong worker instance for async local clusters (most tests)
+                # when run outside of a task (when deserializing a key pointing to an Actor, etc.)
                 self._worker = get_worker()
             except ValueError:
                 self._worker = None
             try:
-                self._client = default_client()
-                self._future = Future(key)
+                self._client = get_client()
+                self._future = Future(key, inform=self._worker is None)
+                # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
             except ValueError:
                 self._client = None
 
@@ -109,7 +112,8 @@ def _sync(self, func, *args, **kwargs):
         if self._client:
             return self._client.sync(func, *args, **kwargs)
         else:
-            # TODO support sync operation by checking against thread ident of loop
+            if self._asynchronous:
+                return func(*args, **kwargs)
             return sync(self._worker.loop, func, *args, **kwargs)
 
     def __dir__(self):
diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py
index 379546c186..d529edcc98 100644
--- a/distributed/tests/test_actor.py
+++ b/distributed/tests/test_actor.py
@@ -564,6 +564,61 @@ async def wait(self):
     await c.gather(futures)
 
 
+@gen_cluster(client=True, client_kwargs=dict(set_as_default=False))
+# ^ NOTE: without `set_as_default=False`, `get_client()` within worker would return
+# the same client instance the test is using (because it's all one process).
+# Even with this, both workers will share the same client instance.
+async def test_worker_actor_handle_is_weakref(c, s, a, b):
+    counter = c.submit(Counter, actor=True, workers=[a.address])
+
+    await c.submit(lambda _: None, counter, workers=[b.address])
+
+    del counter
+
+    start = time()
+    while a.actors or b.data:
+        await asyncio.sleep(0.1)
+        assert time() < start + 30
+
+
+def test_worker_actor_handle_is_weakref_sync(client):
+    workers = list(client.run(lambda: None))
+    counter = client.submit(Counter, actor=True, workers=[workers[0]])
+
+    client.submit(lambda _: None, counter, workers=[workers[1]]).result()
+
+    del counter
+
+    def check(dask_worker):
+        return len(dask_worker.data) + len(dask_worker.actors)
+
+    start = time()
+    while any(client.run(check).values()):
+        sleep(0.01)
+        assert time() < start + 30
+
+
+def test_worker_actor_handle_is_weakref_from_compute_sync(client):
+    workers = list(client.run(lambda: None))
+
+    with dask.annotate(workers=workers[0]):
+        counter = dask.delayed(Counter)()
+    with dask.annotate(workers=workers[1]):
+        intermediate = dask.delayed(lambda c: None)(counter)
+    with dask.annotate(workers=workers[0]):
+        final = dask.delayed(lambda x, c: x)(intermediate, counter)
+
+    final.compute(actors=counter, optimize_graph=False)
+
+    def worker_tasks_running(dask_worker):
+        return len(dask_worker.data) + len(dask_worker.actors)
+
+    start = time()
+    while any(client.run(worker_tasks_running).values()):
+        sleep(0.01)
+        assert time() < start + 30
+
+
 def test_one_thread_deadlock():
     with cluster(nworkers=2) as (cl, w):
         client = Client(cl["address"])
diff --git a/distributed/worker.py b/distributed/worker.py
index 95d6a116e6..04bc32e2e9 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1412,7 +1412,7 @@ async def get_data(
                 if k in self.actors:
                     from .actor import Actor
 
-                    data[k] = Actor(type(self.actors[k]), self.address, k)
+                    data[k] = Actor(type(self.actors[k]), self.address, k, worker=self)
 
         msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}}
         nbytes = {k: self.tasks[k].nbytes for k in data if k in self.tasks}