Skip to content

Commit

Permalink
attrs.evolve: support generics and unions (#15050)
Browse files Browse the repository at this point in the history
Fixes `attrs.evolve` signature generation to support the `inst`
parameter being
- a generic attrs class
- a union of attrs classes
- a mix of the two

In the case of unions, we "meet" the fields of the potential attrs
classes, so that the resulting signature is the lower bound.

Fixes #15088.
  • Loading branch information
ikonst authored Apr 21, 2023
1 parent 0845818 commit 2a4c473
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 27 deletions.
115 changes: 90 additions & 25 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

from __future__ import annotations

from typing import Iterable, List, cast
from collections import defaultdict
from functools import reduce
from typing import Iterable, List, Mapping, cast
from typing_extensions import Final, Literal

import mypy.plugin # To avoid circular imports.
from mypy.applytype import apply_generic_arguments
from mypy.checker import TypeChecker
from mypy.errorcodes import LITERAL_REQ
from mypy.expandtype import expand_type
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
from mypy.meet import meet_types
from mypy.messages import format_type_bare
from mypy.nodes import (
ARG_NAMED,
Expand Down Expand Up @@ -67,6 +70,7 @@
Type,
TypeOfAny,
TypeVarType,
UninhabitedType,
UnionType,
get_proper_type,
)
Expand Down Expand Up @@ -942,12 +946,82 @@ def _get_attrs_init_type(typ: Instance) -> CallableType | None:
return init_method.type


def _get_attrs_cls_and_init(typ: ProperType) -> tuple[Instance | None, CallableType | None]:
if isinstance(typ, TypeVarType):
typ = get_proper_type(typ.upper_bound)
if not isinstance(typ, Instance):
return None, None
return typ, _get_attrs_init_type(typ)
def _fail_not_attrs_class(ctx: mypy.plugin.FunctionSigContext, t: Type, parent_t: Type) -> None:
t_name = format_type_bare(t, ctx.api.options)
if parent_t is t:
msg = (
f'Argument 1 to "evolve" has a variable type "{t_name}" not bound to an attrs class'
if isinstance(t, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{t_name}"; expected an attrs class'
)
else:
pt_name = format_type_bare(parent_t, ctx.api.options)
msg = (
f'Argument 1 to "evolve" has type "{pt_name}" whose item "{t_name}" is not bound to an attrs class'
if isinstance(t, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{pt_name}" whose item "{t_name}" is not an attrs class'
)

ctx.api.fail(msg, ctx.context)


def _get_expanded_attr_types(
ctx: mypy.plugin.FunctionSigContext,
typ: ProperType,
display_typ: ProperType,
parent_typ: ProperType,
) -> list[Mapping[str, Type]] | None:
"""
For a given type, determine what attrs classes it can be: for each class, return the field types.
For generic classes, the field types are expanded.
If the type contains Any or a non-attrs type, returns None; in the latter case, also reports an error.
"""
if isinstance(typ, AnyType):
return None
elif isinstance(typ, UnionType):
ret: list[Mapping[str, Type]] | None = []
for item in typ.relevant_items():
item = get_proper_type(item)
item_types = _get_expanded_attr_types(ctx, item, item, parent_typ)
if ret is not None and item_types is not None:
ret += item_types
else:
ret = None # but keep iterating to emit all errors
return ret
elif isinstance(typ, TypeVarType):
return _get_expanded_attr_types(
ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ
)
elif isinstance(typ, Instance):
init_func = _get_attrs_init_type(typ)
if init_func is None:
_fail_not_attrs_class(ctx, display_typ, parent_typ)
return None
init_func = expand_type_by_instance(init_func, typ)
# [1:] to skip the self argument of AttrClass.__init__
field_names = cast(List[str], init_func.arg_names[1:])
field_types = init_func.arg_types[1:]
return [dict(zip(field_names, field_types))]
else:
_fail_not_attrs_class(ctx, display_typ, parent_typ)
return None


def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]:
"""
"Meets" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound.
"""
field_to_types = defaultdict(list)
for fields in types:
for name, typ in fields.items():
field_to_types[name].append(typ)

return {
name: get_proper_type(reduce(meet_types, f_types))
if len(f_types) == len(types)
else UninhabitedType()
for name, f_types in field_to_types.items()
}


def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
Expand All @@ -971,27 +1045,18 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
# </hack>

inst_type = get_proper_type(inst_type)
if isinstance(inst_type, AnyType):
return ctx.default_signature # evolve(Any, ....) -> Any
inst_type_str = format_type_bare(inst_type, ctx.api.options)

attrs_type, attrs_init_type = _get_attrs_cls_and_init(inst_type)
if attrs_type is None or attrs_init_type is None:
ctx.api.fail(
f'Argument 1 to "evolve" has a variable type "{inst_type_str}" not bound to an attrs class'
if isinstance(inst_type, TypeVarType)
else f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class',
ctx.context,
)
attr_types = _get_expanded_attr_types(ctx, inst_type, inst_type, inst_type)
if attr_types is None:
return ctx.default_signature
fields = _meet_fields(attr_types)

# AttrClass.__init__ has the following signature (or similar, if having kw-only & defaults):
# def __init__(self, attr1: Type1, attr2: Type2) -> None:
# We want to generate a signature for evolve that looks like this:
# def evolve(inst: AttrClass, *, attr1: Type1 = ..., attr2: Type2 = ...) -> AttrClass:
return attrs_init_type.copy_modified(
arg_names=["inst"] + attrs_init_type.arg_names[1:],
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT for _ in attrs_init_type.arg_kinds[1:]],
return CallableType(
arg_names=["inst", *fields.keys()],
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT] * len(fields),
arg_types=[inst_type, *fields.values()],
ret_type=inst_type,
fallback=ctx.default_signature.fallback,
name=f"{ctx.default_signature.name} of {inst_type_str}",
)
81 changes: 80 additions & 1 deletion test-data/unit/check-attr.test
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,81 @@ reveal_type(ret) # N: Revealed type is "Any"

[typing fixtures/typing-medium.pyi]

[case testEvolveGeneric]
import attrs
from typing import Generic, TypeVar

T = TypeVar('T')

@attrs.define
class A(Generic[T]):
x: T


a = A(x=42)
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int]"
a2 = attrs.evolve(a, x=42)
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
a2 = attrs.evolve(a, x='42') # E: Argument "x" to "evolve" of "A[int]" has incompatible type "str"; expected "int"
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"

[builtins fixtures/attr.pyi]

[case testEvolveUnion]
# flags: --python-version 3.10
from typing import Generic, TypeVar
import attrs

T = TypeVar('T')


@attrs.define
class A(Generic[T]):
x: T # exercises meet(T=int, int) = int
y: bool # exercises meet(bool, int) = bool
z: str # exercises meet(str, bytes) = <nothing>
w: dict # exercises meet(dict, <nothing>) = <nothing>


@attrs.define
class B:
x: int
y: bool
z: bytes


a_or_b: A[int] | B
a2 = attrs.evolve(a_or_b, x=42, y=True)
a2 = attrs.evolve(a_or_b, x=42, y=True, z='42') # E: Argument "z" to "evolve" of "Union[A[int], B]" has incompatible type "str"; expected <nothing>
a2 = attrs.evolve(a_or_b, x=42, y=True, w={}) # E: Argument "w" to "evolve" of "Union[A[int], B]" has incompatible type "Dict[<nothing>, <nothing>]"; expected <nothing>

[builtins fixtures/attr.pyi]

[case testEvolveUnionOfTypeVar]
# flags: --python-version 3.10
import attrs
from typing import TypeVar

@attrs.define
class A:
x: int
y: int
z: str
w: dict


class B:
pass

TA = TypeVar('TA', bound=A)
TB = TypeVar('TB', bound=B)

def f(b_or_t: TA | TB | int) -> None:
a2 = attrs.evolve(b_or_t) # E: Argument 1 to "evolve" has type "Union[TA, TB, int]" whose item "TB" is not bound to an attrs class # E: Argument 1 to "evolve" has incompatible type "Union[TA, TB, int]" whose item "int" is not an attrs class


[builtins fixtures/attr.pyi]

[case testEvolveTypeVarBound]
import attrs
from typing import TypeVar
Expand Down Expand Up @@ -1997,11 +2072,12 @@ f(B(x=42))

[case testEvolveTypeVarBoundNonAttrs]
import attrs
from typing import TypeVar
from typing import Union, TypeVar

TInt = TypeVar('TInt', bound=int)
TAny = TypeVar('TAny')
TNone = TypeVar('TNone', bound=None)
TUnion = TypeVar('TUnion', bound=Union[str, int])

def f(t: TInt) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TInt" not bound to an attrs class
Expand All @@ -2012,6 +2088,9 @@ def g(t: TAny) -> None:
def h(t: TNone) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TNone" not bound to an attrs class

def x(t: TUnion) -> None:
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "str" is not an attrs class # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "int" is not an attrs class

[builtins fixtures/attr.pyi]

[case testEvolveTypeVarConstrained]
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/fixtures/attr.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ class object:
class type: pass
class bytes: pass
class function: pass
class bool: pass
class float: pass
class int:
@overload
def __init__(self, x: Union[str, bytes, int] = ...) -> None: ...
@overload
def __init__(self, x: Union[str, bytes], base: int) -> None: ...
class bool(int): pass
class complex:
@overload
def __init__(self, real: float = ..., im: float = ...) -> None: ...
Expand Down

0 comments on commit 2a4c473

Please sign in to comment.