Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support annotated types #1

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,8 @@ def _punch_through_alias(type_: Any) -> type:
and type(type_).__name__ == 'NewType'
):
return type_.__supertype__
elif isinstance(type_, _AnnotatedAlias) and getattr(type_, '__metadata__', None) is not None:
return type_.__origin__
else:
return type_

Expand Down Expand Up @@ -1237,8 +1239,17 @@ def _is_new_union_type(instance: Any) -> bool:

for k, v in list(bindings.items()):
if _is_specialization(v, Annotated):
v, metadata = v.__origin__, v.__metadata__
bindings[k] = v
origin, metadata = v.__origin__, v.__metadata__

if _inject_marker in metadata or _noinject_marker in metadata:
new_metadata = tuple(m for m in metadata if m not in [_inject_marker, _noinject_marker])
if len(new_metadata) == 0:
new_type = origin
else:
new_type = _AnnotatedAlias(origin, new_metadata)
bindings[k] = new_type
else:
bindings[k] = v
else:
metadata = tuple()

Expand Down
150 changes: 150 additions & 0 deletions injector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@
"""Functional tests for the "Injector" dependency injection framework."""

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, NewType, Optional, Union
import abc
import sys
import threading
import traceback
import warnings

if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated

from typing import Dict, List, NewType

import pytest
Expand Down Expand Up @@ -1682,3 +1688,147 @@ def function1(a: int | str) -> None:
pass

assert get_bindings(function1) == {'a': Union[int, str]}


# test for https://github.com/python-injector/injector/issues/217
def test_annotated_instance_integration_works():
UserID = Annotated[int, "user_id"]

def configure(binder):
binder.bind(UserID, to=123)

injector = Injector([configure])
assert injector.get(UserID) == 123


def test_annotated_class_integration_works():
class Shape(abc.ABC):
pass

class Circle(Shape):
pass

first = Annotated[Shape, "first"]

def configure(binder):
binder.bind(first, to=Circle)

injector = Injector([configure])
assert isinstance(injector.get(first), Circle)


def test_annotated_meta_separate_bindings():
first = Annotated[int, "first"]
second = Annotated[int, "second"]

def configure(binder):
binder.bind(first, to=123)
binder.bind(second, to=456)

injector = Injector([configure])
assert injector.get(first) == 123
assert injector.get(second) == 456
assert injector.get(first) != injector.get(second)


def test_annotated_origin_separate_bindings():
UserID = Annotated[int, "user_id"]

def configure(binder):
binder.bind(UserID, to=123)
binder.bind(int, to=456)

injector = Injector([configure])
assert injector.get(UserID) == 123
assert injector.get(int) == 456
assert injector.get(UserID) != injector.get(int)


def test_annotated_non_comparable_types():
foo = Annotated[int, float("nan")]
bar = Annotated[int, object()]

def configure(binder):
binder.bind(foo, to=123)
binder.bind(bar, to=456)

injector = Injector([configure])
assert injector.get(foo) == 123
assert injector.get(bar) == 456


def test_annotated_injection():
UserID = Annotated[int, "user_id"]

def configure(binder):
binder.bind(UserID, to=123)

@inject
@dataclass
class Data:
def __init__(self, user_id: UserID) -> None:
self.user_id = user_id

injector = Injector([configure])
assert injector.get(Data).user_id == 123


def test_annotated_call_with_injection():
UserID = Annotated[int, "user_id"]

def configure(binder):
binder.bind(UserID, to=123)

@inject
def fun(user_id: UserID) -> int:
return user_id

injector = Injector([configure])
assert injector.call_with_injection(fun) == 123


def test_with_injector_wrapper():
UserID = Annotated[int, "user_id"]

def configure(binder):
binder.bind(UserID, to=123)

@dataclass
class Data:
user_id: Inject[UserID]

injector = Injector([configure])

assert injector.get(Data).user_id == 123


def test_annotated_in_injection_configuration():
UserID = Annotated[int, "user_id"]

def configure(binder):
binder.bind(UserID, to=123)

@inject
@dataclass
class Data:
user_id: Annotated[int, "user_id"]

injector = Injector([configure])

assert injector.get(Data).user_id == 123


def test_bind_to_annotated_directly():
def configure(binder):
# works at runtime, mypy complains:
# Argument 1 to "bind" of "Binder" has incompatible type "object"; expected "type[str]"
binder.bind(Annotated[int, "user_id"], to=123)

@inject
@dataclass
class Data:
user_id: Annotated[int, "user_id"]

injector = Injector([configure])

assert injector.get(Data).user_id == 123
Loading