From 9b42fae8672af8fa202f937bbfcc877e3bb66d15 Mon Sep 17 00:00:00 2001 From: hodgespodge <44984311+hodgespodge@users.noreply.github.com> Date: Sun, 21 Jan 2024 04:20:54 -0500 Subject: [PATCH] Fix annotated #123 (#126) * added failing test case for annotated * Fix issues with Annotated * fix failing CI checks * cleaned up pr --------- Co-authored-by: Samuel Hodges Co-authored-by: Wessel --- plum/alias.py | 2 +- plum/plum_typing.py | 22 ++++++++++++++++++++++ plum/signature.py | 5 +++-- plum/type.py | 12 ++++++++---- tests/advanced/test_annotated.py | 29 +++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+), 7 deletions(-) create mode 100644 plum/plum_typing.py create mode 100644 tests/advanced/test_annotated.py diff --git a/plum/alias.py b/plum/alias.py index 5c34ad7d..28b8652c 100644 --- a/plum/alias.py +++ b/plum/alias.py @@ -28,7 +28,7 @@ import typing from functools import wraps -from typing import get_args +from .plum_typing import get_args __all__ = ["activate_union_aliases", "deactivate_union_aliases", "set_union_alias"] diff --git a/plum/plum_typing.py b/plum/plum_typing.py new file mode 100644 index 00000000..bf666a07 --- /dev/null +++ b/plum/plum_typing.py @@ -0,0 +1,22 @@ +import sys +from typing import Literal + +if sys.version_info < (3, 9): + import typing_extensions + + get_type_hints = typing_extensions.get_type_hints + get_origin = typing_extensions.get_origin + get_args = typing_extensions.get_args + + def is_literal(x): + return x == Literal and not isinstance(x, typing_extensions.Annotated) + +else: + import typing + + get_type_hints = typing.get_type_hints + get_origin = typing.get_origin + get_args = typing.get_args + + def is_literal(x): + return x == Literal \ No newline at end of file diff --git a/plum/signature.py b/plum/signature.py index 5944cff8..de70bc70 100644 --- a/plum/signature.py +++ b/plum/signature.py @@ -11,7 +11,8 @@ from . import _is_bearable from .repr import repr_short, rich_repr from .type import is_faithful, resolve_type_hint -from .util import Comparable, Missing, TypeHint, multihash, wrap_lambda +from .plum_typing import get_type_hints +from .util import Comparable, Missing, TypeHint, multihash, repr_short, wrap_lambda __all__ = ["Signature", "append_default_args"] @@ -285,7 +286,7 @@ def resolve_pep563(f: Callable): beartype_resolve_pep563(f) # This mutates `f`. # Override the `__annotations__` attribute, since `resolve_pep563` modifies # `f` too. - for k, v in typing.get_type_hints(f).items(): + for k, v in get_type_hints(f, include_extras=True).items(): f.__annotations__[k] = v diff --git a/plum/type.py b/plum/type.py index d6ea2148..3ed7b9e6 100644 --- a/plum/type.py +++ b/plum/type.py @@ -2,7 +2,9 @@ import sys import typing import warnings -from typing import Literal, get_args, get_origin +from .plum_typing import get_args, get_origin, is_literal + +from beartype.vale._core._valecore import BeartypeValidator try: # pragma: specific no cover 3.8 3.9 from types import UnionType @@ -128,7 +130,8 @@ def _is_hint(x): return x.__module__ in { "types", # E.g., `tuple[int]` "typing", - "collections.abc", # E.g., `Callable` + "collections.abc", # E.g., `Callable`, + "typing_extensions", } except AttributeError: return False @@ -183,7 +186,7 @@ def resolve_type_hint(x): return y else: # Do not resolve the arguments for `Literal`s. - if origin != Literal: + if not is_literal(origin): args = resolve_type_hint(args) try: return origin[args] @@ -218,7 +221,8 @@ def resolve_type_hint(x): return resolve_type_hint(x.resolve()) else: return x - + elif isinstance(x, BeartypeValidator): + return x else: warnings.warn( f"Could not resolve the type hint of `{x}`. " diff --git a/tests/advanced/test_annotated.py b/tests/advanced/test_annotated.py new file mode 100644 index 00000000..5fac2ecf --- /dev/null +++ b/tests/advanced/test_annotated.py @@ -0,0 +1,29 @@ +import sys + +if sys.version_info < (3, 9): + from typing_extensions import Annotated +else: + from typing import Annotated + +import pytest +from beartype.vale import Is + +from plum import Dispatcher, NotFoundLookupError + + +def test_simple_annotated(): + dispatch = Dispatcher() + + positive_int = Annotated[int, Is[lambda value: value > 0]] + + @dispatch + def f(x: positive_int): + return x + + assert f(1) == 1 + + with pytest.raises(NotFoundLookupError): + f("my string") + + with pytest.raises(NotFoundLookupError): + f(-1)