
Ensure Nanny doesn't restart workers that fail to start, and joins subprocess (dask#6427)

gjoseph92 committed Oct 31, 2022
1 parent 9e89d64 · commit 2eec21e
Showing 5 changed files with 102 additions and 48 deletions.
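In short, the commit makes two guarantees: a worker subprocess that fails to start leaves the Nanny in `Status.failed` and is never restarted, and `WorkerProcess.kill` now joins the subprocess, so no zombie process is left behind. A toy, self-contained sketch of the restart guard (the `Supervisor` class and all of its names are invented for illustration; the real logic lives in `Nanny._on_worker_exit` in the diff below):

```python
import asyncio
import enum

class Status(enum.Enum):
    running = "running"
    failed = "failed"   # subprocess never started successfully
    closed = "closed"

class Supervisor:
    """Toy stand-in for the Nanny's restart guard."""

    def __init__(self) -> None:
        self.status = Status.running

    async def on_child_exit(self, exitcode: int) -> None:
        # Key behavior from this commit: a child that failed to start
        # (or a supervisor that is closing/closed) is *not* restarted,
        # which previously caused an endless respawn loop.
        if self.status in (Status.failed, Status.closed):
            return
        print(f"child exited with {exitcode}; restarting")
        await self.restart()

    async def restart(self) -> None:
        await asyncio.sleep(0)  # stand-in for spawning a new subprocess

async def main() -> None:
    sup = Supervisor()
    sup.status = Status.failed  # startup failed
    await sup.on_child_exit(1)  # prints nothing: no restart
    sup.status = Status.running
    await sup.on_child_exit(1)  # restarts as before

asyncio.run(main())
```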
3 changes: 2 additions & 1 deletion distributed/core.py
@@ -15,7 +15,7 @@
from contextlib import suppress
from enum import Enum
from functools import partial
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypedDict, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypedDict, TypeVar, final

import tblib
from tlz import merge
@@ -462,6 +462,7 @@ async def start_unsafe(self):
await self.rpc.start()
return self

+@final
async def start(self):
async with self._startup_lock:
if self.status == Status.failed:
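The `@final` marker on `Server.start` means subclasses must hook into startup by overriding `start_unsafe` rather than `start` itself; `typing.final` is enforced by static type checkers such as mypy, not at runtime. A minimal sketch of the pattern, with the class bodies simplified to just the override contract:

```python
from typing import final

class Server:
    @final
    async def start(self):
        # Template method: shared startup handling (locking, failure
        # cleanup) lives here and must not be bypassed by subclasses.
        return await self.start_unsafe()

    async def start_unsafe(self):
        # The hook subclasses are supposed to override.
        return self

class Worker(Server):
    # Overriding `start` here would be flagged by mypy/pyright:
    #   error: Cannot override final attribute "start"
    async def start_unsafe(self):
        return self
```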
60 changes: 35 additions & 25 deletions distributed/nanny.py
@@ -398,7 +398,7 @@ async def instantiate(self) -> Status:
self.process = WorkerProcess(
worker_kwargs=worker_kwargs,
silence_logs=self.silence_logs,
-on_exit=self._on_exit_sync,
+on_exit=self._on_worker_exit_sync,
worker=self.Worker,
env=self.env,
config=self.config,
@@ -490,19 +490,20 @@ def is_alive(self):
def run(self, comm, *args, **kwargs):
return run(self, comm, *args, **kwargs)

-def _on_exit_sync(self, exitcode):
+def _on_worker_exit_sync(self, exitcode):
try:
-self._ongoing_background_tasks.call_soon(self._on_exit, exitcode)
+self._ongoing_background_tasks.call_soon(self._on_worker_exit, exitcode)
except AsyncTaskGroupClosedError:
    # Async task group has already been closed, so the nanny is already clos(ed|ing).
    pass

@log_errors
-async def _on_exit(self, exitcode):
+async def _on_worker_exit(self, exitcode):
if self.status not in (
Status.init,
Status.closing,
Status.closed,
Status.closing_gracefully,
+Status.failed,
):
try:
await self._unregister()
@@ -517,6 +518,7 @@ async def _on_exit(self, exitcode):
Status.closing,
Status.closed,
Status.closing_gracefully,
+Status.failed,
):
logger.warning("Restarting worker")
await self.instantiate()
@@ -577,7 +579,7 @@ async def close(self, timeout=5):
if self.process is not None:
await self.kill(timeout=timeout)
except Exception:
-pass
+logger.exception("Error in Nanny killing Worker subprocess")
self.process = None
await self.rpc.close()
self.status = Status.closed
@@ -662,15 +664,15 @@ async def start(self) -> Status:
await self.process.start()
except OSError:
logger.exception("Nanny failed to start process", exc_info=True)
-self.process.terminate()
+# NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
+await self.process.terminate()
self.status = Status.failed
return self.status
try:
msg = await self._wait_until_connected(uid)
except Exception:
logger.exception("Failed to connect to process")
+# NOTE: doesn't wait for process to terminate, just for terminate signal to be sent
+await self.process.terminate()
self.status = Status.failed
-self.process.terminate()
raise
if not msg:
return self.status
@@ -731,7 +733,12 @@ def mark_stopped(self):
async def kill(self, timeout: float = 2, executor_wait: bool = True) -> None:
"""
Ensure the worker process is stopped, waiting at most
-*timeout* seconds before terminating it abruptly.
+``timeout * 0.8`` seconds before killing it abruptly.
+
+When `kill` returns, the worker process has been joined.
+
+If the worker process does not terminate within ``timeout`` seconds,
+even after being killed, `asyncio.TimeoutError` is raised.
"""
deadline = time() + timeout

@@ -740,32 +747,38 @@ async def kill(self, timeout: float = 2, executor_wait: bool = True) -> None:
if self.status == Status.stopping:
await self.stopped.wait()
return
-assert self.status in (Status.starting, Status.running)
+assert self.status in (
+    Status.starting,
+    Status.running,
+    Status.failed,  # process failed to start, but hasn't been joined yet
+), self.status
self.status = Status.stopping
logger.info("Nanny asking worker to close")

process = self.process
+assert self.process
+wait_timeout = timeout * 0.8
self.child_stop_q.put(
{
"op": "stop",
"timeout": max(0, deadline - time()) * 0.8,
"timeout": wait_timeout,
"executor_wait": executor_wait,
}
)
await asyncio.sleep(0) # otherwise we get broken pipe errors
self.child_stop_q.close()

-while process.is_alive() and time() < deadline:
-    await asyncio.sleep(0.05)
+try:
+    await process.join(wait_timeout)
+    return
+except asyncio.TimeoutError:
+    pass

-if process.is_alive():
-    logger.warning(
-        f"Worker process still alive after {timeout} seconds, killing"
-    )
-    try:
-        await process.terminate()
-    except Exception as e:
-        logger.error("Failed to kill worker process: %s", e)
+logger.warning(
+    f"Worker process still alive after {wait_timeout} seconds, killing"
+)
+await process.kill()
+await process.join(max(0, deadline - time()))

async def _wait_until_connected(self, uid):
while True:
@@ -783,9 +796,6 @@ async def _wait_until_connected(self, uid):
continue

if "exception" in msg:
-logger.error(
-    "Failed while trying to start worker process: %s", msg["exception"]
-)
raise msg["exception"]
else:
return msg
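Taken together, the new `kill` flow spends 80% of its timeout budget on a graceful stop and the rest on a forceful one, and always joins the child before returning. A condensed, self-contained version of that escalation; `process` is assumed to expose `join()`/`kill()` coroutines like `distributed.process.AsyncProcess`, and the graceful-stop request itself is elided:

```python
import asyncio
from time import monotonic

async def stop_child(process, timeout: float = 2.0) -> None:
    deadline = monotonic() + timeout
    wait_timeout = timeout * 0.8  # 80% of the budget for a clean shutdown
    # ... ask the child to stop gracefully here (the Nanny uses child_stop_q) ...
    try:
        await process.join(wait_timeout)  # join() now raises on timeout
        return  # clean exit: the child has already been reaped
    except asyncio.TimeoutError:
        pass  # child is stuck: escalate
    await process.kill()  # forceful termination (SIGKILL on POSIX)
    # Spend whatever budget remains waiting for the kill to take effect;
    # if even this times out, asyncio.TimeoutError propagates to the caller.
    await process.join(max(0, deadline - monotonic()))
```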
14 changes: 4 additions & 10 deletions distributed/process.py
@@ -17,7 +17,7 @@

import dask

-from distributed.utils import TimeoutError, get_mp_context
+from distributed.utils import get_mp_context

logger = logging.getLogger(__name__)

@@ -299,15 +299,9 @@ async def join(self, timeout=None):
assert self._state.pid is not None, "can only join a started process"
if self._state.exitcode is not None:
return
-if timeout is None:
-    await self._exit_future
-else:
-    try:
-        # Shield otherwise the timeout cancels the future and our
-        # on_exit callback will try to set a result on a canceled future
-        await asyncio.wait_for(asyncio.shield(self._exit_future), timeout)
-    except TimeoutError:
-        pass
+# Shield otherwise the timeout cancels the future and our
+# on_exit callback will try to set a result on a canceled future
+await asyncio.wait_for(asyncio.shield(self._exit_future), timeout)

def close(self):
"""
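This simplification works because `asyncio.wait_for(..., timeout=None)` just awaits the future, making the old `if timeout is None` branch redundant, and because a timeout now propagates to the caller instead of being swallowed. The `shield` is what keeps the exit future usable after a timeout; a small runnable demonstration (the `exit_future` name is illustrative):

```python
import asyncio

async def main() -> None:
    loop = asyncio.get_running_loop()
    exit_future = loop.create_future()  # set later by an on-exit callback

    # Without shield(), wait_for would cancel exit_future on timeout, and a
    # later set_result() would raise InvalidStateError in the callback.
    try:
        await asyncio.wait_for(asyncio.shield(exit_future), timeout=0.1)
    except asyncio.TimeoutError:
        pass  # only the shield wrapper was cancelled

    exit_future.set_result(0)  # still valid: the inner future stayed pending
    assert exit_future.result() == 0

asyncio.run(main())
```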
3 changes: 2 additions & 1 deletion distributed/tests/test_asyncprocess.py
@@ -80,7 +80,8 @@ async def test_simple():
assert proc.exitcode is None

t1 = time()
-await proc.join(timeout=0.02)
+with pytest.raises(asyncio.TimeoutError):
+    await proc.join(timeout=0.02)
dt = time() - t1
assert 0.2 >= dt >= 0.001
assert proc.is_alive()
70 changes: 59 additions & 11 deletions distributed/tests/test_nanny.py
@@ -13,6 +13,8 @@
import psutil
import pytest

+from distributed.diagnostics.plugin import WorkerPlugin

pytestmark = pytest.mark.gpu

from tlz import first, valmap
@@ -27,7 +29,7 @@
from distributed.diagnostics import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
-from distributed.utils import TimeoutError, parse_ports
+from distributed.utils import TimeoutError, get_mp_context, parse_ports
from distributed.utils_test import (
captured_logger,
gen_cluster,
@@ -474,21 +476,29 @@ async def test_nanny_closed_by_keyboard_interrupt(ucx_loop, protocol):
assert "remove-worker" in str(s.events)


-class StartException(Exception):
-    pass


class BrokenWorker(worker.Worker):
-    async def start(self):
-        raise StartException("broken")
+    async def start_unsafe(self):
+        raise ValueError("broken")


@gen_cluster(nthreads=[])
async def test_worker_start_exception(s):
-    # make sure this raises the right Exception:
-    with raises_with_cause(RuntimeError, None, StartException, None):
-        async with Nanny(s.address, worker_class=BrokenWorker) as n:
-            pass
+    nanny = Nanny(s.address, worker_class=BrokenWorker)
+    with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs:
+        with raises_with_cause(
+            RuntimeError,
+            "Nanny failed to start",
+            RuntimeError,
+            "BrokenWorker failed to start",
+        ):
+            async with nanny:
+                pass
+    assert nanny.status == Status.failed
+    # ^ NOTE: `Nanny.close` sets it to `closed`, then `Server.start._close_on_failure` sets it to `failed`
+    assert nanny.process is None
+    assert "Restarting worker" not in logs.getvalue()
+    # Avoid excessive spewing. (It's also printed once extra within the subprocess, which is okay.)
+    assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue()


@gen_cluster(nthreads=[])
@@ -571,6 +581,44 @@ async def test_restart_memory(c, s, n):
await asyncio.sleep(0.1)


+class BlockClose(WorkerPlugin):
+    def __init__(self, close_happened):
+        self.close_happened = close_happened
+
+    async def teardown(self, worker):
+        # Never let the worker cleanly shut down, so it has to be killed
+        self.close_happened.set()
+        while True:
+            await asyncio.sleep(10)
+
+
+@pytest.mark.slow
+@gen_cluster(nthreads=[])
+async def test_close_joins(s):
+    close_happened = get_mp_context().Event()
+
+    nanny = Nanny(s.address, plugins=[BlockClose(close_happened)])
+    async with nanny:
+        p = nanny.process
+        assert p
+        close_t = asyncio.create_task(nanny.close())
+
+        while not close_happened.wait(0):
+            await asyncio.sleep(0.01)
+
+        assert not close_t.done()
+        assert nanny.status == Status.closing
+        assert nanny.process and nanny.process.status == Status.stopping
+
+        await close_t
+
+        assert nanny.status == Status.closed
+        assert not nanny.process
+
+        assert p.status == Status.stopped
+        assert not p.process


@gen_cluster(Worker=Nanny, nthreads=[("", 1)])
async def test_scheduler_crash_doesnt_restart(s, a):
# Simulate a scheduler crash by disconnecting it first
