Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 5, 2024
1 parent d59901a commit d000428
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 29 deletions.
11 changes: 4 additions & 7 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,17 +390,14 @@ async def start_unsafe(self):

return self

async def kill(self, timeout: float = 2, reason: str = "nanny-kill") -> None:
async def kill(self, timeout: float = 5, reason: str = "nanny-kill") -> None:
"""Kill the local worker process
Blocks until both the process is down and the scheduler is properly
informed
"""
if self.process is None:
return

deadline = time() + timeout
await self.process.kill(reason=reason, timeout=0.8 * (deadline - time()))
if self.process is not None:
await self.process.kill(reason=reason, timeout=0.8 * timeout)

async def instantiate(self) -> Status:
"""Start a local worker process
Expand Down Expand Up @@ -822,7 +819,7 @@ def mark_stopped(self):

async def kill(
self,
timeout: float = 2,
timeout: float = 5,
executor_wait: bool = True,
reason: str = "workerprocess-kill",
) -> None:
Expand Down
31 changes: 9 additions & 22 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
NO_AMM,
BlockedGatherDep,
BlockedGetData,
BlockedKillNanny,
BrokenComm,
NoSchedulerDelayWorker,
assert_story,
Expand Down Expand Up @@ -1132,22 +1133,8 @@ async def test_restart_waits_for_new_workers(c, s, *workers):
assert set(s.workers.values()).isdisjoint(original_workers.values())


class SlowKillNanny(Nanny):
    """Nanny subclass whose ``kill`` pauses until a test lets it continue.

    Attributes
    ----------
    kill_called : asyncio.Event
        Set as soon as ``kill`` is entered.
    kill_proceed : asyncio.Event
        ``kill`` waits on this event before delegating to the parent class.
    """

    def __init__(self, *args, **kwargs):
        # Both events are created before the parent constructor runs.
        self.kill_proceed, self.kill_called = asyncio.Event(), asyncio.Event()
        super().__init__(*args, **kwargs)

    async def kill(self, *, timeout, reason=None):
        """Signal entry, wait for ``kill_proceed``, then run the parent ``kill``.

        ``timeout`` is handed both to ``wait_for`` (bounding the wait on
        ``kill_proceed``) and to the parent ``kill``.
        """
        self.kill_called.set()
        print("kill called")
        # NOTE(review): presumably wait_for raises on expiry, in which case the
        # parent kill is never reached -- confirm against the helper's definition.
        gate = self.kill_proceed.wait()
        await wait_for(gate, timeout)
        print("kill proceed")
        outcome = await super().kill(timeout=timeout, reason=reason)
        return outcome


@pytest.mark.slow
@gen_cluster(client=True, Worker=SlowKillNanny, nthreads=[("", 1)] * 2)
@gen_cluster(client=True, Worker=BlockedKillNanny, nthreads=[("", 1)] * 2)
async def test_restart_nanny_timeout_exceeded(c, s, a, b):
try:
f = c.submit(div, 1, 0)
Expand All @@ -1162,8 +1149,8 @@ async def test_restart_nanny_timeout_exceeded(c, s, a, b):
TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s"
):
await c.restart(timeout="1s")
assert a.kill_called.is_set()
assert b.kill_called.is_set()
assert a.in_kill.is_set()
assert b.in_kill.is_set()

assert not s.workers
assert not s.erred_tasks
Expand All @@ -1175,8 +1162,8 @@ async def test_restart_nanny_timeout_exceeded(c, s, a, b):
assert f.status == "cancelled"
assert fr.status == "cancelled"
finally:
a.kill_proceed.set()
b.kill_proceed.set()
a.wait_kill.set()
b.wait_kill.set()


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
Expand Down Expand Up @@ -1260,7 +1247,7 @@ async def test_restart_some_nannies_some_not(c, s, a, b):
@gen_cluster(
client=True,
nthreads=[("", 1)],
Worker=SlowKillNanny,
Worker=BlockedKillNanny,
worker_kwargs={"heartbeat_interval": "1ms"},
)
async def test_restart_heartbeat_before_closing(c, s, n):
Expand All @@ -1271,13 +1258,13 @@ async def test_restart_heartbeat_before_closing(c, s, n):
prev_workers = dict(s.workers)
restart_task = asyncio.create_task(s.restart(stimulus_id="test"))

await n.kill_called.wait()
await n.in_kill.wait()
await asyncio.sleep(0.5) # significantly longer than the heartbeat interval

# WorkerState should not be removed yet, because the worker hasn't been told to close
assert s.workers

n.kill_proceed.set()
n.wait_kill.set()
# Wait until the worker has left (possibly until it's come back too)
while s.workers == prev_workers:
await asyncio.sleep(0.01)
Expand Down

0 comments on commit d000428

Please sign in to comment.