Skip to content

Commit

Permalink
Enable generic NamedTuples (#13396)
Browse files Browse the repository at this point in the history
Fixes #685

This builds on top of some infra I added for recursive types (Ref #13297). Implementation is based on the idea in #13297 (comment). Generally it works well, but there are actually some problems for named tuples that are recursive. Special-casing them in `maptype.py` is a bit ugly, but I think this is best we can get at the moment.
  • Loading branch information
ilevkivskyi authored Aug 15, 2022
1 parent fd7040e commit 8deeaf3
Show file tree
Hide file tree
Showing 17 changed files with 399 additions and 38 deletions.
3 changes: 3 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3667,6 +3667,9 @@ def visit_type_application(self, tapp: TypeApplication) -> Type:
if isinstance(item, Instance):
tp = type_object_type(item.type, self.named_type)
return self.apply_type_arguments_to_callable(tp, item.args, tapp)
elif isinstance(item, TupleType) and item.partial_fallback.type.is_named_tuple:
tp = type_object_type(item.partial_fallback.type, self.named_type)
return self.apply_type_arguments_to_callable(tp, item.partial_fallback.args, tapp)
else:
self.chk.fail(message_registry.ONLY_CLASS_APPLICATION, tapp)
return AnyType(TypeOfAny.from_error)
Expand Down
8 changes: 8 additions & 0 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,14 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
]

if isinstance(actual, TupleType) and len(actual.items) == len(template.items):
if (
actual.partial_fallback.type.is_named_tuple
and template.partial_fallback.type.is_named_tuple
):
# For named tuples using just the fallbacks usually gives better results.
return infer_constraints(
template.partial_fallback, actual.partial_fallback, self.direction
)
res: List[Constraint] = []
for i in range(len(template.items)):
res.extend(infer_constraints(template.items[i], actual.items[i], self.direction))
Expand Down
6 changes: 5 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,11 @@ def expand_types_with_unpack(
def visit_tuple_type(self, t: TupleType) -> Type:
items = self.expand_types_with_unpack(t.items)
if isinstance(items, list):
return t.copy_modified(items=items)
fallback = t.partial_fallback.accept(self)
fallback = get_proper_type(fallback)
if not isinstance(fallback, Instance):
fallback = t.partial_fallback
return t.copy_modified(items=items, fallback=fallback)
else:
return items

Expand Down
27 changes: 26 additions & 1 deletion mypy/maptype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@

from typing import Dict, List

import mypy.typeops
from mypy.expandtype import expand_type
from mypy.nodes import TypeInfo
from mypy.types import AnyType, Instance, ProperType, Type, TypeOfAny, TypeVarId
from mypy.types import (
AnyType,
Instance,
ProperType,
TupleType,
Type,
TypeOfAny,
TypeVarId,
get_proper_type,
has_type_vars,
)


def map_instance_to_supertype(instance: Instance, superclass: TypeInfo) -> Instance:
Expand All @@ -18,6 +29,20 @@ def map_instance_to_supertype(instance: Instance, superclass: TypeInfo) -> Insta
# Fast path: `instance` already belongs to `superclass`.
return instance

if superclass.fullname == "builtins.tuple" and instance.type.tuple_type:
if has_type_vars(instance.type.tuple_type):
# We special case mapping generic tuple types to tuple base, because for
# such tuples fallback can't be calculated before applying type arguments.
alias = instance.type.special_alias
assert alias is not None
if not alias._is_recursive:
# Unfortunately we can't support this for generic recursive tuples.
# If we skip this special casing we will fall back to tuple[Any, ...].
env = instance_to_type_environment(instance)
tuple_type = get_proper_type(expand_type(instance.type.tuple_type, env))
if isinstance(tuple_type, TupleType):
return mypy.typeops.tuple_fallback(tuple_type)

if not superclass.type_vars:
# Fast path: `superclass` has no type variables to map to.
return Instance(superclass, [])
Expand Down
2 changes: 1 addition & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3294,7 +3294,7 @@ def from_tuple_type(cls, info: TypeInfo) -> TypeAlias:
"""Generate an alias to the tuple type described by a given TypeInfo."""
assert info.tuple_type
return TypeAlias(
info.tuple_type.copy_modified(fallback=mypy.types.Instance(info, [])),
info.tuple_type.copy_modified(fallback=mypy.types.Instance(info, info.defn.type_vars)),
info.fullname,
info.line,
info.column,
Expand Down
80 changes: 65 additions & 15 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,16 +1392,12 @@ def analyze_class(self, defn: ClassDef) -> None:
if self.analyze_typeddict_classdef(defn):
return

if self.analyze_namedtuple_classdef(defn):
if self.analyze_namedtuple_classdef(defn, tvar_defs):
return

# Create TypeInfo for class now that base classes and the MRO can be calculated.
self.prepare_class_def(defn)

defn.type_vars = tvar_defs
defn.info.type_vars = []
# we want to make sure any additional logic in add_type_vars gets run
defn.info.add_type_vars()
self.setup_type_vars(defn, tvar_defs)
if base_error:
defn.info.fallback_to_any = True

Expand All @@ -1414,6 +1410,19 @@ def analyze_class(self, defn: ClassDef) -> None:
self.analyze_class_decorator(defn, decorator)
self.analyze_class_body_common(defn)

def setup_type_vars(self, defn: ClassDef, tvar_defs: List[TypeVarLikeType]) -> None:
defn.type_vars = tvar_defs
defn.info.type_vars = []
# we want to make sure any additional logic in add_type_vars gets run
defn.info.add_type_vars()

def setup_alias_type_vars(self, defn: ClassDef) -> None:
assert defn.info.special_alias is not None
defn.info.special_alias.alias_tvars = list(defn.info.type_vars)
target = defn.info.special_alias.target
assert isinstance(target, ProperType) and isinstance(target, TupleType)
target.partial_fallback.args = tuple(defn.type_vars)

def is_core_builtin_class(self, defn: ClassDef) -> bool:
return self.cur_mod_id == "builtins" and defn.name in CORE_BUILTIN_CLASSES

Expand Down Expand Up @@ -1446,7 +1455,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool:
return True
return False

def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
def analyze_namedtuple_classdef(
self, defn: ClassDef, tvar_defs: List[TypeVarLikeType]
) -> bool:
"""Check if this class can define a named tuple."""
if (
defn.info
Expand All @@ -1465,7 +1476,9 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
if info is None:
self.mark_incomplete(defn.name, defn)
else:
self.prepare_class_def(defn, info)
self.prepare_class_def(defn, info, custom_names=True)
self.setup_type_vars(defn, tvar_defs)
self.setup_alias_type_vars(defn)
with self.scope.class_scope(defn.info):
with self.named_tuple_analyzer.save_namedtuple_body(info):
self.analyze_class_body_common(defn)
Expand Down Expand Up @@ -1690,7 +1703,31 @@ def get_all_bases_tvars(
tvars.extend(base_tvars)
return remove_dups(tvars)

def prepare_class_def(self, defn: ClassDef, info: Optional[TypeInfo] = None) -> None:
def get_and_bind_all_tvars(self, type_exprs: List[Expression]) -> List[TypeVarLikeType]:
"""Return all type variable references in item type expressions.
This is a helper for generic TypedDicts and NamedTuples. Essentially it is
a simplified version of the logic we use for ClassDef bases. We duplicate
some amount of code, because it is hard to refactor common pieces.
"""
tvars = []
for base_expr in type_exprs:
try:
base = self.expr_to_unanalyzed_type(base_expr)
except TypeTranslationError:
# This error will be caught later.
continue
base_tvars = base.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope))
tvars.extend(base_tvars)
tvars = remove_dups(tvars) # Variables are defined in order of textual appearance.
tvar_defs = []
for name, tvar_expr in tvars:
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
tvar_defs.append(tvar_def)
return tvar_defs

def prepare_class_def(
self, defn: ClassDef, info: Optional[TypeInfo] = None, custom_names: bool = False
) -> None:
"""Prepare for the analysis of a class definition.
Create an empty TypeInfo and store it in a symbol table, or if the 'info'
Expand All @@ -1702,10 +1739,13 @@ def prepare_class_def(self, defn: ClassDef, info: Optional[TypeInfo] = None) ->
info = info or self.make_empty_type_info(defn)
defn.info = info
info.defn = defn
if not self.is_func_scope():
info._fullname = self.qualified_name(defn.name)
else:
info._fullname = info.name
if not custom_names:
# Some special classes (in particular NamedTuples) use custom fullname logic.
# Don't override it here (also see comment below, this needs cleanup).
if not self.is_func_scope():
info._fullname = self.qualified_name(defn.name)
else:
info._fullname = info.name
local_name = defn.name
if "@" in local_name:
local_name = local_name.split("@")[0]
Expand Down Expand Up @@ -1866,6 +1906,7 @@ def configure_tuple_base_class(self, defn: ClassDef, base: TupleType) -> Instanc
if info.special_alias and has_placeholder(info.special_alias.target):
self.defer(force_progress=True)
info.update_tuple_type(base)
self.setup_alias_type_vars(defn)

if base.partial_fallback.type.fullname == "builtins.tuple" and not has_placeholder(base):
# Fallback can only be safely calculated after semantic analysis, since base
Expand Down Expand Up @@ -2658,7 +2699,7 @@ def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
return False
lvalue = s.lvalues[0]
name = lvalue.name
internal_name, info = self.named_tuple_analyzer.check_namedtuple(
internal_name, info, tvar_defs = self.named_tuple_analyzer.check_namedtuple(
s.rvalue, name, self.is_func_scope()
)
if internal_name is None:
Expand All @@ -2678,6 +2719,9 @@ def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
# Yes, it's a valid namedtuple, but defer if it is not ready.
if not info:
self.mark_incomplete(name, lvalue, becomes_typeinfo=True)
else:
self.setup_type_vars(info.defn, tvar_defs)
self.setup_alias_type_vars(info.defn)
return True

def analyze_typeddict_assign(self, s: AssignmentStmt) -> bool:
Expand Down Expand Up @@ -5864,10 +5908,16 @@ def expr_to_analyzed_type(
self, expr: Expression, report_invalid_types: bool = True, allow_placeholder: bool = False
) -> Optional[Type]:
if isinstance(expr, CallExpr):
# This is a legacy syntax intended mostly for Python 2, we keep it for
# backwards compatibility, but new features like generic named tuples
# and recursive named tuples will be not supported.
expr.accept(self)
internal_name, info = self.named_tuple_analyzer.check_namedtuple(
internal_name, info, tvar_defs = self.named_tuple_analyzer.check_namedtuple(
expr, None, self.is_func_scope()
)
if tvar_defs:
self.fail("Generic named tuples are not supported for legacy class syntax", expr)
self.note("Use either Python 3 class syntax, or the assignment syntax", expr)
if internal_name is None:
# Some form of namedtuple is the only valid type that looks like a call
# expression. This isn't a valid type.
Expand Down
42 changes: 28 additions & 14 deletions mypy/semanal_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@
Type,
TypeOfAny,
TypeType,
TypeVarLikeType,
TypeVarType,
UnboundType,
has_type_vars,
)
from mypy.util import get_unique_redefinition_name

Expand Down Expand Up @@ -118,7 +120,6 @@ def analyze_namedtuple_classdef(
info = self.build_namedtuple_typeinfo(
defn.name, items, types, default_items, defn.line, existing_info
)
defn.info = info
defn.analyzed = NamedTupleExpr(info, is_typed=True)
defn.analyzed.line = defn.line
defn.analyzed.column = defn.column
Expand Down Expand Up @@ -201,7 +202,7 @@ def check_namedtuple_classdef(

def check_namedtuple(
self, node: Expression, var_name: Optional[str], is_func_scope: bool
) -> Tuple[Optional[str], Optional[TypeInfo]]:
) -> Tuple[Optional[str], Optional[TypeInfo], List[TypeVarLikeType]]:
"""Check if a call defines a namedtuple.
The optional var_name argument is the name of the variable to
Expand All @@ -216,21 +217,21 @@ def check_namedtuple(
report errors but return (some) TypeInfo.
"""
if not isinstance(node, CallExpr):
return None, None
return None, None, []
call = node
callee = call.callee
if not isinstance(callee, RefExpr):
return None, None
return None, None, []
fullname = callee.fullname
if fullname == "collections.namedtuple":
is_typed = False
elif fullname in TYPED_NAMEDTUPLE_NAMES:
is_typed = True
else:
return None, None
return None, None, []
result = self.parse_namedtuple_args(call, fullname)
if result:
items, types, defaults, typename, ok = result
items, types, defaults, typename, tvar_defs, ok = result
else:
# Error. Construct dummy return value.
if var_name:
Expand All @@ -244,10 +245,10 @@ def check_namedtuple(
if name != var_name or is_func_scope:
# NOTE: we skip local namespaces since they are not serialized.
self.api.add_symbol_skip_local(name, info)
return var_name, info
return var_name, info, []
if not ok:
# This is a valid named tuple but some types are not ready.
return typename, None
return typename, None, []

# We use the variable name as the class name if it exists. If
# it doesn't, we use the name passed as an argument. We prefer
Expand Down Expand Up @@ -306,7 +307,7 @@ def check_namedtuple(
if name != var_name or is_func_scope:
# NOTE: we skip local namespaces since they are not serialized.
self.api.add_symbol_skip_local(name, info)
return typename, info
return typename, info, tvar_defs

def store_namedtuple_info(
self, info: TypeInfo, name: str, call: CallExpr, is_typed: bool
Expand All @@ -317,7 +318,9 @@ def store_namedtuple_info(

def parse_namedtuple_args(
self, call: CallExpr, fullname: str
) -> Optional[Tuple[List[str], List[Type], List[Expression], str, bool]]:
) -> Optional[
Tuple[List[str], List[Type], List[Expression], str, List[TypeVarLikeType], bool]
]:
"""Parse a namedtuple() call into data needed to construct a type.
Returns a 5-tuple:
Expand Down Expand Up @@ -363,6 +366,7 @@ def parse_namedtuple_args(
return None
typename = cast(StrExpr, call.args[0]).value
types: List[Type] = []
tvar_defs = []
if not isinstance(args[1], (ListExpr, TupleExpr)):
if fullname == "collections.namedtuple" and isinstance(args[1], StrExpr):
str_expr = args[1]
Expand All @@ -384,14 +388,20 @@ def parse_namedtuple_args(
return None
items = [cast(StrExpr, item).value for item in listexpr.items]
else:
type_exprs = [
t.items[1]
for t in listexpr.items
if isinstance(t, TupleExpr) and len(t.items) == 2
]
tvar_defs = self.api.get_and_bind_all_tvars(type_exprs)
# The fields argument contains (name, type) tuples.
result = self.parse_namedtuple_fields_with_types(listexpr.items, call)
if result is None:
# One of the types is not ready, defer.
return None
items, types, _, ok = result
if not ok:
return [], [], [], typename, False
return [], [], [], typename, [], False
if not types:
types = [AnyType(TypeOfAny.unannotated) for _ in items]
underscore = [item for item in items if item.startswith("_")]
Expand All @@ -404,7 +414,7 @@ def parse_namedtuple_args(
if len(defaults) > len(items):
self.fail(f'Too many defaults given in call to "{type_name}()"', call)
defaults = defaults[: len(items)]
return items, types, defaults, typename, True
return items, types, defaults, typename, tvar_defs, True

def parse_namedtuple_fields_with_types(
self, nodes: List[Expression], context: Context
Expand Down Expand Up @@ -490,7 +500,7 @@ def build_namedtuple_typeinfo(
# We can't calculate the complete fallback type until after semantic
# analysis, since otherwise base classes might be incomplete. Postpone a
# callback function that patches the fallback.
if not has_placeholder(tuple_base):
if not has_placeholder(tuple_base) and not has_type_vars(tuple_base):
self.api.schedule_patch(
PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(tuple_base)
)
Expand Down Expand Up @@ -525,7 +535,11 @@ def add_field(

assert info.tuple_type is not None # Set by update_tuple_type() above.
tvd = TypeVarType(
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], info.tuple_type
SELF_TVAR_NAME,
info.fullname + "." + SELF_TVAR_NAME,
self.api.tvar_scope.new_unique_func_id(),
[],
info.tuple_type,
)
selftype = tvd

Expand Down
Loading

0 comments on commit 8deeaf3

Please sign in to comment.