Skip to content

Commit

Permalink
Add support for Annotated types (#219)
Browse files Browse the repository at this point in the history
This is a relatively simple and straightforward implementation with minimal changes and seems to work fine.

Resolves #217
  • Loading branch information
ljnsn authored Jan 31, 2024
1 parent 44ac4e0 commit 78bb621
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
2 changes: 2 additions & 0 deletions injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,8 @@ def _punch_through_alias(type_: Any) -> type:
and type(type_).__name__ == 'NewType'
):
return type_.__supertype__
elif isinstance(type_, _AnnotatedAlias) and getattr(type_, '__metadata__', None) is not None:
return type_.__origin__
else:
return type_

Expand Down
72 changes: 72 additions & 0 deletions injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
import traceback
import warnings

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated

from typing import Dict, List, NewType

import pytest
Expand Down Expand Up @@ -1682,3 +1687,70 @@ def function1(a: int | str) -> None:
pass

assert get_bindings(function1) == {'a': Union[int, str]}


# test for https://github.com/python-injector/injector/issues/217
def test_annotated_instance_integration_works():
UserID = Annotated[int, "user_id"]

def configure(binder):
binder.bind(UserID, to=123)

injector = Injector([configure])
assert injector.get(UserID) == 123


def test_annotated_class_integration_works():
class Shape(abc.ABC):
pass

class Circle(Shape):
pass

first = Annotated[Shape, "first"]

def configure(binder):
binder.bind(first, to=Circle)

injector = Injector([configure])
assert isinstance(injector.get(first), Circle)


def test_annotated_meta_separate_bindings():
first = Annotated[int, "first"]
second = Annotated[int, "second"]

def configure(binder):
binder.bind(first, to=123)
binder.bind(second, to=456)

injector = Injector([configure])
assert injector.get(first) == 123
assert injector.get(second) == 456
assert injector.get(first) != injector.get(second)


def test_annotated_origin_separate_bindings():
UserID = Annotated[int, "user_id"]

def configure(binder):
binder.bind(UserID, to=123)
binder.bind(int, to=456)

injector = Injector([configure])
assert injector.get(UserID) == 123
assert injector.get(int) == 456
assert injector.get(UserID) != injector.get(int)


def test_annotated_non_comparable_types():
foo = Annotated[int, float("nan")]
bar = Annotated[int, object()]

def configure(binder):
binder.bind(foo, to=123)
binder.bind(bar, to=456)

injector = Injector([configure])
assert injector.get(foo) == 123
assert injector.get(bar) == 456

0 comments on commit 78bb621

Please sign in to comment.