diff --git a/distributed/core.py b/distributed/core.py index 14e6949278..45c9c95e5f 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -15,7 +15,7 @@ from contextlib import suppress from enum import Enum from functools import partial -from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, Callable, ClassVar, TypedDict, TypeVar, final import tblib from tlz import merge @@ -462,6 +462,7 @@ async def start_unsafe(self): await self.rpc.start() return self + @final async def start(self): async with self._startup_lock: if self.status == Status.failed: diff --git a/distributed/nanny.py b/distributed/nanny.py index 0e73206339..8012a3d545 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -398,7 +398,7 @@ async def instantiate(self) -> Status: self.process = WorkerProcess( worker_kwargs=worker_kwargs, silence_logs=self.silence_logs, - on_exit=self._on_exit_sync, + on_exit=self._on_worker_exit_sync, worker=self.Worker, env=self.env, config=self.config, @@ -490,19 +490,20 @@ def is_alive(self): def run(self, comm, *args, **kwargs): return run(self, comm, *args, **kwargs) - def _on_exit_sync(self, exitcode): + def _on_worker_exit_sync(self, exitcode): try: - self._ongoing_background_tasks.call_soon(self._on_exit, exitcode) + self._ongoing_background_tasks.call_soon(self._on_worker_exit, exitcode) except AsyncTaskGroupClosedError: # Async task group has already been closed, so the nanny is already clos(ed|ing). pass @log_errors - async def _on_exit(self, exitcode): + async def _on_worker_exit(self, exitcode): if self.status not in ( Status.init, Status.closing, Status.closed, Status.closing_gracefully, + Status.failed, ): try: await self._unregister() @@ -517,6 +518,7 @@ async def _on_exit(self, exitcode): Status.closing, Status.closed, Status.closing_gracefully, + Status.failed, ): logger.warning("Restarting worker") await self.instantiate() @@ -577,7 +579,7 @@ async def close(self, timeout=5): if self.process is not None: await self.kill(timeout=timeout) except Exception: - pass + logger.exception("Error in Nanny killing Worker subprocess") self.process = None await self.rpc.close() self.status = Status.closed @@ -662,15 +664,15 @@ async def start(self) -> Status: await self.process.start() except OSError: logger.exception("Nanny failed to start process", exc_info=True) - self.process.terminate() + # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent + await self.process.terminate() self.status = Status.failed - return self.status try: msg = await self._wait_until_connected(uid) except Exception: - logger.exception("Failed to connect to process") + # NOTE: doesn't wait for process to terminate, just for terminate signal to be sent + await self.process.terminate() self.status = Status.failed - self.process.terminate() raise if not msg: return self.status @@ -731,7 +733,12 @@ def mark_stopped(self): async def kill(self, timeout: float = 2, executor_wait: bool = True) -> None: """ Ensure the worker process is stopped, waiting at most - *timeout* seconds before terminating it abruptly. + ``timeout * 0.8`` seconds before killing it abruptly. + + When `kill` returns, the worker process has been joined. + + If the worker process does not terminate within ``timeout`` seconds, + even after being killed, `asyncio.TimeoutError` is raised. """ deadline = time() + timeout @@ -740,32 +747,38 @@ async def kill(self, timeout: float = 2, executor_wait: bool = True) -> None: if self.status == Status.stopping: await self.stopped.wait() return - assert self.status in (Status.starting, Status.running) + assert self.status in ( + Status.starting, + Status.running, + Status.failed, # process failed to start, but hasn't been joined yet + ), self.status self.status = Status.stopping logger.info("Nanny asking worker to close") process = self.process + assert self.process + wait_timeout = timeout * 0.8 self.child_stop_q.put( { "op": "stop", - "timeout": max(0, deadline - time()) * 0.8, + "timeout": wait_timeout, "executor_wait": executor_wait, } ) await asyncio.sleep(0) # otherwise we get broken pipe errors self.child_stop_q.close() - while process.is_alive() and time() < deadline: - await asyncio.sleep(0.05) + try: + await process.join(wait_timeout) + return + except asyncio.TimeoutError: + pass - if process.is_alive(): - logger.warning( - f"Worker process still alive after {timeout} seconds, killing" - ) - try: - await process.terminate() - except Exception as e: - logger.error("Failed to kill worker process: %s", e) + logger.warning( + f"Worker process still alive after {wait_timeout} seconds, killing" + ) + await process.kill() + await process.join(max(0, deadline - time())) async def _wait_until_connected(self, uid): while True: @@ -783,9 +796,6 @@ async def _wait_until_connected(self, uid): continue if "exception" in msg: - logger.error( - "Failed while trying to start worker process: %s", msg["exception"] - ) raise msg["exception"] else: return msg diff --git a/distributed/process.py b/distributed/process.py index 7f43b5beca..ca7a9b705e 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -17,7 +17,7 @@ import dask -from distributed.utils import TimeoutError, get_mp_context +from distributed.utils import get_mp_context logger = logging.getLogger(__name__) @@ -299,15 +299,9 @@ async def join(self, timeout=None): assert self._state.pid is not None, "can only join a started process" if self._state.exitcode is not None: return - if timeout is None: - await self._exit_future - else: - try: - # Shield otherwise the timeout cancels the future and our - # on_exit callback will try to set a result on a canceled future - await asyncio.wait_for(asyncio.shield(self._exit_future), timeout) - except TimeoutError: - pass + # Shield otherwise the timeout cancels the future and our + # on_exit callback will try to set a result on a canceled future + await asyncio.wait_for(asyncio.shield(self._exit_future), timeout) def close(self): """ diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index 65aa0e303c..b0cd9848fb 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -80,7 +80,8 @@ async def test_simple(): assert proc.exitcode is None t1 = time() - await proc.join(timeout=0.02) + with pytest.raises(asyncio.TimeoutError): + await proc.join(timeout=0.02) dt = time() - t1 assert 0.2 >= dt >= 0.001 assert proc.is_alive() diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 504ade763c..b9c6b2377f 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -13,6 +13,8 @@ import psutil import pytest +from distributed.diagnostics.plugin import WorkerPlugin + pytestmark = pytest.mark.gpu from tlz import first, valmap @@ -27,7 +29,7 @@ from distributed.diagnostics import SchedulerPlugin from distributed.metrics import time from distributed.protocol.pickle import dumps -from distributed.utils import TimeoutError, parse_ports +from distributed.utils import TimeoutError, get_mp_context, parse_ports from distributed.utils_test import ( captured_logger, gen_cluster, @@ -474,21 +476,29 @@ async def test_nanny_closed_by_keyboard_interrupt(ucx_loop, protocol): assert "remove-worker" in str(s.events) -class StartException(Exception): - pass - - class BrokenWorker(worker.Worker): - async def start(self): - raise StartException("broken") + async def start_unsafe(self): + raise ValueError("broken") @gen_cluster(nthreads=[]) async def test_worker_start_exception(s): - # make sure this raises the right Exception: - with raises_with_cause(RuntimeError, None, StartException, None): - async with Nanny(s.address, worker_class=BrokenWorker) as n: - pass + nanny = Nanny(s.address, worker_class=BrokenWorker) + with captured_logger(logger="distributed.nanny", level=logging.WARNING) as logs: + with raises_with_cause( + RuntimeError, + "Nanny failed to start", + RuntimeError, + "BrokenWorker failed to start", + ): + async with nanny: + pass + assert nanny.status == Status.failed + # ^ NOTE: `Nanny.close` sets it to `closed`, then `Server.start._close_on_failure` sets it to `failed` + assert nanny.process is None + assert "Restarting worker" not in logs.getvalue() + # Avoid excessive spewing. (It's also printed once extra within the subprocess, which is okay.) + assert logs.getvalue().count("ValueError: broken") == 1, logs.getvalue() @gen_cluster(nthreads=[]) @@ -571,6 +581,44 @@ async def test_restart_memory(c, s, n): await asyncio.sleep(0.1) +class BlockClose(WorkerPlugin): + def __init__(self, close_happened): + self.close_happened = close_happened + + async def teardown(self, worker): + # Never let the worker cleanly shut down, so it has to be killed + self.close_happened.set() + while True: + await asyncio.sleep(10) + + +@pytest.mark.slow +@gen_cluster(nthreads=[]) +async def test_close_joins(s): + close_happened = get_mp_context().Event() + + nanny = Nanny(s.address, plugins=[BlockClose(close_happened)]) + async with nanny: + p = nanny.process + assert p + close_t = asyncio.create_task(nanny.close()) + + while not close_happened.wait(0): + await asyncio.sleep(0.01) + + assert not close_t.done() + assert nanny.status == Status.closing + assert nanny.process and nanny.process.status == Status.stopping + + await close_t + + assert nanny.status == Status.closed + assert not nanny.process + + assert p.status == Status.stopped + assert not p.process + + @gen_cluster(Worker=Nanny, nthreads=[("", 1)]) async def test_scheduler_crash_doesnt_restart(s, a): # Simulate a scheduler crash by disconnecting it first