Skip to content

Commit

Permalink
Allow provider to be a context manager (sync/async)
Browse files Browse the repository at this point in the history
  • Loading branch information
fzyukio committed Nov 14, 2023
1 parent 5be9189 commit 7d161d7
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 4 deletions.
34 changes: 30 additions & 4 deletions src/inject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def my_config(binder):
inject.configure(my_config)
"""
import contextlib

from inject._version import __version__

import inspect
Expand Down Expand Up @@ -335,12 +337,28 @@ def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[...,
async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
provided_params = frozenset(
arg_names[:len(args)]) | frozenset(kwargs.keys())
ctx_managers = {}
async_ctx_managers = {}
for param, cls in params_to_provide.items():
if param not in provided_params:
kwargs[param] = instance(cls)
inst = instance(cls)
if isinstance(inst, contextlib.AbstractContextManager):
ctx_managers[param] = inst
elif isinstance(inst, contextlib.AbstractAsyncContextManager):
async_ctx_managers[param] = inst
else:
kwargs[param] = inst
async_func = cast(Callable[..., Awaitable[T]], func)
try:
return await async_func(*args, **kwargs)
with contextlib.ExitStack() as sync_stack:
ctx_kwargs = {param: sync_stack.enter_context(ctx_manager) for param, ctx_manager in
ctx_managers.items()}
kwargs.update(ctx_kwargs)
async with contextlib.AsyncExitStack() as async_stack:
asynx_ctx_kwargs = {param: await async_stack.enter_async_context(ctx_manager) for param, ctx_manager in
async_ctx_managers.items()}
kwargs.update(asynx_ctx_kwargs)
return await async_func(*args, **kwargs)
except TypeError as previous_error:
raise ConstructorTypeError(func, previous_error)

Expand All @@ -350,12 +368,20 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
def injection_wrapper(*args: Any, **kwargs: Any) -> T:
provided_params = frozenset(
arg_names[:len(args)]) | frozenset(kwargs.keys())
ctx_managers = {}
for param, cls in params_to_provide.items():
if param not in provided_params:
kwargs[param] = instance(cls)
inst = instance(cls)
if isinstance(inst, contextlib.AbstractContextManager):
ctx_managers[param] = inst
else:
kwargs[param] = inst
sync_func = cast(Callable[..., T], func)
try:
return sync_func(*args, **kwargs)
with contextlib.ExitStack() as stack:
ctx_kwargs = {param: stack.enter_context(ctx_manager) for param, ctx_manager in ctx_managers.items()}
kwargs.update(ctx_kwargs)
return sync_func(*args, **kwargs)
except TypeError as previous_error:
raise ConstructorTypeError(func, previous_error)
return injection_wrapper
Expand Down
107 changes: 107 additions & 0 deletions test/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import contextlib

import inject
from test import BaseTestInject


class Destroyable:
def __init__(self):
self.started = True

def destroy(self):
self.started = False


class MockFile(Destroyable):
...


class MockConnection(Destroyable):
...


class MockFoo(Destroyable):
...


@contextlib.contextmanager
def get_file_sync():
obj = MockFile()
yield obj
obj.destroy()


@contextlib.contextmanager
def get_conn_sync():
obj = MockConnection()
yield obj
obj.destroy()


@contextlib.contextmanager
def get_foo_sync():
obj = MockFoo()
yield obj
obj.destroy()


@contextlib.asynccontextmanager
async def get_file_async():
obj = MockFile()
yield obj
obj.destroy()


@contextlib.asynccontextmanager
async def get_conn_async():
obj = MockConnection()
yield obj
obj.destroy()


class TestContextManagerFunctional(BaseTestInject):

def test_provider_as_context_manager_sync(self):
def config(binder):
binder.bind_to_provider(MockFile, get_file_sync)
binder.bind(int, 100)
binder.bind_to_provider(str, lambda: "Hello")
binder.bind_to_provider(MockConnection, get_conn_sync)

inject.configure(config)

@inject.autoparams()
def mock_func(conn: MockConnection, name: str, f: MockFile, number: int):
assert f.started
assert conn.started
assert name == "Hello"
assert number == 100
return f, conn

f_, conn_ = mock_func()
assert not f_.started
assert not conn_.started

def test_provider_as_context_manager_async(self):
def config(binder):
binder.bind_to_provider(MockFile, get_file_async)
binder.bind(int, 100)
binder.bind_to_provider(str, lambda: "Hello")
binder.bind_to_provider(MockConnection, get_conn_async)
binder.bind_to_provider(MockFoo, get_foo_sync)

inject.configure(config)

@inject.autoparams()
async def mock_func(conn: MockConnection, name: str, f: MockFile, number: int, foo: MockFoo):
assert f.started
assert conn.started
assert foo.started
assert name == "Hello"
assert number == 100
return f, conn, foo

f_, conn_, foo_ = self.run_async(mock_func())
assert not f_.started
assert not conn_.started
assert not foo_.started

0 comments on commit 7d161d7

Please sign in to comment.