Skip to content

Commit

Permalink
Actor: don't hold key references on workers (#4937)
Browse files Browse the repository at this point in the history
Fixes #4936

When an Actor handle is constructed while there is a current worker, its Future holds only a weak reference to the key; otherwise the key could become unreleasable.
  • Loading branch information
gjoseph92 authored Jul 20, 2021
1 parent cb18bd1 commit f28c719
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 6 deletions.
14 changes: 9 additions & 5 deletions distributed/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import threading
from queue import Queue

from .client import Future, default_client
from .client import Future
from .protocol import to_serialize
from .utils import iscoroutinefunction, sync, thread_state
from .utils_comm import WrappedKey
from .worker import get_worker
from .worker import get_client, get_worker


class Actor(WrappedKey):
Expand Down Expand Up @@ -59,12 +59,15 @@ def __init__(self, cls, address, key, worker=None):
self._client = None
else:
try:
# TODO: `get_worker` may return the wrong worker instance for async local clusters (most tests)
# when run outside of a task (when deserializing a key pointing to an Actor, etc.)
self._worker = get_worker()
except ValueError:
self._worker = None
try:
self._client = default_client()
self._future = Future(key)
self._client = get_client()
self._future = Future(key, inform=self._worker is None)
# ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
except ValueError:
self._client = None

Expand Down Expand Up @@ -109,7 +112,8 @@ def _sync(self, func, *args, **kwargs):
if self._client:
return self._client.sync(func, *args, **kwargs)
else:
# TODO support sync operation by checking against thread ident of loop
if self._asynchronous:
return func(*args, **kwargs)
return sync(self._worker.loop, func, *args, **kwargs)

def __dir__(self):
Expand Down
55 changes: 55 additions & 0 deletions distributed/tests/test_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,61 @@ async def wait(self):
await c.gather(futures)


@gen_cluster(client=True, client_kwargs=dict(set_as_default=False))
# NOTE: `set_as_default=False` matters here. The whole test runs in one process, so
# without it `get_client()` called inside a worker would return the very client
# instance this test is using. Even with the flag set, both workers still share a
# single client instance.
async def test_worker_actor_handle_is_weakref(c, s, a, b):
    """Actor handle sent to another worker must not keep the actor's key alive.

    Creates a ``Counter`` actor on worker ``a``, passes it to a task on worker
    ``b``, drops the client-side future, then waits (up to 30 seconds) until
    ``a.actors`` and ``b.data`` are both empty.
    """
    actor_future = c.submit(Counter, actor=True, workers=[a.address])

    # The task's future is intentionally not bound to a name: it is awaited and
    # then dropped immediately.
    await c.submit(lambda _: None, actor_future, workers=[b.address])

    del actor_future

    deadline = time() + 30
    while True:
        if not a.actors and not b.data:
            break
        await asyncio.sleep(0.1)
        assert time() < deadline


def test_worker_actor_handle_is_weakref_sync(client):
    """Synchronous-client variant of the weak-reference actor handle check.

    Places a ``Counter`` actor on the first worker, hands it to a task on the
    second worker, drops the local future, and polls every worker until each
    reports no stored data and no actors (failing after 30 seconds).
    """
    # `client.run` is keyed by worker; iterating the result yields those keys.
    addresses = list(client.run(lambda: None))
    actor_future = client.submit(Counter, actor=True, workers=[addresses[0]])

    # The task's future is not bound to a name; only its result is waited on.
    client.submit(lambda _: None, actor_future, workers=[addresses[1]]).result()

    del actor_future

    def held_count(dask_worker):
        # Number of data entries plus actors still present on this worker.
        return len(dask_worker.data) + len(dask_worker.actors)

    deadline = time() + 30
    while True:
        per_worker = client.run(held_count)
        if not any(per_worker.values()):
            break
        sleep(0.01)
        assert time() < deadline


def test_worker_actor_handle_is_weakref_from_compute_sync(client):
    """Weak-reference actor handle check driven through ``dask.delayed``.

    Builds a three-step delayed graph: a ``Counter`` annotated for the first
    worker, a task consuming it annotated for the second worker, and a final
    task annotated for the first worker that takes both. After computing with
    ``actors=counter``, polls every worker until each reports no stored data
    and no actors (failing after 30 seconds).
    """
    # `client.run` is keyed by worker; iterating the result yields those keys.
    addresses = list(client.run(lambda: None))
    home, other = addresses[0], addresses[1]

    with dask.annotate(workers=home):
        counter = dask.delayed(Counter)()
    with dask.annotate(workers=other):
        touched = dask.delayed(lambda c: None)(counter)
    with dask.annotate(workers=home):
        last_step = dask.delayed(lambda x, c: x)(touched, counter)

    last_step.compute(actors=counter, optimize_graph=False)

    def held_count(dask_worker):
        # Number of data entries plus actors still present on this worker.
        return len(dask_worker.data) + len(dask_worker.actors)

    deadline = time() + 30
    while True:
        per_worker = client.run(held_count)
        if not any(per_worker.values()):
            break
        sleep(0.01)
        assert time() < deadline


def test_one_thread_deadlock():
with cluster(nworkers=2) as (cl, w):
client = Client(cl["address"])
Expand Down
2 changes: 1 addition & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@ async def get_data(
if k in self.actors:
from .actor import Actor

data[k] = Actor(type(self.actors[k]), self.address, k)
data[k] = Actor(type(self.actors[k]), self.address, k, worker=self)

msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}}
nbytes = {k: self.tasks[k].nbytes for k in data if k in self.tasks}
Expand Down

0 comments on commit f28c719

Please sign in to comment.