Skip to content

Commit

Permalink
Merge pull request #1376 from dbluhm/fix/injector-typing
Browse files Browse the repository at this point in the history
feat: add inject_or
  • Loading branch information
andrewwhitehead authored Aug 31, 2021
2 parents 102d9f1 + 28ae318 commit 040b2f7
Show file tree
Hide file tree
Showing 48 changed files with 280 additions and 150 deletions.
39 changes: 33 additions & 6 deletions aries_cloudagent/admin/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def inject(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
*,
required: bool = True
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier.
Expand All @@ -75,7 +73,27 @@ def inject(
An instance of the base class, or None
"""
return self._context.inject(base_cls, settings, required=required)
return self._context.inject(base_cls, settings)

def inject_or(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
default: Optional[InjectType] = None,
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier or default if not found.
Args:
base_cls: The base class to retrieve an instance of
settings: An optional dict providing configuration to the provider
default: default return value if no instance is found
Returns:
An instance of the base class, or None
"""
return self._context.inject_or(base_cls, settings, default)

def update_settings(self, settings: Mapping[str, object]):
"""Update the current scope with additional settings."""
Expand All @@ -96,17 +114,26 @@ def test_context(
def _test_session(self) -> ProfileSession:
session = self.profile.session(self._context)

def _inject(base_cls, required=True):
def _inject(base_cls):
if session._active and base_cls in self.session_inject:
ret = self.session_inject[base_cls]
if ret is None and required:
if ret is None:
raise InjectionError(
"No instance provided for class: {}".format(base_cls.__name__)
)
return ret
return session._context.injector.inject(base_cls, required=required)
return session._context.injector.inject(base_cls)

def _inject_or(base_cls, default=None):
if session._active and base_cls in self.session_inject:
ret = self.session_inject[base_cls]
if ret is None:
ret = default
return ret
return session._context.injector.inject_or(base_cls, default)

setattr(session, "inject", _inject)
setattr(session, "inject_or", _inject_or)
return session

def __repr__(self) -> str:
Expand Down
16 changes: 8 additions & 8 deletions aries_cloudagent/admin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __init__(
self.webhook_router = webhook_router
self.websocket_queues = {}
self.site = None
self.multitenant_manager = context.inject(MultitenantManager, required=False)
self.multitenant_manager = context.inject_or(MultitenantManager)

self.server_paths = []

Expand Down Expand Up @@ -293,7 +293,7 @@ async def check_token(request: web.Request, handler):

middlewares.append(check_token)

collector = self.context.inject(Collector, required=False)
collector = self.context.inject_or(Collector)

if self.multitenant_manager:

Expand Down Expand Up @@ -393,7 +393,7 @@ async def setup_context(request: web.Request, handler):
self.server_paths = [route.path for route in server_routes]
app.add_routes(server_routes)

plugin_registry = self.context.inject(PluginRegistry, required=False)
plugin_registry = self.context.inject_or(PluginRegistry)
if plugin_registry:
await plugin_registry.register_admin_routes(app)

Expand Down Expand Up @@ -445,11 +445,11 @@ def sort_dict(raw: dict) -> dict:
runner = web.AppRunner(self.app)
await runner.setup()

plugin_registry = self.context.inject(PluginRegistry, required=False)
plugin_registry = self.context.inject_or(PluginRegistry)
if plugin_registry:
plugin_registry.post_process_routes(self.app)

event_bus = self.context.inject(EventBus, required=False)
event_bus = self.context.inject_or(EventBus)
if event_bus:
event_bus.subscribe(EVENT_PATTERN_WEBHOOK, self._on_webhook_event)
event_bus.subscribe(EVENT_PATTERN_RECORD, self._on_record_event)
Expand Down Expand Up @@ -548,7 +548,7 @@ async def plugins_handler(self, request: web.BaseRequest):
The module list response
"""
registry = self.context.inject(PluginRegistry, required=False)
registry = self.context.inject_or(PluginRegistry)
plugins = registry and sorted(registry.plugin_names) or []
return web.json_response({"result": plugins})

Expand Down Expand Up @@ -602,7 +602,7 @@ async def status_handler(self, request: web.BaseRequest):
"""
status = {"version": __version__}
status["label"] = self.context.settings.get("default_label")
collector = self.context.inject(Collector, required=False)
collector = self.context.inject_or(Collector)
if collector:
status["timing"] = collector.results
if self.conductor_stats:
Expand All @@ -622,7 +622,7 @@ async def status_reset_handler(self, request: web.BaseRequest):
The web response
"""
collector = self.context.inject(Collector, required=False)
collector = self.context.inject_or(Collector)
if collector:
collector.reset()
return web.json_response({})
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/admin/tests/test_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ async def test_session_inject_x(self):
test_ctx = test_module.AdminRequestContext.test_context({Collector: None})
async with test_ctx.session() as test_sesn:
with self.assertRaises(test_module.InjectionError):
test_sesn.inject(Collector, required=True)
test_sesn.inject(Collector)
2 changes: 1 addition & 1 deletion aries_cloudagent/askar/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def init_ledger_pool(self):
if read_only:
LOGGER.error("Note: setting ledger to read-only mode")
genesis_transactions = self.settings.get("ledger.genesis_transactions")
cache = self.context.injector.inject(BaseCache, required=False)
cache = self.context.injector.inject_or(BaseCache)
self.ledger_pool = IndyVdrLedgerPool(
pool_name,
keepalive=keepalive,
Expand Down
24 changes: 21 additions & 3 deletions aries_cloudagent/config/injection_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def inject(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
*,
required: bool = True
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier.
Expand All @@ -123,7 +121,27 @@ def inject(
An instance of the base class, or None
"""
return self.injector.inject(base_cls, settings, required=required)
return self.injector.inject(base_cls, settings)

def inject_or(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
default: Optional[InjectType] = None,
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier or default if not found.
Args:
base_cls: The base class to retrieve an instance of
settings: An optional dict providing configuration to the provider
default: default return value if no instance is found
Returns:
An instance of the base class, or None
"""
return self.injector.inject_or(base_cls, settings, default)

def copy(self) -> "InjectionContext":
"""Produce a copy of the injector instance."""
Expand Down
44 changes: 32 additions & 12 deletions aries_cloudagent/config/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ def get_provider(self, base_cls: Type[InjectType]):
"""Find the provider associated with a class binding."""
return self._providers.get(base_cls)

def inject(
def inject_or(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
*,
required: bool = True,
default: Optional[InjectType] = None,
) -> Optional[InjectType]:
"""
Get the provided instance of a given class identifier.
Get the provided instance of a given class identifier or default if not found.
Args:
cls: The base class to retrieve an instance of
params: An optional dict providing configuration to the provider
base_cls: The base class to retrieve an instance of
settings: An optional dict providing configuration to the provider
default: default return value if no instance is found
Returns:
An instance of the base class, or None
Expand All @@ -80,17 +80,37 @@ def inject(
result = provider.provide(ext_settings, self)
else:
result = None
if result is None:
if required:
raise InjectionError(
"No instance provided for class: {}".format(base_cls.__name__)
)
elif not isinstance(result, base_cls) and self.enforce_typing:

if result and not isinstance(result, base_cls) and self.enforce_typing:
raise InjectionError(
"Provided instance does not implement the base class: {}".format(
base_cls.__name__
)
)

return result if result is not None else default

def inject(
self,
base_cls: Type[InjectType],
settings: Mapping[str, object] = None,
) -> InjectType:
"""
Get the provided instance of a given class identifier.
Args:
cls: The base class to retrieve an instance of
params: An optional dict providing configuration to the provider
Returns:
An instance of the base class, or None
"""
result = self.inject_or(base_cls, settings)
if result is None:
raise InjectionError(
"No instance provided for class: {}".format(base_cls.__name__)
)
return result

def copy(self) -> BaseInjector:
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/config/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def ledger_config(

session = await profile.session()

ledger = session.inject(BaseLedger, required=False)
ledger = session.inject_or(BaseLedger)
if not ledger:
LOGGER.info("Ledger instance not provided")
return False
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/config/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def provide(self, config: BaseSettings, injector: BaseInjector):
"""Provide the object instance given a config and injector."""
instance = self._provider.provide(config, injector)
if self._methods:
collector: Collector = injector.inject(Collector, required=False)
collector: Collector = injector.inject_or(Collector)
if collector:
collector.wrap(
instance, self._methods, ignore_missing=self._ignore_missing
Expand Down
10 changes: 5 additions & 5 deletions aries_cloudagent/config/tests/test_injection_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_settings_scope(self):

async def test_inject_simple(self):
"""Test a basic injection."""
assert self.test_instance.inject(str, required=False) is None
assert self.test_instance.inject_or(str) is None
with self.assertRaises(InjectionError):
self.test_instance.inject(str)
self.test_instance.injector.bind_instance(str, self.test_value)
Expand All @@ -65,10 +65,10 @@ async def test_inject_simple(self):
async def test_inject_scope(self):
"""Test a scoped injection."""
context = self.test_instance.start_scope(self.test_scope)
assert context.inject(str, required=False) is None
assert context.inject_or(str) is None
context.injector.bind_instance(str, self.test_value)
assert context.inject(str) is self.test_value
assert self.test_instance.inject(str, required=False) is None
assert self.test_instance.inject_or(str) is None
root = context.injector_for_scope(context.ROOT_SCOPE)
assert root.inject(str, required=False) is None
assert self.test_instance.inject(str, required=False) is None
assert root.inject_or(str) is None
assert self.test_instance.inject_or(str) is None
10 changes: 8 additions & 2 deletions aries_cloudagent/config/tests/test_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_settings_init(self):

def test_inject_simple(self):
"""Test a basic injection."""
assert self.test_instance.inject(str, required=False) is None
assert self.test_instance.inject_or(str) is None
with self.assertRaises(InjectionError):
self.test_instance.inject(str)
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_bad_provider(self):
self.test_instance.bind_provider(str, MockProvider(None))
with self.assertRaises(InjectionError):
self.test_instance.inject(str)
self.test_instance.inject(str, required=False)
self.test_instance.inject_or(str)
self.test_instance.bind_provider(str, MockProvider(1))
self.test_instance.clear_binding(str)
assert self.test_instance.get_provider(str) is None
Expand Down Expand Up @@ -139,3 +139,9 @@ def test_inject_cached(self):
i1 = self.test_instance.inject(MockInstance)
i2 = self.test_instance.inject(MockInstance)
assert i1 is i2

def test_falsey_still_returns(self):
"""Test the injector still returns falsey values."""
self.test_instance.bind_instance(dict, dict())
assert self.test_instance.inject_or(dict) is not None
assert self.test_instance.inject(dict) is not None
4 changes: 2 additions & 2 deletions aries_cloudagent/config/tests/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ async def setUp(self):
transaction=async_mock.CoroutineMock(return_value=self.session),
)

def _inject(cls, required=True):
return self.injector.inject(cls, required=required)
def _inject(cls):
return self.injector.inject(cls)

self.session.inject = _inject
self.context = InjectionContext()
Expand Down
6 changes: 3 additions & 3 deletions aries_cloudagent/core/conductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def setup(self):
self.dispatcher = Dispatcher(self.root_profile)
await self.dispatcher.setup()

wire_format = context.inject(BaseWireFormat, required=False)
wire_format = context.inject_or(BaseWireFormat)
if wire_format and hasattr(wire_format, "task_queue"):
wire_format.task_queue = self.dispatcher.task_queue

Expand Down Expand Up @@ -159,7 +159,7 @@ async def setup(self):
raise

# Fetch stats collector, if any
collector = context.inject(Collector, required=False)
collector = context.inject_or(Collector)
if collector:
# add stats to our own methods
collector.wrap(
Expand Down Expand Up @@ -358,7 +358,7 @@ async def stop(self, timeout=1.0):
shutdown.run(self.outbound_transport_manager.stop())

# close multitenant profiles
multitenant_mgr = self.context.inject(MultitenantManager, required=False)
multitenant_mgr = self.context.inject_or(MultitenantManager)
if multitenant_mgr:
for profile in multitenant_mgr._instances.values():
shutdown.run(profile.close())
Expand Down
2 changes: 1 addition & 1 deletion aries_cloudagent/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, profile: Profile):

async def setup(self):
"""Perform async instance setup."""
self.collector = self.profile.inject(Collector, required=False)
self.collector = self.profile.inject_or(Collector)
max_active = int(os.getenv("DISPATCHER_MAX_ACTIVE", 50))
self.task_queue = TaskQueue(
max_active=max_active, timed=bool(self.collector), trace_fn=self.log_task
Expand Down
Loading

0 comments on commit 040b2f7

Please sign in to comment.