From fc212ef893a0c3e6d500be89aa4d9302c1219595 Mon Sep 17 00:00:00 2001 From: Nate McMaster Date: Fri, 27 Nov 2020 17:25:00 -0800 Subject: [PATCH] Fix generic inheritance for attrs init methods (#9383) Fixes #5744 Updates the attrs plugin. Instead of directly copying attribute type along the MRO, this first resolves typevar in the context of the subtype. --- mypy/plugins/attrs.py | 46 +++++++++++---- test-data/unit/check-attr.test | 103 +++++++++++++++++++++++++++++++++ 2 files changed, 137 insertions(+), 12 deletions(-) diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index f8ca2161a7e9..5fd2dde01a03 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -15,14 +15,16 @@ SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef, ARG_NAMED_OPT, ARG_NAMED, TypeVarExpr, PlaceholderNode ) +from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import ( - _get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method + _get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method, + deserialize_and_fixup_type ) from mypy.types import ( Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarDef, TypeVarType, Overloaded, UnionType, FunctionLike, get_proper_type ) -from mypy.typeops import make_simplified_union +from mypy.typeops import make_simplified_union, map_type_from_supertype from mypy.typevars import fill_typevars from mypy.util import unmangle from mypy.server.trigger import make_wildcard_trigger @@ -70,7 +72,8 @@ class Attribute: def __init__(self, name: str, info: TypeInfo, has_default: bool, init: bool, kw_only: bool, converter: Converter, - context: Context) -> None: + context: Context, + init_type: Optional[Type]) -> None: self.name = name self.info = info self.has_default = has_default @@ -78,11 +81,13 @@ def __init__(self, name: str, info: TypeInfo, self.kw_only = kw_only self.converter = converter self.context = context + self.init_type = init_type def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument: """Return this attribute as an argument to __init__.""" assert self.init - init_type = self.info[self.name].type + + init_type = self.init_type or self.info[self.name].type if self.converter.name: # When a converter is set the init_type is overridden by the first argument @@ -168,20 +173,33 @@ def serialize(self) -> JsonDict: 'converter_is_attr_converters_optional': self.converter.is_attr_converters_optional, 'context_line': self.context.line, 'context_column': self.context.column, + 'init_type': self.init_type.serialize() if self.init_type else None, } @classmethod - def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'Attribute': + def deserialize(cls, info: TypeInfo, + data: JsonDict, + api: SemanticAnalyzerPluginInterface) -> 'Attribute': """Return the Attribute that was serialized.""" - return Attribute( - data['name'], + raw_init_type = data['init_type'] + init_type = deserialize_and_fixup_type(raw_init_type, api) if raw_init_type else None + + return Attribute(data['name'], info, data['has_default'], data['init'], data['kw_only'], Converter(data['converter_name'], data['converter_is_attr_converters_optional']), - Context(line=data['context_line'], column=data['context_column']) - ) + Context(line=data['context_line'], column=data['context_column']), + init_type) + + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is inherited + from a generic super type.""" + if not isinstance(self.init_type, TypeVarType): + return + + self.init_type = map_type_from_supertype(self.init_type, sub_type, self.info) def _determine_eq_order(ctx: 'mypy.plugin.ClassDefContext') -> bool: @@ -363,7 +381,8 @@ def _analyze_class(ctx: 'mypy.plugin.ClassDefContext', # Only add an attribute if it hasn't been defined before. This # allows for overwriting attribute definitions by subclassing. if data['name'] not in taken_attr_names: - a = Attribute.deserialize(super_info, data) + a = Attribute.deserialize(super_info, data, ctx.api) + a.expand_typevar_from_subtype(ctx.cls.info) super_attrs.append(a) taken_attr_names.add(a.name) attributes = super_attrs + list(own_attrs.values()) @@ -491,7 +510,9 @@ def _attribute_from_auto_attrib(ctx: 'mypy.plugin.ClassDefContext', name = unmangle(lhs.name) # `x: int` (without equal sign) assigns rvalue to TempNode(AnyType()) has_rhs = not isinstance(rvalue, TempNode) - return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt) + sym = ctx.cls.info.names.get(name) + init_type = sym.type if sym else None + return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt, init_type) def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext', @@ -557,7 +578,8 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext', converter_info = _parse_converter(ctx, converter) name = unmangle(lhs.name) - return Attribute(name, ctx.cls.info, attr_has_default, init, kw_only, converter_info, stmt) + return Attribute(name, ctx.cls.info, attr_has_default, init, + kw_only, converter_info, stmt, init_type) def _parse_converter(ctx: 'mypy.plugin.ClassDefContext', diff --git a/test-data/unit/check-attr.test b/test-data/unit/check-attr.test index e83f80c85948..5a97f6c5dd38 100644 --- a/test-data/unit/check-attr.test +++ b/test-data/unit/check-attr.test @@ -454,6 +454,109 @@ A([1], '2') # E: Cannot infer type argument 1 of "A" [builtins fixtures/list.pyi] + +[case testAttrsUntypedGenericInheritance] +from typing import Generic, TypeVar +import attr + +T = TypeVar("T") + +@attr.s(auto_attribs=True) +class Base(Generic[T]): + attr: T + +@attr.s(auto_attribs=True) +class Sub(Base): + pass + +sub = Sub(attr=1) +reveal_type(sub) # N: Revealed type is '__main__.Sub' +reveal_type(sub.attr) # N: Revealed type is 'Any' + +[builtins fixtures/bool.pyi] + + +[case testAttrsGenericInheritance] +from typing import Generic, TypeVar +import attr + +S = TypeVar("S") +T = TypeVar("T") + +@attr.s(auto_attribs=True) +class Base(Generic[T]): + attr: T + +@attr.s(auto_attribs=True) +class Sub(Base[S]): + pass + +sub_int = Sub[int](attr=1) +reveal_type(sub_int) # N: Revealed type is '__main__.Sub[builtins.int*]' +reveal_type(sub_int.attr) # N: Revealed type is 'builtins.int*' + +sub_str = Sub[str](attr='ok') +reveal_type(sub_str) # N: Revealed type is '__main__.Sub[builtins.str*]' +reveal_type(sub_str.attr) # N: Revealed type is 'builtins.str*' + +[builtins fixtures/bool.pyi] + + +[case testAttrsGenericInheritance] +from typing import Generic, TypeVar +import attr + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") + +@attr.s(auto_attribs=True) +class Base(Generic[T1, T2, T3]): + one: T1 + two: T2 + three: T3 + +@attr.s(auto_attribs=True) +class Sub(Base[int, str, float]): + pass + +sub = Sub(one=1, two='ok', three=3.14) +reveal_type(sub) # N: Revealed type is '__main__.Sub' +reveal_type(sub.one) # N: Revealed type is 'builtins.int*' +reveal_type(sub.two) # N: Revealed type is 'builtins.str*' +reveal_type(sub.three) # N: Revealed type is 'builtins.float*' + +[builtins fixtures/bool.pyi] + + +[case testAttrsMultiGenericInheritance] +from typing import Generic, TypeVar +import attr + +T = TypeVar("T") + +@attr.s(auto_attribs=True, eq=False) +class Base(Generic[T]): + base_attr: T + +S = TypeVar("S") + +@attr.s(auto_attribs=True, eq=False) +class Middle(Base[int], Generic[S]): + middle_attr: S + +@attr.s(auto_attribs=True, eq=False) +class Sub(Middle[str]): + pass + +sub = Sub(base_attr=1, middle_attr='ok') +reveal_type(sub) # N: Revealed type is '__main__.Sub' +reveal_type(sub.base_attr) # N: Revealed type is 'builtins.int*' +reveal_type(sub.middle_attr) # N: Revealed type is 'builtins.str*' + +[builtins fixtures/bool.pyi] + + [case testAttrsGenericClassmethod] from typing import TypeVar, Generic, Optional import attr