From 2e5174c82317068645ec888c36902bd19ffcd49d Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:01:15 +0100 Subject: [PATCH] Add basic support for recursive TypeVar defaults (PEP 696) (#16878) Ref: https://github.com/python/mypy/issues/14851 --- mypy/applytype.py | 13 +++- mypy/expandtype.py | 9 +++ mypy/semanal.py | 9 +++ mypy/tvar_scope.py | 22 ++++++ mypy/typetraverser.py | 6 +- test-data/unit/check-typevar-defaults.test | 78 ++++++++++++++++++++++ 6 files changed, 133 insertions(+), 4 deletions(-) diff --git a/mypy/applytype.py b/mypy/applytype.py index b00372855d9c..e14906fa2772 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -147,7 +147,18 @@ def apply_generic_arguments( # TODO: move apply_poly() logic from checkexpr.py here when new inference # becomes universally used (i.e. in all passes + in unification). # With this new logic we can actually *add* some new free variables. - remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type] + remaining_tvars: list[TypeVarLikeType] = [] + for tv in tvars: + if tv.id in id_to_type: + continue + if not tv.has_default(): + remaining_tvars.append(tv) + continue + # TypeVarLike isn't in id_to_type mapping. + # Only expand the TypeVar default here. + typ = expand_type(tv, id_to_type) + assert isinstance(typ, TypeVarLikeType) + remaining_tvars.append(typ) return callable.copy_modified( ret_type=expand_type(callable.ret_type, id_to_type), diff --git a/mypy/expandtype.py b/mypy/expandtype.py index d2d294fb77f3..3bf45854b2a0 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator): def __init__(self, variables: Mapping[TypeVarId, Type]) -> None: self.variables = variables + self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {} def visit_unbound_type(self, t: UnboundType) -> Type: return t @@ -226,6 +227,14 @@ def visit_type_var(self, t: TypeVarType) -> Type: # TODO: do we really need to do this? # If I try to remove this special-casing ~40 tests fail on reveal_type(). return repl.copy_modified(last_known_value=None) + if isinstance(repl, TypeVarType) and repl.has_default(): + if (tvar_id := repl.id) in self.recursive_tvar_guard: + return self.recursive_tvar_guard[tvar_id] or repl + self.recursive_tvar_guard[tvar_id] = None + repl = repl.accept(self) + if isinstance(repl, TypeVarType): + repl.default = repl.default.accept(self) + self.recursive_tvar_guard[tvar_id] = repl return repl def visit_param_spec(self, t: ParamSpecType) -> Type: diff --git a/mypy/semanal.py b/mypy/semanal.py index 4bf9f0c3eabb..aeceb644fe52 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1954,6 +1954,15 @@ class Foo(Bar, Generic[T]): ... del base_type_exprs[i] tvar_defs: list[TypeVarLikeType] = [] for name, tvar_expr in declared_tvars: + tvar_expr_default = tvar_expr.default + if isinstance(tvar_expr_default, UnboundType): + # TODO: - detect out of order and self-referencing TypeVars + # - nested default types, e.g. list[T1] + n = self.lookup_qualified( + tvar_expr_default.name, tvar_expr_default, suppress_errors=True + ) + if n is not None and (default := self.tvar_scope.get_binding(n)) is not None: + tvar_expr.default = default tvar_def = self.tvar_scope.bind_new(name, tvar_expr) tvar_defs.append(tvar_def) return base_type_exprs, tvar_defs, is_protocol diff --git a/mypy/tvar_scope.py b/mypy/tvar_scope.py index c7a653a1552d..4dc663df0399 100644 --- a/mypy/tvar_scope.py +++ b/mypy/tvar_scope.py @@ -15,6 +15,26 @@ TypeVarTupleType, TypeVarType, ) +from mypy.typetraverser import TypeTraverserVisitor + + +class TypeVarLikeNamespaceSetter(TypeTraverserVisitor): + """Set namespace for all TypeVarLikeTypes types.""" + + def __init__(self, namespace: str) -> None: + self.namespace = namespace + + def visit_type_var(self, t: TypeVarType) -> None: + t.id.namespace = self.namespace + super().visit_type_var(t) + + def visit_param_spec(self, t: ParamSpecType) -> None: + t.id.namespace = self.namespace + return super().visit_param_spec(t) + + def visit_type_var_tuple(self, t: TypeVarTupleType) -> None: + t.id.namespace = self.namespace + super().visit_type_var_tuple(t) class TypeVarLikeScope: @@ -88,6 +108,8 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: i = self.func_id # TODO: Consider also using namespaces for functions namespace = "" + tvar_expr.default.accept(TypeVarLikeNamespaceSetter(namespace)) + if isinstance(tvar_expr, TypeVarExpr): tvar_def: TypeVarLikeType = TypeVarType( name=name, diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index d9ab54871f4a..1ff5f6685eb8 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -61,16 +61,16 @@ def visit_type_var(self, t: TypeVarType) -> None: # Note that type variable values and upper bound aren't treated as # components, since they are components of the type variable # definition. We want to traverse everything just once. - pass + t.default.accept(self) def visit_param_spec(self, t: ParamSpecType) -> None: - pass + t.default.accept(self) def visit_parameters(self, t: Parameters) -> None: self.traverse_types(t.arg_types) def visit_type_var_tuple(self, t: TypeVarTupleType) -> None: - pass + t.default.accept(self) def visit_literal_type(self, t: LiteralType) -> None: t.fallback.accept(self) diff --git a/test-data/unit/check-typevar-defaults.test b/test-data/unit/check-typevar-defaults.test index 544bc59494b3..1a08823cb692 100644 --- a/test-data/unit/check-typevar-defaults.test +++ b/test-data/unit/check-typevar-defaults.test @@ -349,6 +349,84 @@ def func_c4( reveal_type(m) # N: Revealed type is "__main__.ClassC4[builtins.int, builtins.float]" [builtins fixtures/tuple.pyi] +[case testTypeVarDefaultsClassRecursive1] +# flags: --disallow-any-generics +from typing import Generic, TypeVar + +T1 = TypeVar("T1", default=str) +T2 = TypeVar("T2", default=T1) +T3 = TypeVar("T3", default=T2) + +class ClassD1(Generic[T1, T2]): ... + +def func_d1( + a: ClassD1, + b: ClassD1[int], + c: ClassD1[int, float] +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]" + reveal_type(b) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]" + reveal_type(c) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]" + + k = ClassD1() + reveal_type(k) # N: Revealed type is "__main__.ClassD1[builtins.str, builtins.str]" + l = ClassD1[int]() + reveal_type(l) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.int]" + m = ClassD1[int, float]() + reveal_type(m) # N: Revealed type is "__main__.ClassD1[builtins.int, builtins.float]" + +class ClassD2(Generic[T1, T2, T3]): ... + +def func_d2( + a: ClassD2, + b: ClassD2[int], + c: ClassD2[int, float], + d: ClassD2[int, float, str], +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]" + reveal_type(b) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]" + reveal_type(c) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]" + reveal_type(d) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]" + + k = ClassD2() + reveal_type(k) # N: Revealed type is "__main__.ClassD2[builtins.str, builtins.str, builtins.str]" + l = ClassD2[int]() + reveal_type(l) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.int, builtins.int]" + m = ClassD2[int, float]() + reveal_type(m) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.float]" + n = ClassD2[int, float, str]() + reveal_type(n) # N: Revealed type is "__main__.ClassD2[builtins.int, builtins.float, builtins.str]" + +[case testTypeVarDefaultsClassRecursiveMultipleFiles] +# flags: --disallow-any-generics +from typing import Generic, TypeVar +from file2 import T as T2 + +T = TypeVar('T', default=T2) + +class ClassG1(Generic[T2, T]): + pass + +def func( + a: ClassG1, + b: ClassG1[str], + c: ClassG1[str, float], +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassG1[builtins.int, builtins.int]" + reveal_type(b) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.str]" + reveal_type(c) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.float]" + + k = ClassG1() + reveal_type(k) # N: Revealed type is "__main__.ClassG1[builtins.int, builtins.int]" + l = ClassG1[str]() + reveal_type(l) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.str]" + m = ClassG1[str, float]() + reveal_type(m) # N: Revealed type is "__main__.ClassG1[builtins.str, builtins.float]" + +[file file2.py] +from typing import TypeVar +T = TypeVar('T', default=int) + [case testTypeVarDefaultsTypeAlias1] # flags: --disallow-any-generics from typing import Any, Dict, List, Tuple, TypeVar, Union