Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add the ability to deny specific plugins from loading #1737

Merged
17 changes: 17 additions & 0 deletions aries_cloudagent/config/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,20 @@ def add_arguments(self, parser: ArgumentParser):
),
)

parser.add_argument(
"--block-plugin",
dest="blocked_plugins",
type=str,
action="append",
required=False,
metavar="<module>",
env_var="ACAPY_BLOCKED_PLUGIN",
help=(
"Block <module> plugin module from loading. Multiple "
"instances of this parameter can be specified."
),
)

parser.add_argument(
"--plugin-config",
dest="plugin_config",
Expand Down Expand Up @@ -611,6 +625,9 @@ def get_settings(self, args: Namespace) -> dict:
if args.external_plugins:
settings["external_plugins"] = args.external_plugins

if args.blocked_plugins:
settings["blocked_plugins"] = args.blocked_plugins

if args.plugin_config:
with open(args.plugin_config, "r") as stream:
settings["plugin_config"] = yaml.safe_load(stream)
Expand Down
4 changes: 3 additions & 1 deletion aries_cloudagent/config/default_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ async def bind_providers(self, context: InjectionContext):
async def load_plugins(self, context: InjectionContext):
"""Set up plugin registry and load plugins."""

plugin_registry = PluginRegistry()
plugin_registry = PluginRegistry(
blocklist=self.settings.get("blocked_plugins", [])
)
context.injector.bind_instance(PluginRegistry, plugin_registry)

# Register standard protocol plugins
Expand Down
8 changes: 6 additions & 2 deletions aries_cloudagent/core/plugin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from collections import OrderedDict
from types import ModuleType
from typing import Sequence
from typing import Sequence, Iterable

from ..config.injection_context import InjectionContext
from ..core.event_bus import EventBus
Expand All @@ -19,9 +19,10 @@
class PluginRegistry:
"""Plugin registry for indexing application plugins."""

def __init__(self):
def __init__(self, blocklist: Iterable[str] = []):
"""Initialize a `PluginRegistry` instance."""
self._plugins = OrderedDict()
self._blocklist = set(blocklist)

@property
def plugin_names(self) -> Sequence[str]:
Expand Down Expand Up @@ -119,6 +120,9 @@ def register_plugin(self, module_name: str) -> ModuleType:
"""Register a plugin module."""
if module_name in self._plugins:
mod = self._plugins[module_name]
elif module_name in self._blocklist:
LOGGER.debug(f"Blocked {module_name} from loading due to blocklist")
return None
else:
try:
mod = ClassLoader.load_module(module_name)
Expand Down
20 changes: 19 additions & 1 deletion aries_cloudagent/core/tests/test_plugin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

class TestPluginRegistry(AsyncTestCase):
def setUp(self):
self.registry = PluginRegistry()
self.blocked_module = "blocked_module"
self.registry = PluginRegistry(blocklist=[self.blocked_module])

self.context = InjectionContext(enforce_typing=False)
self.proto_registry = async_mock.MagicMock(
Expand Down Expand Up @@ -478,6 +479,23 @@ class MODULE:
]
assert self.registry.register_plugin("dummy") == obj

async def test_unregister_plugin_has_setup(self):
class MODULE:
setup = "present"

obj = MODULE()
with async_mock.patch.object(
ClassLoader, "load_module", async_mock.MagicMock()
) as load_module:
load_module.side_effect = [
obj, # module
None, # routes
None, # message types
None, # definition without versions attr
]
assert self.registry.register_plugin(self.blocked_module) == None
assert self.blocked_module not in self.registry._plugins.keys()

async def test_register_definitions_malformed(self):
class MODULE:
no_setup = "no setup attr"
Expand Down