From 4d579a70630cf22c377de6f34ac64cdcc65b5649 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 7 Nov 2024 14:24:33 +0100 Subject: [PATCH] Ensure ConnectionPool is closed even if network stack swallows CancelledErrors --- distributed/core.py | 14 +++++++++-- distributed/nanny.py | 1 + distributed/tests/test_core.py | 43 +++++++++++++++++++++++++++++++++- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index a26758f7761..33f3a9a9fa4 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -1358,6 +1358,7 @@ def __init__( ) self._pending_count = 0 self._connecting_count = 0 + self._connecting_close_timeout = 5 self.status = Status.init def _validate(self) -> None: @@ -1537,7 +1538,9 @@ def callback(task: asyncio.Task[Comm]) -> None: try: return connect_attempt.result() except asyncio.CancelledError: - raise CommClosedError(reason) + if reason: + raise CommClosedError(reason) + raise def reuse(self, addr: str, comm: Comm) -> None: """ @@ -1615,8 +1618,15 @@ async def close(self) -> None: for _ in comms: self.semaphore.release() + start = time() while self._connecting: - await asyncio.sleep(0.005) + if time() - start > self._connecting_close_timeout: + logger.warning( + "Pending connections refuse to cancel. %d connections pending. Closing anyway.", + len(self._connecting), + ) + break + await asyncio.sleep(0.01) def coerce_to_address(o): diff --git a/distributed/nanny.py b/distributed/nanny.py index 859b9f22dcc..845b73bd050 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -623,6 +623,7 @@ async def close( # type:ignore[override] self.status = Status.closed await super().close() self.__exit_stack.__exit__(None, None, None) + logger.info("Nanny at %r closed.", self.address_safe) return "OK" async def _log_event(self, topic, msg): diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 6f4836ae08d..79d63ad0687 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -17,7 +17,7 @@ import dask from distributed.batched import BatchedSend -from distributed.comm.core import CommClosedError +from distributed.comm.core import CommClosedError, FatalCommClosedError from distributed.comm.registry import backends from distributed.comm.tcp import TCPBackend, TCPListener from distributed.core import ( @@ -707,6 +707,47 @@ async def connect_to_server(): assert all(t.cancelled() for t in tasks) +@gen_test() +async def test_connection_pool_catch_all_cancellederrors(monkeypatch): + from distributed.comm.registry import backends + from distributed.comm.tcp import TCPBackend, TCPConnector + + in_connect = asyncio.Event() + block_connect = asyncio.Event() + + class BlockedConnector(TCPConnector): + async def connect(self, address, deserialize, **connection_args): + # This is extremely artificial and assumes that something further + # down in the stack would block a cancellation. We want to make sure + # that our ConnectionPool closes regardless of this. + in_connect.set() + try: + await block_connect.wait() + except asyncio.CancelledError: + await asyncio.sleep(30) + raise + raise FatalCommClosedError() + + class BlockedConnectBackend(TCPBackend): + _connector_class = BlockedConnector + + monkeypatch.setitem(backends, "tcp", BlockedConnectBackend()) + + async with Server({}) as server: + await server.listen("tcp://") + pool = await ConnectionPool(limit=2) + pool._connecting_close_timeout = 0 + + t = asyncio.create_task(pool.connect(server.address)) + + await in_connect.wait() + while not pool._connecting_count: + await asyncio.sleep(0.1) + with captured_logger("distributed.core") as sio: + await pool.close() + assert "Pending connections refuse to cancel" in sio.getvalue() + + @gen_test() async def test_remove_cancels_connect_attempts(): loop = asyncio.get_running_loop()