Skip to content

Commit

Permalink
Raise plugin exceptions on Worker.start() (#4298)
Browse files Browse the repository at this point in the history
Raise plugin exceptions on Worker.start()
  • Loading branch information
pentschev authored Jan 18, 2022
1 parent 922d987 commit 4c8ebfd
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 3 deletions.
10 changes: 10 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6565,6 +6565,16 @@ async def test_get_task_metadata_multiple(c, s, a, b):
assert metadata2[f2.key] == s.tasks.get(f2.key).metadata


@gen_cluster(client=True)
async def test_register_worker_plugin_exception(c, s, a, b):
class MyPlugin:
def setup(self, worker=None):
raise ValueError("Setup failed")

with pytest.raises(ValueError, match="Setup failed"):
await c.register_worker_plugin(MyPlugin())


@gen_cluster(client=True)
async def test_log_event(c, s, a, b):

Expand Down
59 changes: 59 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from distributed.diagnostics import nvml
from distributed.diagnostics.plugin import PipInstall
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.utils import TimeoutError
from distributed.utils_test import (
Expand Down Expand Up @@ -380,6 +381,64 @@ def __str__(self):
assert "Bar" in str(e.__cause__)


@pytest.mark.asyncio
async def test_plugin_exception(cleanup):
class MyPlugin:
def setup(self, worker=None):
raise ValueError("Setup failed")

async with Scheduler(port=0) as s:
with pytest.raises(ValueError, match="Setup failed"):
async with Worker(
s.address,
plugins={
MyPlugin(),
},
) as w:
pass


@pytest.mark.asyncio
async def test_plugin_multiple_exceptions(cleanup):
class MyPlugin1:
def setup(self, worker=None):
raise ValueError("MyPlugin1 Error")

class MyPlugin2:
def setup(self, worker=None):
raise RuntimeError("MyPlugin2 Error")

async with Scheduler(port=0) as s:
# There's no guarantee on the order of which exception is raised first
with pytest.raises((ValueError, RuntimeError), match="MyPlugin.* Error"):
with captured_logger("distributed.worker") as logger:
async with Worker(
s.address,
plugins={
MyPlugin1(),
MyPlugin2(),
},
) as w:
pass

text = logger.getvalue()
assert "MyPlugin1 Error" in text
assert "MyPlugin2 Error" in text


@pytest.mark.asyncio
async def test_plugin_internal_exception(cleanup):
async with Scheduler(port=0) as s:
with pytest.raises(UnicodeDecodeError, match="codec can't decode"):
async with Worker(
s.address,
plugins={
b"corrupting pickle" + pickle.dumps(lambda: None, protocol=4),
},
) as w:
pass


@gen_cluster(client=True)
async def test_gather(c, s, a, b):
x, y = await c.scatter(["x", "y"], workers=[b.address])
Expand Down
22 changes: 19 additions & 3 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,9 +1504,23 @@ async def start(self):

setproctitle("dask-worker [%s]" % self.address)

await asyncio.gather(
*(self.plugin_add(plugin=plugin) for plugin in self._pending_plugins)
plugins_msgs = await asyncio.gather(
*(
self.plugin_add(plugin=plugin, catch_errors=False)
for plugin in self._pending_plugins
),
return_exceptions=True,
)
plugins_exceptions = [msg for msg in plugins_msgs if isinstance(msg, Exception)]
if len(plugins_exceptions) >= 1:
if len(plugins_exceptions) > 1:
logger.error(
"Multiple plugin exceptions raised. All exceptions will be logged, the first is raised."
)
for exc in plugins_exceptions:
logger.error(repr(exc))
raise plugins_exceptions[0]

self._pending_plugins = ()

await self._register_with_scheduler()
Expand Down Expand Up @@ -3248,7 +3262,7 @@ def run(self, comm, function, args=(), wait=True, kwargs=None):
def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True):
return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)

async def plugin_add(self, comm=None, plugin=None, name=None):
async def plugin_add(self, comm=None, plugin=None, name=None, catch_errors=True):
with log_errors(pdb=False):
if isinstance(plugin, bytes):
plugin = pickle.loads(plugin)
Expand All @@ -3270,6 +3284,8 @@ async def plugin_add(self, comm=None, plugin=None, name=None):
if isawaitable(result):
result = await result
except Exception as e:
if not catch_errors:
raise
msg = error_message(e)
return msg

Expand Down

0 comments on commit 4c8ebfd

Please sign in to comment.