Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add inject_or #1376

Merged
39 changes: 33 additions & 6 deletions aries_cloudagent/admin/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def inject(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
*,
required: bool = True
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier.
Expand All @@ -75,7 +73,27 @@ def inject(
An instance of the base class, or None

"""
return self._context.inject(base_cls, settings, required=required)
return self._context.inject(base_cls, settings)

def inject_or(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
default: Optional[InjectType] = None,
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier or default if not found.

Args:
base_cls: The base class to retrieve an instance of
settings: An optional dict providing configuration to the provider
default: default return value if no instance is found

Returns:
An instance of the base class, or None

"""
return self._context.inject_or(base_cls, settings, default)

def update_settings(self, settings: Mapping[str, object]):
"""Update the current scope with additional settings."""
Expand All @@ -96,17 +114,26 @@ def test_context(
def _test_session(self) -> ProfileSession:
session = self.profile.session(self._context)

def _inject(base_cls, required=True):
def _inject(base_cls):
if session._active and base_cls in self.session_inject:
ret = self.session_inject[base_cls]
if ret is None and required:
if ret is None:
raise InjectionError(
"No instance provided for class: {}".format(base_cls.__name__)
)
return ret
return session._context.injector.inject(base_cls, required=required)
return session._context.injector.inject(base_cls)

def _inject_or(base_cls, default=None):
if session._active and base_cls in self.session_inject:
ret = self.session_inject[base_cls]
if ret is None:
ret = default
return ret
return session._context.injector.inject_or(base_cls, default)

setattr(session, "inject", _inject)
setattr(session, "inject_or", _inject_or)
return session

def __repr__(self) -> str:
Expand Down
16 changes: 8 additions & 8 deletions aries_cloudagent/admin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(
self.webhook_router = webhook_router
self.websocket_queues = {}
self.site = None
self.multitenant_manager = context.inject(MultitenantManager, required=False)
self.multitenant_manager = context.inject_or(MultitenantManager)

self.server_paths = []

Expand Down Expand Up @@ -293,7 +293,7 @@ async def check_token(request: web.Request, handler):

middlewares.append(check_token)

collector = self.context.inject(Collector, required=False)
collector = self.context.inject_or(Collector)

if self.multitenant_manager:

Expand Down Expand Up @@ -393,7 +393,7 @@ async def setup_context(request: web.Request, handler):
self.server_paths = [route.path for route in server_routes]
app.add_routes(server_routes)

plugin_registry = self.context.inject(PluginRegistry, required=False)
plugin_registry = self.context.inject_or(PluginRegistry)
if plugin_registry:
await plugin_registry.register_admin_routes(app)

Expand Down Expand Up @@ -445,11 +445,11 @@ def sort_dict(raw: dict) -> dict:
runner = web.AppRunner(self.app)
await runner.setup()

plugin_registry = self.context.inject(PluginRegistry, required=False)
plugin_registry = self.context.inject_or(PluginRegistry)
if plugin_registry:
plugin_registry.post_process_routes(self.app)

event_bus = self.context.inject(EventBus, required=False)
event_bus = self.context.inject_or(EventBus)
if event_bus:
event_bus.subscribe(EVENT_PATTERN_WEBHOOK, self._on_webhook_event)
event_bus.subscribe(EVENT_PATTERN_RECORD, self._on_record_event)
Expand Down Expand Up @@ -548,7 +548,7 @@ async def plugins_handler(self, request: web.BaseRequest):
The module list response

"""
registry = self.context.inject(PluginRegistry, required=False)
registry = self.context.inject_or(PluginRegistry)
plugins = registry and sorted(registry.plugin_names) or []
return web.json_response({"result": plugins})

Expand Down Expand Up @@ -602,7 +602,7 @@ async def status_handler(self, request: web.BaseRequest):
"""
status = {"version": __version__}
status["label"] = self.context.settings.get("default_label")
collector = self.context.inject(Collector, required=False)
collector = self.context.inject_or(Collector)
if collector:
status["timing"] = collector.results
if self.conductor_stats:
Expand All @@ -622,7 +622,7 @@ async def status_reset_handler(self, request: web.BaseRequest):
The web response

"""
collector = self.context.inject(Collector, required=False)
collector = self.context.inject_or(Collector)
if collector:
collector.reset()
return web.json_response({})
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/admin/tests/test_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ async def test_session_inject_x(self):
test_ctx = test_module.AdminRequestContext.test_context({Collector: None})
async with test_ctx.session() as test_sesn:
with self.assertRaises(test_module.InjectionError):
test_sesn.inject(Collector, required=True)
test_sesn.inject(Collector)
2 changes: 1 addition & 1 deletion aries_cloudagent/askar/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def init_ledger_pool(self):
if read_only:
LOGGER.error("Note: setting ledger to read-only mode")
genesis_transactions = self.settings.get("ledger.genesis_transactions")
cache = self.context.injector.inject(BaseCache, required=False)
cache = self.context.injector.inject_or(BaseCache)
self.ledger_pool = IndyVdrLedgerPool(
pool_name,
keepalive=keepalive,
Expand Down
24 changes: 21 additions & 3 deletions aries_cloudagent/config/injection_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def inject(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
*,
required: bool = True
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier.
Expand All @@ -123,7 +121,27 @@ def inject(
An instance of the base class, or None

"""
return self.injector.inject(base_cls, settings, required=required)
return self.injector.inject(base_cls, settings)

def inject_or(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
default: Optional[InjectType] = None,
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier or default if not found.

Args:
base_cls: The base class to retrieve an instance of
settings: An optional dict providing configuration to the provider
default: default return value if no instance is found

Returns:
An instance of the base class, or None

"""
return self.injector.inject_or(base_cls, settings, default)

def copy(self) -> "InjectionContext":
"""Produce a copy of the injector instance."""
Expand Down
44 changes: 32 additions & 12 deletions aries_cloudagent/config/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ def get_provider(self, base_cls: Type[InjectType]):
"""Find the provider associated with a class binding."""
return self._providers.get(base_cls)

def inject(
def inject_or(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
*,
required: bool = True,
default: Optional[InjectType] = None,
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier.
Get the provided instance of a given class identifier or default if not found.

Args:
cls: The base class to retrieve an instance of
params: An optional dict providing configuration to the provider
base_cls: The base class to retrieve an instance of
settings: An optional dict providing configuration to the provider
default: default return value if no instance is found

Returns:
An instance of the base class, or None
Expand All @@ -80,17 +80,37 @@ def inject(
result = provider.provide(ext_settings, self)
else:
result = None
if result is None:
if required:
raise InjectionError(
"No instance provided for class: {}".format(base_cls.__name__)
)
elif not isinstance(result, base_cls) and self.enforce_typing:

if result and not isinstance(result, base_cls) and self.enforce_typing:
raise InjectionError(
"Provided instance does not implement the base class: {}".format(
base_cls.__name__
)
)

return result if result is not None else default

def inject(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
) -> InjectType:
"""
Get the provided instance of a given class identifier.

Args:
cls: The base class to retrieve an instance of
params: An optional dict providing configuration to the provider

Returns:
An instance of the base class, or None

"""
result = self.inject_or(base_cls, settings)
if result is None:
raise InjectionError(
"No instance provided for class: {}".format(base_cls.__name__)
)
return result

def copy(self) -> BaseInjector:
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/config/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def ledger_config(

session = await profile.session()

ledger = session.inject(BaseLedger, required=False)
ledger = session.inject_or(BaseLedger)
if not ledger:
LOGGER.info("Ledger instance not provided")
return False
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/config/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def provide(self, config: BaseSettings, injector: BaseInjector):
"""Provide the object instance given a config and injector."""
instance = self._provider.provide(config, injector)
if self._methods:
collector: Collector = injector.inject(Collector, required=False)
collector: Collector = injector.inject_or(Collector)
if collector:
collector.wrap(
instance, self._methods, ignore_missing=self._ignore_missing
Expand Down
10 changes: 5 additions & 5 deletions aries_cloudagent/config/tests/test_injection_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_settings_scope(self):

async def test_inject_simple(self):
"""Test a basic injection."""
assert self.test_instance.inject(str, required=False) is None
assert self.test_instance.inject_or(str) is None
with self.assertRaises(InjectionError):
self.test_instance.inject(str)
self.test_instance.injector.bind_instance(str, self.test_value)
Expand All @@ -65,10 +65,10 @@ async def test_inject_simple(self):
async def test_inject_scope(self):
"""Test a scoped injection."""
context = self.test_instance.start_scope(self.test_scope)
assert context.inject(str, required=False) is None
assert context.inject_or(str) is None
context.injector.bind_instance(str, self.test_value)
assert context.inject(str) is self.test_value
assert self.test_instance.inject(str, required=False) is None
assert self.test_instance.inject_or(str) is None
root = context.injector_for_scope(context.ROOT_SCOPE)
assert root.inject(str, required=False) is None
assert self.test_instance.inject(str, required=False) is None
assert root.inject_or(str) is None
assert self.test_instance.inject_or(str) is None
10 changes: 8 additions & 2 deletions aries_cloudagent/config/tests/test_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_settings_init(self):

def test_inject_simple(self):
"""Test a basic injection."""
assert self.test_instance.inject(str, required=False) is None
assert self.test_instance.inject_or(str) is None
with self.assertRaises(InjectionError):
self.test_instance.inject(str)
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_bad_provider(self):
self.test_instance.bind_provider(str, MockProvider(None))
with self.assertRaises(InjectionError):
self.test_instance.inject(str)
self.test_instance.inject(str, required=False)
self.test_instance.inject_or(str)
self.test_instance.bind_provider(str, MockProvider(1))
self.test_instance.clear_binding(str)
assert self.test_instance.get_provider(str) is None
Expand Down Expand Up @@ -139,3 +139,9 @@ def test_inject_cached(self):
i1 = self.test_instance.inject(MockInstance)
i2 = self.test_instance.inject(MockInstance)
assert i1 is i2

def test_falsey_still_returns(self):
"""Test the injector still returns falsey values."""
self.test_instance.bind_instance(dict, dict())
assert self.test_instance.inject_or(dict) is not None
assert self.test_instance.inject(dict) is not None
4 changes: 2 additions & 2 deletions aries_cloudagent/config/tests/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ async def setUp(self):
transaction=async_mock.CoroutineMock(return_value=self.session),
)

def _inject(cls, required=True):
return self.injector.inject(cls, required=required)
def _inject(cls):
return self.injector.inject(cls)

self.session.inject = _inject
self.context = InjectionContext()
Expand Down
6 changes: 3 additions & 3 deletions aries_cloudagent/core/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def setup(self):
self.dispatcher = Dispatcher(self.root_profile)
await self.dispatcher.setup()

wire_format = context.inject(BaseWireFormat, required=False)
wire_format = context.inject_or(BaseWireFormat)
if wire_format and hasattr(wire_format, "task_queue"):
wire_format.task_queue = self.dispatcher.task_queue

Expand Down Expand Up @@ -159,7 +159,7 @@ async def setup(self):
raise

# Fetch stats collector, if any
collector = context.inject(Collector, required=False)
collector = context.inject_or(Collector)
if collector:
# add stats to our own methods
collector.wrap(
Expand Down Expand Up @@ -358,7 +358,7 @@ async def stop(self, timeout=1.0):
shutdown.run(self.outbound_transport_manager.stop())

# close multitenant profiles
multitenant_mgr = self.context.inject(MultitenantManager, required=False)
multitenant_mgr = self.context.inject_or(MultitenantManager)
if multitenant_mgr:
for profile in multitenant_mgr._instances.values():
shutdown.run(profile.close())
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, profile: Profile):

async def setup(self):
"""Perform async instance setup."""
self.collector = self.profile.inject(Collector, required=False)
self.collector = self.profile.inject_or(Collector)
max_active = int(os.getenv("DISPATCHER_MAX_ACTIVE", 50))
self.task_queue = TaskQueue(
max_active=max_active, timed=bool(self.collector), trace_fn=self.log_task
Expand Down
Loading