Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TypeAliasType #16926

Merged
merged 18 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 108 additions & 15 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

from contextlib import contextmanager
from typing import Any, Callable, Collection, Final, Iterable, Iterator, List, TypeVar, cast
from typing_extensions import TypeAlias as _TypeAlias
from typing_extensions import TypeAlias as _TypeAlias, TypeGuard

from mypy import errorcodes as codes, message_registry
from mypy.constant_fold import constant_fold_expr
Expand Down Expand Up @@ -2007,34 +2007,36 @@ def analyze_class_typevar_declaration(self, base: Type) -> tuple[TypeVarLikeList

def analyze_unbound_tvar(self, t: Type) -> tuple[str, TypeVarLikeExpr] | None:
if isinstance(t, UnpackType) and isinstance(t.type, UnboundType):
return self.analyze_unbound_tvar_impl(t.type, allow_tvt=True)
return self.analyze_unbound_tvar_impl(t.type, is_unpacked=True)
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym and sym.fullname in ("typing.Unpack", "typing_extensions.Unpack"):
inner_t = t.args[0]
if isinstance(inner_t, UnboundType):
return self.analyze_unbound_tvar_impl(inner_t, allow_tvt=True)
return self.analyze_unbound_tvar_impl(inner_t, is_unpacked=True)
return None
return self.analyze_unbound_tvar_impl(t)
return None

def analyze_unbound_tvar_impl(
self, t: UnboundType, allow_tvt: bool = False
self, t: UnboundType, is_unpacked: bool = False, is_typealias_param: bool = False
) -> tuple[str, TypeVarLikeExpr] | None:
if is_unpacked and is_typealias_param:
return None # This should be unreachable
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
sym = self.lookup_qualified(t.name, t)
if sym and isinstance(sym.node, PlaceholderNode):
self.record_incomplete_ref()
if not allow_tvt and sym and isinstance(sym.node, ParamSpecExpr):
if not is_unpacked and sym and isinstance(sym.node, ParamSpecExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
return t.name, sym.node
if allow_tvt and sym and isinstance(sym.node, TypeVarTupleExpr):
if (is_unpacked or is_typealias_param) and sym and isinstance(sym.node, TypeVarTupleExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
return t.name, sym.node
if sym is None or not isinstance(sym.node, TypeVarExpr) or allow_tvt:
if sym is None or not isinstance(sym.node, TypeVarExpr) or is_unpacked:
return None
elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
Expand Down Expand Up @@ -3490,7 +3492,11 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Typ
return typ

def analyze_alias(
self, name: str, rvalue: Expression, allow_placeholder: bool = False
self,
name: str,
rvalue: Expression,
allow_placeholder: bool = False,
declared_type_vars: TypeVarLikeList | None = None,
) -> tuple[Type | None, list[TypeVarLikeType], set[str], list[str], bool]:
"""Check if 'rvalue' is a valid type allowed for aliasing (e.g. not a type variable).

Expand All @@ -3515,8 +3521,9 @@ def analyze_alias(
found_type_vars = self.find_type_var_likes(typ)
tvar_defs: list[TypeVarLikeType] = []
namespace = self.qualified_name(name)
alias_type_vars = found_type_vars if declared_type_vars is None else declared_type_vars
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
for name, tvar_expr in found_type_vars:
for name, tvar_expr in alias_type_vars:
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
tvar_defs.append(tvar_def)

Expand All @@ -3543,7 +3550,7 @@ def analyze_alias(
variadic = True
new_tvar_defs.append(td)

qualified_tvars = [node.fullname for _name, node in found_type_vars]
qualified_tvars = [node.fullname for _name, node in alias_type_vars]
empty_tuple_index = typ.empty_tuple_index if isinstance(typ, UnboundType) else False
return analyzed, new_tvar_defs, depends_on, qualified_tvars, empty_tuple_index

Expand Down Expand Up @@ -3576,7 +3583,17 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# unless using PEP 613 `cls: TypeAlias = A`
return False

if isinstance(s.rvalue, CallExpr) and s.rvalue.analyzed:
# It can be `A = TypeAliasType('A', ...)` call, in this case,
# we just take the second argument and analyze it:
type_params: TypeVarLikeList | None
if self.check_type_alias_type_call(s.rvalue, name=lvalue.name):
rvalue = s.rvalue.args[1]
type_params = self.analyze_type_alias_type_params(s.rvalue)
else:
rvalue = s.rvalue
type_params = None

if isinstance(rvalue, CallExpr) and rvalue.analyzed:
return False

existing = self.current_symbol_table().get(lvalue.name)
Expand All @@ -3602,7 +3619,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
return False

non_global_scope = self.type or self.is_func_scope()
if not pep_613 and isinstance(s.rvalue, RefExpr) and non_global_scope:
if not pep_613 and isinstance(rvalue, RefExpr) and non_global_scope:
# Fourth rule (special case): Non-subscripted right hand side creates a variable
# at class and function scopes. For example:
#
Expand All @@ -3614,7 +3631,6 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# without this rule, this typical use case will require a lot of explicit
# annotations (see the second rule).
return False
rvalue = s.rvalue
if not pep_613 and not self.can_be_type_alias(rvalue):
return False

Expand All @@ -3632,7 +3648,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
else:
tag = self.track_incomplete_refs()
res, alias_tvars, depends_on, qualified_tvars, empty_tuple_index = self.analyze_alias(
lvalue.name, rvalue, allow_placeholder=True
lvalue.name, rvalue, allow_placeholder=True, declared_type_vars=type_params
)
if not res:
return False
Expand Down Expand Up @@ -3662,13 +3678,15 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# so we need to replace it with non-explicit Anys.
res = make_any_non_explicit(res)
# Note: with the new (lazy) type alias representation we only need to set no_args to True
# if the expected number of arguments is non-zero, so that aliases like A = List work.
# if the expected number of arguments is non-zero, so that aliases like `A = List` work
# but not aliases like `A = TypeAliasType("A", List)` as these need explicit type params.
# However, eagerly expanding aliases like Text = str is a nice performance optimization.
no_args = (
isinstance(res, ProperType)
and isinstance(res, Instance)
and not res.args
and not empty_tuple_index
and type_params is None
)
if isinstance(res, ProperType) and isinstance(res, Instance):
if not validate_instance(res, self.fail, empty_tuple_index):
Expand Down Expand Up @@ -3735,6 +3753,75 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
self.note("Use variable annotation syntax to define protocol members", s)
return True

def check_type_alias_type_call(self, rvalue: Expression, *, name: str) -> TypeGuard[CallExpr]:
if not isinstance(rvalue, CallExpr):
return False

names = ["typing_extensions.TypeAliasType"]
if self.options.python_version >= (3, 12):
names.append("typing.TypeAliasType")
if not refers_to_fullname(rvalue.callee, tuple(names)):
return False

return self.check_typevarlike_name(rvalue, name, rvalue)

def analyze_type_alias_type_params(self, rvalue: CallExpr) -> TypeVarLikeList:
if "type_params" in rvalue.arg_names:
type_params_arg = rvalue.args[rvalue.arg_names.index("type_params")]
if not isinstance(type_params_arg, TupleExpr):
self.fail(
"Tuple literal expected as the type_params argument to TypeAliasType",
type_params_arg,
)
return []
type_params = type_params_arg.items
else:
type_params = []

declared_tvars: TypeVarLikeList = []
have_type_var_tuple = False
for tp_expr in type_params:
if isinstance(tp_expr, StarExpr):
tp_expr.valid = False
self.analyze_type_expr(tp_expr)
try:
base = self.expr_to_unanalyzed_type(tp_expr)
except TypeTranslationError:
continue
if not isinstance(base, UnboundType):
continue

tag = self.track_incomplete_refs()
tvar = self.analyze_unbound_tvar_impl(base, is_typealias_param=True)
if tvar:
if isinstance(tvar[1], TypeVarTupleExpr):
if have_type_var_tuple:
self.fail(
"Can only use one type var tuple in type_params argument to TypeAliasType",
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
base,
code=codes.TYPE_VAR,
)
have_type_var_tuple = True
continue
have_type_var_tuple = True
elif not self.found_incomplete_ref(tag):
self.fail(
"Free type variable expected in type_params argument to TypeAliasType",
base,
code=codes.TYPE_VAR,
)
continue
if tvar in declared_tvars:
self.fail(
"Duplicate type variables in type_params argument to TypeAliasType",
base,
code=codes.TYPE_VAR,
)
continue
if tvar:
declared_tvars.append(tvar)
return declared_tvars

def disable_invalid_recursive_aliases(
self, s: AssignmentStmt, current_node: TypeAlias
) -> None:
Expand Down Expand Up @@ -5151,6 +5238,12 @@ def visit_call_expr(self, expr: CallExpr) -> None:
expr.analyzed = OpExpr("divmod", expr.args[0], expr.args[1])
expr.analyzed.line = expr.line
expr.analyzed.accept(self)
elif refers_to_fullname(
expr.callee, ("typing.TypeAliasType", "typing_extensions.TypeAliasType")
):
with self.allow_unbound_tvars_set():
for a in expr.args:
a.accept(self)
else:
# Normal call expression.
for a in expr.args:
Expand Down
54 changes: 27 additions & 27 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3124,8 +3124,8 @@ def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ...
def id(x: U) -> U: ...
def either(x: U, y: U) -> U: ...
def pair(x: U, y: V) -> Tuple[U, V]: ...
reveal_type(dec(id)) # N: Revealed type is "def [T] (x: T`2) -> builtins.list[T`2]"
reveal_type(dec(either)) # N: Revealed type is "def [T] (x: T`4, y: T`4) -> builtins.list[T`4]"
reveal_type(dec(id)) # N: Revealed type is "def [T] (x: T`3) -> builtins.list[T`3]"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why this changed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea to be honest. They changed in the original PR by sobolevn and I didn't know why. I'll do some more digging.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Not too important if it's hard to figure out, I assume these numbers can change based on all kinds of factors.

reveal_type(dec(either)) # N: Revealed type is "def [T] (x: T`5, y: T`5) -> builtins.list[T`5]"
reveal_type(dec(pair)) # N: Revealed type is "def [U, V] (x: U`-1, y: V`-2) -> builtins.list[Tuple[U`-1, V`-2]]"
[builtins fixtures/list.pyi]

Expand All @@ -3142,8 +3142,8 @@ V = TypeVar('V')
def dec(f: Callable[P, List[T]]) -> Callable[P, T]: ...
def id(x: U) -> U: ...
def either(x: U, y: U) -> U: ...
reveal_type(dec(id)) # N: Revealed type is "def [T] (x: builtins.list[T`2]) -> T`2"
reveal_type(dec(either)) # N: Revealed type is "def [T] (x: builtins.list[T`4], y: builtins.list[T`4]) -> T`4"
reveal_type(dec(id)) # N: Revealed type is "def [T] (x: builtins.list[T`3]) -> T`3"
reveal_type(dec(either)) # N: Revealed type is "def [T] (x: builtins.list[T`5], y: builtins.list[T`5]) -> T`5"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericParamSpecPopOff]
Expand All @@ -3161,9 +3161,9 @@ def dec(f: Callable[Concatenate[T, P], S]) -> Callable[P, Callable[[T], S]]: ...
def id(x: U) -> U: ...
def either(x: U, y: U) -> U: ...
def pair(x: U, y: V) -> Tuple[U, V]: ...
reveal_type(dec(id)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1"
reveal_type(dec(either)) # N: Revealed type is "def [T] (y: T`4) -> def (T`4) -> T`4"
reveal_type(dec(pair)) # N: Revealed type is "def [V] (y: V`-2) -> def [T] (T`7) -> Tuple[T`7, V`-2]"
reveal_type(dec(id)) # N: Revealed type is "def () -> def [T] (T`2) -> T`2"
reveal_type(dec(either)) # N: Revealed type is "def [T] (y: T`5) -> def (T`5) -> T`5"
reveal_type(dec(pair)) # N: Revealed type is "def [V] (y: V`-2) -> def [T] (T`8) -> Tuple[T`8, V`-2]"
reveal_type(dec(dec)) # N: Revealed type is "def () -> def [T, P, S] (def (T`-1, *P.args, **P.kwargs) -> S`-3) -> def (*P.args, **P.kwargs) -> def (T`-1) -> S`-3"
[builtins fixtures/list.pyi]

Expand All @@ -3182,11 +3182,11 @@ def dec(f: Callable[P, Callable[[T], S]]) -> Callable[Concatenate[T, P], S]: ...
def id() -> Callable[[U], U]: ...
def either(x: U) -> Callable[[U], U]: ...
def pair(x: U) -> Callable[[V], Tuple[V, U]]: ...
reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> T`2"
reveal_type(dec(either)) # N: Revealed type is "def [T] (T`5, x: T`5) -> T`5"
reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`8, x: U`-1) -> Tuple[T`8, U`-1]"
reveal_type(dec(id)) # N: Revealed type is "def [T] (T`3) -> T`3"
reveal_type(dec(either)) # N: Revealed type is "def [T] (T`6, x: T`6) -> T`6"
reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`9, x: U`-1) -> Tuple[T`9, U`-1]"
# This is counter-intuitive but looks correct, dec matches itself only if P can be empty
reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`11, f: def () -> def (T`11) -> S`12) -> S`12"
reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`12, f: def () -> def (T`12) -> S`13) -> S`13"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericParamSpecVsParamSpec]
Expand All @@ -3203,7 +3203,7 @@ class Bar(Generic[P, T]): ...

def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ...
def f(*args: Q.args, **kwargs: Q.kwargs) -> Foo[Q]: ...
reveal_type(dec(f)) # N: Revealed type is "def [P] (*P.args, **P.kwargs) -> builtins.list[__main__.Foo[P`1]]"
reveal_type(dec(f)) # N: Revealed type is "def [P] (*P.args, **P.kwargs) -> builtins.list[__main__.Foo[P`2]]"
g: Callable[Concatenate[int, Q], Foo[Q]]
reveal_type(dec(g)) # N: Revealed type is "def [Q] (builtins.int, *Q.args, **Q.kwargs) -> builtins.list[__main__.Foo[Q`-1]]"
h: Callable[Concatenate[T, Q], Bar[Q, T]]
Expand Down Expand Up @@ -3264,8 +3264,8 @@ def transform(

def dec(f: Callable[W, U]) -> Callable[W, U]: ...
def dec2(f: Callable[Concatenate[str, W], U]) -> Callable[Concatenate[bytes, W], U]: ...
reveal_type(transform(dec)) # N: Revealed type is "def [P, T] (def (builtins.int, *P.args, **P.kwargs) -> T`2) -> def (builtins.int, *P.args, **P.kwargs) -> T`2"
reveal_type(transform(dec2)) # N: Revealed type is "def [W, T] (def (builtins.int, builtins.str, *W.args, **W.kwargs) -> T`6) -> def (builtins.int, builtins.bytes, *W.args, **W.kwargs) -> T`6"
reveal_type(transform(dec)) # N: Revealed type is "def [P, T] (def (builtins.int, *P.args, **P.kwargs) -> T`3) -> def (builtins.int, *P.args, **P.kwargs) -> T`3"
reveal_type(transform(dec2)) # N: Revealed type is "def [W, T] (def (builtins.int, builtins.str, *W.args, **W.kwargs) -> T`7) -> def (builtins.int, builtins.bytes, *W.args, **W.kwargs) -> T`7"
[builtins fixtures/tuple.pyi]

[case testNoAccidentalVariableClashInNestedGeneric]
Expand Down Expand Up @@ -3319,8 +3319,8 @@ def id(x: U) -> U: ...
def either(x: U, y: U) -> U: ...
def pair(x: U, y: V) -> Tuple[U, V]: ...

reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> builtins.list[T`2]"
reveal_type(dec(either)) # N: Revealed type is "def [T] (T`4, T`4) -> builtins.list[T`4]"
reveal_type(dec(id)) # N: Revealed type is "def [T] (T`3) -> builtins.list[T`3]"
reveal_type(dec(either)) # N: Revealed type is "def [T] (T`5, T`5) -> builtins.list[T`5]"
reveal_type(dec(pair)) # N: Revealed type is "def [U, V] (U`-1, V`-2) -> builtins.list[Tuple[U`-1, V`-2]]"
[builtins fixtures/tuple.pyi]

Expand All @@ -3338,8 +3338,8 @@ V = TypeVar("V")
def id(x: U) -> U: ...
def either(x: U, y: U) -> U: ...

reveal_type(dec(id)) # N: Revealed type is "def [T] (builtins.list[T`2]) -> T`2"
reveal_type(dec(either)) # N: Revealed type is "def [T] (builtins.list[T`4], builtins.list[T`4]) -> T`4"
reveal_type(dec(id)) # N: Revealed type is "def [T] (builtins.list[T`3]) -> T`3"
reveal_type(dec(either)) # N: Revealed type is "def [T] (builtins.list[T`5], builtins.list[T`5]) -> T`5"
[builtins fixtures/tuple.pyi]

[case testInferenceAgainstGenericVariadicPopOff]
Expand All @@ -3358,9 +3358,9 @@ def id(x: U) -> U: ...
def either(x: U, y: U) -> U: ...
def pair(x: U, y: V) -> Tuple[U, V]: ...

reveal_type(dec(id)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1"
reveal_type(dec(either)) # N: Revealed type is "def [T] (T`4) -> def (T`4) -> T`4"
reveal_type(dec(pair)) # N: Revealed type is "def [V] (V`-2) -> def [T] (T`7) -> Tuple[T`7, V`-2]"
reveal_type(dec(id)) # N: Revealed type is "def () -> def [T] (T`2) -> T`2"
reveal_type(dec(either)) # N: Revealed type is "def [T] (T`5) -> def (T`5) -> T`5"
reveal_type(dec(pair)) # N: Revealed type is "def [V] (V`-2) -> def [T] (T`8) -> Tuple[T`8, V`-2]"
reveal_type(dec(dec)) # N: Revealed type is "def () -> def [T, Ts, S] (def (T`-1, *Unpack[Ts`-2]) -> S`-3) -> def (*Unpack[Ts`-2]) -> def (T`-1) -> S`-3"
[builtins fixtures/list.pyi]

Expand All @@ -3380,11 +3380,11 @@ def id() -> Callable[[U], U]: ...
def either(x: U) -> Callable[[U], U]: ...
def pair(x: U) -> Callable[[V], Tuple[V, U]]: ...

reveal_type(dec(id)) # N: Revealed type is "def [T] (T`2) -> T`2"
reveal_type(dec(either)) # N: Revealed type is "def [T] (T`5, T`5) -> T`5"
reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`8, U`-1) -> Tuple[T`8, U`-1]"
reveal_type(dec(id)) # N: Revealed type is "def [T] (T`3) -> T`3"
reveal_type(dec(either)) # N: Revealed type is "def [T] (T`6, T`6) -> T`6"
reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`9, U`-1) -> Tuple[T`9, U`-1]"
# This is counter-intuitive but looks correct, dec matches itself only if Ts is empty
reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`11, def () -> def (T`11) -> S`12) -> S`12"
reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`12, def () -> def (T`12) -> S`13) -> S`13"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericVariadicVsVariadic]
Expand All @@ -3402,9 +3402,9 @@ class Bar(Generic[Unpack[Ts], T]): ...

def dec(f: Callable[[Unpack[Ts]], T]) -> Callable[[Unpack[Ts]], List[T]]: ...
def f(*args: Unpack[Us]) -> Foo[Unpack[Us]]: ...
reveal_type(dec(f)) # N: Revealed type is "def [Ts] (*Unpack[Ts`1]) -> builtins.list[__main__.Foo[Unpack[Ts`1]]]"
reveal_type(dec(f)) # N: Revealed type is "def [Ts] (*Unpack[Ts`2]) -> builtins.list[__main__.Foo[Unpack[Ts`2]]]"
g: Callable[[Unpack[Us]], Foo[Unpack[Us]]]
reveal_type(dec(g)) # N: Revealed type is "def [Ts] (*Unpack[Ts`3]) -> builtins.list[__main__.Foo[Unpack[Ts`3]]]"
reveal_type(dec(g)) # N: Revealed type is "def [Ts] (*Unpack[Ts`4]) -> builtins.list[__main__.Foo[Unpack[Ts`4]]]"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericVariadicVsVariadicConcatenate]
Expand Down
Loading
Loading