From b3018854c031854424bc9d54f20f9e68e3daf3f8 Mon Sep 17 00:00:00 2001 From: Ricardo Busquet Date: Wed, 31 Jan 2024 12:12:30 -0500 Subject: [PATCH 1/5] support annotated types --- injector/__init__.py | 15 +++++- injector_test.py | 117 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 2 deletions(-) diff --git a/injector/__init__.py b/injector/__init__.py index 4136f8f..e570452 100644 --- a/injector/__init__.py +++ b/injector/__init__.py @@ -705,6 +705,8 @@ def _punch_through_alias(type_: Any) -> type: and type(type_).__name__ == 'NewType' ): return type_.__supertype__ + elif isinstance(type_, _AnnotatedAlias) and getattr(type_, '__metadata__', None) is not None: + return type_.__origin__ else: return type_ @@ -1237,8 +1239,17 @@ def _is_new_union_type(instance: Any) -> bool: for k, v in list(bindings.items()): if _is_specialization(v, Annotated): - v, metadata = v.__origin__, v.__metadata__ - bindings[k] = v + origin, metadata = v.__origin__, v.__metadata__ + + if _inject_marker in metadata or _noinject_marker in metadata: + new_metadata = tuple(m for m in metadata if m not in [_inject_marker, _noinject_marker]) + if len(new_metadata) == 0: + new_type = origin + else: + new_type = _AnnotatedAlias(origin, new_metadata) + bindings[k] = new_type + else: + bindings[k] = v else: metadata = tuple() diff --git a/injector_test.py b/injector_test.py index 10087f2..59c677e 100644 --- a/injector_test.py +++ b/injector_test.py @@ -11,6 +11,7 @@ """Functional tests for the "Injector" dependency injection framework.""" from contextlib import contextmanager +from dataclasses import dataclass from typing import Any, NewType, Optional, Union import abc import sys @@ -18,6 +19,11 @@ import traceback import warnings +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated + from typing import Dict, List, NewType import pytest @@ -1682,3 +1688,114 @@ def function1(a: int | str) -> None: pass assert get_bindings(function1) == {'a': Union[int, str]} + + +# test for https://github.com/python-injector/injector/issues/217 +def test_annotated_instance_integration_works(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + injector = Injector([configure]) + assert injector.get(UserID) == 123 + + +def test_annotated_class_integration_works(): + class Shape(abc.ABC): + pass + + class Circle(Shape): + pass + + first = Annotated[Shape, "first"] + + def configure(binder): + binder.bind(first, to=Circle) + + injector = Injector([configure]) + assert isinstance(injector.get(first), Circle) + + +def test_annotated_meta_separate_bindings(): + first = Annotated[int, "first"] + second = Annotated[int, "second"] + + def configure(binder): + binder.bind(first, to=123) + binder.bind(second, to=456) + + injector = Injector([configure]) + assert injector.get(first) == 123 + assert injector.get(second) == 456 + assert injector.get(first) != injector.get(second) + + +def test_annotated_origin_separate_bindings(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + binder.bind(int, to=456) + + injector = Injector([configure]) + assert injector.get(UserID) == 123 + assert injector.get(int) == 456 + assert injector.get(UserID) != injector.get(int) + + +def test_annotated_non_comparable_types(): + foo = Annotated[int, float("nan")] + bar = Annotated[int, object()] + + def configure(binder): + binder.bind(foo, to=123) + binder.bind(bar, to=456) + + injector = Injector([configure]) + assert injector.get(foo) == 123 + assert injector.get(bar) == 456 + + +def test_annotated_injection(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + @inject + @dataclass + class Data: + def __init__(self, user_id: UserID) -> None: + self.user_id = user_id + + injector = Injector([configure]) + assert injector.get(Data).user_id == 123 + + +def test_annotated_call_with_injection(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + @inject + def fun(user_id: UserID) -> int: + return user_id + + injector = Injector([configure]) + assert injector.call_with_injection(fun) == 123 + +def test_with_injector_wrapper(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + @dataclass + class Data: + user_id: Inject[UserID] + + injector = Injector([configure]) + + assert injector.get(Data).user_id == 123 \ No newline at end of file From 68d4e5242878aaa7609a7d715bb76cb43a45af9a Mon Sep 17 00:00:00 2001 From: Ricardo Busquet Date: Wed, 31 Jan 2024 12:46:23 -0500 Subject: [PATCH 2/5] new line --- injector_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/injector_test.py b/injector_test.py index 59c677e..d477112 100644 --- a/injector_test.py +++ b/injector_test.py @@ -1798,4 +1798,4 @@ class Data: injector = Injector([configure]) - assert injector.get(Data).user_id == 123 \ No newline at end of file + assert injector.get(Data).user_id == 123 From 0607e70e8922b0de782e3f632482d42c68fffd76 Mon Sep 17 00:00:00 2001 From: Ricardo Busquet Date: Wed, 31 Jan 2024 12:52:26 -0500 Subject: [PATCH 3/5] start actions From 21a2197fd33eca824366db1e5e5acea3bbb82192 Mon Sep 17 00:00:00 2001 From: Ricardo Busquet Date: Wed, 31 Jan 2024 12:54:03 -0500 Subject: [PATCH 4/5] black --- injector_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/injector_test.py b/injector_test.py index d477112..8a93872 100644 --- a/injector_test.py +++ b/injector_test.py @@ -1786,6 +1786,7 @@ def fun(user_id: UserID) -> int: injector = Injector([configure]) assert injector.call_with_injection(fun) == 123 + def test_with_injector_wrapper(): UserID = Annotated[int, "user_id"] From 2cd11918b108b8c4239d6392b584b1d59fc2fafc Mon Sep 17 00:00:00 2001 From: Ricardo Busquet Date: Wed, 31 Jan 2024 14:27:04 -0500 Subject: [PATCH 5/5] more tests --- injector_test.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/injector_test.py b/injector_test.py index 8a93872..323a250 100644 --- a/injector_test.py +++ b/injector_test.py @@ -1800,3 +1800,35 @@ class Data: injector = Injector([configure]) assert injector.get(Data).user_id == 123 + + +def test_annotated_in_injection_configuration(): + UserID = Annotated[int, "user_id"] + + def configure(binder): + binder.bind(UserID, to=123) + + @inject + @dataclass + class Data: + user_id: Annotated[int, "user_id"] + + injector = Injector([configure]) + + assert injector.get(Data).user_id == 123 + + +def test_bind_to_annotated_directly(): + def configure(binder): + # works at runtime, mypy complains: + # Argument 1 to "bind" of "Binder" has incompatible type "object"; expected "type[str]" + binder.bind(Annotated[int, "user_id"], to=123) + + @inject + @dataclass + class Data: + user_id: Annotated[int, "user_id"] + + injector = Injector([configure]) + + assert injector.get(Data).user_id == 123