Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow provider to be a context manager (sync/async) #92

Merged
merged 3 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions src/inject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def my_config(binder):
inject.configure(my_config)

"""
import contextlib

from inject._version import __version__

import inspect
Expand Down Expand Up @@ -156,7 +158,10 @@ def bind_to_constructor(self, cls: Binding, constructor: Constructor) -> 'Binder
return self

def bind_to_provider(self, cls: Binding, provider: Provider) -> 'Binder':
"""Bind a class to a callable instance provider executed for each injection."""
"""
Bind a class to a callable instance provider executed for each injection.
A provider can be a normal function or a context manager. Both sync and async are supported.
"""
self._check_class(cls)
if provider is None:
raise InjectorException('Provider cannot be None, key=%s' % cls)
Expand Down Expand Up @@ -323,6 +328,35 @@ class _ParametersInjection(Generic[T]):
def __init__(self, **kwargs: Any) -> None:
self._params = kwargs

@staticmethod
def _aggregate_sync_stack(
sync_stack: contextlib.ExitStack,
provided_params: frozenset[str],
kwargs: dict[str, Any]
) -> None:
"""Extracts context managers, aggregate them in an ExitStack and swap out the param value with results of
running __enter__(). The result is equivalent to using `with` multiple times """
executed_kwargs = {
param: sync_stack.enter_context(inst)
for param, inst in kwargs.items()
if param not in provided_params and isinstance(inst, contextlib._GeneratorContextManager)
}
kwargs.update(executed_kwargs)

@staticmethod
async def _aggregate_async_stack(
async_stack: contextlib.AsyncExitStack,
provided_params: frozenset[str],
kwargs: dict[str, Any]
) -> None:
"""Similar to _aggregate_sync_stack, but for async context managers"""
executed_kwargs = {
param: await async_stack.enter_async_context(inst)
for param, inst in kwargs.items()
if param not in provided_params and isinstance(inst, contextlib._AsyncGeneratorContextManager)
}
kwargs.update(executed_kwargs)

def __call__(self, func: Callable[..., Union[Awaitable[T], T]]) -> Callable[..., Union[Awaitable[T], T]]:
if sys.version_info.major == 2:
arg_names = inspect.getargspec(func).args
Expand All @@ -340,7 +374,11 @@ async def async_injection_wrapper(*args: Any, **kwargs: Any) -> T:
kwargs[param] = instance(cls)
async_func = cast(Callable[..., Awaitable[T]], func)
try:
return await async_func(*args, **kwargs)
with contextlib.ExitStack() as sync_stack:
async with contextlib.AsyncExitStack() as async_stack:
self._aggregate_sync_stack(sync_stack, provided_params, kwargs)
await self._aggregate_async_stack(async_stack, provided_params, kwargs)
return await async_func(*args, **kwargs)
except TypeError as previous_error:
raise ConstructorTypeError(func, previous_error)

Expand All @@ -355,7 +393,9 @@ def injection_wrapper(*args: Any, **kwargs: Any) -> T:
kwargs[param] = instance(cls)
sync_func = cast(Callable[..., T], func)
try:
return sync_func(*args, **kwargs)
with contextlib.ExitStack() as sync_stack:
self._aggregate_sync_stack(sync_stack, provided_params, kwargs)
return sync_func(*args, **kwargs)
except TypeError as previous_error:
raise ConstructorTypeError(func, previous_error)
return injection_wrapper
Expand Down
107 changes: 107 additions & 0 deletions test/test_context_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import contextlib

import inject
from test import BaseTestInject


class Destroyable:
def __init__(self):
self.started = True

def destroy(self):
self.started = False


class MockFile(Destroyable):
...


class MockConnection(Destroyable):
...


class MockFoo(Destroyable):
...


@contextlib.contextmanager
def get_file_sync():
obj = MockFile()
yield obj
obj.destroy()


@contextlib.contextmanager
def get_conn_sync():
obj = MockConnection()
yield obj
obj.destroy()


@contextlib.contextmanager
def get_foo_sync():
obj = MockFoo()
yield obj
obj.destroy()


@contextlib.asynccontextmanager
async def get_file_async():
obj = MockFile()
yield obj
obj.destroy()


@contextlib.asynccontextmanager
async def get_conn_async():
obj = MockConnection()
yield obj
obj.destroy()


class TestContextManagerFunctional(BaseTestInject):

def test_provider_as_context_manager_sync(self):
def config(binder):
binder.bind_to_provider(MockFile, get_file_sync)
binder.bind(int, 100)
binder.bind_to_provider(str, lambda: "Hello")
binder.bind_to_provider(MockConnection, get_conn_sync)

inject.configure(config)

@inject.autoparams()
def mock_func(conn: MockConnection, name: str, f: MockFile, number: int):
assert f.started
assert conn.started
assert name == "Hello"
assert number == 100
return f, conn

f_, conn_ = mock_func()
assert not f_.started
assert not conn_.started

def test_provider_as_context_manager_async(self):
def config(binder):
binder.bind_to_provider(MockFile, get_file_async)
binder.bind(int, 100)
binder.bind_to_provider(str, lambda: "Hello")
binder.bind_to_provider(MockConnection, get_conn_async)
binder.bind_to_provider(MockFoo, get_foo_sync)

inject.configure(config)

@inject.autoparams()
async def mock_func(conn: MockConnection, name: str, f: MockFile, number: int, foo: MockFoo):
assert f.started
assert conn.started
assert foo.started
assert name == "Hello"
assert number == 100
return f, conn, foo

f_, conn_, foo_ = self.run_async(mock_func())
assert not f_.started
assert not conn_.started
assert not foo_.started