Skip to content

Commit

Permalink
Some final touches for variadic types support (#16334)
Browse files Browse the repository at this point in the history
I decided to go again over various parts of variadic types
implementation to double-check nothing is missing, checked interaction
with various "advanced" features (dataclasses, protocols, self-types,
match statement, etc.), added some more tests (including incremental),
and `grep`ed for potentially unhandled cases (and did found few
crashes). This mostly touches only variadic types but one thing goes
beyond, the fix for self-types upper bound, I think it is correct and
should be safe.

If there are no objections, next PR will flip the switch.

---------

Co-authored-by: Shantanu <[email protected]>
  • Loading branch information
ilevkivskyi and hauntsaninja authored Oct 28, 2023
1 parent 6c7faf3 commit f33c9a3
Show file tree
Hide file tree
Showing 19 changed files with 726 additions and 72 deletions.
7 changes: 7 additions & 0 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, Sequence

import mypy.subtypes
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type
from mypy.nodes import Context
from mypy.types import (
Expand Down Expand Up @@ -62,6 +63,11 @@ def get_target_type(
report_incompatible_typevar_value(callable, type, tvar.name, context)
else:
upper_bound = tvar.upper_bound
if tvar.name == "Self":
# Internally constructed Self-types contain class type variables in upper bound,
# so we need to erase them to avoid false positives. This is safe because we do
# not support type variables in upper bounds of user defined types.
upper_bound = erase_typevars(upper_bound)
if not mypy.subtypes.is_subtype(type, upper_bound):
if skip_unsatisfied:
return None
Expand Down Expand Up @@ -121,6 +127,7 @@ def apply_generic_arguments(
# Apply arguments to argument types.
var_arg = callable.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
# Same as for ParamSpec, callable with variadic types needs to be expanded as a whole.
callable = expand_type(callable, id_to_type)
assert isinstance(callable, CallableType)
return callable.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])
Expand Down
39 changes: 20 additions & 19 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,6 @@ def expand_typevars(
if defn.info:
# Class type variables
tvars += defn.info.defn.type_vars or []
# TODO(PEP612): audit for paramspec
for tvar in tvars:
if isinstance(tvar, TypeVarType) and tvar.values:
subst.append([(tvar.id, value) for value in tvar.values])
Expand Down Expand Up @@ -2538,6 +2537,9 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
object_type = Instance(info.mro[-1], [])
tvars = info.defn.type_vars
for i, tvar in enumerate(tvars):
if not isinstance(tvar, TypeVarType):
# Variance of TypeVarTuple and ParamSpec is underspecified by PEPs.
continue
up_args: list[Type] = [
object_type if i == j else AnyType(TypeOfAny.special_form)
for j, _ in enumerate(tvars)
Expand All @@ -2554,7 +2556,7 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
expected = CONTRAVARIANT
else:
expected = INVARIANT
if isinstance(tvar, TypeVarType) and expected != tvar.variance:
if expected != tvar.variance:
self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn)

def check_multiple_inheritance(self, typ: TypeInfo) -> None:
Expand Down Expand Up @@ -6695,19 +6697,6 @@ def check_possible_missing_await(
return
self.msg.possible_missing_await(context, code)

def contains_none(self, t: Type) -> bool:
t = get_proper_type(t)
return (
isinstance(t, NoneType)
or (isinstance(t, UnionType) and any(self.contains_none(ut) for ut in t.items))
or (isinstance(t, TupleType) and any(self.contains_none(tt) for tt in t.items))
or (
isinstance(t, Instance)
and bool(t.args)
and any(self.contains_none(it) for it in t.args)
)
)

def named_type(self, name: str) -> Instance:
"""Return an instance type with given name and implicit Any type args.
Expand Down Expand Up @@ -7471,10 +7460,22 @@ def builtin_item_type(tp: Type) -> Type | None:
return None
if not isinstance(get_proper_type(tp.args[0]), AnyType):
return tp.args[0]
elif isinstance(tp, TupleType) and all(
not isinstance(it, AnyType) for it in get_proper_types(tp.items)
):
return make_simplified_union(tp.items) # this type is not externally visible
elif isinstance(tp, TupleType):
normalized_items = []
for it in tp.items:
# This use case is probably rare, but not handling unpacks here can cause crashes.
if isinstance(it, UnpackType):
unpacked = get_proper_type(it.type)
if isinstance(unpacked, TypeVarTupleType):
unpacked = get_proper_type(unpacked.upper_bound)
assert (
isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple"
)
normalized_items.append(unpacked.args[0])
else:
normalized_items.append(it)
if all(not isinstance(it, AnyType) for it in get_proper_types(normalized_items)):
return make_simplified_union(normalized_items) # this type is not externally visible
elif isinstance(tp, TypedDictType):
# TypedDict always has non-optional string keys. Find the key type from the Mapping
# base class.
Expand Down
7 changes: 4 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
result = self.alias_type_in_runtime_context(
node, ctx=e, alias_definition=e.is_alias_rvalue or lvalue
)
elif isinstance(node, (TypeVarExpr, ParamSpecExpr)):
elif isinstance(node, (TypeVarExpr, ParamSpecExpr, TypeVarTupleExpr)):
result = self.object_type()
else:
if isinstance(node, PlaceholderNode):
Expand Down Expand Up @@ -3316,6 +3316,7 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty

def concat_tuples(self, left: TupleType, right: TupleType) -> TupleType:
"""Concatenate two fixed length tuples."""
assert not (find_unpack_in_list(left.items) and find_unpack_in_list(right.items))
return TupleType(
items=left.items + right.items, fallback=self.named_type("builtins.tuple")
)
Expand Down Expand Up @@ -6507,8 +6508,8 @@ def merge_typevars_in_callables_by_name(
for tv in target.variables:
name = tv.fullname
if name not in unique_typevars:
# TODO(PEP612): fix for ParamSpecType
if isinstance(tv, ParamSpecType):
# TODO: support ParamSpecType and TypeVarTuple.
if isinstance(tv, (ParamSpecType, TypeVarTupleType)):
continue
assert isinstance(tv, TypeVarType)
unique_typevars[name] = tv
Expand Down
111 changes: 93 additions & 18 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,13 @@
Type,
TypedDictType,
TypeOfAny,
TypeVarTupleType,
UninhabitedType,
UnionType,
UnpackType,
find_unpack_in_list,
get_proper_type,
split_with_prefix_and_suffix,
)
from mypy.typevars import fill_typevars
from mypy.visitor import PatternVisitor
Expand Down Expand Up @@ -239,13 +243,29 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
#
# get inner types of original type
#
unpack_index = None
if isinstance(current_type, TupleType):
inner_types = current_type.items
size_diff = len(inner_types) - required_patterns
if size_diff < 0:
return self.early_non_match()
elif size_diff > 0 and star_position is None:
return self.early_non_match()
unpack_index = find_unpack_in_list(inner_types)
if unpack_index is None:
size_diff = len(inner_types) - required_patterns
if size_diff < 0:
return self.early_non_match()
elif size_diff > 0 and star_position is None:
return self.early_non_match()
else:
normalized_inner_types = []
for it in inner_types:
# Unfortunately, it is not possible to "split" the TypeVarTuple
# into individual items, so we just use its upper bound for the whole
# analysis instead.
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
it = UnpackType(it.type.upper_bound)
normalized_inner_types.append(it)
inner_types = normalized_inner_types
current_type = current_type.copy_modified(items=normalized_inner_types)
if len(inner_types) - 1 > required_patterns and star_position is None:
return self.early_non_match()
else:
inner_type = self.get_sequence_type(current_type, o)
if inner_type is None:
Expand All @@ -270,18 +290,18 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
self.update_type_map(captures, type_map)

new_inner_types = self.expand_starred_pattern_types(
contracted_new_inner_types, star_position, len(inner_types)
contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None
)
rest_inner_types = self.expand_starred_pattern_types(
contracted_rest_inner_types, star_position, len(inner_types)
contracted_rest_inner_types, star_position, len(inner_types), unpack_index is not None
)

#
# Calculate new type
#
new_type: Type
rest_type: Type = current_type
if isinstance(current_type, TupleType):
if isinstance(current_type, TupleType) and unpack_index is None:
narrowed_inner_types = []
inner_rest_types = []
for inner_type, new_inner_type in zip(inner_types, new_inner_types):
Expand All @@ -301,6 +321,14 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
if all(is_uninhabited(typ) for typ in inner_rest_types):
# All subpatterns always match, so we can apply negative narrowing
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
elif isinstance(current_type, TupleType):
# For variadic tuples it is too tricky to match individual items like for fixed
# tuples, so we instead try to narrow the entire type.
# TODO: use more precise narrowing when possible (e.g. for identical shapes).
new_tuple_type = TupleType(new_inner_types, current_type.partial_fallback)
new_type, rest_type = self.chk.conditional_types_with_intersection(
new_tuple_type, [get_type_range(current_type)], o, default=new_tuple_type
)
else:
new_inner_type = UninhabitedType()
for typ in new_inner_types:
Expand Down Expand Up @@ -345,17 +373,45 @@ def contract_starred_pattern_types(
If star_pos in None the types are returned unchanged.
"""
if star_pos is None:
return types
new_types = types[:star_pos]
star_length = len(types) - num_patterns
new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
new_types += types[star_pos + star_length :]

return new_types
unpack_index = find_unpack_in_list(types)
if unpack_index is not None:
# Variadic tuples require "re-shaping" to match the requested pattern.
unpack = types[unpack_index]
assert isinstance(unpack, UnpackType)
unpacked = get_proper_type(unpack.type)
# This should be guaranteed by the normalization in the caller.
assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple"
if star_pos is None:
missing = num_patterns - len(types) + 1
new_types = types[:unpack_index]
new_types += [unpacked.args[0]] * missing
new_types += types[unpack_index + 1 :]
return new_types
prefix, middle, suffix = split_with_prefix_and_suffix(
tuple([UnpackType(unpacked) if isinstance(t, UnpackType) else t for t in types]),
star_pos,
num_patterns - star_pos,
)
new_middle = []
for m in middle:
# The existing code expects the star item type, rather than the type of
# the whole tuple "slice".
if isinstance(m, UnpackType):
new_middle.append(unpacked.args[0])
else:
new_middle.append(m)
return list(prefix) + [make_simplified_union(new_middle)] + list(suffix)
else:
if star_pos is None:
return types
new_types = types[:star_pos]
star_length = len(types) - num_patterns
new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
new_types += types[star_pos + star_length :]
return new_types

def expand_starred_pattern_types(
self, types: list[Type], star_pos: int | None, num_types: int
self, types: list[Type], star_pos: int | None, num_types: int, original_unpack: bool
) -> list[Type]:
"""Undoes the contraction done by contract_starred_pattern_types.
Expand All @@ -364,6 +420,17 @@ def expand_starred_pattern_types(
"""
if star_pos is None:
return types
if original_unpack:
# In the case where original tuple type has an unpack item, it is not practical
# to coerce pattern type back to the original shape (and may not even be possible),
# so we only restore the type of the star item.
res = []
for i, t in enumerate(types):
if i != star_pos:
res.append(t)
else:
res.append(UnpackType(self.chk.named_generic_type("builtins.tuple", [t])))
return res
new_types = types[:star_pos]
star_length = num_types - len(types) + 1
new_types += [types[star_pos]] * star_length
Expand Down Expand Up @@ -459,7 +526,15 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
return self.early_non_match()
if isinstance(type_info, TypeInfo):
any_type = AnyType(TypeOfAny.implementation_artifact)
typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
args: list[Type] = []
for tv in type_info.defn.type_vars:
if isinstance(tv, TypeVarTupleType):
args.append(
UnpackType(self.chk.named_generic_type("builtins.tuple", [any_type]))
)
else:
args.append(any_type)
typ: Type = Instance(type_info, args)
elif isinstance(type_info, TypeAlias):
typ = type_info.target
elif (
Expand Down
19 changes: 9 additions & 10 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Instance,
LiteralType,
NoneType,
NormalizedCallableType,
Overloaded,
Parameters,
ParamSpecType,
Expand Down Expand Up @@ -1388,7 +1389,7 @@ def find_matching_overload_items(
return res


def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo | None:
def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo:
"""Get builtins.tuple type from available types to construct homogeneous tuples."""
tp = get_proper_type(unpack.type)
if isinstance(tp, Instance) and tp.type.fullname == "builtins.tuple":
Expand All @@ -1399,10 +1400,10 @@ def get_tuple_fallback_from_unpack(unpack: UnpackType) -> TypeInfo | None:
for base in tp.partial_fallback.type.mro:
if base.fullname == "builtins.tuple":
return base
return None
assert False, "Invalid unpack type"


def repack_callable_args(callable: CallableType, tuple_type: TypeInfo | None) -> list[Type]:
def repack_callable_args(callable: CallableType, tuple_type: TypeInfo) -> list[Type]:
"""Present callable with star unpack in a normalized form.
Since positional arguments cannot follow star argument, they are packed in a suffix,
Expand All @@ -1417,12 +1418,8 @@ def repack_callable_args(callable: CallableType, tuple_type: TypeInfo | None) ->
star_type = callable.arg_types[star_index]
suffix_types = []
if not isinstance(star_type, UnpackType):
if tuple_type is not None:
# Re-normalize *args: X -> *args: *tuple[X, ...]
star_type = UnpackType(Instance(tuple_type, [star_type]))
else:
# This is unfortunate, something like tuple[Any, ...] would be better.
star_type = UnpackType(AnyType(TypeOfAny.from_error))
# Re-normalize *args: X -> *args: *tuple[X, ...]
star_type = UnpackType(Instance(tuple_type, [star_type]))
else:
tp = get_proper_type(star_type.type)
if isinstance(tp, TupleType):
Expand Down Expand Up @@ -1544,7 +1541,9 @@ def infer_directed_arg_constraints(left: Type, right: Type, direction: int) -> l


def infer_callable_arguments_constraints(
template: CallableType | Parameters, actual: CallableType | Parameters, direction: int
template: NormalizedCallableType | Parameters,
actual: NormalizedCallableType | Parameters,
direction: int,
) -> list[Constraint]:
"""Infer constraints between argument types of two callables.
Expand Down
4 changes: 3 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def visit_parameters(self, t: Parameters) -> ProperType:
raise RuntimeError("Parameters should have been bound to a class")

def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
return AnyType(TypeOfAny.special_form)
# Likely, we can never get here because of aggressive erasure of types that
# can contain this, but better still return a valid replacement.
return t.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])

def visit_unpack_type(self, t: UnpackType) -> ProperType:
return AnyType(TypeOfAny.special_form)
Expand Down
Loading

0 comments on commit f33c9a3

Please sign in to comment.