Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support get_worker() and worker_client() in async tasks #7844

Merged
merged 4 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2354,9 +2354,9 @@ def gather(self, futures, errors="raise", direct=None, asynchronous=None):
elif isinstance(futures, Iterator):
return (self.gather(f, errors=errors, direct=direct) for f in futures)
else:
if hasattr(thread_state, "execution_state"): # within worker task
local_worker = thread_state.execution_state["worker"]
else:
try:
local_worker = get_worker()
except ValueError:
local_worker = None
return self.sync(
self._gather,
Expand Down Expand Up @@ -2579,10 +2579,11 @@ def scatter(
"Consider using a normal for loop and Client.submit"
)

if hasattr(thread_state, "execution_state"): # within worker task
local_worker = thread_state.execution_state["worker"]
else:
try:
local_worker = get_worker()
except ValueError:
local_worker = None

graingert marked this conversation as resolved.
Show resolved Hide resolved
return self.sync(
self._scatter,
data,
Expand Down
18 changes: 18 additions & 0 deletions distributed/tests/test_worker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ def func(x):
assert len([id for id in s.clients if id.lower().startswith("client")]) == 1


@gen_cluster(client=True)
async def test_submit_from_worker_async(c, s, a, b):
async def func(x):
with worker_client() as c:
x = c.submit(inc, x)
y = c.submit(double, x)
return await x + await y

x, y = c.map(func, [10, 20])
xx, yy = await c.gather([x, y])

assert xx == 10 + 1 + (10 + 1) * 2
assert yy == 20 + 1 + (20 + 1) * 2

assert len(s.transition_log) > 10
assert len([id for id in s.clients if id.lower().startswith("client")]) == 1


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_scatter_from_worker(c, s, a, b):
def func():
Expand Down
48 changes: 36 additions & 12 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
Expand Down Expand Up @@ -2142,7 +2143,11 @@ async def actor_execute(

try:
if iscoroutinefunction(func):
result = await func(*args, **kwargs)
token = _worker_cvar.set(self)
try:
result = await func(*args, **kwargs)
finally:
_worker_cvar.reset(token)
elif separate_thread:
result = await self.loop.run_in_executor(
self.executors["actor"],
Expand All @@ -2156,7 +2161,11 @@ async def actor_execute(
self.active_threads_lock,
)
else:
result = func(*args, **kwargs)
token = _worker_cvar.set(self)
try:
result = func(*args, **kwargs)
finally:
_worker_cvar.reset(token)
return {"status": "OK", "result": to_serialize(result)}
except Exception as ex:
return {"status": "error", "exception": to_serialize(ex)}
Expand Down Expand Up @@ -2236,12 +2245,16 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
try:
ts.start_time = time()
if iscoroutinefunction(function):
result = await apply_function_async(
function,
args2,
kwargs2,
self.scheduler_delay,
)
token = _worker_cvar.set(self)
try:
result = await apply_function_async(
function,
args2,
kwargs2,
self.scheduler_delay,
)
finally:
_worker_cvar.reset(token)
elif "ThreadPoolExecutor" in str(type(e)):
# The 'executor' time metric should be almost zero most of the time,
# e.g. thread synchronization overhead only, since thread-noncpu and
Expand Down Expand Up @@ -2663,6 +2676,9 @@ def total_in_connections(self):
return self.transfer_outgoing_count_limit


_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")


def get_worker() -> Worker:
"""Get the worker currently running this task

Expand All @@ -2682,8 +2698,8 @@ def get_worker() -> Worker:
worker_client
"""
try:
return thread_state.execution_state["worker"]
except AttributeError:
return _worker_cvar.get()
except LookupError:
raise ValueError("No worker found") from None


Expand Down Expand Up @@ -3010,7 +3026,11 @@ def apply_function(
execution_state=execution_state,
key=key,
):
msg = apply_function_simple(function, args, kwargs, time_delay)
token = _worker_cvar.set(execution_state["worker"])
try:
msg = apply_function_simple(function, args, kwargs, time_delay)
finally:
_worker_cvar.reset(token)

with active_threads_lock:
del active_threads[ident]
Expand Down Expand Up @@ -3143,7 +3163,11 @@ def apply_function_actor(
key=key,
actor=True,
):
result = function(*args, **kwargs)
token = _worker_cvar.set(execution_state["worker"])
try:
result = function(*args, **kwargs)
finally:
_worker_cvar.reset(token)

with active_threads_lock:
del active_threads[ident]
Expand Down
40 changes: 22 additions & 18 deletions distributed/worker_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import contextlib
import warnings
from contextlib import contextmanager

import dask

Expand All @@ -11,7 +11,7 @@
from distributed.worker_state_machine import SecedeEvent


@contextmanager
@contextlib.contextmanager
def worker_client(timeout=None, separate_thread=True):
"""Get client for this thread

Expand Down Expand Up @@ -53,22 +53,26 @@ def worker_client(timeout=None, separate_thread=True):

worker = get_worker()
client = get_client(timeout=timeout)
if separate_thread:
duration = time() - thread_state.start_time
secede() # have this thread secede from the thread pool
worker.loop.add_callback(
worker.handle_stimulus,
SecedeEvent(
key=thread_state.key,
compute_duration=duration,
stimulus_id=f"worker-client-secede-{time()}",
),
)

yield client

if separate_thread:
rejoin()
Copy link
Member Author

@graingert graingert May 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

previously this didn't rejoin if there was an exception in with worker_client():

with contextlib.ExitStack() as stack:
if separate_thread:
try:
thread_state.start_time
except AttributeError: # not in a synchronous task, can't secede
pass
else:
duration = time() - thread_state.start_time
secede() # have this thread secede from the thread pool
stack.callback(rejoin)
worker.loop.add_callback(
worker.handle_stimulus,
SecedeEvent(
key=thread_state.key,
compute_duration=duration,
stimulus_id=f"worker-client-secede-{time()}",
),
)

yield client


def local_client(*args, **kwargs):
Expand Down