Skip to content

Commit

Permalink
Remove some not needed (and dangerous) uses of get_proper_type (#15198)
Browse files Browse the repository at this point in the history
This is in preparation for support of variadic type aliases. This PR
should be a no-op from user perspective. Note I also added a test to
prohibit new usages of `get_proper_type()` in `expand_type()`, since
this can easily create unwanted recursion.
  • Loading branch information
ilevkivskyi authored May 7, 2023
1 parent ba8ae29 commit 3a1dc4c
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 23 deletions.
1 change: 1 addition & 0 deletions misc/proper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def is_special_target(right: ProperType) -> bool:
"mypy.types.UnpackType",
"mypy.types.TypeVarTupleType",
"mypy.types.ParamSpecType",
"mypy.types.Parameters",
"mypy.types.RawExpressionType",
"mypy.types.EllipsisType",
"mypy.types.StarType",
Expand Down
4 changes: 2 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2057,7 +2057,7 @@ def check_argument_types(
if (
isinstance(first_actual_arg_type, TupleType)
and len(first_actual_arg_type.items) == 1
and isinstance(get_proper_type(first_actual_arg_type.items[0]), UnpackType)
and isinstance(first_actual_arg_type.items[0], UnpackType)
):
# TODO: use walrus operator
actual_types = [first_actual_arg_type.items[0]] + [
Expand All @@ -2084,7 +2084,7 @@ def check_argument_types(
callee_arg_types = unpacked_type.items
callee_arg_kinds = [ARG_POS] * len(actuals)
else:
inner_unpack = get_proper_type(unpacked_type.items[inner_unpack_index])
inner_unpack = unpacked_type.items[inner_unpack_index]
assert isinstance(inner_unpack, UnpackType)
inner_unpacked_type = get_proper_type(inner_unpack.type)
# We assume heterogenous tuples are desugared earlier
Expand Down
2 changes: 1 addition & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def infer_constraints_for_callable(
# not to hold we can always handle the prefixes too.
inner_unpack = unpacked_type.items[0]
assert isinstance(inner_unpack, UnpackType)
inner_unpacked_type = get_proper_type(inner_unpack.type)
inner_unpacked_type = inner_unpack.type
assert isinstance(inner_unpacked_type, TypeVarTupleType)
suffix_len = len(unpacked_type.items) - 1
constraints.append(
Expand Down
19 changes: 8 additions & 11 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,11 @@ def interpolate_args_for_unpack(
star_index = t.arg_kinds.index(ARG_STAR)

# We have something like Unpack[Tuple[X1, X2, Unpack[Ts], Y1, Y2]]
if isinstance(get_proper_type(var_arg.type), TupleType):
expanded_tuple = get_proper_type(var_arg.type.accept(self))
var_arg_type = get_proper_type(var_arg.type)
if isinstance(var_arg_type, TupleType):
expanded_tuple = var_arg_type.accept(self)
# TODO: handle the case that expanded_tuple is a variable length tuple.
assert isinstance(expanded_tuple, TupleType)
assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
expanded_items = expanded_tuple.items
else:
expanded_items_res = self.expand_unpack(var_arg)
Expand Down Expand Up @@ -320,11 +321,11 @@ def interpolate_args_for_unpack(
# homogenous tuple, then only the prefix can be represented as
# positional arguments, and we pass Tuple[Unpack[Ts-1], Y1, Y2]
# as the star arg, for example.
expanded_unpack = get_proper_type(expanded_items[expanded_unpack_index])
expanded_unpack = expanded_items[expanded_unpack_index]
assert isinstance(expanded_unpack, UnpackType)

# Extract the typevartuple so we can get a tuple fallback from it.
expanded_unpacked_tvt = get_proper_type(expanded_unpack.type)
expanded_unpacked_tvt = expanded_unpack.type
assert isinstance(expanded_unpacked_tvt, TypeVarTupleType)

prefix_len = expanded_unpack_index
Expand Down Expand Up @@ -450,18 +451,14 @@ def visit_tuple_type(self, t: TupleType) -> Type:
items = self.expand_types_with_unpack(t.items)
if isinstance(items, list):
fallback = t.partial_fallback.accept(self)
fallback = get_proper_type(fallback)
if not isinstance(fallback, Instance):
fallback = t.partial_fallback
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
return t.copy_modified(items=items, fallback=fallback)
else:
return items

def visit_typeddict_type(self, t: TypedDictType) -> Type:
fallback = t.fallback.accept(self)
fallback = get_proper_type(fallback)
if not isinstance(fallback, Instance):
fallback = t.fallback
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)

def visit_literal_type(self, t: LiteralType) -> Type:
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ def analyze_func_def(self, defn: FuncDef) -> None:
def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
return typ
last_type = get_proper_type(typ.arg_types[-1])
last_type = typ.arg_types[-1]
if not isinstance(last_type, UnpackType):
return typ
last_type = get_proper_type(last_type.type)
Expand Down
17 changes: 17 additions & 0 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from __future__ import annotations

import re
from unittest import TestCase, skipUnless

import mypy.expandtype
from mypy.erasetype import erase_type, remove_instance_last_known_values
from mypy.expandtype import expand_type
from mypy.indirection import TypeIndirectionVisitor
Expand Down Expand Up @@ -1435,3 +1439,16 @@ def make_call(*items: tuple[str, str | None]) -> CallExpr:
else:
arg_kinds.append(ARG_POS)
return CallExpr(NameExpr("f"), args, arg_kinds, arg_names)


class TestExpandTypeLimitGetProperType(TestCase):
# WARNING: do not increase this number unless absolutely necessary,
# and you understand what you are doing.
ALLOWED_GET_PROPER_TYPES = 7

@skipUnless(mypy.expandtype.__file__.endswith(".py"), "Skip for compiled mypy")
def test_count_get_proper_type(self) -> None:
with open(mypy.expandtype.__file__) as f:
code = f.read()
get_proper_type_count = len(re.findall("get_proper_type", code))
assert get_proper_type_count == self.ALLOWED_GET_PROPER_TYPES
8 changes: 3 additions & 5 deletions mypy/typevartuples.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
def find_unpack_in_list(items: Sequence[Type]) -> int | None:
unpack_index: int | None = None
for i, item in enumerate(items):
proper_item = get_proper_type(item)
if isinstance(proper_item, UnpackType):
if isinstance(item, UnpackType):
# We cannot fail here, so we must check this in an earlier
# semanal phase.
# Funky code here avoids mypyc narrowing the type of unpack_index.
Expand Down Expand Up @@ -181,9 +180,8 @@ def fully_split_with_mapped_and_template(
def extract_unpack(types: Sequence[Type]) -> ProperType | None:
"""Given a list of types, extracts either a single type from an unpack, or returns None."""
if len(types) == 1:
proper_type = get_proper_type(types[0])
if isinstance(proper_type, UnpackType):
return get_proper_type(proper_type.type)
if isinstance(types[0], UnpackType):
return get_proper_type(types[0].type)
return None


Expand Down
6 changes: 3 additions & 3 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class Array(Generic[Unpack[Shape]]):

def get_shape(self) -> Tuple[Unpack[Shape]]:
return self._shape

def __abs__(self) -> Array[Unpack[Shape]]: ...

def __add__(self, other: Array[Unpack[Shape]]) -> Array[Unpack[Shape]]: ...
Expand Down Expand Up @@ -237,7 +237,7 @@ class Array(Generic[DType, Unpack[Shape]]):

def get_shape(self) -> Tuple[Unpack[Shape]]:
return self._shape

def __abs__(self) -> Array[DType, Unpack[Shape]]: ...

def __add__(self, other: Array[DType, Unpack[Shape]]) -> Array[DType, Unpack[Shape]]: ...
Expand Down Expand Up @@ -443,7 +443,7 @@ def foo(*args: Unpack[Tuple[int, ...]]) -> None:

foo(0, 1, 2)
# TODO: this should say 'expected "int"' rather than the unpack
foo(0, 1, "bar") # E: Argument 3 to "foo" has incompatible type "str"; expected "Unpack[Tuple[int, ...]]"
foo(0, 1, "bar") # E: Argument 3 to "foo" has incompatible type "str"; expected "Unpack[Tuple[int, ...]]"


def foo2(*args: Unpack[Tuple[str, Unpack[Tuple[int, ...]], bool, bool]]) -> None:
Expand Down

0 comments on commit 3a1dc4c

Please sign in to comment.