From aaa663896f4b2ccb6cb70313a252c4dd1fdd95b0 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 23 Feb 2023 16:31:33 +0100 Subject: [PATCH 01/11] Get futures to work nicely --- distributed/actor.py | 34 +++++---- distributed/client.py | 87 +++++++++++----------- distributed/diagnostics/plugin.py | 8 ++- distributed/semaphore.py | 28 ++++---- distributed/tests/test_client.py | 74 +++++++++++++------ distributed/tests/test_multi_locks.py | 4 +- distributed/tests/test_semaphore.py | 7 +- distributed/tests/test_utils_test.py | 29 ++++++++ distributed/tests/test_worker.py | 26 ++----- distributed/utils_test.py | 8 +++ distributed/worker.py | 100 +++++++++++++++++--------- 11 files changed, 257 insertions(+), 148 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index b5d1b32a0f..d3bfb771de 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -95,23 +95,27 @@ def __init__(self, cls, address, key, worker=None): super().__init__(key) self._cls = cls self._address = address + self._key = key self._future = None if worker: self._worker = worker self._client = None else: - try: - # TODO: `get_worker` may return the wrong worker instance for async local clusters (most tests) - # when run outside of a task (when deserializing a key pointing to an Actor, etc.) - self._worker = get_worker() - except ValueError: - self._worker = None - try: - self._client = get_client() - self._future = Future(key, inform=self._worker is None) - # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable. - except ValueError: - self._client = None + self._try_bind_worker_client() + + def _try_bind_worker_client(self): + try: + # TODO: `get_worker` may return the wrong worker instance for async local clusters (most tests) + # when run outside of a task (when deserializing a key pointing to an Actor, etc.) + self._worker = get_worker() + except ValueError: + self._worker = None + try: + self._client = get_client() + self._future = Future(self._key, inform=self._worker is None) + # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable. + except ValueError: + self._client = None def __repr__(self): return f"" @@ -121,6 +125,8 @@ def __reduce__(self): @property def _io_loop(self): + if self._worker is None and self._client is None: + self._try_bind_worker_client() if self._worker: return self._worker.loop else: @@ -128,6 +134,8 @@ def _io_loop(self): @property def _scheduler_rpc(self): + if self._worker is None and self._client is None: + self._try_bind_worker_client() if self._worker: return self._worker.scheduler else: @@ -135,6 +143,8 @@ def _scheduler_rpc(self): @property def _worker_rpc(self): + if self._worker is None and self._client is None: + self._try_bind_worker_client() if self._worker: return self._worker.rpc(self._address) else: diff --git a/distributed/client.py b/distributed/client.py index be6b216d7c..4fd56fae6d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -199,31 +199,45 @@ def __init__(self, key, client=None, inform=True, state=None): self.key = key self._cleared = False tkey = stringify(key) - self.client = client or Client.current() - self.client._inc_ref(tkey) - self._generation = self.client.generation - - if tkey in self.client.futures: - self._state = self.client.futures[tkey] - else: - self._state = self.client.futures[tkey] = FutureState() - - if inform: - self.client._send_to_scheduler( - { - "op": "client-desires-keys", - "keys": [stringify(key)], - "client": self.client.id, - } - ) - - if state is not None: + if not client: try: - handler = self.client._state_handlers[state] - except KeyError: - pass + client = get_client() + except ValueError: + client = None + self.client = client + if self.client: + self.client._inc_ref(tkey) + self._generation = self.client.generation + + if tkey in self.client.futures: + self._state = self.client.futures[tkey] else: - handler(key=key) + self._state = self.client.futures[tkey] = FutureState() + + if inform: + self.client._send_to_scheduler( + { + "op": "client-desires-keys", + "keys": [stringify(key)], + "client": self.client.id, + } + ) + + if state is not None: + try: + handler = self.client._state_handlers[state] + except KeyError: + pass + else: + handler(key=key) + + def _verify_initialized(self): + if not self.client: + raise RuntimeError( + f"{type(self)} object not properly initialized. This can happen" + " if the object is being deserialized outside of the context of" + " a Client or Worker." + ) @property def executor(self): @@ -277,6 +291,7 @@ def result(self, timeout=None): result The result of the computation. Or a coroutine if the client is asynchronous. """ + self._verify_initialized() if self.client.asynchronous: return self.client.sync(self._result, callback_timeout=timeout) @@ -338,6 +353,7 @@ def exception(self, timeout=None, **kwargs): -------- Future.traceback """ + self._verify_initialized() return self.client.sync(self._exception, callback_timeout=timeout, **kwargs) def add_done_callback(self, fn): @@ -354,6 +370,7 @@ def add_done_callback(self, fn): fn : callable The method or function to be called """ + self._verify_initialized() cls = Future if cls._cb_executor is None or cls._cb_executor_pid != os.getpid(): try: @@ -381,6 +398,7 @@ def cancel(self, **kwargs): -------- Client.cancel """ + self._verify_initialized() return self.client.cancel([self], **kwargs) def retry(self, **kwargs): @@ -390,6 +408,7 @@ def retry(self, **kwargs): -------- Client.retry """ + self._verify_initialized() return self.client.retry([self], **kwargs) def cancelled(self): @@ -440,6 +459,7 @@ def traceback(self, timeout=None, **kwargs): -------- Future.exception """ + self._verify_initialized() return self.client.sync(self._traceback, callback_timeout=timeout, **kwargs) @property @@ -454,6 +474,7 @@ def release(self): This method can be called from different threads (see e.g. Client.get() or Future.__del__()) """ + self._verify_initialized() if not self._cleared and self.client.generation == self._generation: self._cleared = True try: @@ -461,24 +482,8 @@ def release(self): except TypeError: # pragma: no cover pass # Shutting down, add_callback may be None - def __getstate__(self): - return self.key, self.client.scheduler.address - - def __setstate__(self, state): - key, address = state - try: - c = Client.current(allow_global=False) - except ValueError: - c = get_client(address) - self.__init__(key, c) - c._send_to_scheduler( - { - "op": "update-graph", - "tasks": {}, - "keys": [stringify(self.key)], - "client": c.id, - } - ) + def __reduce__(self) -> str | tuple[Any, ...]: + return Future, (self.key,) def __del__(self): try: diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 527b85e1ee..0434f8c8f3 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -299,7 +299,13 @@ async def setup(self, worker): from distributed.semaphore import Semaphore async with ( - await Semaphore(max_leases=1, name=socket.gethostname(), register=True) + await Semaphore( + max_leases=1, + name=socket.gethostname(), + register=True, + scheduler_rpc=worker.scheduler, + loop=worker.loop, + ) ): if not await self._is_installed(worker): logger.info( diff --git a/distributed/semaphore.py b/distributed/semaphore.py index 6c5a8def1a..d4ccf1cddf 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -342,20 +342,21 @@ def __init__( scheduler_rpc=None, loop=None, ): - try: + self.scheduler = scheduler_rpc + self.loop = loop + if scheduler_rpc is None or loop is None: try: - worker = get_worker() - self.scheduler = scheduler_rpc or worker.scheduler - self.loop = loop or worker.loop - + try: + worker = get_worker() + self.scheduler = scheduler_rpc or worker.scheduler + self.loop = loop or worker.loop + + except ValueError: + client = get_client() + self.scheduler = scheduler_rpc or client.scheduler + self.loop = loop or client.loop except ValueError: - client = get_client() - self.scheduler = scheduler_rpc or client.scheduler - self.loop = loop or client.loop - except ValueError: - # This happens if this is deserialized on the scheduler - self.scheduler = None - self.loop = None + pass self.name = name or "semaphore-" + uuid.uuid4().hex self.max_leases = max_leases self.id = uuid.uuid4().hex @@ -547,4 +548,5 @@ def close(self): return self.sync(self.scheduler.semaphore_close, name=self.name) def __del__(self): - self.refresh_callback.stop() + if hasattr(self, "refresh_callback"): + self.refresh_callback.stop() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 1d31144887..b533837012 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -104,6 +104,7 @@ dec, div, double, + ensure_no_new_clients, gen_cluster, gen_test, get_cert, @@ -1355,12 +1356,6 @@ async def test_scatter_direct(c, s, a, b): assert result == 123 assert not s.counters["op"].components[0]["scatter"] - result = await future - assert not s.counters["op"].components[0]["gather"] - - result = await c.gather(future) - assert not s.counters["op"].components[0]["gather"] - @gen_cluster() async def test_scatter_direct_2(s, a, b): @@ -3977,15 +3972,37 @@ async def test_serialize_future(s, a, b): result = await future for ci in (c1, c2): - for ctxman in lambda ci: ci.as_current(), lambda ci: temp_default_client( - ci - ): - with ctxman(ci): + with ensure_no_new_clients(): + with ci.as_current(): future2 = pickle.loads(pickle.dumps(future)) assert future2.client is ci assert stringify(future2.key) in ci.futures result2 = await future2 assert result == result2 + with temp_default_client(ci): + future2 = pickle.loads(pickle.dumps(future)) + + +@gen_cluster() +async def test_serialize_future_without_client(s, a, b): + # Do not use a ctx manager to avoid having this being set as a current and/or default client + c1 = await Client(s.address, asynchronous=True, set_as_default=False) + + with ensure_no_new_clients(): + + def do_stuff(): + return 1 + + future = c1.submit(do_stuff) + pickled = pickle.dumps(future) + unpickled_fut = pickle.loads(pickled) + + with pytest.raises(RuntimeError): + await unpickled_fut + + with c1.as_current(): + unpickled_fut_ctx = pickle.loads(pickled) + assert await unpickled_fut_ctx == 1 @gen_cluster() @@ -5301,11 +5318,19 @@ def func(): def test_get_client_sync(c, s, a, b): - results = c.run(lambda: get_worker().scheduler.address) - assert results == {w["address"]: s["address"] for w in [a, b]} - - results = c.run(lambda: get_client().scheduler.address) - assert results == {w["address"]: s["address"] for w in [a, b]} + for w in [a, b]: + assert ( + c.submit( + lambda: get_worker().scheduler.address, workers=[w["address"]] + ).result() + == s["address"] + ) + assert ( + c.submit( + lambda: get_client().scheduler.address, workers=[w["address"]] + ).result() + == s["address"] + ) @gen_cluster(client=True) @@ -5331,13 +5356,20 @@ def test_serialize_collections_of_futures_sync(c): df = pd.DataFrame({"x": [1, 2, 3]}) ddf = dd.from_pandas(df, npartitions=2).persist() - future = c.scatter(ddf) + # future = c.scatter(ddf) + + # result = future.result() + # futs = futures_of(result) + assert_eq(ddf.compute(), df) - result = future.result() - assert_eq(result.compute(), df) + # assert future.type == dd.DataFrame + # futs = futures_of(future) + def inner(x, y): + futs = futures_of(x) + df = x.compute() + assert_eq(df, y) - assert future.type == dd.DataFrame - assert c.submit(lambda x, y: assert_eq(x.compute(), y), future, df).result() + c.submit(inner, ddf, df).result() def _dynamic_workload(x, delay=0.01): @@ -6356,7 +6388,6 @@ async def test_futures_of_sorted(c, s, a, b): assert str(k) in str(f) -@pytest.mark.flaky(reruns=10, reruns_delay=5) @gen_cluster( client=True, config={ @@ -6456,7 +6487,6 @@ async def f(): await f() assert set(results) == set(range(1, 11)) - assert not s.counters["op"].components[0]["gather"] @gen_cluster(client=True) diff --git a/distributed/tests/test_multi_locks.py b/distributed/tests/test_multi_locks.py index 79a1cbc2e3..8d8d8417dd 100644 --- a/distributed/tests/test_multi_locks.py +++ b/distributed/tests/test_multi_locks.py @@ -57,8 +57,8 @@ async def test_timeout(c, s, a, b): await lock1.release() -@gen_cluster() -async def test_timeout_wake_waiter(s, a, b): +@gen_cluster(client=True) +async def test_timeout_wake_waiter(c, s, a, b): l1 = MultiLock(names=["x"]) l2 = MultiLock(names=["x", "y"]) l3 = MultiLock(names=["y"]) diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index a855a63f4f..299cceb7b8 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -98,12 +98,13 @@ def test_timeout_sync(client): @gen_cluster( + client=True, config={ "distributed.scheduler.locks.lease-validation-interval": "200ms", "distributed.scheduler.locks.lease-timeout": "200ms", }, ) -async def test_release_semaphore_after_timeout(s, a, b): +async def test_release_semaphore_after_timeout(c, s, a, b): sem = await Semaphore(name="x", max_leases=2) await sem.acquire() # leases: 2 - 1 = 1 @@ -121,8 +122,8 @@ async def test_release_semaphore_after_timeout(s, a, b): assert not (await sem.acquire(timeout=0.1)) -@gen_cluster() -async def test_async_ctx(s, a, b): +@gen_cluster(client=True) +async def test_async_ctx(c, s, a, b): sem = await Semaphore(name="x") async with sem: assert not await sem.acquire(timeout=0.025) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 3a75fe6cf9..45b7966e12 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -39,6 +39,7 @@ check_thread_leak, cluster, dump_cluster_state, + ensure_no_new_clients, freeze_batched_send, gen_cluster, gen_nbytes, @@ -1081,3 +1082,31 @@ def test_sizeof(): def test_sizeof_error(input, exc, msg): with pytest.raises(exc, match=msg): SizeOf(input) + + +@gen_test() +async def test_ensure_no_new_clients(): + with ensure_no_new_clients(): + async with Scheduler() as s: + pass + async with Scheduler() as s: + # Running client already exists + async with Client(s.address, asynchronous=True): + try: + # We detect a running client + with pytest.raises(AssertionError): + with ensure_no_new_clients(): + c = await Client(s.address, asynchronous=True) + finally: + await c.close() + + # Forbid also clients that are closed again + with pytest.raises(AssertionError): + with ensure_no_new_clients(): + async with Client(s.address, asynchronous=True): + pass + + # But we also want to forbid any initialization + with pytest.raises(AssertionError): + with ensure_no_new_clients(): + Client(s.address, asynchronous=True) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 1039bd01a6..be9f77613d 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1061,6 +1061,9 @@ async def f(): assert results == {a.address: 11, b.address: 11} +@pytest.mark.xfail( + reason="Async tasks do not provide worker context. See https://github.com/dask/distributed/pull/7339" +) def test_get_client_coroutine_sync(client, s, a, b): async def f(): client = await get_client() @@ -1068,8 +1071,8 @@ async def f(): result = await future return result - results = client.run(f) - assert results == {a["address"]: 11, b["address"]: 11} + for w in [a, b]: + assert client.submit(f, workers=[w["address"]]).result() == 1 @gen_cluster() @@ -1240,21 +1243,6 @@ async def test_deque_handler(s): assert any(msg.msg == "foo456" for msg in deque_handler.deque) -def test_get_worker_name(client): - def f(): - get_client().submit(inc, 1).result() - - client.run(f) - - def func(dask_scheduler): - return list(dask_scheduler.clients) - - start = time() - while not any("worker" in n for n in client.run_on_scheduler(func)): - sleep(0.1) - assert time() < start + 10 - - @gen_cluster(nthreads=[], client=True) async def test_scheduler_address_config(c, s): with dask.config.set({"scheduler-address": s.address}): @@ -2225,8 +2213,8 @@ def get_worker_client_id(): def_client = get_client() return def_client.id - worker_client = await c.submit(get_worker_client_id) - assert worker_client == existing_client + wclient = await c.submit(get_worker_client_id) + assert wclient == existing_client assert not Worker._initialized_clients diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 8665dbeed2..3e349fdb7a 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -312,6 +312,14 @@ def get(self): return getattr(sys.modules[self.modname], self.slotname) +@contextmanager +def ensure_no_new_clients(): + before = set(Client._instances) + yield + after = set(Client._instances) + assert after.issubset(before) + + def varying(items): """ Return a function that returns a result (or raises an exception) diff --git a/distributed/worker.py b/distributed/worker.py index 59b2c8784d..d34e4aa6a3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -41,7 +41,7 @@ overload, ) -from tlz import first, keymap, pluck +from tlz import keymap, pluck from tornado.ioloop import IOLoop import dask @@ -106,6 +106,7 @@ parse_ports, recursive_to_dict, run_in_executor_with_context, + set_thread_state, silence_logging, thread_state, wait_for, @@ -2154,10 +2155,26 @@ def find_missing(self) -> None: ################ def run(self, comm, function, args=(), wait=True, kwargs=None): - return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) + return run( + self, + comm, + function=function, + args=args, + kwargs=kwargs, + wait=wait, + thread_state={"execution_state": self.execution_state}, + ) def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True): - return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) + return run( + self, + comm, + function=function, + args=args, + kwargs=kwargs, + wait=wait, + thread_state={"execution_state": self.execution_state}, + ) async def actor_execute( self, @@ -2208,9 +2225,13 @@ async def _maybe_deserialize_task( start = time() # Offload deserializing large tasks if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD: - function, args, kwargs = await offload(_deserialize, *ts.run_spec) + function, args, kwargs = await offload( + _deserialize, *ts.run_spec, execution_state=self.execution_state + ) else: - function, args, kwargs = _deserialize(*ts.run_spec) + function, args, kwargs = _deserialize( + *ts.run_spec, execution_state=self.execution_state + ) stop = time() if stop - start > 0.010: @@ -2704,10 +2725,7 @@ def get_worker() -> Worker: try: return thread_state.execution_state["worker"] except AttributeError: - try: - return first(w for w in Worker._instances if w.status in WORKER_ANY_RUNNING) - except StopIteration: - raise ValueError("No workers found") + raise ValueError("No workers found") from None def get_client(address=None, timeout=None, resolve_address=True) -> Client: @@ -2925,21 +2943,26 @@ def loads_function(bytes_object): @context_meter.meter("deserialize") -def _deserialize(function=None, args=None, kwargs=None, task=NO_VALUE): +def _deserialize( + function=None, args=None, kwargs=None, task=NO_VALUE, execution_state=None +): """Deserialize task inputs and regularize to func, args, kwargs""" - if function is not None: - function = loads_function(function) - if args and isinstance(args, bytes): - args = pickle.loads(args) - if kwargs and isinstance(kwargs, bytes): - kwargs = pickle.loads(kwargs) + # Some objects require threadlocal state during deserialization, e.g. to + # detect the current worker + with set_thread_state(execution_state=execution_state): + if function is not None: + function = loads_function(function) + if args and isinstance(args, bytes): + args = pickle.loads(args) + if kwargs and isinstance(kwargs, bytes): + kwargs = pickle.loads(kwargs) - if task is not NO_VALUE: - assert not function and not args and not kwargs - function = execute_task - args = (task,) + if task is not NO_VALUE: + assert not function and not args and not kwargs + function = execute_task + args = (task,) - return function, args or (), kwargs or {} + return function, args or (), kwargs or {} def execute_task(task): @@ -3055,11 +3078,12 @@ def apply_function( ident = threading.get_ident() with active_threads_lock: active_threads[ident] = key - thread_state.start_time = time() - thread_state.execution_state = execution_state - thread_state.key = key - - msg = apply_function_simple(function, args, kwargs, time_delay) + with set_thread_state( + start_time=time(), + execution_state=execution_state, + key=key, + ): + msg = apply_function_simple(function, args, kwargs, time_delay) with active_threads_lock: del active_threads[ident] @@ -3186,16 +3210,18 @@ def apply_function_actor( with active_threads_lock: active_threads[ident] = key - thread_state.execution_state = execution_state - thread_state.key = key - thread_state.actor = True + with set_thread_state( + start_time=time(), + execution_state=execution_state, + key=key, + actor=True, + ): + result = function(*args, **kwargs) - result = function(*args, **kwargs) + with active_threads_lock: + del active_threads[ident] - with active_threads_lock: - del active_threads[ident] - - return result + return result def get_msg_safe_str(msg): @@ -3259,7 +3285,9 @@ def convert_kwargs_to_str(kwargs: dict, max_len: int | None = None) -> str: return "{{{}}}".format(", ".join(strs)) -async def run(server, comm, function, args=(), kwargs=None, wait=True): +async def run( + server, comm, function, args=(), kwargs=None, wait=True, thread_state=None +): kwargs = kwargs or {} function = pickle.loads(function) is_coro = iscoroutinefunction(function) @@ -3273,7 +3301,9 @@ async def run(server, comm, function, args=(), kwargs=None, wait=True): if has_arg(function, "dask_scheduler"): kwargs["dask_scheduler"] = server logger.info("Run out-of-band function %r", funcname(function)) + try: + thread_state = thread_state or {} if not is_coro: result = function(*args, **kwargs) else: From f55151248ecd3b95c388bd685dbf93a02d973bd6 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 23 Feb 2023 17:13:28 +0100 Subject: [PATCH 02/11] Late bindings --- distributed/actor.py | 2 +- distributed/client.py | 46 ++++++++++++++++++---------- distributed/semaphore.py | 41 ++++++++++++++++--------- distributed/tests/test_utils_test.py | 21 +++---------- 4 files changed, 61 insertions(+), 49 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index d3bfb771de..227e822e58 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -178,7 +178,7 @@ def __getattr__(self, key): raise ValueError( "Worker holding Actor was lost. Status: " + self._future.status ) - + self._try_bind_worker_client() if ( self._worker and self._worker.address == self._address diff --git a/distributed/client.py b/distributed/client.py index 4fd56fae6d..953ea20146 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -198,41 +198,53 @@ class Future(WrappedKey): def __init__(self, key, client=None, inform=True, state=None): self.key = key self._cleared = False - tkey = stringify(key) - if not client: + self._tkey = stringify(key) + self._client = client + self._input_state = state + self._inform = inform + self._state = None + self._bind_late() + + @property + def client(self): + self._bind_late() + return self._client + + def _bind_late(self): + if not self._client: try: client = get_client() except ValueError: client = None - self.client = client - if self.client: - self.client._inc_ref(tkey) - self._generation = self.client.generation + self._client = client + if self._client and not self._state: + self._client._inc_ref(self._tkey) + self._generation = self._client.generation - if tkey in self.client.futures: - self._state = self.client.futures[tkey] + if self._tkey in self._client.futures: + self._state = self._client.futures[self._tkey] else: - self._state = self.client.futures[tkey] = FutureState() + self._state = self._client.futures[self._tkey] = FutureState() - if inform: - self.client._send_to_scheduler( + if self._inform: + self._client._send_to_scheduler( { "op": "client-desires-keys", - "keys": [stringify(key)], - "client": self.client.id, + "keys": [self._tkey], + "client": self._client.id, } ) - if state is not None: + if self._input_state is not None: try: - handler = self.client._state_handlers[state] + handler = self._client._state_handlers[self._input_state] except KeyError: pass else: - handler(key=key) + handler(key=self.key) def _verify_initialized(self): - if not self.client: + if not self.client or not self._state: raise RuntimeError( f"{type(self)} object not properly initialized. This can happen" " if the object is being deserialized outside of the context of" diff --git a/distributed/semaphore.py b/distributed/semaphore.py index d4ccf1cddf..59d6951d7b 100644 --- a/distributed/semaphore.py +++ b/distributed/semaphore.py @@ -342,21 +342,9 @@ def __init__( scheduler_rpc=None, loop=None, ): - self.scheduler = scheduler_rpc - self.loop = loop - if scheduler_rpc is None or loop is None: - try: - try: - worker = get_worker() - self.scheduler = scheduler_rpc or worker.scheduler - self.loop = loop or worker.loop + self._scheduler = scheduler_rpc + self._loop = loop - except ValueError: - client = get_client() - self.scheduler = scheduler_rpc or client.scheduler - self.loop = loop or client.loop - except ValueError: - pass self.name = name or "semaphore-" + uuid.uuid4().hex self.max_leases = max_leases self.id = uuid.uuid4().hex @@ -387,6 +375,31 @@ def __init__( if self.loop is not None: self.loop.add_callback(pc.start) + @property + def scheduler(self): + self._bind_late() + return self._scheduler + + @property + def loop(self): + self._bind_late() + return self._loop + + def _bind_late(self): + if self._scheduler is None or self._loop is None: + try: + try: + worker = get_worker() + self._scheduler = self._scheduler or worker.scheduler + self._loop = self._loop or worker.loop + + except ValueError: + client = get_client() + self._scheduler = self._scheduler or client.scheduler + self._loop = self._loop or client.loop + except ValueError: + pass + def _verify_running(self): if not self.scheduler or not self.loop: raise RuntimeError( diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 45b7966e12..c2c87ef91c 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -1090,23 +1090,10 @@ async def test_ensure_no_new_clients(): async with Scheduler() as s: pass async with Scheduler() as s: - # Running client already exists - async with Client(s.address, asynchronous=True): - try: - # We detect a running client - with pytest.raises(AssertionError): - with ensure_no_new_clients(): - c = await Client(s.address, asynchronous=True) - finally: - await c.close() - - # Forbid also clients that are closed again + with ensure_no_new_clients(): + pass with pytest.raises(AssertionError): with ensure_no_new_clients(): async with Client(s.address, asynchronous=True): - pass - - # But we also want to forbid any initialization - with pytest.raises(AssertionError): - with ensure_no_new_clients(): - Client(s.address, asynchronous=True) + with ensure_no_new_clients(): + pass From 045658e1cda7a9b3d5690487781272b0e1b1394e Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 23 Feb 2023 17:16:47 +0100 Subject: [PATCH 03/11] self review --- distributed/actor.py | 2 -- distributed/tests/test_client.py | 17 +++++------------ distributed/worker.py | 25 +++---------------------- 3 files changed, 8 insertions(+), 36 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index 227e822e58..b4baa44c92 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -105,8 +105,6 @@ def __init__(self, cls, address, key, worker=None): def _try_bind_worker_client(self): try: - # TODO: `get_worker` may return the wrong worker instance for async local clusters (most tests) - # when run outside of a task (when deserializing a key pointing to an Actor, etc.) self._worker = get_worker() except ValueError: self._worker = None diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index b533837012..aa02962850 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5356,20 +5356,13 @@ def test_serialize_collections_of_futures_sync(c): df = pd.DataFrame({"x": [1, 2, 3]}) ddf = dd.from_pandas(df, npartitions=2).persist() - # future = c.scatter(ddf) + future = c.scatter(ddf) - # result = future.result() - # futs = futures_of(result) - assert_eq(ddf.compute(), df) + result = future.result() + assert_eq(result.compute(), df) - # assert future.type == dd.DataFrame - # futs = futures_of(future) - def inner(x, y): - futs = futures_of(x) - df = x.compute() - assert_eq(df, y) - - c.submit(inner, ddf, df).result() + assert future.type == dd.DataFrame + assert c.submit(lambda x, y: assert_eq(x.compute(), y), future, df).result() def _dynamic_workload(x, delay=0.01): diff --git a/distributed/worker.py b/distributed/worker.py index d34e4aa6a3..cf0ffc99af 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2155,26 +2155,10 @@ def find_missing(self) -> None: ################ def run(self, comm, function, args=(), wait=True, kwargs=None): - return run( - self, - comm, - function=function, - args=args, - kwargs=kwargs, - wait=wait, - thread_state={"execution_state": self.execution_state}, - ) + return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True): - return run( - self, - comm, - function=function, - args=args, - kwargs=kwargs, - wait=wait, - thread_state={"execution_state": self.execution_state}, - ) + return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) async def actor_execute( self, @@ -3285,9 +3269,7 @@ def convert_kwargs_to_str(kwargs: dict, max_len: int | None = None) -> str: return "{{{}}}".format(", ".join(strs)) -async def run( - server, comm, function, args=(), kwargs=None, wait=True, thread_state=None -): +async def run(server, comm, function, args=(), kwargs=None, wait=True): kwargs = kwargs or {} function = pickle.loads(function) is_coro = iscoroutinefunction(function) @@ -3303,7 +3285,6 @@ async def run( logger.info("Run out-of-band function %r", funcname(function)) try: - thread_state = thread_state or {} if not is_coro: result = function(*args, **kwargs) else: From 285dde959349b3181f20d99c4175bc3e29a4fb60 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 10 Mar 2023 12:17:53 +0100 Subject: [PATCH 04/11] more lazy clients --- distributed/event.py | 14 ++++++++++---- distributed/lock.py | 28 +++++++++++++++++++++------- distributed/queues.py | 18 ++++++++++++++---- distributed/variable.py | 18 ++++++++++-------- distributed/worker.py | 39 ++++++++++++++++----------------------- 5 files changed, 71 insertions(+), 46 deletions(-) diff --git a/distributed/event.py b/distributed/event.py index 91e70d0031..145abbf385 100644 --- a/distributed/event.py +++ b/distributed/event.py @@ -178,12 +178,18 @@ class Event: """ def __init__(self, name=None, client=None): - try: - self.client = client or get_client() - except ValueError: - self.client = None + self._client = client self.name = name or "event-" + uuid.uuid4().hex + @property + def client(self): + if not self._client: + try: + self._client = get_client() + except ValueError: + pass + return self._client + def __await__(self): """async constructor diff --git a/distributed/lock.py b/distributed/lock.py index c4a44aebda..99ec34cd6f 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -7,9 +7,8 @@ from dask.utils import parse_timedelta -from distributed.client import Client from distributed.utils import TimeoutError, log_errors, wait_for -from distributed.worker import get_worker +from distributed.worker import get_client logger = logging.getLogger(__name__) @@ -95,15 +94,28 @@ class Lock: """ def __init__(self, name=None, client=None): - try: - self.client = client or Client.current() - except ValueError: - # Initialise new client - self.client = get_worker().client + self._client = client self.name = name or "lock-" + uuid.uuid4().hex self.id = uuid.uuid4().hex self._locked = False + @property + def client(self): + if not self._client: + try: + self._client = get_client() + except ValueError: + pass + return self._client + + def _verify_running(self): + if not self.client: + raise RuntimeError( + f"{type(self)} object not properly initialized. This can happen" + " if the object is being deserialized outside of the context of" + " a Client or Worker." + ) + def acquire(self, blocking=True, timeout=None): """Acquire the lock @@ -127,6 +139,7 @@ def acquire(self, blocking=True, timeout=None): ------- True or False whether or not it successfully acquired the lock """ + self._verify_running() timeout = parse_timedelta(timeout) if not blocking: @@ -145,6 +158,7 @@ def acquire(self, blocking=True, timeout=None): def release(self): """Release the lock if already acquired""" + self._verify_running() if not self.locked(): raise ValueError("Lock is not yet acquired") result = self.client.sync( diff --git a/distributed/queues.py b/distributed/queues.py index 65cb80c204..0a026d2c3b 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -167,18 +167,27 @@ class Queue: """ def __init__(self, name=None, client=None, maxsize=0): - try: - self.client = client or get_client() - except ValueError: - self.client = None + self._client = client self.name = name or "queue-" + uuid.uuid4().hex self.maxsize = maxsize + self._maybe_start() + + def _maybe_start(self): if self.client: if self.client.asynchronous: self._started = asyncio.ensure_future(self._start()) else: self.client.sync(self._start) + @property + def client(self): + if not self._client: + try: + self._client = get_client() + except ValueError: + pass + return self._client + def _verify_running(self): if not self.client: raise RuntimeError( @@ -192,6 +201,7 @@ async def _start(self): return self def __await__(self): + self._maybe_start() if hasattr(self, "_started"): return self._started.__await__() else: diff --git a/distributed/variable.py b/distributed/variable.py index 2816840c09..c541d8e133 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -136,10 +136,6 @@ class Variable: it is wise not to send too much. If you want to share a large amount of data then ``scatter`` it and share the future instead. - .. warning:: - - This object is experimental and has known issues in Python 2 - Parameters ---------- name: string (optional) @@ -166,12 +162,18 @@ class Variable: """ def __init__(self, name=None, client=None): - try: - self.client = client or get_client() - except ValueError: - self.client = None + self._client = client self.name = name or "variable-" + uuid.uuid4().hex + @property + def client(self): + if not self._client: + try: + self._client = get_client() + except ValueError: + pass + return self._client + def _verify_running(self): if not self.client: raise RuntimeError( diff --git a/distributed/worker.py b/distributed/worker.py index cf0ffc99af..23d727a151 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2209,13 +2209,9 @@ async def _maybe_deserialize_task( start = time() # Offload deserializing large tasks if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD: - function, args, kwargs = await offload( - _deserialize, *ts.run_spec, execution_state=self.execution_state - ) + function, args, kwargs = await offload(_deserialize, *ts.run_spec) else: - function, args, kwargs = _deserialize( - *ts.run_spec, execution_state=self.execution_state - ) + function, args, kwargs = _deserialize(*ts.run_spec) stop = time() if stop - start > 0.010: @@ -2927,26 +2923,23 @@ def loads_function(bytes_object): @context_meter.meter("deserialize") -def _deserialize( - function=None, args=None, kwargs=None, task=NO_VALUE, execution_state=None -): +def _deserialize(function=None, args=None, kwargs=None, task=NO_VALUE): """Deserialize task inputs and regularize to func, args, kwargs""" # Some objects require threadlocal state during deserialization, e.g. to # detect the current worker - with set_thread_state(execution_state=execution_state): - if function is not None: - function = loads_function(function) - if args and isinstance(args, bytes): - args = pickle.loads(args) - if kwargs and isinstance(kwargs, bytes): - kwargs = pickle.loads(kwargs) - - if task is not NO_VALUE: - assert not function and not args and not kwargs - function = execute_task - args = (task,) - - return function, args or (), kwargs or {} + if function is not None: + function = loads_function(function) + if args and isinstance(args, bytes): + args = pickle.loads(args) + if kwargs and isinstance(kwargs, bytes): + kwargs = pickle.loads(kwargs) + + if task is not NO_VALUE: + assert not function and not args and not kwargs + function = execute_task + args = (task,) + + return function, args or (), kwargs or {} def execute_task(task): From 77905767c7685c8733677df9369c4a7d8799836c Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 24 Feb 2023 11:32:25 +0100 Subject: [PATCH 05/11] Never inform on Actor --- distributed/actor.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/distributed/actor.py b/distributed/actor.py index b4baa44c92..1d3b446347 100644 --- a/distributed/actor.py +++ b/distributed/actor.py @@ -97,23 +97,23 @@ def __init__(self, cls, address, key, worker=None): self._address = address self._key = key self._future = None - if worker: - self._worker = worker - self._client = None - else: - self._try_bind_worker_client() + self._worker = worker + self._client = None + self._try_bind_worker_client() def _try_bind_worker_client(self): - try: - self._worker = get_worker() - except ValueError: - self._worker = None - try: - self._client = get_client() - self._future = Future(self._key, inform=self._worker is None) - # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable. - except ValueError: - self._client = None + if not self._worker: + try: + self._worker = get_worker() + except ValueError: + self._worker = None + if not self._client: + try: + self._client = get_client() + self._future = Future(self._key, inform=False) + # ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable. + except ValueError: + self._client = None def __repr__(self): return f"" From 1198e76d45fff6dd3d7c85ae5098aa5609ddf235 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 8 Mar 2023 13:10:02 +0100 Subject: [PATCH 06/11] Ensure profiling messages can be serialized despite deep nesting --- distributed/profile.py | 6 +++++- distributed/protocol/tests/test_protocol.py | 19 +++++++++++++++++++ distributed/tests/test_profile.py | 2 +- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/distributed/profile.py b/distributed/profile.py index e455b6a2ee..140f0b662f 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -155,7 +155,11 @@ def process( merge """ if depth is None: - depth = sys.getrecursionlimit() - 50 + # Cut off rather conservatively since the output of the profiling + # sometimes need to be recursed into as well, e.g. for serialization + # which can cause recursion errors later on since this can generate + # deeply nested dictionaries + depth = min(250, sys.getrecursionlimit() // 4) if depth <= 0: return None if any(frame.f_code.co_filename.endswith(o) for o in omit): diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index 3bf639451d..9b8ec58323 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from threading import Thread from time import sleep @@ -412,3 +413,21 @@ def test_sizeof_serialize(Wrapper, Wrapped): assert size <= sizeof(ser_obj) < size * 1.05 serialized = Wrapped(*serialize(ser_obj)) assert size <= sizeof(serialized) < size * 1.05 + + +def test_deeply_nested_structures(): + # These kind of deeply nested structures are generated in our profiling code + def gen_deeply_nested(depth): + msg = {} + d = msg + while depth: + depth -= 1 + d["children"] = d = {} + return msg + + msg = gen_deeply_nested(sys.getrecursionlimit() - 100) + with pytest.raises(TypeError, match="Could not serialize object"): + serialize(msg, on_error="raise") + + msg = gen_deeply_nested(sys.getrecursionlimit() // 4) + assert isinstance(serialize(msg), tuple) diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index 36ec8de6ea..7c030776bf 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -370,7 +370,7 @@ def f(i): else: return f(i - 1) - f(sys.getrecursionlimit() - 40) + f(sys.getrecursionlimit() - 100) process(frame, None, state) merge(state, state, state) From c9b1ab6b663b94e07882aba867f59e94bf8002b3 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 8 Mar 2023 13:20:14 +0100 Subject: [PATCH 07/11] Extend tests for stackoverflow profile --- distributed/profile.py | 9 ++++---- distributed/tests/test_profile.py | 36 ++++++++++++++----------------- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/distributed/profile.py b/distributed/profile.py index 140f0b662f..0dac6f614b 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -160,14 +160,15 @@ def process( # which can cause recursion errors later on since this can generate # deeply nested dictionaries depth = min(250, sys.getrecursionlimit() // 4) - if depth <= 0: - return None + if any(frame.f_code.co_filename.endswith(o) for o in omit): return None prev = frame.f_back - if prev is not None and ( - stop is None or not prev.f_code.co_filename.endswith(stop) + if ( + depth > 0 + and prev is not None + and (stop is None or not prev.f_code.co_filename.endswith(stop)) ): new_state = process(prev, frame, state, stop=stop, depth=depth - 1) if new_state is None: diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index 7c030776bf..adb782d2b4 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -356,23 +356,19 @@ def test_call_stack_f_lineno(f_lasti: int, f_lineno: int) -> None: def test_stack_overflow(): - old = sys.getrecursionlimit() - sys.setrecursionlimit(200) - try: - state = create() - frame = None - - def f(i): - if i == 0: - nonlocal frame - frame = sys._current_frames()[threading.get_ident()] - return - else: - return f(i - 1) - - f(sys.getrecursionlimit() - 100) - process(frame, None, state) - merge(state, state, state) - - finally: - sys.setrecursionlimit(old) + state = create() + frame = None + + def f(i): + if i == 0: + nonlocal frame + frame = sys._current_frames()[threading.get_ident()] + return + else: + return f(i - 1) + + f(sys.getrecursionlimit() - 200) + process(frame, None, state) + assert state["children"] + assert state["count"] + assert merge(state, state, state) From e0b121eb2fd414eca7e44b7c6f4128d3e2356c47 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 10 Mar 2023 12:15:20 +0100 Subject: [PATCH 08/11] Skip test on windows --- distributed/protocol/tests/test_protocol.py | 2 ++ distributed/shuffle/tests/test_shuffle.py | 13 +++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py index 9b8ec58323..2bf9eefb9b 100644 --- a/distributed/protocol/tests/test_protocol.py +++ b/distributed/protocol/tests/test_protocol.py @@ -9,6 +9,7 @@ import dask from dask.sizeof import sizeof +from distributed.compatibility import WINDOWS from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize from distributed.protocol.compression import ( compressions, @@ -415,6 +416,7 @@ def test_sizeof_serialize(Wrapper, Wrapped): assert size <= sizeof(serialized) < size * 1.05 +@pytest.mark.skipif(WINDOWS, reason="On windows this is triggering a stackoverflow") def test_deeply_nested_structures(): # These kind of deeply nested structures are generated in our profiling code def gen_deeply_nested(depth): diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index 02604f0638..c4ee925b41 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -74,10 +74,10 @@ async def clean_scheduler( assert not extension.heartbeats -@pytest.mark.skipif( - pa is not None, - reason="We don't have a CI job that is installing a very old pyarrow version", -) +# @pytest.mark.skipif( +# pa is not None, +# reason="We don't have a CI job that is installing a very old pyarrow version", +# ) @gen_cluster(client=True) async def test_minimal_version(c, s, a, b): df = dask.datasets.timeseries( @@ -86,8 +86,9 @@ async def test_minimal_version(c, s, a, b): dtypes={"x": float, "y": float}, freq="10 s", ) - with pytest.raises(RuntimeError, match="requires pyarrow"): - await c.compute(dd.shuffle.shuffle(df, "x", shuffle="p2p")) + # with pytest.raises(RuntimeError, match="requires pyarrow"): + res = c.persist(dd.shuffle.shuffle(df, "x", shuffle="p2p")) + await c.compute(res) @gen_cluster(client=True) From 4950082cbb350d64a256c098406af605264d3dcf Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 10 Mar 2023 13:14:49 +0100 Subject: [PATCH 09/11] restore test_stack_overflow --- distributed/tests/test_profile.py | 38 ++++++++++++++++++------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index adb782d2b4..5e7feebf2f 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -356,19 +356,25 @@ def test_call_stack_f_lineno(f_lasti: int, f_lineno: int) -> None: def test_stack_overflow(): - state = create() - frame = None - - def f(i): - if i == 0: - nonlocal frame - frame = sys._current_frames()[threading.get_ident()] - return - else: - return f(i - 1) - - f(sys.getrecursionlimit() - 200) - process(frame, None, state) - assert state["children"] - assert state["count"] - assert merge(state, state, state) + old = sys.getrecursionlimit() + sys.setrecursionlimit(200) + try: + state = create() + frame = None + + def f(i): + if i == 0: + nonlocal frame + frame = sys._current_frames()[threading.get_ident()] + return + else: + return f(i - 1) + + f(sys.getrecursionlimit() - 200) + process(frame, None, state) + assert state["children"] + assert state["count"] + assert merge(state, state, state) + + finally: + sys.setrecursionlimit(old) From f7786eed8befc26d7972b7a81c3f56aafd851878 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 10 Mar 2023 13:15:37 +0100 Subject: [PATCH 10/11] adjust limits --- distributed/tests/test_profile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index 5e7feebf2f..3036d12745 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -357,7 +357,7 @@ def test_call_stack_f_lineno(f_lasti: int, f_lineno: int) -> None: def test_stack_overflow(): old = sys.getrecursionlimit() - sys.setrecursionlimit(200) + sys.setrecursionlimit(300) try: state = create() frame = None @@ -370,7 +370,7 @@ def f(i): else: return f(i - 1) - f(sys.getrecursionlimit() - 200) + f(sys.getrecursionlimit() - 100) process(frame, None, state) assert state["children"] assert state["count"] From c6787d8e49696e2f904b98d2dccbdd2c01b73c0a Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Mon, 20 Mar 2023 21:44:33 -0500 Subject: [PATCH 11/11] Code review --- distributed/shuffle/tests/test_shuffle.py | 13 ++++++------- distributed/tests/test_pubsub.py | 2 +- distributed/tests/test_utils_test.py | 6 ++++-- distributed/worker.py | 2 +- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index c4ee925b41..02604f0638 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -74,10 +74,10 @@ async def clean_scheduler( assert not extension.heartbeats -# @pytest.mark.skipif( -# pa is not None, -# reason="We don't have a CI job that is installing a very old pyarrow version", -# ) +@pytest.mark.skipif( + pa is not None, + reason="We don't have a CI job that is installing a very old pyarrow version", +) @gen_cluster(client=True) async def test_minimal_version(c, s, a, b): df = dask.datasets.timeseries( @@ -86,9 +86,8 @@ async def test_minimal_version(c, s, a, b): dtypes={"x": float, "y": float}, freq="10 s", ) - # with pytest.raises(RuntimeError, match="requires pyarrow"): - res = c.persist(dd.shuffle.shuffle(df, "x", shuffle="p2p")) - await c.compute(res) + with pytest.raises(RuntimeError, match="requires pyarrow"): + await c.compute(dd.shuffle.shuffle(df, "x", shuffle="p2p")) @gen_cluster(client=True) diff --git a/distributed/tests/test_pubsub.py b/distributed/tests/test_pubsub.py index 1b096cb19e..664ea4f4ab 100644 --- a/distributed/tests/test_pubsub.py +++ b/distributed/tests/test_pubsub.py @@ -54,7 +54,7 @@ def pingpong(a, b, start=False, n=1000, msg=1): @gen_cluster(client=True, nthreads=[]) async def test_client(c, s): - with pytest.raises(ValueError, match="No workers found"): + with pytest.raises(ValueError, match="No worker found"): get_worker() sub = Sub("a") pub = Pub("a") diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index c2c87ef91c..0d3f88180b 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -1095,5 +1095,7 @@ async def test_ensure_no_new_clients(): with pytest.raises(AssertionError): with ensure_no_new_clients(): async with Client(s.address, asynchronous=True): - with ensure_no_new_clients(): - pass + pass + async with Client(s.address, asynchronous=True): + with ensure_no_new_clients(): + pass diff --git a/distributed/worker.py b/distributed/worker.py index 23d727a151..9723436bfd 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2705,7 +2705,7 @@ def get_worker() -> Worker: try: return thread_state.execution_state["worker"] except AttributeError: - raise ValueError("No workers found") from None + raise ValueError("No worker found") from None def get_client(address=None, timeout=None, resolve_address=True) -> Client: