From dac88f346cada58ae022599fffbbd961643b5d5b Mon Sep 17 00:00:00 2001 From: Nikita Sobolev Date: Fri, 14 Jun 2024 17:35:19 +0300 Subject: [PATCH] Support `enum.member` for python3.11+ (#17382) There are no tests for `@enum.member` used as a decorator, because I can only decorate classes and functions, which are not supported right now: https://mypy-play.net/?mypy=latest&python=3.12&gist=449ee8c12eba9f807cfc7832f1ea2c49 ```python import enum class A(enum.Enum): class x: ... reveal_type(A.x) # Revealed type is "def () -> __main__.A.x" ``` This issue is separate and rather complex, so I would prefer to solve it independently. Refs https://github.com/python/mypy/pull/17376 --------- Co-authored-by: Alex Waygood --- mypy/plugins/default.py | 4 +++- mypy/plugins/enums.py | 18 ++++++++++++++++++ test-data/unit/check-enum.test | 19 +++++++++++++++++++ test-data/unit/lib-stub/enum.pyi | 5 +++++ 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 3ad301a15f6c..5139b9b82289 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -41,7 +41,7 @@ class DefaultPlugin(Plugin): """Type checker plugin that is enabled by default.""" def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None: - from mypy.plugins import ctypes, singledispatch + from mypy.plugins import ctypes, enums, singledispatch if fullname == "_ctypes.Array": return ctypes.array_constructor_callback @@ -51,6 +51,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] import mypy.plugins.functools return mypy.plugins.functools.partial_new_callback + elif fullname == "enum.member": + return enums.enum_member_callback return None diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index 167b330f9b09..816241fa6e9a 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -87,6 +87,8 @@ def _infer_value_type_with_auto_fallback( return None proper_type = get_proper_type(fixup_partial_type(proper_type)) if not (isinstance(proper_type, Instance) and proper_type.type.fullname == "enum.auto"): + if is_named_instance(proper_type, "enum.member") and proper_type.args: + return proper_type.args[0] return proper_type assert isinstance(ctx.type, Instance), "An incorrect ctx.type was passed." info = ctx.type.type @@ -126,6 +128,22 @@ def _implements_new(info: TypeInfo) -> bool: return type_with_new.fullname not in ("enum.Enum", "enum.IntEnum", "enum.StrEnum") +def enum_member_callback(ctx: mypy.plugin.FunctionContext) -> Type: + """By default `member(1)` will be infered as `member[int]`, + we want to improve the inference to be `Literal[1]` here.""" + if ctx.arg_types or ctx.arg_types[0]: + arg = get_proper_type(ctx.arg_types[0][0]) + proper_return = get_proper_type(ctx.default_return_type) + if ( + isinstance(arg, Instance) + and arg.last_known_value + and isinstance(proper_return, Instance) + and len(proper_return.args) == 1 + ): + return proper_return.copy_modified(args=[arg]) + return ctx.default_return_type + + def enum_value_callback(ctx: mypy.plugin.AttributeContext) -> Type: """This plugin refines the 'value' attribute in enums to refer to the original underlying value. For example, suppose we have the diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 183901416604..d53935085325 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2166,3 +2166,22 @@ class Other(Enum): reveal_type(Other.a) # N: Revealed type is "Literal[__main__.Other.a]?" reveal_type(Other.Support.b) # N: Revealed type is "builtins.int" [builtins fixtures/dict.pyi] + + +[case testEnumMemberSupport] +# flags: --python-version 3.11 +# This was added in 3.11 +from enum import Enum, member + +class A(Enum): + x = member(1) + y = 2 + +reveal_type(A.x) # N: Revealed type is "Literal[__main__.A.x]?" +reveal_type(A.x.value) # N: Revealed type is "Literal[1]?" +reveal_type(A.y) # N: Revealed type is "Literal[__main__.A.y]?" +reveal_type(A.y.value) # N: Revealed type is "Literal[2]?" + +def some_a(a: A): + reveal_type(a.value) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]" +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/lib-stub/enum.pyi b/test-data/unit/lib-stub/enum.pyi index 32dd7c38d251..0e0b8e025d9f 100644 --- a/test-data/unit/lib-stub/enum.pyi +++ b/test-data/unit/lib-stub/enum.pyi @@ -53,3 +53,8 @@ class StrEnum(str, Enum): class nonmember(Generic[_T]): value: _T def __init__(self, value: _T) -> None: ... + +# It is python-3.11+ only: +class member(Generic[_T]): + value: _T + def __init__(self, value: _T) -> None: ...