diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 59511ba456a..880c5f65881 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -198,3 +198,14 @@ def f(): await c.submit(f) assert ("foo", 123) in s._recorded_events + + +@gen_cluster(client=True) +async def test_register_plugin_on_scheduler(c, s, a, b): + class MyPlugin(SchedulerPlugin): + async def start(self, scheduler: Scheduler) -> None: + scheduler._foo = "bar" # type: ignore + + await s.register_scheduler_plugin(MyPlugin()) + + assert s._foo == "bar" diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 84562cee66f..d33c32d0219 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4865,7 +4865,8 @@ async def register_scheduler_plugin(self, plugin, name=None, idempotent=None): "arbitrary bytestrings using pickle via the " "'distributed.scheduler.pickle' configuration setting." ) - plugin = loads(plugin) + if not isinstance(plugin, SchedulerPlugin): + plugin = loads(plugin) if name is None: name = _get_plugin_name(plugin)