Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into issue-16614
Browse files Browse the repository at this point in the history
  • Loading branch information
hamdanal committed Feb 20, 2024
2 parents 310c75d + 790e8a7 commit 7ffa87f
Show file tree
Hide file tree
Showing 15 changed files with 404 additions and 65 deletions.
19 changes: 16 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5053,6 +5053,19 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
return

def visit_match_stmt(self, s: MatchStmt) -> None:
named_subject: Expression
if isinstance(s.subject, CallExpr):
# Create a dummy subject expression to handle cases where a match statement's subject
# is not a literal value. This lets us correctly narrow types and check exhaustivity
# This is hack!
id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else ""
name = "dummy-match-" + id
v = Var(name)
named_subject = NameExpr(name)
named_subject.node = v
else:
named_subject = s.subject

with self.binder.frame_context(can_skip=False, fall_through=0):
subject_type = get_proper_type(self.expr_checker.accept(s.subject))

Expand All @@ -5071,7 +5084,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
# The second pass narrows down the types and type checks bodies.
for p, g, b in zip(s.patterns, s.guards, s.bodies):
current_subject_type = self.expr_checker.narrow_type_from_binder(
s.subject, subject_type
named_subject, subject_type
)
pattern_type = self.pattern_checker.accept(p, current_subject_type)
with self.binder.frame_context(can_skip=True, fall_through=2):
Expand All @@ -5082,7 +5095,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
else_map: TypeMap = {}
else:
pattern_map, else_map = conditional_types_to_typemaps(
s.subject, pattern_type.type, pattern_type.rest_type
named_subject, pattern_type.type, pattern_type.rest_type
)
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
self.push_type_map(pattern_map)
Expand Down Expand Up @@ -5110,7 +5123,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
and expr.fullname == case_target.fullname
):
continue
type_map[s.subject] = type_map[expr]
type_map[named_subject] = type_map[expr]

self.push_type_map(guard_map)
self.accept(b)
Expand Down
9 changes: 9 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2059,6 +2059,15 @@ def impossible_intersection(
template.format(formatted_base_class_list, reason), context, code=codes.UNREACHABLE
)

def tvar_without_default_type(
self, tvar_name: str, last_tvar_name_with_default: str, context: Context
) -> None:
self.fail(
f'"{tvar_name}" cannot appear after "{last_tvar_name_with_default}" '
"in type parameter list because it has no default type",
context,
)

def report_protocol_problems(
self,
subtype: Instance | TupleType | TypedDictType | TypeType | CallableType,
Expand Down
47 changes: 38 additions & 9 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@
SELF_TYPE_NAMES,
FindTypeVarVisitor,
TypeAnalyser,
TypeVarDefaultTranslator,
TypeVarLikeList,
analyze_type_alias,
check_for_explicit_any,
Expand All @@ -252,6 +253,7 @@
TPDICT_NAMES,
TYPE_ALIAS_NAMES,
TYPE_CHECK_ONLY_NAMES,
TYPE_VAR_LIKE_NAMES,
TYPED_NAMEDTUPLE_NAMES,
AnyType,
CallableType,
Expand Down Expand Up @@ -1953,17 +1955,19 @@ class Foo(Bar, Generic[T]): ...
defn.removed_base_type_exprs.append(defn.base_type_exprs[i])
del base_type_exprs[i]
tvar_defs: list[TypeVarLikeType] = []
last_tvar_name_with_default: str | None = None
for name, tvar_expr in declared_tvars:
tvar_expr_default = tvar_expr.default
if isinstance(tvar_expr_default, UnboundType):
# TODO: - detect out of order and self-referencing TypeVars
# - nested default types, e.g. list[T1]
n = self.lookup_qualified(
tvar_expr_default.name, tvar_expr_default, suppress_errors=True
)
if n is not None and (default := self.tvar_scope.get_binding(n)) is not None:
tvar_expr.default = default
tvar_expr.default = tvar_expr.default.accept(
TypeVarDefaultTranslator(self, tvar_expr.name, context)
)
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
if last_tvar_name_with_default is not None and not tvar_def.has_default():
self.msg.tvar_without_default_type(
tvar_def.name, last_tvar_name_with_default, context
)
tvar_def.default = AnyType(TypeOfAny.from_error)
elif tvar_def.has_default():
last_tvar_name_with_default = tvar_def.name
tvar_defs.append(tvar_def)
return base_type_exprs, tvar_defs, is_protocol

Expand Down Expand Up @@ -2857,6 +2861,10 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
with self.allow_unbound_tvars_set():
s.rvalue.accept(self)
self.basic_type_applications = old_basic_type_applications
elif self.can_possibly_be_typevarlike_declaration(s):
# Allow unbound tvars inside TypeVarLike defaults to be evaluated later
with self.allow_unbound_tvars_set():
s.rvalue.accept(self)
else:
s.rvalue.accept(self)

Expand Down Expand Up @@ -3033,6 +3041,16 @@ def can_possibly_be_type_form(self, s: AssignmentStmt) -> bool:
# Something that looks like Foo = Bar[Baz, ...]
return True

def can_possibly_be_typevarlike_declaration(self, s: AssignmentStmt) -> bool:
"""Check if r.h.s. can be a TypeVarLike declaration."""
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr):
return False
if not isinstance(s.rvalue, CallExpr) or not isinstance(s.rvalue.callee, NameExpr):
return False
ref = s.rvalue.callee
ref.accept(self)
return ref.fullname in TYPE_VAR_LIKE_NAMES

def is_type_ref(self, rv: Expression, bare: bool = False) -> bool:
"""Does this expression refer to a type?
Expand Down Expand Up @@ -3522,9 +3540,20 @@ def analyze_alias(
tvar_defs: list[TypeVarLikeType] = []
namespace = self.qualified_name(name)
alias_type_vars = found_type_vars if declared_type_vars is None else declared_type_vars
last_tvar_name_with_default: str | None = None
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
for name, tvar_expr in alias_type_vars:
tvar_expr.default = tvar_expr.default.accept(
TypeVarDefaultTranslator(self, tvar_expr.name, typ)
)
tvar_def = self.tvar_scope.bind_new(name, tvar_expr)
if last_tvar_name_with_default is not None and not tvar_def.has_default():
self.msg.tvar_without_default_type(
tvar_def.name, last_tvar_name_with_default, typ
)
tvar_def.default = AnyType(TypeOfAny.from_error)
elif tvar_def.has_default():
last_tvar_name_with_default = tvar_def.name
tvar_defs.append(tvar_def)

analyzed, depends_on = analyze_type_alias(
Expand Down
51 changes: 33 additions & 18 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import os.path
import sys
import traceback
from typing import Final, Iterable
from typing import Final, Iterable, Iterator

import mypy.build
import mypy.mixedtraverser
Expand Down Expand Up @@ -114,6 +114,7 @@
from mypy.stubdoc import ArgSig, FunctionSig
from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module
from mypy.stubutil import (
TYPING_BUILTIN_REPLACEMENTS,
BaseStubGenerator,
CantImport,
ClassInfo,
Expand Down Expand Up @@ -289,20 +290,19 @@ def visit_call_expr(self, node: CallExpr) -> str:
raise ValueError(f"Unknown argument kind {kind} in call")
return f"{callee}({', '.join(args)})"

def _visit_ref_expr(self, node: NameExpr | MemberExpr) -> str:
fullname = self.stubgen.get_fullname(node)
if fullname in TYPING_BUILTIN_REPLACEMENTS:
return self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=False)
qualname = get_qualified_name(node)
self.stubgen.import_tracker.require_name(qualname)
return qualname

def visit_name_expr(self, node: NameExpr) -> str:
self.stubgen.import_tracker.require_name(node.name)
return node.name
return self._visit_ref_expr(node)

def visit_member_expr(self, o: MemberExpr) -> str:
node: Expression = o
trailer = ""
while isinstance(node, MemberExpr):
trailer = "." + node.name + trailer
node = node.expr
if not isinstance(node, NameExpr):
return ERROR_MARKER
self.stubgen.import_tracker.require_name(node.name)
return node.name + trailer
return self._visit_ref_expr(o)

def visit_str_expr(self, node: StrExpr) -> str:
return repr(node.value)
Expand Down Expand Up @@ -351,11 +351,17 @@ def find_defined_names(file: MypyFile) -> set[str]:
return finder.names


def get_assigned_names(lvalues: Iterable[Expression]) -> Iterator[str]:
for lvalue in lvalues:
if isinstance(lvalue, NameExpr):
yield lvalue.name
elif isinstance(lvalue, TupleExpr):
yield from get_assigned_names(lvalue.items)


class DefinitionFinder(mypy.traverser.TraverserVisitor):
"""Find names of things defined at the top level of a module."""

# TODO: Assignment statements etc.

def __init__(self) -> None:
# Short names of things defined at the top level.
self.names: set[str] = set()
Expand All @@ -368,6 +374,10 @@ def visit_func_def(self, o: FuncDef) -> None:
# Don't recurse, as we only keep track of top-level definitions.
self.names.add(o.name)

def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
for name in get_assigned_names(o.lvalues):
self.names.add(name)


def find_referenced_names(file: MypyFile) -> set[str]:
finder = ReferenceFinder()
Expand Down Expand Up @@ -1023,10 +1033,15 @@ def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
and isinstance(expr.node, (FuncDef, Decorator, MypyFile))
or isinstance(expr.node, TypeInfo)
) and not self.is_private_member(expr.node.fullname)
elif (
isinstance(expr, IndexExpr)
and isinstance(expr.base, NameExpr)
and not self.is_private_name(expr.base.name)
elif isinstance(expr, IndexExpr) and (
(isinstance(expr.base, NameExpr) and not self.is_private_name(expr.base.name))
or ( # Also some known aliases that could be member expression
isinstance(expr.base, MemberExpr)
and not self.is_private_member(get_qualified_name(expr.base))
and self.get_fullname(expr.base).startswith(
("builtins.", "typing.", "typing_extensions.", "collections.abc.")
)
)
):
if isinstance(expr.index, TupleExpr):
indices = expr.index.items
Expand Down
33 changes: 29 additions & 4 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@
# Modules that may fail when imported, or that may have side effects (fully qualified).
NOT_IMPORTABLE_MODULES = ()

# Typing constructs to be replaced by their builtin equivalents.
TYPING_BUILTIN_REPLACEMENTS: Final = {
# From typing
"typing.Text": "builtins.str",
"typing.Tuple": "builtins.tuple",
"typing.List": "builtins.list",
"typing.Dict": "builtins.dict",
"typing.Set": "builtins.set",
"typing.FrozenSet": "builtins.frozenset",
"typing.Type": "builtins.type",
# From typing_extensions
"typing_extensions.Text": "builtins.str",
"typing_extensions.Tuple": "builtins.tuple",
"typing_extensions.List": "builtins.list",
"typing_extensions.Dict": "builtins.dict",
"typing_extensions.Set": "builtins.set",
"typing_extensions.FrozenSet": "builtins.frozenset",
"typing_extensions.Type": "builtins.type",
}


class CantImport(Exception):
def __init__(self, module: str, message: str) -> None:
Expand Down Expand Up @@ -229,6 +249,8 @@ def visit_unbound_type(self, t: UnboundType) -> str:
return " | ".join([item.accept(self) for item in t.args])
if fullname == "typing.Optional":
return f"{t.args[0].accept(self)} | None"
if fullname in TYPING_BUILTIN_REPLACEMENTS:
s = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True)
if self.known_modules is not None and "." in s:
# see if this object is from any of the modules that we're currently processing.
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
Expand Down Expand Up @@ -476,7 +498,7 @@ def reexport(self, name: str) -> None:
def import_lines(self) -> list[str]:
"""The list of required import lines (as strings with python code).
In order for a module be included in this output, an indentifier must be both
In order for a module be included in this output, an identifier must be both
'required' via require_name() and 'imported' via add_import_from()
or add_import()
"""
Expand Down Expand Up @@ -585,9 +607,9 @@ def __init__(
# a corresponding import statement.
self.known_imports = {
"_typeshed": ["Incomplete"],
"typing": ["Any", "TypeVar", "NamedTuple"],
"typing": ["Any", "TypeVar", "NamedTuple", "TypedDict"],
"collections.abc": ["Generator"],
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
"typing_extensions": ["ParamSpec", "TypeVarTuple"],
}

def get_sig_generators(self) -> list[SignatureGenerator]:
Expand All @@ -613,7 +635,10 @@ def add_name(self, fullname: str, require: bool = True) -> str:
"""
module, name = fullname.rsplit(".", 1)
alias = "_" + name if name in self.defined_names else None
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
while alias in self.defined_names:
alias = "_" + alias
if module != "builtins" or alias: # don't import from builtins unless needed
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
return alias or name

def add_import_line(self, line: str) -> None:
Expand Down
36 changes: 35 additions & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
)
from mypy.options import Options
from mypy.plugin import AnalyzeTypeContext, Plugin, TypeAnalyzerPluginInterface
from mypy.semanal_shared import SemanticAnalyzerCoreInterface, paramspec_args, paramspec_kwargs
from mypy.semanal_shared import (
SemanticAnalyzerCoreInterface,
SemanticAnalyzerInterface,
paramspec_args,
paramspec_kwargs,
)
from mypy.state import state
from mypy.tvar_scope import TypeVarLikeScope
from mypy.types import (
Expand Down Expand Up @@ -2520,3 +2525,32 @@ def process_types(self, types: list[Type] | tuple[Type, ...]) -> None:
else:
for t in types:
t.accept(self)


class TypeVarDefaultTranslator(TrivialSyntheticTypeTranslator):
"""Type translate visitor that replaces UnboundTypes with in-scope TypeVars."""

def __init__(
self, api: SemanticAnalyzerInterface, tvar_expr_name: str, context: Context
) -> None:
self.api = api
self.tvar_expr_name = tvar_expr_name
self.context = context

def visit_unbound_type(self, t: UnboundType) -> Type:
sym = self.api.lookup_qualified(t.name, t, suppress_errors=True)
if sym is not None:
if type_var := self.api.tvar_scope.get_binding(sym):
return type_var
if isinstance(sym.node, TypeVarLikeExpr):
self.api.fail(
f'Type parameter "{self.tvar_expr_name}" has a default type '
"that refers to one or more type variables that are out of scope",
self.context,
)
return AnyType(TypeOfAny.from_error)
return super().visit_unbound_type(t)

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# TypeAliasTypes are analyzed separately already, just return it
return t
9 changes: 9 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@
TypeVisitor as TypeVisitor,
)

TYPE_VAR_LIKE_NAMES: Final = (
"typing.TypeVar",
"typing_extensions.TypeVar",
"typing.ParamSpec",
"typing_extensions.ParamSpec",
"typing.TypeVarTuple",
"typing_extensions.TypeVarTuple",
)

TYPED_NAMEDTUPLE_NAMES: Final = ("typing.NamedTuple", "typing_extensions.NamedTuple")

# Supported names of TypedDict type constructors.
Expand Down
Loading

0 comments on commit 7ffa87f

Please sign in to comment.