From 37777b3f52560c6d801e76a2ca58b91f3981f43f Mon Sep 17 00:00:00 2001 From: Matt Gilson Date: Thu, 17 Sep 2020 10:24:08 -0400 Subject: [PATCH] Predict enum value type for unknown member names (#9443) It is very common for enums to have homogenous member-value types. In the case where we do not know what enum member we are dealing with, we should sniff for that case and still collapse to a known type if that assumption holds. Handles auto() too, even if you override _get_next_value_. --- mypy/plugins/enums.py | 90 ++++++++++++++++++++++++++++---- test-data/unit/check-enum.test | 88 +++++++++++++++++++++++++++---- test-data/unit/lib-stub/enum.pyi | 6 ++- 3 files changed, 164 insertions(+), 20 deletions(-) diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index 81aa29afcb11..e246e9de14b6 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -10,11 +10,11 @@ we actually bake some of it directly in to the semantic analysis layer (see semanal_enum.py). """ -from typing import Optional +from typing import Iterable, Optional, TypeVar from typing_extensions import Final import mypy.plugin # To avoid circular imports. -from mypy.types import Type, Instance, LiteralType, get_proper_type +from mypy.types import Type, Instance, LiteralType, CallableType, ProperType, get_proper_type # Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use # enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes. @@ -53,6 +53,56 @@ def enum_name_callback(ctx: 'mypy.plugin.AttributeContext') -> Type: return str_type.copy_modified(last_known_value=literal_type) +_T = TypeVar('_T') + + +def _first(it: Iterable[_T]) -> Optional[_T]: + """Return the first value from any iterable. + + Returns ``None`` if the iterable is empty. + """ + for val in it: + return val + return None + + +def _infer_value_type_with_auto_fallback( + ctx: 'mypy.plugin.AttributeContext', + proper_type: Optional[ProperType]) -> Optional[Type]: + """Figure out the type of an enum value accounting for `auto()`. + + This method is a no-op for a `None` proper_type and also in the case where + the type is not "enum.auto" + """ + if proper_type is None: + return None + if not ((isinstance(proper_type, Instance) and + proper_type.type.fullname == 'enum.auto')): + return proper_type + assert isinstance(ctx.type, Instance), 'An incorrect ctx.type was passed.' + info = ctx.type.type + # Find the first _generate_next_value_ on the mro. We need to know + # if it is `Enum` because `Enum` types say that the return-value of + # `_generate_next_value_` is `Any`. In reality the default `auto()` + # returns an `int` (presumably the `Any` in typeshed is to make it + # easier to subclass and change the returned type). + type_with_gnv = _first( + ti for ti in info.mro if ti.names.get('_generate_next_value_')) + if type_with_gnv is None: + return ctx.default_attr_type + + stnode = type_with_gnv.names['_generate_next_value_'] + + # This should be a `CallableType` + node_type = get_proper_type(stnode.type) + if isinstance(node_type, CallableType): + if type_with_gnv.fullname == 'enum.Enum': + int_type = ctx.api.named_generic_type('builtins.int', []) + return int_type + return get_proper_type(node_type.ret_type) + return ctx.default_attr_type + + def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type: """This plugin refines the 'value' attribute in enums to refer to the original underlying value. For example, suppose we have the @@ -78,6 +128,32 @@ class SomeEnum: """ enum_field_name = _extract_underlying_field_name(ctx.type) if enum_field_name is None: + # We do not know the enum field name (perhaps it was passed to a + # function and we only know that it _is_ a member). All is not lost + # however, if we can prove that the all of the enum members have the + # same value-type, then it doesn't matter which member was passed in. + # The value-type is still known. + if isinstance(ctx.type, Instance): + info = ctx.type.type + stnodes = (info.get(name) for name in info.names) + # Enums _can_ have methods. + # Omit methods for our value inference. + node_types = ( + get_proper_type(n.type) if n else None + for n in stnodes) + proper_types = ( + _infer_value_type_with_auto_fallback(ctx, t) + for t in node_types + if t is None or not isinstance(t, CallableType)) + underlying_type = _first(proper_types) + if underlying_type is None: + return ctx.default_attr_type + all_same_value_type = all( + proper_type is not None and proper_type == underlying_type + for proper_type in proper_types) + if all_same_value_type: + if underlying_type is not None: + return underlying_type return ctx.default_attr_type assert isinstance(ctx.type, Instance) @@ -86,15 +162,9 @@ class SomeEnum: if stnode is None: return ctx.default_attr_type - underlying_type = get_proper_type(stnode.type) + underlying_type = _infer_value_type_with_auto_fallback( + ctx, get_proper_type(stnode.type)) if underlying_type is None: - # TODO: Deduce the inferred type if the user omits adding their own default types. - # TODO: Consider using the return type of `Enum._generate_next_value_` here? - return ctx.default_attr_type - - if isinstance(underlying_type, Instance) and underlying_type.type.fullname == 'enum.auto': - # TODO: Deduce the correct inferred type when the user uses 'enum.auto'. - # We should use the same strategy we end up picking up above. return ctx.default_attr_type return underlying_type diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index e66fdfe277a1..37b12a0c32eb 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -59,6 +59,76 @@ reveal_type(Truth.true.name) # N: Revealed type is 'Literal['true']?' reveal_type(Truth.false.value) # N: Revealed type is 'builtins.bool' [builtins fixtures/bool.pyi] +[case testEnumValueExtended] +from enum import Enum +class Truth(Enum): + true = True + false = False + +def infer_truth(truth: Truth) -> None: + reveal_type(truth.value) # N: Revealed type is 'builtins.bool' +[builtins fixtures/bool.pyi] + +[case testEnumValueAllAuto] +from enum import Enum, auto +class Truth(Enum): + true = auto() + false = auto() + +def infer_truth(truth: Truth) -> None: + reveal_type(truth.value) # N: Revealed type is 'builtins.int' +[builtins fixtures/primitives.pyi] + +[case testEnumValueSomeAuto] +from enum import Enum, auto +class Truth(Enum): + true = 8675309 + false = auto() + +def infer_truth(truth: Truth) -> None: + reveal_type(truth.value) # N: Revealed type is 'builtins.int' +[builtins fixtures/primitives.pyi] + +[case testEnumValueExtraMethods] +from enum import Enum, auto +class Truth(Enum): + true = True + false = False + + def foo(self) -> str: + return 'bar' + +def infer_truth(truth: Truth) -> None: + reveal_type(truth.value) # N: Revealed type is 'builtins.bool' +[builtins fixtures/bool.pyi] + +[case testEnumValueCustomAuto] +from enum import Enum, auto +class AutoName(Enum): + + # In `typeshed`, this is a staticmethod and has more arguments, + # but I have lied a bit to keep the test stubs lean. + def _generate_next_value_(self) -> str: + return "name" + +class Truth(AutoName): + true = auto() + false = auto() + +def infer_truth(truth: Truth) -> None: + reveal_type(truth.value) # N: Revealed type is 'builtins.str' +[builtins fixtures/primitives.pyi] + +[case testEnumValueInhomogenous] +from enum import Enum +class Truth(Enum): + true = 'True' + false = 0 + +def cannot_infer_truth(truth: Truth) -> None: + reveal_type(truth.value) # N: Revealed type is 'Any' +[builtins fixtures/bool.pyi] + [case testEnumUnique] import enum @enum.unique @@ -497,8 +567,8 @@ reveal_type(A1.x.value) # N: Revealed type is 'Any' reveal_type(A1.x._value_) # N: Revealed type is 'Any' is_x(reveal_type(A2.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(A2.x._name_)) # N: Revealed type is 'Literal['x']' -reveal_type(A2.x.value) # N: Revealed type is 'Any' -reveal_type(A2.x._value_) # N: Revealed type is 'Any' +reveal_type(A2.x.value) # N: Revealed type is 'builtins.int' +reveal_type(A2.x._value_) # N: Revealed type is 'builtins.int' is_x(reveal_type(A3.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(A3.x._name_)) # N: Revealed type is 'Literal['x']' reveal_type(A3.x.value) # N: Revealed type is 'builtins.int' @@ -519,7 +589,7 @@ reveal_type(B1.x._value_) # N: Revealed type is 'Any' is_x(reveal_type(B2.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(B2.x._name_)) # N: Revealed type is 'Literal['x']' reveal_type(B2.x.value) # N: Revealed type is 'builtins.int' -reveal_type(B2.x._value_) # N: Revealed type is 'Any' +reveal_type(B2.x._value_) # N: Revealed type is 'builtins.int' is_x(reveal_type(B3.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(B3.x._name_)) # N: Revealed type is 'Literal['x']' reveal_type(B3.x.value) # N: Revealed type is 'builtins.int' @@ -540,8 +610,8 @@ reveal_type(C1.x.value) # N: Revealed type is 'Any' reveal_type(C1.x._value_) # N: Revealed type is 'Any' is_x(reveal_type(C2.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(C2.x._name_)) # N: Revealed type is 'Literal['x']' -reveal_type(C2.x.value) # N: Revealed type is 'Any' -reveal_type(C2.x._value_) # N: Revealed type is 'Any' +reveal_type(C2.x.value) # N: Revealed type is 'builtins.int' +reveal_type(C2.x._value_) # N: Revealed type is 'builtins.int' is_x(reveal_type(C3.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(C3.x._name_)) # N: Revealed type is 'Literal['x']' reveal_type(C3.x.value) # N: Revealed type is 'builtins.int' @@ -559,8 +629,8 @@ reveal_type(D1.x.value) # N: Revealed type is 'Any' reveal_type(D1.x._value_) # N: Revealed type is 'Any' is_x(reveal_type(D2.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(D2.x._name_)) # N: Revealed type is 'Literal['x']' -reveal_type(D2.x.value) # N: Revealed type is 'Any' -reveal_type(D2.x._value_) # N: Revealed type is 'Any' +reveal_type(D2.x.value) # N: Revealed type is 'builtins.int' +reveal_type(D2.x._value_) # N: Revealed type is 'builtins.int' is_x(reveal_type(D3.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(D3.x._name_)) # N: Revealed type is 'Literal['x']' reveal_type(D3.x.value) # N: Revealed type is 'builtins.int' @@ -578,8 +648,8 @@ class E3(Parent): is_x(reveal_type(E2.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(E2.x._name_)) # N: Revealed type is 'Literal['x']' -reveal_type(E2.x.value) # N: Revealed type is 'Any' -reveal_type(E2.x._value_) # N: Revealed type is 'Any' +reveal_type(E2.x.value) # N: Revealed type is 'builtins.int' +reveal_type(E2.x._value_) # N: Revealed type is 'builtins.int' is_x(reveal_type(E3.x.name)) # N: Revealed type is 'Literal['x']' is_x(reveal_type(E3.x._name_)) # N: Revealed type is 'Literal['x']' reveal_type(E3.x.value) # N: Revealed type is 'builtins.int' diff --git a/test-data/unit/lib-stub/enum.pyi b/test-data/unit/lib-stub/enum.pyi index 14908c2d1063..8d0e5fce291a 100644 --- a/test-data/unit/lib-stub/enum.pyi +++ b/test-data/unit/lib-stub/enum.pyi @@ -21,6 +21,10 @@ class Enum(metaclass=EnumMeta): _name_: str _value_: Any + # In reality, _generate_next_value_ is python3.6 only and has a different signature. + # However, this should be quick and doesn't require additional stubs (e.g. `staticmethod`) + def _generate_next_value_(self) -> Any: pass + class IntEnum(int, Enum): value: int @@ -37,4 +41,4 @@ class IntFlag(int, Flag): class auto(IntFlag): - value: Any \ No newline at end of file + value: Any