Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate type narrowing to nested functions #15133

Merged
merged 30 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e4ecb91
Add test case
JukkaL Apr 24, 2023
05e3958
WIP debugging stuff
JukkaL Apr 24, 2023
f3c87ca
Propagate narrowed types to nested functions in some cases
JukkaL Apr 24, 2023
01e044d
Add test cases
JukkaL Apr 24, 2023
eecc877
Support for and with statements
JukkaL Apr 24, 2023
030457e
Test more
JukkaL Apr 24, 2023
7968859
Fix nested functions
JukkaL Apr 25, 2023
7603f24
Add unit tests
JukkaL Apr 25, 2023
a81856a
More testing
JukkaL Apr 25, 2023
c503b3c
Check if a match statement assigns to a variable
JukkaL Apr 25, 2023
a601098
Support walrus expression
JukkaL Apr 25, 2023
fe05131
Fix self check and isort
JukkaL Apr 25, 2023
9d0b478
Add docstrings
JukkaL Apr 25, 2023
639a9a1
Minor tweak to test
JukkaL Apr 25, 2023
6dbf9f1
Refactoring and comment updates
JukkaL Apr 25, 2023
d463240
More refactoring
JukkaL Apr 25, 2023
77b67df
Test narrowing multiple variables
JukkaL Apr 25, 2023
339ae64
Add another test case
JukkaL Apr 25, 2023
56c3f3a
Don't leak frames + add assert to find frame leaks
JukkaL Apr 25, 2023
24faaa5
Test method
JukkaL Apr 25, 2023
ecc2d34
Fix mypyc build
JukkaL Apr 25, 2023
67098aa
Update mypy/literals.py
JukkaL Apr 25, 2023
25cafbd
Fix tests when using compiled mypy
JukkaL Apr 25, 2023
40f3b53
Fix unused import
JukkaL Apr 25, 2023
4bfb5f1
Actually fix compiled mypy
JukkaL Apr 25, 2023
0c4dce4
Add more tests for reading narrowed variable after nested function
JukkaL Apr 26, 2023
4fe08c2
Fix dealing with member and index expressions after nested function
JukkaL Apr 26, 2023
dc6516e
Merge branch 'master' into nested-func-optional
JukkaL May 2, 2023
c323e06
Fix test case
JukkaL May 2, 2023
c350905
More tests
JukkaL May 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(self, id: int, conditional_frame: bool = False) -> None:
# need this field.
self.suppress_unreachable_warnings = False

def __repr__(self) -> str:
return f"Frame({self.id}, {self.types}, {self.unreachable}, {self.conditional_frame})"


Assigns = DefaultDict[Expression, List[Tuple[Type, Optional[Type]]]]

Expand All @@ -63,7 +66,7 @@ class ConditionalTypeBinder:

```
class A:
a = None # type: Union[int, str]
a: Union[int, str] = None
x = A()
lst = [x]
reveal_type(x.a) # Union[int, str]
Expand Down Expand Up @@ -446,6 +449,7 @@ def top_frame_context(self) -> Iterator[Frame]:
assert len(self.frames) == 1
yield self.push_frame()
self.pop_frame(True, 0)
assert len(self.frames) == 1


def get_declaration(expr: BindableExpression) -> Type | None:
Expand Down
117 changes: 114 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import mypy.checkexpr
from mypy import errorcodes as codes, message_registry, nodes, operators
from mypy.binder import ConditionalTypeBinder, get_declaration
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
from mypy.checkmember import (
MemberContext,
analyze_decorator_or_funcbase_access,
Expand All @@ -41,7 +41,7 @@
from mypy.errors import Errors, ErrorWatcher, report_internal_error
from mypy.expandtype import expand_self_type, expand_type, expand_type_by_instance
from mypy.join import join_types
from mypy.literals import Key, literal, literal_hash
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import is_overlapping_erased_types, is_overlapping_types
from mypy.message_registry import ErrorMessage
Expand Down Expand Up @@ -134,6 +134,7 @@
is_final_node,
)
from mypy.options import Options
from mypy.patterns import AsPattern, StarredPattern
from mypy.plugin import CheckerPluginInterface, Plugin
from mypy.scope import Scope
from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name
Expand All @@ -151,7 +152,7 @@
restrict_subtype_away,
unify_generic_callable,
)
from mypy.traverser import all_return_statements, has_return_statement
from mypy.traverser import TraverserVisitor, all_return_statements, has_return_statement
from mypy.treetransform import TransformVisitor
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type
from mypy.typeops import (
Expand Down Expand Up @@ -1207,6 +1208,20 @@ def check_func_def(

# Type check body in a new scope.
with self.binder.top_frame_context():
# Copy some type narrowings from an outer function when it seems safe enough
# (i.e. we can't find an assignment that might change the type of the
# variable afterwards).
new_frame: Frame | None = None
for frame in old_binder.frames:
for key, narrowed_type in frame.types.items():
key_var = extract_var_from_literal_hash(key)
if key_var is not None and not self.is_var_redefined_in_outer_context(
key_var, defn.line
):
# It seems safe to propagate the type narrowing to a nested scope.
if new_frame is None:
new_frame = self.binder.push_frame()
new_frame.types[key] = narrowed_type
with self.scope.push_function(defn):
# We suppress reachability warnings when we use TypeVars with value
# restrictions: we only want to report a warning if a certain statement is
Expand All @@ -1218,6 +1233,8 @@ def check_func_def(
self.binder.suppress_unreachable_warnings()
self.accept(item.body)
unreachable = self.binder.is_unreachable()
if new_frame is not None:
self.binder.pop_frame(True, 0)

if not unreachable:
if defn.is_generator or is_named_instance(
Expand Down Expand Up @@ -1310,6 +1327,23 @@ def check_func_def(

self.binder = old_binder

def is_var_redefined_in_outer_context(self, v: Var, after_line: int) -> bool:
"""Can the variable be assigned to at module top level or outer function?
Note that this doesn't do a full CFG analysis but uses a line number based
heuristic that isn't correct in some (rare) cases.
"""
outers = self.tscope.outer_functions()
if not outers:
# Top-level function -- outer context is top level, and we can't reason about
# globals
return True
for outer in outers:
if isinstance(outer, FuncDef):
if find_last_var_assignment_line(outer.body, v) >= after_line:
return True
return False

def check_unbound_return_typevar(self, typ: CallableType) -> None:
"""Fails when the return typevar is not defined in arguments."""
if isinstance(typ.ret_type, TypeVarType) and typ.ret_type in typ.variables:
Expand Down Expand Up @@ -7629,3 +7663,80 @@ def collapse_walrus(e: Expression) -> Expression:
if isinstance(e, AssignmentExpr):
return e.target
return e


def find_last_var_assignment_line(n: Node, v: Var) -> int:
"""Find the highest line number of a potential assignment to variable within node.
This supports local and global variables.
Return -1 if no assignment was found.
"""
visitor = VarAssignVisitor(v)
n.accept(visitor)
return visitor.last_line


class VarAssignVisitor(TraverserVisitor):
def __init__(self, v: Var) -> None:
self.last_line = -1
self.lvalue = False
self.var_node = v

def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
self.lvalue = True
for lv in s.lvalues:
lv.accept(self)
self.lvalue = False

def visit_name_expr(self, e: NameExpr) -> None:
if self.lvalue and e.node is self.var_node:
self.last_line = max(self.last_line, e.line)

def visit_member_expr(self, e: MemberExpr) -> None:
old_lvalue = self.lvalue
self.lvalue = False
super().visit_member_expr(e)
self.lvalue = old_lvalue

def visit_index_expr(self, e: IndexExpr) -> None:
old_lvalue = self.lvalue
self.lvalue = False
super().visit_index_expr(e)
self.lvalue = old_lvalue

def visit_with_stmt(self, s: WithStmt) -> None:
self.lvalue = True
for lv in s.target:
if lv is not None:
lv.accept(self)
self.lvalue = False
s.body.accept(self)

def visit_for_stmt(self, s: ForStmt) -> None:
self.lvalue = True
s.index.accept(self)
self.lvalue = False
s.body.accept(self)
if s.else_body:
s.else_body.accept(self)

def visit_assignment_expr(self, e: AssignmentExpr) -> None:
self.lvalue = True
e.target.accept(self)
self.lvalue = False
e.value.accept(self)

def visit_as_pattern(self, p: AsPattern) -> None:
if p.pattern is not None:
p.pattern.accept(self)
if p.name is not None:
self.lvalue = True
p.name.accept(self)
self.lvalue = False

def visit_starred_pattern(self, p: StarredPattern) -> None:
if p.capture is not None:
self.lvalue = True
p.capture.accept(self)
self.lvalue = False
3 changes: 2 additions & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,8 @@ def visit_MatchStar(self, n: MatchStar) -> StarredPattern:
if n.name is None:
node = StarredPattern(None)
else:
node = StarredPattern(NameExpr(n.name))
name = self.set_line(NameExpr(n.name), n)
node = StarredPattern(name)

return self.set_line(node, n)

Expand Down
10 changes: 10 additions & 0 deletions mypy/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,16 @@ def literal_hash(e: Expression) -> Key | None:
return e.accept(_hasher)


def extract_var_from_literal_hash(key: Key) -> Var | None:
"""If key refers to a Var node, return it.
Return None otherwise.
"""
if len(key) == 2 and key[0] == "Var" and isinstance(key[1], Var):
return key[1]
return None


class _Hasher(ExpressionVisitor[Optional[Key]]):
def visit_int_expr(self, e: IntExpr) -> Key:
return ("Literal", e.value)
Expand Down
6 changes: 6 additions & 0 deletions mypy/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self) -> None:
self.module: str | None = None
self.classes: list[TypeInfo] = []
self.function: FuncBase | None = None
self.functions: list[FuncBase] = []
# Number of nested scopes ignored (that don't get their own separate targets)
self.ignored = 0

Expand Down Expand Up @@ -65,19 +66,24 @@ def module_scope(self, prefix: str) -> Iterator[None]:

@contextmanager
def function_scope(self, fdef: FuncBase) -> Iterator[None]:
self.functions.append(fdef)
if not self.function:
self.function = fdef
else:
# Nested functions are part of the topmost function target.
self.ignored += 1
yield
self.functions.pop()
if self.ignored:
# Leave a scope that's included in the enclosing target.
self.ignored -= 1
else:
assert self.function
self.function = None

def outer_functions(self) -> list[FuncBase]:
return self.functions[:-1]

def enter_class(self, info: TypeInfo) -> None:
"""Enter a class target scope."""
if not self.function:
Expand Down
Loading