From 7d161d7b117a0a5e62a673a4576a8bf3e2bf5f1e Mon Sep 17 00:00:00 2001 From: Yukio Fukuzawa Date: Tue, 14 Nov 2023 14:50:39 +1300 Subject: [PATCH 1/3] Allow provider to be a context manager (sync/async) --- src/inject/__init__.py | 34 +++++++++-- test/test_context_manager.py | 107 +++++++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 4 deletions(-) create mode 100644 test/test_context_manager.py diff --git a/src/inject/__init__.py b/src/inject/__init__.py index ebdac22..b6e735a 100644 --- a/src/inject/__init__.py +++ b/src/inject/__init__.py @@ -73,6 +73,8 @@ def my_config(binder): inject.configure(my_config) """ +import contextlib + from inject._version import __version__ import inspect @@ -335,12 +337,28 @@ def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[..., async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: provided_params = frozenset( arg_names[:len(args)]) | frozenset(kwargs.keys()) + ctx_managers = {} + async_ctx_managers = {} for param, cls in params_to_provide.items(): if param not in provided_params: - kwargs[param] = instance(cls) + inst = instance(cls) + if isinstance(inst, contextlib.AbstractContextManager): + ctx_managers[param] = inst + elif isinstance(inst, contextlib.AbstractAsyncContextManager): + async_ctx_managers[param] = inst + else: + kwargs[param] = inst async_func = cast(Callable[..., Awaitable[T]], func) try: - return await async_func(*args, **kwargs) + with contextlib.ExitStack() as sync_stack: + ctx_kwargs = {param: sync_stack.enter_context(ctx_manager) for param, ctx_manager in + ctx_managers.items()} + kwargs.update(ctx_kwargs) + async with contextlib.AsyncExitStack() as async_stack: + asynx_ctx_kwargs = {param: await async_stack.enter_async_context(ctx_manager) for param, ctx_manager in + async_ctx_managers.items()} + kwargs.update(asynx_ctx_kwargs) + return await async_func(*args, **kwargs) except TypeError as previous_error: raise ConstructorTypeError(func, previous_error) @@ -350,12 +368,20 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: def injection_wrapper(*args: Any, **kwargs: Any) -> T: provided_params = frozenset( arg_names[:len(args)]) | frozenset(kwargs.keys()) + ctx_managers = {} for param, cls in params_to_provide.items(): if param not in provided_params: - kwargs[param] = instance(cls) + inst = instance(cls) + if isinstance(inst, contextlib.AbstractContextManager): + ctx_managers[param] = inst + else: + kwargs[param] = inst sync_func = cast(Callable[..., T], func) try: - return sync_func(*args, **kwargs) + with contextlib.ExitStack() as stack: + ctx_kwargs = {param: stack.enter_context(ctx_manager) for param, ctx_manager in ctx_managers.items()} + kwargs.update(ctx_kwargs) + return sync_func(*args, **kwargs) except TypeError as previous_error: raise ConstructorTypeError(func, previous_error) return injection_wrapper diff --git a/test/test_context_manager.py b/test/test_context_manager.py new file mode 100644 index 0000000..652a8ad --- /dev/null +++ b/test/test_context_manager.py @@ -0,0 +1,107 @@ +import contextlib + +import inject +from test import BaseTestInject + + +class Destroyable: + def __init__(self): + self.started = True + + def destroy(self): + self.started = False + + +class MockFile(Destroyable): + ... + + +class MockConnection(Destroyable): + ... + + +class MockFoo(Destroyable): + ... + + +@contextlib.contextmanager +def get_file_sync(): + obj = MockFile() + yield obj + obj.destroy() + + +@contextlib.contextmanager +def get_conn_sync(): + obj = MockConnection() + yield obj + obj.destroy() + + +@contextlib.contextmanager +def get_foo_sync(): + obj = MockFoo() + yield obj + obj.destroy() + + +@contextlib.asynccontextmanager +async def get_file_async(): + obj = MockFile() + yield obj + obj.destroy() + + +@contextlib.asynccontextmanager +async def get_conn_async(): + obj = MockConnection() + yield obj + obj.destroy() + + +class TestContextManagerFunctional(BaseTestInject): + + def test_provider_as_context_manager_sync(self): + def config(binder): + binder.bind_to_provider(MockFile, get_file_sync) + binder.bind(int, 100) + binder.bind_to_provider(str, lambda: "Hello") + binder.bind_to_provider(MockConnection, get_conn_sync) + + inject.configure(config) + + @inject.autoparams() + def mock_func(conn: MockConnection, name: str, f: MockFile, number: int): + assert f.started + assert conn.started + assert name == "Hello" + assert number == 100 + return f, conn + + f_, conn_ = mock_func() + assert not f_.started + assert not conn_.started + + def test_provider_as_context_manager_async(self): + def config(binder): + binder.bind_to_provider(MockFile, get_file_async) + binder.bind(int, 100) + binder.bind_to_provider(str, lambda: "Hello") + binder.bind_to_provider(MockConnection, get_conn_async) + binder.bind_to_provider(MockFoo, get_foo_sync) + + inject.configure(config) + + @inject.autoparams() + async def mock_func(conn: MockConnection, name: str, f: MockFile, number: int, foo: MockFoo): + assert f.started + assert conn.started + assert foo.started + assert name == "Hello" + assert number == 100 + return f, conn, foo + + f_, conn_, foo_ = self.run_async(mock_func()) + assert not f_.started + assert not conn_.started + assert not foo_.started \ No newline at end of file From f98ac51eed3c80bf690e594b282d05d719b1adc0 Mon Sep 17 00:00:00 2001 From: Yukio Fukuzawa Date: Tue, 14 Nov 2023 17:08:20 +1300 Subject: [PATCH 2/3] Narrow down context manager check to only include generator context managers --- src/inject/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/inject/__init__.py b/src/inject/__init__.py index b6e735a..42142a3 100644 --- a/src/inject/__init__.py +++ b/src/inject/__init__.py @@ -342,9 +342,9 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: for param, cls in params_to_provide.items(): if param not in provided_params: inst = instance(cls) - if isinstance(inst, contextlib.AbstractContextManager): + if isinstance(inst, contextlib._GeneratorContextManager): ctx_managers[param] = inst - elif isinstance(inst, contextlib.AbstractAsyncContextManager): + elif isinstance(inst, contextlib._AsyncGeneratorContextManager): async_ctx_managers[param] = inst else: kwargs[param] = inst @@ -372,7 +372,7 @@ def injection_wrapper(*args: Any, **kwargs: Any) -> T: for param, cls in params_to_provide.items(): if param not in provided_params: inst = instance(cls) - if isinstance(inst, contextlib.AbstractContextManager): + if isinstance(inst, contextlib._GeneratorContextManager): ctx_managers[param] = inst else: kwargs[param] = inst From f5272ca7cf88b244042aa8af52f680b422461b69 Mon Sep 17 00:00:00 2001 From: Yukio Fukuzawa Date: Wed, 22 Nov 2023 10:00:54 +1300 Subject: [PATCH 3/3] Restructure the code, add comments --- src/inject/__init__.py | 64 +++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/src/inject/__init__.py b/src/inject/__init__.py index 42142a3..4c48e4c 100644 --- a/src/inject/__init__.py +++ b/src/inject/__init__.py @@ -158,7 +158,10 @@ def bind_to_constructor(self, cls: Binding, constructor: Constructor) -> 'Binder return self def bind_to_provider(self, cls: Binding, provider: Provider) -> 'Binder': - """Bind a class to a callable instance provider executed for each injection.""" + """ + Bind a class to a callable instance provider executed for each injection. + A provider can be a normal function or a context manager. Both sync and async are supported. + """ self._check_class(cls) if provider is None: raise InjectorException('Provider cannot be None, key=%s' % cls) @@ -325,6 +328,35 @@ class _ParametersInjection(Generic[T]): def __init__(self, **kwargs: Any) -> None: self._params = kwargs + @staticmethod + def _aggregate_sync_stack( + sync_stack: contextlib.ExitStack, + provided_params: frozenset[str], + kwargs: dict[str, Any] + ) -> None: + """Extracts context managers, aggregate them in an ExitStack and swap out the param value with results of + running __enter__(). The result is equivalent to using `with` multiple times """ + executed_kwargs = { + param: sync_stack.enter_context(inst) + for param, inst in kwargs.items() + if param not in provided_params and isinstance(inst, contextlib._GeneratorContextManager) + } + kwargs.update(executed_kwargs) + + @staticmethod + async def _aggregate_async_stack( + async_stack: contextlib.AsyncExitStack, + provided_params: frozenset[str], + kwargs: dict[str, Any] + ) -> None: + """Similar to _aggregate_sync_stack, but for async context managers""" + executed_kwargs = { + param: await async_stack.enter_async_context(inst) + for param, inst in kwargs.items() + if param not in provided_params and isinstance(inst, contextlib._AsyncGeneratorContextManager) + } + kwargs.update(executed_kwargs) + def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[..., Union[Awaitable[T], T]]: if sys.version_info.major == 2: arg_names = inspect.getargspec(func).args @@ -337,27 +369,15 @@ def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[..., async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: provided_params = frozenset( arg_names[:len(args)]) | frozenset(kwargs.keys()) - ctx_managers = {} - async_ctx_managers = {} for param, cls in params_to_provide.items(): if param not in provided_params: - inst = instance(cls) - if isinstance(inst, contextlib._GeneratorContextManager): - ctx_managers[param] = inst - elif isinstance(inst, contextlib._AsyncGeneratorContextManager): - async_ctx_managers[param] = inst - else: - kwargs[param] = inst + kwargs[param] = instance(cls) async_func = cast(Callable[..., Awaitable[T]], func) try: with contextlib.ExitStack() as sync_stack: - ctx_kwargs = {param: sync_stack.enter_context(ctx_manager) for param, ctx_manager in - ctx_managers.items()} - kwargs.update(ctx_kwargs) async with contextlib.AsyncExitStack() as async_stack: - asynx_ctx_kwargs = {param: await async_stack.enter_async_context(ctx_manager) for param, ctx_manager in - async_ctx_managers.items()} - kwargs.update(asynx_ctx_kwargs) + self._aggregate_sync_stack(sync_stack, provided_params, kwargs) + await self._aggregate_async_stack(async_stack, provided_params, kwargs) return await async_func(*args, **kwargs) except TypeError as previous_error: raise ConstructorTypeError(func, previous_error) @@ -368,19 +388,13 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: def injection_wrapper(*args: Any, **kwargs: Any) -> T: provided_params = frozenset( arg_names[:len(args)]) | frozenset(kwargs.keys()) - ctx_managers = {} for param, cls in params_to_provide.items(): if param not in provided_params: - inst = instance(cls) - if isinstance(inst, contextlib._GeneratorContextManager): - ctx_managers[param] = inst - else: - kwargs[param] = inst + kwargs[param] = instance(cls) sync_func = cast(Callable[..., T], func) try: - with contextlib.ExitStack() as stack: - ctx_kwargs = {param: stack.enter_context(ctx_manager) for param, ctx_manager in ctx_managers.items()} - kwargs.update(ctx_kwargs) + with contextlib.ExitStack() as sync_stack: + self._aggregate_sync_stack(sync_stack, provided_params, kwargs) return sync_func(*args, **kwargs) except TypeError as previous_error: raise ConstructorTypeError(func, previous_error)