From 7a5b4e2dc7d310e2756f3122cd79b61c17a4b69b Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 18 May 2023 13:26:25 +0100 Subject: [PATCH 1/4] support get_worker() in async tasks --- distributed/client.py | 13 ++++++------ distributed/worker.py | 48 ++++++++++++++++++++++++++++++++----------- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 21598f226e..ef189690a2 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2354,9 +2354,9 @@ def gather(self, futures, errors="raise", direct=None, asynchronous=None): elif isinstance(futures, Iterator): return (self.gather(f, errors=errors, direct=direct) for f in futures) else: - if hasattr(thread_state, "execution_state"): # within worker task - local_worker = thread_state.execution_state["worker"] - else: + try: + local_worker = get_worker() + except ValueError: local_worker = None return self.sync( self._gather, @@ -2579,10 +2579,11 @@ def scatter( "Consider using a normal for loop and Client.submit" ) - if hasattr(thread_state, "execution_state"): # within worker task - local_worker = thread_state.execution_state["worker"] - else: + try: + local_worker = get_worker() + except ValueError: local_worker = None + return self.sync( self._scatter, data, diff --git a/distributed/worker.py b/distributed/worker.py index e6642d278b..18c1272213 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -4,6 +4,7 @@ import bisect import builtins import contextlib +import contextvars import errno import logging import math @@ -2142,7 +2143,11 @@ async def actor_execute( try: if iscoroutinefunction(func): - result = await func(*args, **kwargs) + token = _worker_cvar.set(self) + try: + result = await func(*args, **kwargs) + finally: + _worker_cvar.reset(token) elif separate_thread: result = await self.loop.run_in_executor( self.executors["actor"], @@ -2156,7 +2161,11 @@ async def actor_execute( self.active_threads_lock, ) else: - result = func(*args, **kwargs) + token = _worker_cvar.set(self) + try: + result = func(*args, **kwargs) + finally: + _worker_cvar.reset(token) return {"status": "OK", "result": to_serialize(result)} except Exception as ex: return {"status": "error", "exception": to_serialize(ex)} @@ -2236,12 +2245,16 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: try: ts.start_time = time() if iscoroutinefunction(function): - result = await apply_function_async( - function, - args2, - kwargs2, - self.scheduler_delay, - ) + token = _worker_cvar.set(self) + try: + result = await apply_function_async( + function, + args2, + kwargs2, + self.scheduler_delay, + ) + finally: + _worker_cvar.reset(token) elif "ThreadPoolExecutor" in str(type(e)): # The 'executor' time metric should be almost zero most of the time, # e.g. thread synchronization overhead only, since thread-noncpu and @@ -2663,6 +2676,9 @@ def total_in_connections(self): return self.transfer_outgoing_count_limit +_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar") + + def get_worker() -> Worker: """Get the worker currently running this task @@ -2682,8 +2698,8 @@ def get_worker() -> Worker: worker_client """ try: - return thread_state.execution_state["worker"] - except AttributeError: + return _worker_cvar.get() + except LookupError: raise ValueError("No worker found") from None @@ -3010,7 +3026,11 @@ def apply_function( execution_state=execution_state, key=key, ): - msg = apply_function_simple(function, args, kwargs, time_delay) + token = _worker_cvar.set(execution_state["worker"]) + try: + msg = apply_function_simple(function, args, kwargs, time_delay) + finally: + _worker_cvar.reset(token) with active_threads_lock: del active_threads[ident] @@ -3143,7 +3163,11 @@ def apply_function_actor( key=key, actor=True, ): - result = function(*args, **kwargs) + token = _worker_cvar.set(execution_state["worker"]) + try: + result = function(*args, **kwargs) + finally: + _worker_cvar.reset(token) with active_threads_lock: del active_threads[ident] From 6c483b1f7af1dc5253dfe5d118baf297994b3657 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 18 May 2023 13:37:46 +0100 Subject: [PATCH 2/4] support worker_client() in async tasks --- distributed/tests/test_worker_client.py | 18 +++++++++++ distributed/worker_client.py | 40 ++++++++++++++----------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/distributed/tests/test_worker_client.py b/distributed/tests/test_worker_client.py index 0c7e0f7366..0e6b873d1b 100644 --- a/distributed/tests/test_worker_client.py +++ b/distributed/tests/test_worker_client.py @@ -42,6 +42,24 @@ def func(x): assert len([id for id in s.clients if id.lower().startswith("client")]) == 1 +@gen_cluster(client=True) +async def test_submit_from_worker_async(c, s, a, b): + async def func(x): + with worker_client() as c: + x = c.submit(inc, x) + y = c.submit(double, x) + return await x + await y + + x, y = c.map(func, [10, 20]) + xx, yy = await c.gather([x, y]) + + assert xx == 10 + 1 + (10 + 1) * 2 + assert yy == 20 + 1 + (20 + 1) * 2 + + assert len(s.transition_log) > 10 + assert len([id for id in s.clients if id.lower().startswith("client")]) == 1 + + @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) async def test_scatter_from_worker(c, s, a, b): def func(): diff --git a/distributed/worker_client.py b/distributed/worker_client.py index 793961be04..355156206d 100644 --- a/distributed/worker_client.py +++ b/distributed/worker_client.py @@ -1,7 +1,7 @@ from __future__ import annotations +import contextlib import warnings -from contextlib import contextmanager import dask @@ -11,7 +11,7 @@ from distributed.worker_state_machine import SecedeEvent -@contextmanager +@contextlib.contextmanager def worker_client(timeout=None, separate_thread=True): """Get client for this thread @@ -53,22 +53,26 @@ def worker_client(timeout=None, separate_thread=True): worker = get_worker() client = get_client(timeout=timeout) - if separate_thread: - duration = time() - thread_state.start_time - secede() # have this thread secede from the thread pool - worker.loop.add_callback( - worker.handle_stimulus, - SecedeEvent( - key=thread_state.key, - compute_duration=duration, - stimulus_id=f"worker-client-secede-{time()}", - ), - ) - - yield client - - if separate_thread: - rejoin() + with contextlib.ExitStack() as stack: + if separate_thread: + try: + thread_state.start_time + except AttributeError: # not in a synchronous task, can't secede + pass + else: + duration = time() - thread_state.start_time + secede() # have this thread secede from the thread pool + stack.callback(rejoin) + worker.loop.add_callback( + worker.handle_stimulus, + SecedeEvent( + key=thread_state.key, + compute_duration=duration, + stimulus_id=f"worker-client-secede-{time()}", + ), + ) + + yield client def local_client(*args, **kwargs): From fad99dd7416b45ac62bdb6329e0d612fe88a8113 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 18 May 2023 14:08:01 +0100 Subject: [PATCH 3/4] remove extra whitespace --- distributed/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/distributed/client.py b/distributed/client.py index ef189690a2..32be8e5403 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2583,7 +2583,6 @@ def scatter( local_worker = get_worker() except ValueError: local_worker = None - return self.sync( self._scatter, data, From 8e092cd8de346bd0228041402d5d81fd8746ba45 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 18 May 2023 14:24:11 +0100 Subject: [PATCH 4/4] test worker_client/get_worker in Actors --- distributed/tests/test_actor.py | 42 ++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index 9c8163af6a..a3e6a84ad3 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -17,10 +17,12 @@ as_completed, get_client, wait, + worker_client, ) from distributed.actor import _LateLoopEvent from distributed.metrics import time -from distributed.utils_test import cluster, gen_cluster +from distributed.utils_test import cluster, double, gen_cluster, inc +from distributed.worker import get_worker class Counter: @@ -782,3 +784,41 @@ def __setstate__(self, state): future = c.submit(Foo, workers=a.address) foo = await future assert isinstance(foo.actor, Actor) + + +@gen_cluster(client=True) +async def test_worker_client_async(c, s, a, b): + class Actor: + async def demo(self, x): + with worker_client() as c: + x = c.submit(inc, x) + y = c.submit(double, x) + return await x + await y + + actor = await c.submit(Actor, actor=True) + assert await actor.demo(10) == 10 + 1 + (10 + 1) * 2 + + +@gen_cluster(client=True) +async def test_worker_client_separate_thread(c, s, a, b): + class Actor: + def demo(self, x): + with worker_client() as c: + x = c.submit(inc, x) + y = c.submit(double, x) + return x.result() + y.result() + + actor = await c.submit(Actor, actor=True) + assert await actor.demo(10, separate_thread=True) == 10 + 1 + (10 + 1) * 2 + + +@gen_cluster(client=True) +async def test_get_worker(c, s, a, b): + class Actor: + # There's not much you can do with a worker in a synchronous function + # running on the worker event loop. + def demo(self): + return get_worker().address + + actor = await c.submit(Actor, actor=True, workers=[a.address]) + assert await actor.demo() == a.address