From f76c40bb89a0f501505891d8d0ee00021923f867 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 26 May 2022 14:53:46 -0500 Subject: [PATCH] Log rather than raise exceptions in `preload.teardown()` (#6458) --- distributed/scheduler.py | 5 ++++- distributed/tests/test_preload.py | 18 ++++++++++++++++++ distributed/worker.py | 5 ++++- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5a12dccd37..7925fd18bf 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3374,7 +3374,10 @@ async def log_errors(func): setproctitle("dask-scheduler [closing]") for preload in self.preloads: - await preload.teardown() + try: + await preload.teardown() + except Exception as e: + logger.exception(e) await asyncio.gather( *[log_errors(plugin.close) for plugin in list(self.plugins.values())] diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index 0bb73dd0d1..04bec1cb95 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -1,3 +1,4 @@ +import logging import os import re import shutil @@ -281,6 +282,23 @@ def dask_setup(client, value): assert c.foo == value +@gen_test() +async def test_teardown_failure_doesnt_crash_scheduler(): + text = """ +def dask_teardown(worker): + raise Exception(123) +""" + + with captured_logger(logging.getLogger("distributed.scheduler")) as s_logger: + with captured_logger(logging.getLogger("distributed.worker")) as w_logger: + async with Scheduler(dashboard_address=":0", preload=text) as s: + async with Worker(s.address, preload=[text]) as w: + pass + + assert "123" in s_logger.getvalue() + assert "123" in w_logger.getvalue() + + @gen_cluster(nthreads=[]) async def test_client_preload_config_click(s): text = dedent( diff --git a/distributed/worker.py b/distributed/worker.py index 1795bb80c5..c146cb56ff 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1489,7 +1489,10 @@ async def close( ) for preload in self.preloads: - await preload.teardown() + try: + await preload.teardown() + except Exception as e: + logger.exception(e) for extension in self.extensions.values(): if hasattr(extension, "close"):