Skip to content

Commit

Permalink
Add basic support for recursive TypeVar defaults (PEP 696) (#16878)
Browse files Browse the repository at this point in the history
Ref: #14851
  • Loading branch information
cdce8p authored Feb 16, 2024
1 parent 5ffa6dd commit 2e5174c
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 4 deletions.
13 changes: 12 additions & 1 deletion mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,18 @@ def apply_generic_arguments(
# TODO: move apply_poly() logic from checkexpr.py here when new inference
# becomes universally used (i.e. in all passes + in unification).
# With this new logic we can actually *add* some new free variables.
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
remaining_tvars: list[TypeVarLikeType] = []
for tv in tvars:
if tv.id in id_to_type:
continue
if not tv.has_default():
remaining_tvars.append(tv)
continue
# TypeVarLike isn't in id_to_type mapping.
# Only expand the TypeVar default here.
typ = expand_type(tv, id_to_type)
assert isinstance(typ, TypeVarLikeType)
remaining_tvars.append(typ)

return callable.copy_modified(
ret_type=expand_type(callable.ret_type, id_to_type),
Expand Down
9 changes: 9 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):

def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
self.variables = variables
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}

def visit_unbound_type(self, t: UnboundType) -> Type:
return t
Expand Down Expand Up @@ -226,6 +227,14 @@ def visit_type_var(self, t: TypeVarType) -> Type:
# TODO: do we really need to do this?
# If I try to remove this special-casing ~40 tests fail on reveal_type().
return repl.copy_modified(last_known_value=None)
if isinstance(repl, TypeVarType) and repl.has_default():
if (tvar_id := repl.id) in self.recursive_tvar_guard:
return self.recursive_tvar_guard[tvar_id] or repl
self.recursive_tvar_guard[tvar_id] = None
repl = repl.accept(self)
if isinstance(repl, TypeVarType):
repl.default = repl.default.accept(self)
self.recursive_tvar_guard[tvar_id] = repl
return repl

def visit_param_spec(self, t: ParamSpecType) -> Type:
Expand Down
9 changes: 9 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1954,6 +1954,15 @@ class Foo(Bar, Generic[T]): ...
del base_type_exprs[i]
tvar_defs: list[TypeVarLikeType] = []
for name, tvar_expr in declared_tvars:
tvar_expr_default = tvar_expr.default
if isinstance(tvar_expr_default, UnboundType):
# TODO: - detect out of order and self-referencing TypeVars
# - nested default types, e.g. list[T1]
n = self.lookup_qualified(
tvar_expr_default.name, tvar_expr_default, suppress_errors=True
)
if n is not None and (default := self.tvar_scope.get_binding(n)) is not None:
tvar_expr.default = default
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
tvar_defs.append(tvar_def)
return base_type_exprs, tvar_defs, is_protocol
Expand Down
22 changes: 22 additions & 0 deletions mypy/tvar_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,26 @@
TypeVarTupleType,
TypeVarType,
)
from mypy.typetraverser import TypeTraverserVisitor


class TypeVarLikeNamespaceSetter(TypeTraverserVisitor):
"""Set namespace for all TypeVarLikeTypes types."""

def __init__(self, namespace: str) -> None:
self.namespace = namespace

def visit_type_var(self, t: TypeVarType) -> None:
t.id.namespace = self.namespace
super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> None:
t.id.namespace = self.namespace
return super().visit_param_spec(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.id.namespace = self.namespace
super().visit_type_var_tuple(t)


class TypeVarLikeScope:
Expand Down Expand Up @@ -88,6 +108,8 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType:
i = self.func_id
# TODO: Consider also using namespaces for functions
namespace = ""
tvar_expr.default.accept(TypeVarLikeNamespaceSetter(namespace))

if isinstance(tvar_expr, TypeVarExpr):
tvar_def: TypeVarLikeType = TypeVarType(
name=name,
Expand Down
6 changes: 3 additions & 3 deletions mypy/typetraverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def visit_type_var(self, t: TypeVarType) -> None:
# Note that type variable values and upper bound aren't treated as
# components, since they are components of the type variable
# definition. We want to traverse everything just once.
pass
t.default.accept(self)

def visit_param_spec(self, t: ParamSpecType) -> None:
pass
t.default.accept(self)

def visit_parameters(self, t: Parameters) -> None:
self.traverse_types(t.arg_types)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
pass
t.default.accept(self)

def visit_literal_type(self, t: LiteralType) -> None:
t.fallback.accept(self)
Expand Down
78 changes: 78 additions & 0 deletions test-data/unit/check-typevar-defaults.test
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,84 @@ def func_c4(
reveal_type(m) # N: Revealed type is "__main__.ClassC4[builtins.int, builtins.float]"
[builtins fixtures/tuple.pyi]

[case testTypeVarDefaultsClassRecursive1]
# flags: --disallow-any-generics
from typing import Generic, TypeVar

T1 = TypeVar("T1", default=str)
T2 = TypeVar("T2", default=T1)
T3 = TypeVar("T3", default=T2)

class ClassD1(Generic[T1, T2]): ...

def func_d1(
a: ClassD1,
b: ClassD1[int],
c: ClassD1[int, float]
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]"
reveal_type(b) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]"
reveal_type(c) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]"

k = ClassD1()
reveal_type(k) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]"
l = ClassD1[int]()
reveal_type(l) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]"
m = ClassD1[int, float]()
reveal_type(m) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]"

class ClassD2(Generic[T1, T2, T3]): ...

def func_d2(
a: ClassD2,
b: ClassD2[int],
c: ClassD2[int, float],
d: ClassD2[int, float, str],
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]"
reveal_type(b) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]"
reveal_type(c) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]"
reveal_type(d) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]"

k = ClassD2()
reveal_type(k) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]"
l = ClassD2[int]()
reveal_type(l) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]"
m = ClassD2[int, float]()
reveal_type(m) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]"
n = ClassD2[int, float, str]()
reveal_type(n) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]"

[case testTypeVarDefaultsClassRecursiveMultipleFiles]
# flags: --disallow-any-generics
from typing import Generic, TypeVar
from file2 import T as T2

T = TypeVar('T', default=T2)

class ClassG1(Generic[T2, T]):
pass

def func(
a: ClassG1,
b: ClassG1[str],
c: ClassG1[str, float],
) -> None:
reveal_type(a) # N: Revealed type is "__main__.ClassG1[builtins.int, builtins.int]"
reveal_type(b) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.str]"
reveal_type(c) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.float]"

k = ClassG1()
reveal_type(k) # N: Revealed type is "__main__.ClassG1[builtins.int, builtins.int]"
l = ClassG1[str]()
reveal_type(l) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.str]"
m = ClassG1[str, float]()
reveal_type(m) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.float]"

[file file2.py]
from typing import TypeVar
T = TypeVar('T', default=int)

[case testTypeVarDefaultsTypeAlias1]
# flags: --disallow-any-generics
from typing import Any, Dict, List, Tuple, TypeVar, Union
Expand Down

0 comments on commit 2e5174c

Please sign in to comment.