diff --git a/distributed/actor.py b/distributed/actor.py
index b5d1b32a0fb..1d3b446347c 100644
--- a/distributed/actor.py
+++ b/distributed/actor.py
@@ -95,20 +95,22 @@ def __init__(self, cls, address, key, worker=None):
         super().__init__(key)
         self._cls = cls
         self._address = address
+        self._key = key
         self._future = None
-        if worker:
-            self._worker = worker
-            self._client = None
-        else:
+        self._worker = worker
+        self._client = None
+        self._try_bind_worker_client()
+
+    def _try_bind_worker_client(self):
+        if not self._worker:
             try:
-                # TODO: `get_worker` may return the wrong worker instance for async local clusters (most tests)
-                # when run outside of a task (when deserializing a key pointing to an Actor, etc.)
                 self._worker = get_worker()
             except ValueError:
                 self._worker = None
+        if not self._client:
             try:
                 self._client = get_client()
-                self._future = Future(key, inform=self._worker is None)
+                self._future = Future(self._key, inform=False)
+                # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
             except ValueError:
                 self._client = None
@@ -121,6 +123,8 @@ def __reduce__(self):

     @property
     def _io_loop(self):
+        if self._worker is None and self._client is None:
+            self._try_bind_worker_client()
         if self._worker:
             return self._worker.loop
         else:
@@ -128,6 +132,8 @@ def _io_loop(self):

     @property
     def _scheduler_rpc(self):
+        if self._worker is None and self._client is None:
+            self._try_bind_worker_client()
         if self._worker:
             return self._worker.scheduler
         else:
@@ -135,6 +141,8 @@ def _scheduler_rpc(self):

     @property
     def _worker_rpc(self):
+        if self._worker is None and self._client is None:
+            self._try_bind_worker_client()
         if self._worker:
             return self._worker.rpc(self._address)
         else:
@@ -168,7 +176,7 @@ def __getattr__(self, key):
             raise ValueError(
                 "Worker holding Actor was lost. Status: " + self._future.status
             )
-
+        self._try_bind_worker_client()
         if (
             self._worker
             and self._worker.address == self._address
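
Note: the refactor above replaces the eager (and sometimes wrong) context
lookup in Actor.__init__ with retry-on-access binding: the handle keeps
whatever it was constructed with, and _try_bind_worker_client() is re-run from
each property until a worker or client is actually available. A minimal,
self-contained sketch of that pattern follows; all names in it are
hypothetical, not part of the distributed API:

    class LazyHandle:
        def __init__(self, resolve):
            self._resolve = resolve  # stands in for get_worker()/get_client()
            self._ctx = None

        @property
        def ctx(self):
            if self._ctx is None:
                try:
                    self._ctx = self._resolve()
                except ValueError:
                    pass  # no ambient context yet; retried on next access
            return self._ctx

    def resolve():
        raise ValueError("deserialized outside of a worker or client")

    handle = LazyHandle(resolve)
    assert handle.ctx is None  # binding failed now, but is retried later
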
Status: " + self._future.status ) - + self._try_bind_worker_client() if ( self._worker and self._worker.address == self._address diff --git a/distributed/client.py b/distributed/client.py index be6b216d7cc..953ea20146b 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -198,32 +198,58 @@ class Future(WrappedKey): def __init__(self, key, client=None, inform=True, state=None): self.key = key self._cleared = False - tkey = stringify(key) - self.client = client or Client.current() - self.client._inc_ref(tkey) - self._generation = self.client.generation + self._tkey = stringify(key) + self._client = client + self._input_state = state + self._inform = inform + self._state = None + self._bind_late() - if tkey in self.client.futures: - self._state = self.client.futures[tkey] - else: - self._state = self.client.futures[tkey] = FutureState() - - if inform: - self.client._send_to_scheduler( - { - "op": "client-desires-keys", - "keys": [stringify(key)], - "client": self.client.id, - } - ) + @property + def client(self): + self._bind_late() + return self._client - if state is not None: + def _bind_late(self): + if not self._client: try: - handler = self.client._state_handlers[state] - except KeyError: - pass + client = get_client() + except ValueError: + client = None + self._client = client + if self._client and not self._state: + self._client._inc_ref(self._tkey) + self._generation = self._client.generation + + if self._tkey in self._client.futures: + self._state = self._client.futures[self._tkey] else: - handler(key=key) + self._state = self._client.futures[self._tkey] = FutureState() + + if self._inform: + self._client._send_to_scheduler( + { + "op": "client-desires-keys", + "keys": [self._tkey], + "client": self._client.id, + } + ) + + if self._input_state is not None: + try: + handler = self._client._state_handlers[self._input_state] + except KeyError: + pass + else: + handler(key=self.key) + + def _verify_initialized(self): + if not self.client or not self._state: + raise RuntimeError( + f"{type(self)} object not properly initialized. This can happen" + " if the object is being deserialized outside of the context of" + " a Client or Worker." + ) @property def executor(self): @@ -277,6 +303,7 @@ def result(self, timeout=None): result The result of the computation. Or a coroutine if the client is asynchronous. 
""" + self._verify_initialized() if self.client.asynchronous: return self.client.sync(self._result, callback_timeout=timeout) @@ -338,6 +365,7 @@ def exception(self, timeout=None, **kwargs): -------- Future.traceback """ + self._verify_initialized() return self.client.sync(self._exception, callback_timeout=timeout, **kwargs) def add_done_callback(self, fn): @@ -354,6 +382,7 @@ def add_done_callback(self, fn): fn : callable The method or function to be called """ + self._verify_initialized() cls = Future if cls._cb_executor is None or cls._cb_executor_pid != os.getpid(): try: @@ -381,6 +410,7 @@ def cancel(self, **kwargs): -------- Client.cancel """ + self._verify_initialized() return self.client.cancel([self], **kwargs) def retry(self, **kwargs): @@ -390,6 +420,7 @@ def retry(self, **kwargs): -------- Client.retry """ + self._verify_initialized() return self.client.retry([self], **kwargs) def cancelled(self): @@ -440,6 +471,7 @@ def traceback(self, timeout=None, **kwargs): -------- Future.exception """ + self._verify_initialized() return self.client.sync(self._traceback, callback_timeout=timeout, **kwargs) @property @@ -454,6 +486,7 @@ def release(self): This method can be called from different threads (see e.g. Client.get() or Future.__del__()) """ + self._verify_initialized() if not self._cleared and self.client.generation == self._generation: self._cleared = True try: @@ -461,24 +494,8 @@ def release(self): except TypeError: # pragma: no cover pass # Shutting down, add_callback may be None - def __getstate__(self): - return self.key, self.client.scheduler.address - - def __setstate__(self, state): - key, address = state - try: - c = Client.current(allow_global=False) - except ValueError: - c = get_client(address) - self.__init__(key, c) - c._send_to_scheduler( - { - "op": "update-graph", - "tasks": {}, - "keys": [stringify(self.key)], - "client": c.id, - } - ) + def __reduce__(self) -> str | tuple[Any, ...]: + return Future, (self.key,) def __del__(self): try: diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 527b85e1ee4..0434f8c8f3c 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -299,7 +299,13 @@ async def setup(self, worker): from distributed.semaphore import Semaphore async with ( - await Semaphore(max_leases=1, name=socket.gethostname(), register=True) + await Semaphore( + max_leases=1, + name=socket.gethostname(), + register=True, + scheduler_rpc=worker.scheduler, + loop=worker.loop, + ) ): if not await self._is_installed(worker): logger.info( diff --git a/distributed/event.py b/distributed/event.py index 91e70d00318..145abbf3857 100644 --- a/distributed/event.py +++ b/distributed/event.py @@ -178,12 +178,18 @@ class Event: """ def __init__(self, name=None, client=None): - try: - self.client = client or get_client() - except ValueError: - self.client = None + self._client = client self.name = name or "event-" + uuid.uuid4().hex + @property + def client(self): + if not self._client: + try: + self._client = get_client() + except ValueError: + pass + return self._client + def __await__(self): """async constructor diff --git a/distributed/lock.py b/distributed/lock.py index c4a44aebda9..99ec34cd6f7 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -7,9 +7,8 @@ from dask.utils import parse_timedelta -from distributed.client import Client from distributed.utils import TimeoutError, log_errors, wait_for -from distributed.worker import get_worker +from distributed.worker import get_client logger = 
diff --git a/distributed/lock.py b/distributed/lock.py
index c4a44aebda9..99ec34cd6f7 100644
--- a/distributed/lock.py
+++ b/distributed/lock.py
@@ -7,9 +7,8 @@

 from dask.utils import parse_timedelta

-from distributed.client import Client
 from distributed.utils import TimeoutError, log_errors, wait_for
-from distributed.worker import get_worker
+from distributed.worker import get_client

 logger = logging.getLogger(__name__)
@@ -95,15 +94,28 @@ class Lock:
     """

     def __init__(self, name=None, client=None):
-        try:
-            self.client = client or Client.current()
-        except ValueError:
-            # Initialise new client
-            self.client = get_worker().client
+        self._client = client
         self.name = name or "lock-" + uuid.uuid4().hex
         self.id = uuid.uuid4().hex
         self._locked = False

+    @property
+    def client(self):
+        if not self._client:
+            try:
+                self._client = get_client()
+            except ValueError:
+                pass
+        return self._client
+
+    def _verify_running(self):
+        if not self.client:
+            raise RuntimeError(
+                f"{type(self)} object not properly initialized. This can happen"
+                " if the object is being deserialized outside of the context of"
+                " a Client or Worker."
+            )
+
     def acquire(self, blocking=True, timeout=None):
         """Acquire the lock

@@ -127,6 +139,7 @@ def acquire(self, blocking=True, timeout=None):
         -------
         True or False whether or not it successfully acquired the lock
         """
+        self._verify_running()
         timeout = parse_timedelta(timeout)

         if not blocking:
@@ -145,6 +158,7 @@ def acquire(self, blocking=True, timeout=None):

     def release(self):
         """Release the lock if already acquired"""
+        self._verify_running()
         if not self.locked():
             raise ValueError("Lock is not yet acquired")
         result = self.client.sync(
diff --git a/distributed/profile.py b/distributed/profile.py
index e455b6a2ee3..0dac6f614bc 100644
--- a/distributed/profile.py
+++ b/distributed/profile.py
@@ -155,15 +155,20 @@ def process(
     merge
     """
     if depth is None:
-        depth = sys.getrecursionlimit() - 50
-    if depth <= 0:
-        return None
+        # Cut off rather conservatively since the output of the profiling
+        # sometimes needs to be recursed into as well, e.g. for serialization,
+        # which can cause recursion errors later on since this can generate
+        # deeply nested dictionaries
+        depth = min(250, sys.getrecursionlimit() // 4)
+
     if any(frame.f_code.co_filename.endswith(o) for o in omit):
         return None

     prev = frame.f_back
-    if prev is not None and (
-        stop is None or not prev.f_code.co_filename.endswith(stop)
+    if (
+        depth > 0
+        and prev is not None
+        and (stop is None or not prev.f_code.co_filename.endswith(stop))
     ):
         new_state = process(prev, frame, state, stop=stop, depth=depth - 1)
         if new_state is None:
diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py
index 3bf639451d3..2bf9eefb9b8 100644
--- a/distributed/protocol/tests/test_protocol.py
+++ b/distributed/protocol/tests/test_protocol.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import sys
 from threading import Thread
 from time import sleep

@@ -8,6 +9,7 @@
 import dask
 from dask.sizeof import sizeof

+from distributed.compatibility import WINDOWS
 from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
 from distributed.protocol.compression import (
     compressions,
@@ -412,3 +414,22 @@ def test_sizeof_serialize(Wrapper, Wrapped):
     assert size <= sizeof(ser_obj) < size * 1.05
     serialized = Wrapped(*serialize(ser_obj))
     assert size <= sizeof(serialized) < size * 1.05
+
+
+@pytest.mark.skipif(WINDOWS, reason="On Windows this triggers a stack overflow")
+def test_deeply_nested_structures():
+    # These kinds of deeply nested structures are generated in our profiling code
+    def gen_deeply_nested(depth):
+        msg = {}
+        d = msg
+        while depth:
+            depth -= 1
+            d["children"] = d = {}
+        return msg
+
+    msg = gen_deeply_nested(sys.getrecursionlimit() - 100)
+    with pytest.raises(TypeError, match="Could not serialize object"):
+        serialize(msg, on_error="raise")
+
+    msg = gen_deeply_nested(sys.getrecursionlimit() // 4)
+    assert isinstance(serialize(msg), tuple)
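
Note: the profile.py change trades stack depth for serialization headroom.
With CPython's default recursion limit of 1000, process() now stops after 250
frames instead of 950, leaving room to recurse into the resulting nested
dictionaries during serialization. The arithmetic, for reference:

    import sys

    # 250 with the default limit of 1000; shrinks proportionally when the
    # limit is lowered, as the updated test_stack_overflow does.
    depth = min(250, sys.getrecursionlimit() // 4)
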
diff --git a/distributed/queues.py b/distributed/queues.py
index 65cb80c2041..0a026d2c3b6 100644
--- a/distributed/queues.py
+++ b/distributed/queues.py
@@ -167,18 +167,27 @@ class Queue:
     """

     def __init__(self, name=None, client=None, maxsize=0):
-        try:
-            self.client = client or get_client()
-        except ValueError:
-            self.client = None
+        self._client = client
         self.name = name or "queue-" + uuid.uuid4().hex
         self.maxsize = maxsize
+        self._maybe_start()

+    def _maybe_start(self):
         if self.client:
             if self.client.asynchronous:
                 self._started = asyncio.ensure_future(self._start())
             else:
                 self.client.sync(self._start)

+    @property
+    def client(self):
+        if not self._client:
+            try:
+                self._client = get_client()
+            except ValueError:
+                pass
+        return self._client
+
     def _verify_running(self):
         if not self.client:
             raise RuntimeError(
@@ -192,6 +201,7 @@ async def _start(self):
         return self

     def __await__(self):
+        self._maybe_start()
         if hasattr(self, "_started"):
             return self._started.__await__()
         else:
diff --git a/distributed/semaphore.py b/distributed/semaphore.py
index 6c5a8def1a2..59d6951d7b0 100644
--- a/distributed/semaphore.py
+++ b/distributed/semaphore.py
@@ -342,20 +342,9 @@ def __init__(
         scheduler_rpc=None,
         loop=None,
     ):
-        try:
-            try:
-                worker = get_worker()
-                self.scheduler = scheduler_rpc or worker.scheduler
-                self.loop = loop or worker.loop
+        self._scheduler = scheduler_rpc
+        self._loop = loop

-            except ValueError:
-                client = get_client()
-                self.scheduler = scheduler_rpc or client.scheduler
-                self.loop = loop or client.loop
-        except ValueError:
-            # This happens if this is deserialized on the scheduler
-            self.scheduler = None
-            self.loop = None
         self.name = name or "semaphore-" + uuid.uuid4().hex
         self.max_leases = max_leases
         self.id = uuid.uuid4().hex
@@ -386,6 +375,31 @@ def __init__(
         if self.loop is not None:
             self.loop.add_callback(pc.start)

+    @property
+    def scheduler(self):
+        self._bind_late()
+        return self._scheduler
+
+    @property
+    def loop(self):
+        self._bind_late()
+        return self._loop
+
+    def _bind_late(self):
+        if self._scheduler is None or self._loop is None:
+            try:
+                try:
+                    worker = get_worker()
+                    self._scheduler = self._scheduler or worker.scheduler
+                    self._loop = self._loop or worker.loop
+
+                except ValueError:
+                    client = get_client()
+                    self._scheduler = self._scheduler or client.scheduler
+                    self._loop = self._loop or client.loop
+            except ValueError:
+                pass
+
     def _verify_running(self):
         if not self.scheduler or not self.loop:
             raise RuntimeError(
@@ -547,4 +561,5 @@ def close(self):
         return self.sync(self.scheduler.semaphore_close, name=self.name)

     def __del__(self):
-        self.refresh_callback.stop()
+        if hasattr(self, "refresh_callback"):
+            self.refresh_callback.stop()
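
Note: the __del__ guard in semaphore.py matters because __init__ can now exit
early, before refresh_callback is assigned, and Python still invokes __del__
on the partially constructed instance. A self-contained illustration of the
pattern, with hypothetical names:

    class Resource:
        def __init__(self, fail=False):
            if fail:
                raise ValueError("failed before attributes were set")
            self.callback = lambda: None

        def __del__(self):
            # __del__ runs even if __init__ raised early, so guard the access.
            if hasattr(self, "callback"):
                self.callback()

    try:
        Resource(fail=True)  # no AttributeError from __del__
    except ValueError:
        pass
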
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 1d311448874..aa029628506 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -104,6 +104,7 @@
     dec,
     div,
     double,
+    ensure_no_new_clients,
     gen_cluster,
     gen_test,
     get_cert,
@@ -1355,12 +1356,6 @@ async def test_scatter_direct(c, s, a, b):
     assert result == 123
     assert not s.counters["op"].components[0]["scatter"]

-    result = await future
-    assert not s.counters["op"].components[0]["gather"]
-
-    result = await c.gather(future)
-    assert not s.counters["op"].components[0]["gather"]
-

 @gen_cluster()
 async def test_scatter_direct_2(s, a, b):
@@ -3977,15 +3972,37 @@ async def test_serialize_future(s, a, b):
     result = await future

     for ci in (c1, c2):
-        for ctxman in lambda ci: ci.as_current(), lambda ci: temp_default_client(
-            ci
-        ):
-            with ctxman(ci):
+        with ensure_no_new_clients():
+            with ci.as_current():
                 future2 = pickle.loads(pickle.dumps(future))
                 assert future2.client is ci
                 assert stringify(future2.key) in ci.futures
                 result2 = await future2
                 assert result == result2
+            with temp_default_client(ci):
+                future2 = pickle.loads(pickle.dumps(future))
+
+
+@gen_cluster()
+async def test_serialize_future_without_client(s, a, b):
+    # Do not use a ctx manager to avoid having this set as the current and/or default client
+    c1 = await Client(s.address, asynchronous=True, set_as_default=False)
+
+    with ensure_no_new_clients():
+
+        def do_stuff():
+            return 1
+
+        future = c1.submit(do_stuff)
+        pickled = pickle.dumps(future)
+        unpickled_fut = pickle.loads(pickled)
+
+        with pytest.raises(RuntimeError):
+            await unpickled_fut
+
+        with c1.as_current():
+            unpickled_fut_ctx = pickle.loads(pickled)
+            assert await unpickled_fut_ctx == 1


 @gen_cluster()
@@ -5301,11 +5318,19 @@ def func():


 def test_get_client_sync(c, s, a, b):
-    results = c.run(lambda: get_worker().scheduler.address)
-    assert results == {w["address"]: s["address"] for w in [a, b]}
-
-    results = c.run(lambda: get_client().scheduler.address)
-    assert results == {w["address"]: s["address"] for w in [a, b]}
+    for w in [a, b]:
+        assert (
+            c.submit(
+                lambda: get_worker().scheduler.address, workers=[w["address"]]
+            ).result()
+            == s["address"]
+        )
+        assert (
+            c.submit(
+                lambda: get_client().scheduler.address, workers=[w["address"]]
+            ).result()
+            == s["address"]
+        )


 @gen_cluster(client=True)
@@ -6356,7 +6381,6 @@ async def test_futures_of_sorted(c, s, a, b):
         assert str(k) in str(f)


-@pytest.mark.flaky(reruns=10, reruns_delay=5)
 @gen_cluster(
     client=True,
     config={
@@ -6456,7 +6480,6 @@ async def f():
         await f()

     assert set(results) == set(range(1, 11))
-    assert not s.counters["op"].components[0]["gather"]


 @gen_cluster(client=True)
diff --git a/distributed/tests/test_multi_locks.py b/distributed/tests/test_multi_locks.py
index 79a1cbc2e39..8d8d8417dda 100644
--- a/distributed/tests/test_multi_locks.py
+++ b/distributed/tests/test_multi_locks.py
@@ -57,8 +57,8 @@ async def test_timeout(c, s, a, b):
     await lock1.release()


-@gen_cluster()
-async def test_timeout_wake_waiter(s, a, b):
+@gen_cluster(client=True)
+async def test_timeout_wake_waiter(c, s, a, b):
     l1 = MultiLock(names=["x"])
     l2 = MultiLock(names=["x", "y"])
     l3 = MultiLock(names=["y"])
diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py
index 36ec8de6eac..3036d127453 100644
--- a/distributed/tests/test_profile.py
+++ b/distributed/tests/test_profile.py
@@ -357,7 +357,7 @@ def test_call_stack_f_lineno(f_lasti: int, f_lineno: int) -> None:

 def test_stack_overflow():
     old = sys.getrecursionlimit()
-    sys.setrecursionlimit(200)
+    sys.setrecursionlimit(300)
     try:
         state = create()
         frame = None
@@ -370,9 +370,11 @@ def f(i):
             else:
                 return f(i - 1)

-        f(sys.getrecursionlimit() - 40)
+        f(sys.getrecursionlimit() - 100)
         process(frame, None, state)
-        merge(state, state, state)
+        assert state["children"]
+        assert state["count"]
+        assert merge(state, state, state)
     finally:
         sys.setrecursionlimit(old)
diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py
index 1b096cb19e7..664ea4f4ab7 100644
--- a/distributed/tests/test_pubsub.py
+++ b/distributed/tests/test_pubsub.py
@@ -54,7 +54,7 @@ def pingpong(a, b, start=False, n=1000, msg=1):

 @gen_cluster(client=True, nthreads=[])
 async def test_client(c, s):
-    with pytest.raises(ValueError, match="No workers found"):
+    with pytest.raises(ValueError, match="No worker found"):
         get_worker()

     sub = Sub("a")
     pub = Pub("a")
diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py
index a855a63f4f5..299cceb7b8f 100644
--- a/distributed/tests/test_semaphore.py
+++ b/distributed/tests/test_semaphore.py
@@ -98,12 +98,13 @@ def test_timeout_sync(client):


 @gen_cluster(
+    client=True,
     config={
         "distributed.scheduler.locks.lease-validation-interval": "200ms",
         "distributed.scheduler.locks.lease-timeout": "200ms",
     },
 )
-async def test_release_semaphore_after_timeout(s, a, b):
+async def test_release_semaphore_after_timeout(c, s, a, b):
     sem = await Semaphore(name="x", max_leases=2)
     await sem.acquire()  # leases: 2 - 1 = 1
@@ -121,8 +122,8 @@ async def test_release_semaphore_after_timeout(s, a, b):
     assert not (await sem.acquire(timeout=0.1))


-@gen_cluster()
-async def test_async_ctx(s, a, b):
+@gen_cluster(client=True)
+async def test_async_ctx(c, s, a, b):
     sem = await Semaphore(name="x")
     async with sem:
         assert not await sem.acquire(timeout=0.025)
diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index 3a75fe6cf98..0d3f88180bf 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -39,6 +39,7 @@
     check_thread_leak,
     cluster,
     dump_cluster_state,
+    ensure_no_new_clients,
     freeze_batched_send,
     gen_cluster,
     gen_nbytes,
@@ -1081,3 +1082,20 @@ def test_sizeof():
 def test_sizeof_error(input, exc, msg):
     with pytest.raises(exc, match=msg):
         SizeOf(input)
+
+
+@gen_test()
+async def test_ensure_no_new_clients():
+    with ensure_no_new_clients():
+        async with Scheduler() as s:
+            pass
+    async with Scheduler() as s:
+        with ensure_no_new_clients():
+            pass
+        with pytest.raises(AssertionError):
+            with ensure_no_new_clients():
+                async with Client(s.address, asynchronous=True):
+                    pass
+        async with Client(s.address, asynchronous=True):
+            with ensure_no_new_clients():
+                pass
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 1039bd01a66..be9f77613d3 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -1061,6 +1061,9 @@ async def f():
     assert results == {a.address: 11, b.address: 11}


+@pytest.mark.xfail(
+    reason="Async tasks do not provide worker context. See https://github.com/dask/distributed/pull/7339"
+)
 def test_get_client_coroutine_sync(client, s, a, b):
     async def f():
         client = await get_client()
@@ -1068,8 +1071,8 @@ async def f():
         result = await future
         return result

-    results = client.run(f)
-    assert results == {a["address"]: 11, b["address"]: 11}
+    for w in [a, b]:
+        assert client.submit(f, workers=[w["address"]]).result() == 11


 @gen_cluster()
@@ -1240,21 +1243,6 @@ async def test_deque_handler(s):
     assert any(msg.msg == "foo456" for msg in deque_handler.deque)


-def test_get_worker_name(client):
-    def f():
-        get_client().submit(inc, 1).result()
-
-    client.run(f)
-
-    def func(dask_scheduler):
-        return list(dask_scheduler.clients)
-
-    start = time()
-    while not any("worker" in n for n in client.run_on_scheduler(func)):
-        sleep(0.1)
-    assert time() < start + 10
-
-
 @gen_cluster(nthreads=[], client=True)
 async def test_scheduler_address_config(c, s):
     with dask.config.set({"scheduler-address": s.address}):
@@ -2225,8 +2213,8 @@ def get_worker_client_id():
         def_client = get_client()
         return def_client.id

-    worker_client = await c.submit(get_worker_client_id)
-    assert worker_client == existing_client
+    wclient = await c.submit(get_worker_client_id)
+    assert wclient == existing_client

     assert not Worker._initialized_clients
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 8665dbeed24..3e349fdb7a2 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -312,6 +312,14 @@ def get(self):
         return getattr(sys.modules[self.modname], self.slotname)


+@contextmanager
+def ensure_no_new_clients():
+    before = set(Client._instances)
+    yield
+    after = set(Client._instances)
+    assert after.issubset(before)
+
+
 def varying(items):
     """
     Return a function that returns a result (or raises an exception)
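
Note: ensure_no_new_clients() just snapshots Client._instances around the
block, so any code path that implicitly instantiates a Client trips the
assertion. A hypothetical test sketching the intended usage (the test name
and body are illustrative only):

    import pickle

    from distributed.utils_test import ensure_no_new_clients, gen_cluster

    @gen_cluster(client=True)
    async def test_unpickle_reuses_client(c, s, a, b):
        fut = c.submit(lambda: 1)
        with ensure_no_new_clients():  # fails if loads() were to spawn a Client
            fut2 = pickle.loads(pickle.dumps(fut))
        assert fut2.client is c
        assert await fut2 == 1
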
diff --git a/distributed/variable.py b/distributed/variable.py
index 2816840c099..c541d8e1336 100644
--- a/distributed/variable.py
+++ b/distributed/variable.py
@@ -136,10 +136,6 @@ class Variable:
     it is wise not to send too much.  If you want to share a large amount of
     data then ``scatter`` it and share the future instead.

-    .. warning::
-
-       This object is experimental and has known issues in Python 2
-
     Parameters
     ----------
     name: string (optional)
@@ -166,12 +162,18 @@ class Variable:
     """

     def __init__(self, name=None, client=None):
-        try:
-            self.client = client or get_client()
-        except ValueError:
-            self.client = None
+        self._client = client
         self.name = name or "variable-" + uuid.uuid4().hex

+    @property
+    def client(self):
+        if not self._client:
+            try:
+                self._client = get_client()
+            except ValueError:
+                pass
+        return self._client
+
     def _verify_running(self):
         if not self.client:
             raise RuntimeError(
diff --git a/distributed/worker.py b/distributed/worker.py
index 59b2c8784d3..9723436bfd2 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -41,7 +41,7 @@
     overload,
 )

-from tlz import first, keymap, pluck
+from tlz import keymap, pluck
 from tornado.ioloop import IOLoop

 import dask
@@ -106,6 +106,7 @@
     parse_ports,
     recursive_to_dict,
     run_in_executor_with_context,
+    set_thread_state,
     silence_logging,
     thread_state,
     wait_for,
@@ -2704,10 +2705,7 @@ def get_worker() -> Worker:
     try:
         return thread_state.execution_state["worker"]
     except AttributeError:
-        try:
-            return first(w for w in Worker._instances if w.status in WORKER_ANY_RUNNING)
-        except StopIteration:
-            raise ValueError("No workers found")
+        raise ValueError("No worker found") from None


 def get_client(address=None, timeout=None, resolve_address=True) -> Client:
@@ -2927,6 +2925,8 @@ def loads_function(bytes_object):
 @context_meter.meter("deserialize")
 def _deserialize(function=None, args=None, kwargs=None, task=NO_VALUE):
     """Deserialize task inputs and regularize to func, args, kwargs"""
+    # Some objects require threadlocal state during deserialization, e.g. to
+    # detect the current worker
     if function is not None:
         function = loads_function(function)
     if args and isinstance(args, bytes):
@@ -3055,11 +3055,12 @@ def apply_function(
     ident = threading.get_ident()
     with active_threads_lock:
         active_threads[ident] = key
-    thread_state.start_time = time()
-    thread_state.execution_state = execution_state
-    thread_state.key = key
-
-    msg = apply_function_simple(function, args, kwargs, time_delay)
+    with set_thread_state(
+        start_time=time(),
+        execution_state=execution_state,
+        key=key,
+    ):
+        msg = apply_function_simple(function, args, kwargs, time_delay)

     with active_threads_lock:
         del active_threads[ident]
@@ -3186,16 +3187,18 @@ def apply_function_actor(
     with active_threads_lock:
         active_threads[ident] = key

-    thread_state.execution_state = execution_state
-    thread_state.key = key
-    thread_state.actor = True
-
-    result = function(*args, **kwargs)
+    with set_thread_state(
+        start_time=time(),
+        execution_state=execution_state,
+        key=key,
+        actor=True,
+    ):
+        result = function(*args, **kwargs)

-    with active_threads_lock:
-        del active_threads[ident]
+        with active_threads_lock:
+            del active_threads[ident]

-    return result
+        return result


 def get_msg_safe_str(msg):
@@ -3273,6 +3276,7 @@ async def run(server, comm, function, args=(), kwargs=None, wait=True):
         if has_arg(function, "dask_scheduler"):
             kwargs["dask_scheduler"] = server
     logger.info("Run out-of-band function %r", funcname(function))
+
     try:
         if not is_coro:
             result = function(*args, **kwargs)
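
Note: set_thread_state (imported from distributed.utils above) scopes the
thread-local attributes to the task, where the old code assigned them and
never cleaned up; that leftover state is what previously let get_worker()
succeed outside of a task. A simplified, self-contained stand-in showing the
save/restore behavior (the real helper lives in distributed.utils):

    import threading
    from contextlib import contextmanager

    thread_state = threading.local()

    @contextmanager
    def set_thread_state(**kwargs):
        # Save any attributes we are about to overwrite, set the new values,
        # and restore (or delete) them on exit.
        prev = {k: getattr(thread_state, k) for k in kwargs if hasattr(thread_state, k)}
        for k, v in kwargs.items():
            setattr(thread_state, k, v)
        try:
            yield
        finally:
            for k in kwargs:
                if k in prev:
                    setattr(thread_state, k, prev[k])
                else:
                    delattr(thread_state, k)

    with set_thread_state(key="task-1"):
        assert thread_state.key == "task-1"
    assert not hasattr(thread_state, "key")
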