Skip to content

Commit

Permalink
Fix generic inheritance for attrs init methods (#9383)
Browse files Browse the repository at this point in the history
Fixes #5744

Updates the attrs plugin. Instead of directly copying attribute type along the MRO, this first resolves typevar in the context of the subtype.
  • Loading branch information
natemcmaster authored Nov 28, 2020
1 parent 98beb8e commit fc212ef
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 12 deletions.
46 changes: 34 additions & 12 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef, ARG_NAMED_OPT, ARG_NAMED,
TypeVarExpr, PlaceholderNode
)
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
_get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method
_get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method,
deserialize_and_fixup_type
)
from mypy.types import (
Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarDef, TypeVarType,
Overloaded, UnionType, FunctionLike, get_proper_type
)
from mypy.typeops import make_simplified_union
from mypy.typeops import make_simplified_union, map_type_from_supertype
from mypy.typevars import fill_typevars
from mypy.util import unmangle
from mypy.server.trigger import make_wildcard_trigger
Expand Down Expand Up @@ -70,19 +72,22 @@ class Attribute:

def __init__(self, name: str, info: TypeInfo,
has_default: bool, init: bool, kw_only: bool, converter: Converter,
context: Context) -> None:
context: Context,
init_type: Optional[Type]) -> None:
self.name = name
self.info = info
self.has_default = has_default
self.init = init
self.kw_only = kw_only
self.converter = converter
self.context = context
self.init_type = init_type

def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument:
"""Return this attribute as an argument to __init__."""
assert self.init
init_type = self.info[self.name].type

init_type = self.init_type or self.info[self.name].type

if self.converter.name:
# When a converter is set the init_type is overridden by the first argument
Expand Down Expand Up @@ -168,20 +173,33 @@ def serialize(self) -> JsonDict:
'converter_is_attr_converters_optional': self.converter.is_attr_converters_optional,
'context_line': self.context.line,
'context_column': self.context.column,
'init_type': self.init_type.serialize() if self.init_type else None,
}

@classmethod
def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'Attribute':
def deserialize(cls, info: TypeInfo,
data: JsonDict,
api: SemanticAnalyzerPluginInterface) -> 'Attribute':
"""Return the Attribute that was serialized."""
return Attribute(
data['name'],
raw_init_type = data['init_type']
init_type = deserialize_and_fixup_type(raw_init_type, api) if raw_init_type else None

return Attribute(data['name'],
info,
data['has_default'],
data['init'],
data['kw_only'],
Converter(data['converter_name'], data['converter_is_attr_converters_optional']),
Context(line=data['context_line'], column=data['context_column'])
)
Context(line=data['context_line'], column=data['context_column']),
init_type)

def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type."""
if not isinstance(self.init_type, TypeVarType):
return

self.init_type = map_type_from_supertype(self.init_type, sub_type, self.info)


def _determine_eq_order(ctx: 'mypy.plugin.ClassDefContext') -> bool:
Expand Down Expand Up @@ -363,7 +381,8 @@ def _analyze_class(ctx: 'mypy.plugin.ClassDefContext',
# Only add an attribute if it hasn't been defined before. This
# allows for overwriting attribute definitions by subclassing.
if data['name'] not in taken_attr_names:
a = Attribute.deserialize(super_info, data)
a = Attribute.deserialize(super_info, data, ctx.api)
a.expand_typevar_from_subtype(ctx.cls.info)
super_attrs.append(a)
taken_attr_names.add(a.name)
attributes = super_attrs + list(own_attrs.values())
Expand Down Expand Up @@ -491,7 +510,9 @@ def _attribute_from_auto_attrib(ctx: 'mypy.plugin.ClassDefContext',
name = unmangle(lhs.name)
# `x: int` (without equal sign) assigns rvalue to TempNode(AnyType())
has_rhs = not isinstance(rvalue, TempNode)
return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt)
sym = ctx.cls.info.names.get(name)
init_type = sym.type if sym else None
return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, Converter(), stmt, init_type)


def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
Expand Down Expand Up @@ -557,7 +578,8 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext',
converter_info = _parse_converter(ctx, converter)

name = unmangle(lhs.name)
return Attribute(name, ctx.cls.info, attr_has_default, init, kw_only, converter_info, stmt)
return Attribute(name, ctx.cls.info, attr_has_default, init,
kw_only, converter_info, stmt, init_type)


def _parse_converter(ctx: 'mypy.plugin.ClassDefContext',
Expand Down
103 changes: 103 additions & 0 deletions test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,109 @@ A([1], '2') # E: Cannot infer type argument 1 of "A"

[builtins fixtures/list.pyi]


[case testAttrsUntypedGenericInheritance]
from typing import Generic, TypeVar
import attr

T = TypeVar("T")

@attr.s(auto_attribs=True)
class Base(Generic[T]):
attr: T

@attr.s(auto_attribs=True)
class Sub(Base):
pass

sub = Sub(attr=1)
reveal_type(sub) # N: Revealed type is '__main__.Sub'
reveal_type(sub.attr) # N: Revealed type is 'Any'

[builtins fixtures/bool.pyi]


[case testAttrsGenericInheritance]
from typing import Generic, TypeVar
import attr

S = TypeVar("S")
T = TypeVar("T")

@attr.s(auto_attribs=True)
class Base(Generic[T]):
attr: T

@attr.s(auto_attribs=True)
class Sub(Base[S]):
pass

sub_int = Sub[int](attr=1)
reveal_type(sub_int) # N: Revealed type is '__main__.Sub[builtins.int*]'
reveal_type(sub_int.attr) # N: Revealed type is 'builtins.int*'

sub_str = Sub[str](attr='ok')
reveal_type(sub_str) # N: Revealed type is '__main__.Sub[builtins.str*]'
reveal_type(sub_str.attr) # N: Revealed type is 'builtins.str*'

[builtins fixtures/bool.pyi]


[case testAttrsGenericInheritance]
from typing import Generic, TypeVar
import attr

T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")

@attr.s(auto_attribs=True)
class Base(Generic[T1, T2, T3]):
one: T1
two: T2
three: T3

@attr.s(auto_attribs=True)
class Sub(Base[int, str, float]):
pass

sub = Sub(one=1, two='ok', three=3.14)
reveal_type(sub) # N: Revealed type is '__main__.Sub'
reveal_type(sub.one) # N: Revealed type is 'builtins.int*'
reveal_type(sub.two) # N: Revealed type is 'builtins.str*'
reveal_type(sub.three) # N: Revealed type is 'builtins.float*'

[builtins fixtures/bool.pyi]


[case testAttrsMultiGenericInheritance]
from typing import Generic, TypeVar
import attr

T = TypeVar("T")

@attr.s(auto_attribs=True, eq=False)
class Base(Generic[T]):
base_attr: T

S = TypeVar("S")

@attr.s(auto_attribs=True, eq=False)
class Middle(Base[int], Generic[S]):
middle_attr: S

@attr.s(auto_attribs=True, eq=False)
class Sub(Middle[str]):
pass

sub = Sub(base_attr=1, middle_attr='ok')
reveal_type(sub) # N: Revealed type is '__main__.Sub'
reveal_type(sub.base_attr) # N: Revealed type is 'builtins.int*'
reveal_type(sub.middle_attr) # N: Revealed type is 'builtins.str*'

[builtins fixtures/bool.pyi]


[case testAttrsGenericClassmethod]
from typing import TypeVar, Generic, Optional
import attr
Expand Down

0 comments on commit fc212ef

Please sign in to comment.