diff --git a/src/inject/__init__.py b/src/inject/__init__.py index ebdac22..b6e735a 100644 --- a/src/inject/__init__.py +++ b/src/inject/__init__.py @@ -73,6 +73,8 @@ def my_config(binder): inject.configure(my_config) """ +import contextlib + from inject._version import __version__ import inspect @@ -335,12 +337,28 @@ def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[..., async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: provided_params = frozenset( arg_names[:len(args)]) | frozenset(kwargs.keys()) + ctx_managers = {} + async_ctx_managers = {} for param, cls in params_to_provide.items(): if param not in provided_params: - kwargs[param] = instance(cls) + inst = instance(cls) + if isinstance(inst, contextlib.AbstractContextManager): + ctx_managers[param] = inst + elif isinstance(inst, contextlib.AbstractAsyncContextManager): + async_ctx_managers[param] = inst + else: + kwargs[param] = inst async_func = cast(Callable[..., Awaitable[T]], func) try: - return await async_func(*args, **kwargs) + with contextlib.ExitStack() as sync_stack: + ctx_kwargs = {param: sync_stack.enter_context(ctx_manager) for param, ctx_manager in + ctx_managers.items()} + kwargs.update(ctx_kwargs) + async with contextlib.AsyncExitStack() as async_stack: + asynx_ctx_kwargs = {param: await async_stack.enter_async_context(ctx_manager) for param, ctx_manager in + async_ctx_managers.items()} + kwargs.update(asynx_ctx_kwargs) + return await async_func(*args, **kwargs) except TypeError as previous_error: raise ConstructorTypeError(func, previous_error) @@ -350,12 +368,20 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T: def injection_wrapper(*args: Any, **kwargs: Any) -> T: provided_params = frozenset( arg_names[:len(args)]) | frozenset(kwargs.keys()) + ctx_managers = {} for param, cls in params_to_provide.items(): if param not in provided_params: - kwargs[param] = instance(cls) + inst = instance(cls) + if isinstance(inst, contextlib.AbstractContextManager): + ctx_managers[param] = inst + else: + kwargs[param] = inst sync_func = cast(Callable[..., T], func) try: - return sync_func(*args, **kwargs) + with contextlib.ExitStack() as stack: + ctx_kwargs = {param: stack.enter_context(ctx_manager) for param, ctx_manager in ctx_managers.items()} + kwargs.update(ctx_kwargs) + return sync_func(*args, **kwargs) except TypeError as previous_error: raise ConstructorTypeError(func, previous_error) return injection_wrapper diff --git a/test/test_context_manager.py b/test/test_context_manager.py new file mode 100644 index 0000000..652a8ad --- /dev/null +++ b/test/test_context_manager.py @@ -0,0 +1,107 @@ +import contextlib + +import inject +from test import BaseTestInject + + +class Destroyable: + def __init__(self): + self.started = True + + def destroy(self): + self.started = False + + +class MockFile(Destroyable): + ... + + +class MockConnection(Destroyable): + ... + + +class MockFoo(Destroyable): + ... + + +@contextlib.contextmanager +def get_file_sync(): + obj = MockFile() + yield obj + obj.destroy() + + +@contextlib.contextmanager +def get_conn_sync(): + obj = MockConnection() + yield obj + obj.destroy() + + +@contextlib.contextmanager +def get_foo_sync(): + obj = MockFoo() + yield obj + obj.destroy() + + +@contextlib.asynccontextmanager +async def get_file_async(): + obj = MockFile() + yield obj + obj.destroy() + + +@contextlib.asynccontextmanager +async def get_conn_async(): + obj = MockConnection() + yield obj + obj.destroy() + + +class TestContextManagerFunctional(BaseTestInject): + + def test_provider_as_context_manager_sync(self): + def config(binder): + binder.bind_to_provider(MockFile, get_file_sync) + binder.bind(int, 100) + binder.bind_to_provider(str, lambda: "Hello") + binder.bind_to_provider(MockConnection, get_conn_sync) + + inject.configure(config) + + @inject.autoparams() + def mock_func(conn: MockConnection, name: str, f: MockFile, number: int): + assert f.started + assert conn.started + assert name == "Hello" + assert number == 100 + return f, conn + + f_, conn_ = mock_func() + assert not f_.started + assert not conn_.started + + def test_provider_as_context_manager_async(self): + def config(binder): + binder.bind_to_provider(MockFile, get_file_async) + binder.bind(int, 100) + binder.bind_to_provider(str, lambda: "Hello") + binder.bind_to_provider(MockConnection, get_conn_async) + binder.bind_to_provider(MockFoo, get_foo_sync) + + inject.configure(config) + + @inject.autoparams() + async def mock_func(conn: MockConnection, name: str, f: MockFile, number: int, foo: MockFoo): + assert f.started + assert conn.started + assert foo.started + assert name == "Hello" + assert number == 100 + return f, conn, foo + + f_, conn_, foo_ = self.run_async(mock_func()) + assert not f_.started + assert not conn_.started + assert not foo_.started \ No newline at end of file