diff --git a/distributed/nanny.py b/distributed/nanny.py index 0c4f4172760..863c08e110a 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -390,17 +390,14 @@ async def start_unsafe(self): return self - async def kill(self, timeout: float = 2, reason: str = "nanny-kill") -> None: + async def kill(self, timeout: float = 5, reason: str = "nanny-kill") -> None: """Kill the local worker process Blocks until both the process is down and the scheduler is properly informed """ - if self.process is None: - return - - deadline = time() + timeout - await self.process.kill(reason=reason, timeout=0.8 * (deadline - time())) + if self.process is not None: + await self.process.kill(reason=reason, timeout=0.8 * timeout) async def instantiate(self) -> Status: """Start a local worker process @@ -822,7 +819,7 @@ def mark_stopped(self): async def kill( self, - timeout: float = 2, + timeout: float = 5, executor_wait: bool = True, reason: str = "workerprocess-kill", ) -> None: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 805db3cdbb3..8a1575611f1 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -53,6 +53,7 @@ NO_AMM, BlockedGatherDep, BlockedGetData, + BlockedKillNanny, BrokenComm, NoSchedulerDelayWorker, assert_story, @@ -1132,22 +1133,8 @@ async def test_restart_waits_for_new_workers(c, s, *workers): assert set(s.workers.values()).isdisjoint(original_workers.values()) -class SlowKillNanny(Nanny): - def __init__(self, *args, **kwargs): - self.kill_proceed = asyncio.Event() - self.kill_called = asyncio.Event() - super().__init__(*args, **kwargs) - - async def kill(self, *, timeout, reason=None): - self.kill_called.set() - print("kill called") - await wait_for(self.kill_proceed.wait(), timeout) - print("kill proceed") - return await super().kill(timeout=timeout, reason=reason) - - @pytest.mark.slow -@gen_cluster(client=True, Worker=SlowKillNanny, nthreads=[("", 1)] * 2) +@gen_cluster(client=True, Worker=BlockedKillNanny, nthreads=[("", 1)] * 2) async def test_restart_nanny_timeout_exceeded(c, s, a, b): try: f = c.submit(div, 1, 0) @@ -1162,8 +1149,8 @@ async def test_restart_nanny_timeout_exceeded(c, s, a, b): TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s" ): await c.restart(timeout="1s") - assert a.kill_called.is_set() - assert b.kill_called.is_set() + assert a.in_kill.is_set() + assert b.in_kill.is_set() assert not s.workers assert not s.erred_tasks @@ -1175,8 +1162,8 @@ async def test_restart_nanny_timeout_exceeded(c, s, a, b): assert f.status == "cancelled" assert fr.status == "cancelled" finally: - a.kill_proceed.set() - b.kill_proceed.set() + a.wait_kill.set() + b.wait_kill.set() @gen_cluster(client=True, nthreads=[("", 1)] * 2) @@ -1260,7 +1247,7 @@ async def test_restart_some_nannies_some_not(c, s, a, b): @gen_cluster( client=True, nthreads=[("", 1)], - Worker=SlowKillNanny, + Worker=BlockedKillNanny, worker_kwargs={"heartbeat_interval": "1ms"}, ) async def test_restart_heartbeat_before_closing(c, s, n): @@ -1271,13 +1258,13 @@ async def test_restart_heartbeat_before_closing(c, s, n): prev_workers = dict(s.workers) restart_task = asyncio.create_task(s.restart(stimulus_id="test")) - await n.kill_called.wait() + await n.in_kill.wait() await asyncio.sleep(0.5) # significantly longer than the heartbeat interval # WorkerState should not be removed yet, because the worker hasn't been told to close assert s.workers - n.kill_proceed.set() + n.wait_kill.set() # Wait until the worker has left (possibly until it's come back too) while s.workers == prev_workers: await asyncio.sleep(0.01)