diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 7698a2bddc2..e385ca07b34 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -480,6 +480,73 @@ def f(ev): assert "Resuming worker" in logger.getvalue() +@gen_cluster( + client=True, + nthreads=[("", 1), ("", 1)], + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, +) +async def test_pause_prevents_deps_fetch(c, s, a, b): + """A worker is paused while there are dependencies ready to fetch, but all other + workers are in flight + """ + a_addr = a.address + + class X: + def __sizeof__(self): + return 2**40 # Disable clustering in select_keys_for_gather + + def __reduce__(self): + return X.pause_on_unpickle, () + + @staticmethod + def pause_on_unpickle(): + # Note: outside of task execution, distributed.get_worker() + # returns a random worker running in the process + for w in Worker._instances: + if w.address == a_addr: + w.status = Status.paused + return X() + assert False + + x = c.submit(X, key="x", workers=[b.address]) + y = c.submit(inc, 1, key="y", workers=[b.address]) + await wait([x, y]) + w = c.submit(lambda _: None, x, key="w", priority=1, workers=[a.address]) + z = c.submit(inc, y, key="z", priority=0, workers=[a.address]) + + # - w and z reach worker a within the same message + # - w and z respectively make x and y go into fetch state. + # w has a higher priority than z, therefore w's dependency x has a higher priority + # than z's dependency y. + # a.data_needed = ["x", "y"] + # - ensure_communicating decides to fetch x but not to fetch y together with it, as + # it thinks x is 1TB in size + # - x fetch->flight; a is added to in_flight_workers + # - y is skipped by ensure_communicating since all workers that hold a replica are + # in flight + # - x reaches a and sends a into paused state + # - x flight->memory; a is removed from in_flight_workers + # - ensure_communicating is triggered again + # - ensure_communicating refuses to fetch y because the worker is paused + + while "y" not in a.tasks or a.tasks["y"].state != "fetch": + await asyncio.sleep(0.01) + await asyncio.sleep(0.1) + assert a.tasks["y"].state == "fetch" + assert "y" not in a.data + assert [ts.key for ts in a.data_needed] == ["y"] + + # Unpausing kicks off ensure_communicating again + a.status = Status.running + assert await z == 3 + assert a.tasks["y"].state == "memory" + assert "y" in a.data + + @gen_cluster( client=True, nthreads=[("", 1)], diff --git a/distributed/worker.py b/distributed/worker.py index 6475fea0930..e5563ed09d7 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2748,6 +2748,9 @@ def stimulus_story( return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys] def ensure_communicating(self) -> None: + if self.status != Status.running: + return + stimulus_id = f"ensure-communicating-{time()}" skipped_worker_in_flight_or_busy = [] @@ -3489,7 +3492,7 @@ async def _maybe_deserialize_task( raise def _ensure_computing(self) -> RecsInstrs: - if self.status in (Status.paused, Status.closing_gracefully): + if self.status != Status.running: return {}, [] recs: Recs = {}