diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index f47764b93ca..fbdfec48ef3 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -6565,6 +6565,17 @@ async def test_get_task_metadata_multiple(c, s, a, b):
     assert metadata2[f2.key] == s.tasks.get(f2.key).metadata
 
 
+@gen_cluster(client=True)
+async def test_register_worker_plugin_exception(c, s, a, b):
+    """An exception raised by a plugin's ``setup`` reaches the client."""
+    class MyPlugin:
+        def setup(self, worker=None):
+            raise ValueError("Setup failed")
+
+    with pytest.raises(ValueError, match="Setup failed"):
+        await c.register_worker_plugin(MyPlugin())
+
+
 @gen_cluster(client=True)
 async def test_log_event(c, s, a, b):
 
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 22640916fa1..db580920d8a 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -39,6 +39,7 @@
 from distributed.diagnostics import nvml
 from distributed.diagnostics.plugin import PipInstall
 from distributed.metrics import time
+from distributed.protocol import pickle
 from distributed.scheduler import Scheduler
 from distributed.utils import TimeoutError
 from distributed.utils_test import (
@@ -380,6 +381,69 @@ def __str__(self):
         assert "Bar" in str(e.__cause__)
 
 
+@pytest.mark.asyncio
+async def test_plugin_exception(cleanup):
+    """An exception raised by a plugin's ``setup`` fails worker startup."""
+    class MyPlugin:
+        def setup(self, worker=None):
+            raise ValueError("Setup failed")
+
+    async with Scheduler(port=0) as s:
+        with pytest.raises(ValueError, match="Setup failed"):
+            async with Worker(
+                s.address,
+                plugins={
+                    MyPlugin(),
+                },
+            ) as w:
+                pass
+
+
+@pytest.mark.asyncio
+async def test_plugin_multiple_exceptions(cleanup):
+    """With several failing plugins one error is raised and all are logged."""
+    class MyPlugin1:
+        def setup(self, worker=None):
+            raise ValueError("MyPlugin1 Error")
+
+    class MyPlugin2:
+        def setup(self, worker=None):
+            raise RuntimeError("MyPlugin2 Error")
+
+    async with Scheduler(port=0) as s:
+        # There's no guarantee on the order of which exception is raised first
+        with pytest.raises((ValueError, RuntimeError), match="MyPlugin.* Error"):
+            with captured_logger("distributed.worker") as logger:
+                async with Worker(
+                    s.address,
+                    plugins={
+                        MyPlugin1(),
+                        MyPlugin2(),
+                    },
+                ) as w:
+                    pass
+
+        # Checked after the ``pytest.raises`` block: statements placed inside
+        # it after the raising call would never execute.
+        text = logger.getvalue()
+        assert "MyPlugin1 Error" in text
+        assert "MyPlugin2 Error" in text
+
+
+@pytest.mark.asyncio
+async def test_plugin_internal_exception(cleanup):
+    """A plugin given as bytes that cannot be unpickled fails worker startup."""
+    async with Scheduler(port=0) as s:
+        with pytest.raises(UnicodeDecodeError, match="codec can't decode"):
+            async with Worker(
+                s.address,
+                plugins={
+                    b"corrupting pickle" + pickle.dumps(lambda: None, protocol=4),
+                },
+            ) as w:
+                pass
+
+
 @gen_cluster(client=True)
 async def test_gather(c, s, a, b):
     x, y = await c.scatter(["x", "y"], workers=[b.address])
diff --git a/distributed/worker.py b/distributed/worker.py
index baa4c5f3b89..25148958f9d 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1504,9 +1504,25 @@ async def start(self):
 
         setproctitle("dask-worker [%s]" % self.address)
 
-        await asyncio.gather(
-            *(self.plugin_add(plugin=plugin) for plugin in self._pending_plugins)
+        # Collect every plugin failure (return_exceptions=True) so that all of
+        # them can be logged before the first one is re-raised.
+        plugins_msgs = await asyncio.gather(
+            *(
+                self.plugin_add(plugin=plugin, catch_errors=False)
+                for plugin in self._pending_plugins
+            ),
+            return_exceptions=True,
         )
+        plugins_exceptions = [msg for msg in plugins_msgs if isinstance(msg, Exception)]
+        if plugins_exceptions:
+            if len(plugins_exceptions) > 1:
+                logger.error(
+                    "Multiple plugin exceptions raised. All exceptions will be logged, the first is raised."
+                )
+                for exc in plugins_exceptions:
+                    logger.error(repr(exc))
+            raise plugins_exceptions[0]
+
         self._pending_plugins = ()
 
         await self._register_with_scheduler()
@@ -3248,7 +3264,13 @@ def run(self, comm, function, args=(), wait=True, kwargs=None):
     def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True):
         return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)
 
-    async def plugin_add(self, comm=None, plugin=None, name=None):
+    async def plugin_add(self, comm=None, plugin=None, name=None, catch_errors=True):
+        """Add ``plugin`` to this worker, unpickling it first if given as bytes.
+
+        When ``catch_errors`` is true (the default) an exception raised during
+        plugin setup is returned as an ``error_message`` payload; when false
+        the exception propagates to the caller.
+        """
         with log_errors(pdb=False):
             if isinstance(plugin, bytes):
                 plugin = pickle.loads(plugin)
@@ -3270,6 +3292,9 @@ async def plugin_add(self, comm=None, plugin=None, name=None):
                     if isawaitable(result):
                         result = await result
                 except Exception as e:
+                    if not catch_errors:
+                        # The caller wants the original exception, not a message
+                        raise
                     msg = error_message(e)
                     return msg
 