diff --git a/injector/__init__.py b/injector/__init__.py index 4136f8f..e570452 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -705,6 +705,8 @@ def _punch_through_alias(type_: Any) -> type: and type(type_).__name__ == 'NewType' ): return type_.__supertype__ + elif isinstance(type_, _AnnotatedAlias) and getattr(type_, '__metadata__', None) is not None: + return type_.__origin__ else: return type_ @@ -1237,8 +1239,17 @@ def _is_new_union_type(instance: Any) -> bool: for k, v in list(bindings.items()): if _is_specialization(v, Annotated): - v, metadata = v.__origin__, v.__metadata__ - bindings[k] = v + origin, metadata = v.__origin__, v.__metadata__ + + if _inject_marker in metadata or _noinject_marker in metadata: + new_metadata = tuple(m for m in metadata if m not in [_inject_marker, _noinject_marker]) + if len(new_metadata) == 0: + new_type = origin + else: + new_type = _AnnotatedAlias(origin, new_metadata) + bindings[k] = new_type + else: + bindings[k] = v else: metadata = tuple() diff --git a/injector_test.py b/injector_test.py index 10087f2..323a250 100644 --- a/injector_test.py +++ b/injector_test.py @@ -11,6 +11,7 @@ """Functional tests for the "Injector" dependency injection framework.""" from contextlib import contextmanager +from dataclasses import dataclass from typing import Any, NewType, Optional, Union import abc import sys @@ -18,6 +19,11 @@ import traceback import warnings +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + from typing import Dict, List, NewType import pytest @@ -1682,3 +1688,147 @@ def function1(a: int | str) -> None: pass assert get_bindings(function1) == {'a': Union[int, str]} + + +# test for https://github.com/python-injector/injector/issues/217 +def test_annotated_instance_integration_works(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + injector = Injector([configure]) + assert injector.get(UserID) == 123 + + +def test_annotated_class_integration_works(): + class Shape(abc.ABC): + pass + + class Circle(Shape): + pass + + first = Annotated[Shape, "first"] + + def configure(binder): + binder.bind(first, to=Circle) + + injector = Injector([configure]) + assert isinstance(injector.get(first), Circle) + + +def test_annotated_meta_separate_bindings(): + first = Annotated[int, "first"] + second = Annotated[int, "second"] + + def configure(binder): + binder.bind(first, to=123) + binder.bind(second, to=456) + + injector = Injector([configure]) + assert injector.get(first) == 123 + assert injector.get(second) == 456 + assert injector.get(first) != injector.get(second) + + +def test_annotated_origin_separate_bindings(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + binder.bind(int, to=456) + + injector = Injector([configure]) + assert injector.get(UserID) == 123 + assert injector.get(int) == 456 + assert injector.get(UserID) != injector.get(int) + + +def test_annotated_non_comparable_types(): + foo = Annotated[int, float("nan")] + bar = Annotated[int, object()] + + def configure(binder): + binder.bind(foo, to=123) + binder.bind(bar, to=456) + + injector = Injector([configure]) + assert injector.get(foo) == 123 + assert injector.get(bar) == 456 + + +def test_annotated_injection(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + @inject + @dataclass + class Data: + def __init__(self, user_id: UserID) -> None: + self.user_id = user_id + + injector = Injector([configure]) + assert injector.get(Data).user_id == 123 + + +def test_annotated_call_with_injection(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + @inject + def fun(user_id: UserID) -> int: + return user_id + + injector = Injector([configure]) + assert injector.call_with_injection(fun) == 123 + + +def test_with_injector_wrapper(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + @dataclass + class Data: + user_id: Inject[UserID] + + injector = Injector([configure]) + + assert injector.get(Data).user_id == 123 + + +def test_annotated_in_injection_configuration(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + @inject + @dataclass + class Data: + user_id: Annotated[int, "user_id"] + + injector = Injector([configure]) + + assert injector.get(Data).user_id == 123 + + +def test_bind_to_annotated_directly(): + def configure(binder): + # works at runtime, mypy complains: + # Argument 1 to "bind" of "Binder" has incompatible type "object"; expected "type[str]" + binder.bind(Annotated[int, "user_id"], to=123) + + @inject + @dataclass + class Data: + user_id: Annotated[int, "user_id"] + + injector = Injector([configure]) + + assert injector.get(Data).user_id == 123