Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Raise plugin exceptions on Worker.start() #4298

Merged
merged 17 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6443,6 +6443,16 @@ async def test_get_task_metadata_multiple(c, s, a, b):
assert metadata2[f2.key] == s.tasks.get(f2.key).metadata


@gen_cluster(client=True)
async def test_register_worker_plugin_exception(c, s, a, b):
    """``Client.register_worker_plugin`` should propagate an exception raised by the plugin's ``setup``."""

    class BrokenSetupPlugin:
        def setup(self, worker=None):
            # Deliberately fail during plugin registration.
            raise ValueError("Setup failed")

    with pytest.raises(ValueError, match="Setup failed"):
        await c.register_worker_plugin(BrokenSetupPlugin())


@gen_cluster(client=True)
async def test_log_event(c, s, a, b):

Expand Down
59 changes: 59 additions & 0 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from distributed.diagnostics import nvml
from distributed.diagnostics.plugin import PipInstall
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.utils import TimeoutError
from distributed.utils_test import (
Expand Down Expand Up @@ -378,6 +379,64 @@ def __str__(self):
assert "Bar" in str(e.__cause__)


@pytest.mark.asyncio
async def test_plugin_exception(cleanup):
    """A plugin whose ``setup`` raises should surface that error when the worker starts."""

    class BrokenSetupPlugin:
        def setup(self, worker=None):
            # Deliberately fail during worker startup.
            raise ValueError("Setup failed")

    async with Scheduler(port=0) as scheduler:
        with pytest.raises(ValueError, match="Setup failed"):
            async with Worker(scheduler.address, plugins={BrokenSetupPlugin()}):
                pass


@pytest.mark.asyncio
async def test_plugin_multiple_exceptions(cleanup):
    """When several plugins fail during setup, one exception is raised and all are logged."""

    class BrokenPluginA:
        def setup(self, worker=None):
            raise ValueError("MyPlugin1 Error")

    class BrokenPluginB:
        def setup(self, worker=None):
            raise RuntimeError("MyPlugin2 Error")

    async with Scheduler(port=0) as s:
        # The order in which plugin setups run is unspecified, so either
        # exception type may be the one that propagates.
        with pytest.raises((ValueError, RuntimeError), match="MyPlugin.* Error"):
            with captured_logger("distributed.worker") as logger:
                async with Worker(
                    s.address,
                    plugins={BrokenPluginA(), BrokenPluginB()},
                ):
                    pass

        # Regardless of which exception propagated, both must appear in the log.
        log_text = logger.getvalue()
        assert "MyPlugin1 Error" in log_text
        assert "MyPlugin2 Error" in log_text


@pytest.mark.asyncio
async def test_plugin_internal_exception(cleanup):
    """A plugin payload that cannot be unpickled should raise from worker startup."""
    # Prefix valid pickle bytes with garbage so deserialization fails inside
    # the worker rather than in the plugin's own setup.
    corrupted_payload = b"corrupting pickle" + pickle.dumps(lambda: None, protocol=4)
    async with Scheduler(port=0) as scheduler:
        with pytest.raises(UnicodeDecodeError, match="codec can't decode"):
            async with Worker(scheduler.address, plugins={corrupted_payload}):
                pass


@gen_cluster(client=True)
async def test_gather(c, s, a, b):
x, y = await c.scatter(["x", "y"], workers=[b.address])
Expand Down
16 changes: 15 additions & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,9 +1494,23 @@ async def start(self):

setproctitle("dask-worker [%s]" % self.address)

await asyncio.gather(
plugins_msgs = await asyncio.gather(
*(self.plugin_add(plugin=plugin) for plugin in self._pending_plugins)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's also possible for plugin_add to raise other exceptions (e.g. unpickling errors). What would happen in these cases? I assume it is failing hard but I would like to confirm. Do we have tests for this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test for that in d323913 .

plugins_exceptions = []
for msg in plugins_msgs:
if msg["status"] != "OK":
exc = pickle.loads(msg["exception"].data)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pickle-unpickle here feels a little cumbersome to me. I'm not sure I love this idea, but what if you add a catch_errors=True to plugin_add, and in start just set catch_errors=False?

diff --git a/distributed/worker.py b/distributed/worker.py
index 6e2fc56b..f4e3ed03 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1494,22 +1494,12 @@ class Worker(ServerNode):
 
         setproctitle("dask-worker [%s]" % self.address)
 
-        plugins_msgs = await asyncio.gather(
-            *(self.plugin_add(plugin=plugin) for plugin in self._pending_plugins)
+        await asyncio.gather(
+            *(
+                self.plugin_add(plugin=plugin, catch_errors=False)
+                for plugin in self._pending_plugins
+            )
         )
-        plugins_exceptions = []
-        for msg in plugins_msgs:
-            if msg["status"] != "OK":
-                exc = pickle.loads(msg["exception"].data)
-                plugins_exceptions.append(pickle.loads(msg["exception"].data))
-        if len(plugins_exceptions) >= 1:
-            if len(plugins_exceptions) > 1:
-                logger.error(
-                    "Multiple plugin exceptions raised. All exceptions will be logged, the first is raised."
-                )
-                for exc in plugins_exceptions:
-                    logger.error(repr(exc))
-            raise plugins_exceptions[0]
 
         self._pending_plugins = ()
 
@@ -3182,7 +3172,7 @@ class Worker(ServerNode):
     def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True):
         return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)
 
-    async def plugin_add(self, comm=None, plugin=None, name=None):
+    async def plugin_add(self, comm=None, plugin=None, name=None, catch_errors=True):
         with log_errors(pdb=False):
             if isinstance(plugin, bytes):
                 plugin = pickle.loads(plugin)
@@ -3204,6 +3194,8 @@ class Worker(ServerNode):
                     if isawaitable(result):
                         result = await result
                 except Exception as e:
+                    if not catch_errors:
+                        raise
                     msg = error_message(e)
                     return msg

If you wanted to maintain the behavior you have here where multiple plugin exceptions can be handled and logged, then you could pass gather(..., return_exceptions=True) and handle them similarly as here, just without the unpickling. But I'm not sure if this functionality is actually necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pickle-unpickle here feels a little cumbersome to me. I'm not sure I love this idea, but what if you add a catch_errors=True to plugin_add, and in start just set catch_errors=False?

Thanks for the suggestion, that indeed looks better.

If you wanted to maintain the behavior you have here where multiple plugin exceptions can be handled and logged, then you could pass gather(..., return_exceptions=True) and handle them similarly as here, just without the unpickling. But I'm not sure if this functionality is actually necessary.

I'm ok with having this or not, it was suggested by @fjetter in #4298 (comment) , for now I'm keeping this but we can remove it if people prefer.

Relevant changes applied in 560bb55 .

plugins_exceptions.append(pickle.loads(msg["exception"].data))
if len(plugins_exceptions) >= 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think plugins_exceptions is not just the exceptions; it contains all the values returned by plugin_add, including OK messages. I think you'd want plugins_exceptions = [x for x in plugins_results if isinstance(x, Exception)] first.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so — that was the case with plugins_msgs (the previous pickled version). Now, after adding the line print(f"plugins_exceptions: {plugins_exceptions}") just above this, I see the following for distributed/tests/test_worker.py::test_plugin_exception:

plugins_exceptions: [ValueError('Setup failed')]

and for distributed/tests/test_worker.py::test_plugin_internal_exception I see:

plugins_exceptions: [UnicodeDecodeError('utf-8', b'orrupting pickle\x80\x04\x95e\x02\x00\x00\x00\x00\x00\x00\x8c\x17cloudpickle.cloudpickle\x94\x8c\r_builtin_type\x94\x93\x94\x8c', 16, 17, 'invalid start byte')]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try having the plugin not raise an exception and see what happens.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're absolutely right, thanks for pointing it out. This is fixed now by dc3aaeb. Could you take another look when you have a chance?

if len(plugins_exceptions) > 1:
logger.error(
"Multiple plugin exceptions raised. All exceptions will be logged, the first is raised."
)
for exc in plugins_exceptions:
logger.error(repr(exc))
raise plugins_exceptions[0]

self._pending_plugins = ()

await self._register_with_scheduler()
Expand Down