Skip to content

Commit

Permalink
[partially defined] implement support for try statements (#14114)
Browse files Browse the repository at this point in the history
This adds support for try/except/else/finally statements to the partially-defined check.

The implementation ended up pretty complicated because it had to handle
jumps differently for `finally`. It took me a few iterations to get to this
solution, and it's the cleanest one I could come up with.

Closes #13928.
  • Loading branch information
ilinum authored Dec 16, 2022
1 parent df6e828 commit 96ac8b3
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 2 deletions.
113 changes: 111 additions & 2 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
RefExpr,
ReturnStmt,
StarExpr,
TryStmt,
TupleExpr,
WhileStmt,
WithStmt,
Expand Down Expand Up @@ -66,6 +67,13 @@ def __init__(
self.must_be_defined = set(must_be_defined)
self.skipped = skipped

def copy(self) -> BranchState:
    """Return a new BranchState holding fresh copies of this state's name sets.

    The `skipped` flag is carried over unchanged.
    """
    must = set(self.must_be_defined)
    may = set(self.may_be_defined)
    return BranchState(must_be_defined=must, may_be_defined=may, skipped=self.skipped)


class BranchStatement:
def __init__(self, initial_state: BranchState) -> None:
Expand All @@ -77,6 +85,11 @@ def __init__(self, initial_state: BranchState) -> None:
)
]

def copy(self) -> BranchStatement:
    """Return a new BranchStatement whose branches are copies of this one's.

    The new statement is constructed from the same `initial_state` object (it is
    passed through, not copied); its branch list is then replaced with the result
    of calling `copy()` on each existing branch.
    """
    duplicate = BranchStatement(self.initial_state)
    duplicate.branches = [branch.copy() for branch in self.branches]
    return duplicate

def next_branch(self) -> None:
self.branches.append(
BranchState(
Expand All @@ -90,6 +103,11 @@ def record_definition(self, name: str) -> None:
self.branches[-1].must_be_defined.add(name)
self.branches[-1].may_be_defined.discard(name)

def delete_var(self, name: str) -> None:
    """Forget `name` in the most recent branch.

    The name is discarded from both the must-be-defined and the may-be-defined
    sets; nothing happens for a set that does not contain it.
    """
    assert len(self.branches) > 0
    current = self.branches[-1]
    current.must_be_defined.discard(name)
    current.may_be_defined.discard(name)

def record_nested_branch(self, state: BranchState) -> None:
assert len(self.branches) > 0
current_branch = self.branches[-1]
Expand Down Expand Up @@ -151,6 +169,11 @@ def __init__(self, stmts: list[BranchStatement]) -> None:
self.branch_stmts: list[BranchStatement] = stmts
self.undefined_refs: dict[str, set[NameExpr]] = {}

def copy(self) -> Scope:
    """Return a copy of this scope that can be mutated independently.

    Every branch statement is duplicated through its own `copy()`, and the
    undefined-reference map receives a fresh set for each name.
    """
    result = Scope([s.copy() for s in self.branch_stmts])
    # `dict.copy()` alone is shallow: the per-name sets would stay shared, so a
    # reference recorded on one scope would also appear in the other.
    result.undefined_refs = {name: set(refs) for name, refs in self.undefined_refs.items()}
    return result

def record_undefined_ref(self, o: NameExpr) -> None:
if o.name not in self.undefined_refs:
self.undefined_refs[o.name] = set()
Expand All @@ -166,6 +189,15 @@ class DefinedVariableTracker:
def __init__(self) -> None:
    # disable_branch_skip turns off skipping a branch due to a return/raise/etc. This is
    # useful in things like try/except/finally statements.
    self.disable_branch_skip = False
    # There's always at least one scope, and within each scope there's at least one
    # "global" BranchStatement.
    root_statement = BranchStatement(BranchState())
    self.scopes: list[Scope] = [Scope([root_statement])]

def copy(self) -> DefinedVariableTracker:
    """Return a new tracker with copied scopes and the same branch-skip setting.

    Each scope is duplicated through its own `copy()`.
    """
    duplicate = DefinedVariableTracker()
    duplicate.disable_branch_skip = self.disable_branch_skip
    duplicate.scopes = [scope.copy() for scope in self.scopes]
    return duplicate

def _scope(self) -> Scope:
assert len(self.scopes) > 0
Expand Down Expand Up @@ -195,14 +227,19 @@ def end_branch_statement(self) -> None:

def skip_branch(self) -> None:
# Only skip branch if we're outside of "root" branch statement.
if len(self._scope().branch_stmts) > 1:
if len(self._scope().branch_stmts) > 1 and not self.disable_branch_skip:
self._scope().branch_stmts[-1].skip_branch()

def record_definition(self, name: str) -> None:
    """Record `name` as defined in the last branch statement of the current scope.

    The current scope is whatever `_scope()` returns; the call is forwarded to
    that scope's most recent branch statement.
    """
    assert len(self.scopes) > 0
    assert len(self.scopes[-1].branch_stmts) > 0
    innermost = self._scope().branch_stmts[-1]
    innermost.record_definition(name)

def delete_var(self, name: str) -> None:
    """Forward deletion of `name` to the last branch statement of the current scope.

    The current scope is whatever `_scope()` returns.
    """
    assert len(self.scopes) > 0
    assert len(self.scopes[-1].branch_stmts) > 0
    innermost = self._scope().branch_stmts[-1]
    innermost.delete_var(name)

def record_undefined_ref(self, o: NameExpr) -> None:
"""Records an undefined reference. These can later be retrieved via `pop_undefined_ref`."""
assert len(self.scopes) > 0
Expand Down Expand Up @@ -268,6 +305,7 @@ def __init__(
self.type_map = type_map
self.options = options
self.loops: list[Loop] = []
self.try_depth = 0
self.tracker = DefinedVariableTracker()
for name in implicit_module_attrs:
self.tracker.record_definition(name)
Expand Down Expand Up @@ -432,6 +470,75 @@ def visit_expression_stmt(self, o: ExpressionStmt) -> None:
self.tracker.skip_branch()
super().visit_expression_stmt(o)

def visit_try_stmt(self, o: TryStmt) -> None:
    """Analyze a try statement, with an extra pass when it has a `finally` block.

    Undefined variables inside `finally` have to be found with branch skipping
    turned off, while code after the whole statement has to be analyzed with
    branch skipping left on. Consider:

        def f() -> int:
            try:
                x = 1
            except:
                # How this jump statement is treated depends on whether we are
                # processing `finally` or the code after the try statement.
                return 0
            finally:
                # `x` may be undefined here.
                pass
            # `x` is always defined here.
            return x
    """
    has_finally = o.finally_body is not None
    self.try_depth += 1
    if has_finally:
        # First pass: process the statement with branch skipping disabled so that
        # undefined variables in `finally` are found. Processing is not idempotent,
        # so snapshot the tracker beforehand and restore the snapshot afterwards.
        snapshot = self.tracker.copy()
        self.tracker.disable_branch_skip = True
        self.process_try_stmt(o)
        self.tracker = snapshot
    # Regular pass, using the tracker's existing branch-skip setting, for the sake
    # of the code that follows the try statement.
    self.process_try_stmt(o)
    self.try_depth -= 1

def process_try_stmt(self, o: TryStmt) -> None:
    """Process a try statement as if it were one if/elif/else chain.

    The statement is decomposed like this:

        if ...:
            body
            else_body
        elif ...:
            except 1
        elif ...:
            except 2
        else:
            except n
        finally

    The `finally` body is visited after the branch statement has been closed.
    """
    tracker = self.tracker
    tracker.start_branch_statement()

    # The first branch holds the try body followed by the `else` body.
    o.body.accept(self)
    if o.else_body is not None:
        o.else_body.accept(self)

    if len(o.handlers) > 0:
        assert len(o.handlers) == len(o.vars) == len(o.types)
        # Every except handler becomes a branch of its own.
        for index, handler in enumerate(o.handlers):
            self.tracker.next_branch()
            exc_type = o.types[index]
            if exc_type is not None:
                exc_type.accept(self)
            var = o.vars[index]
            has_var = var is not None
            if has_var:
                self.process_definition(var.name)
                var.accept(self)
            handler.accept(self)
            if has_var:
                # Drop the handler's variable again once the handler body is done.
                self.tracker.delete_var(var.name)
    self.tracker.end_branch_statement()

    if o.finally_body is not None:
        o.finally_body.accept(self)

def visit_while_stmt(self, o: WhileStmt) -> None:
o.expr.accept(self)
self.tracker.start_branch_statement()
Expand Down Expand Up @@ -478,7 +585,9 @@ def visit_name_expr(self, o: NameExpr) -> None:
self.tracker.record_definition(o.name)
elif self.tracker.is_defined_in_different_branch(o.name):
# A variable is defined in one branch but used in a different branch.
if self.loops:
if self.loops or self.try_depth > 0:
# If we're in a loop or in a try, we can't be sure that this variable
# is undefined. Report it as "may be undefined".
self.variable_may_be_undefined(o.name, o)
else:
self.var_used_before_def(o.name, o)
Expand Down
184 changes: 184 additions & 0 deletions test-data/unit/check-possibly-undefined.test
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,190 @@ def f3() -> None:
y = x
z = x # E: Name "x" may be undefined

[case testTryBasic]
# flags: --enable-error-code possibly-undefined --enable-error-code used-before-def
def f1() -> int:
try:
x = 1
except:
pass
return x # E: Name "x" may be undefined

def f2() -> int:
try:
pass
except:
x = 1
return x # E: Name "x" may be undefined

def f3() -> int:
try:
x = 1
except:
y = x # E: Name "x" may be undefined
return x # E: Name "x" may be undefined

def f4() -> int:
try:
x = 1
except:
return 0
return x

def f5() -> int:
try:
x = 1
except:
raise
return x

def f6() -> None:
try:
pass
except BaseException as exc:
x = exc # No error.
exc = BaseException()
# This case is covered by the other check, not by possibly undefined check.
y = exc # E: Trying to read deleted variable "exc"

def f7() -> int:
try:
if int():
x = 1
assert False
except:
pass
return x # E: Name "x" may be undefined
[builtins fixtures/exception.pyi]

[case testTryMultiExcept]
# flags: --enable-error-code possibly-undefined
def f1() -> int:
try:
x = 1
except BaseException:
x = 2
except:
x = 3
return x

def f2() -> int:
try:
x = 1
except BaseException:
pass
except:
x = 3
return x # E: Name "x" may be undefined
[builtins fixtures/exception.pyi]

[case testTryFinally]
# flags: --enable-error-code possibly-undefined --enable-error-code used-before-def
def f1() -> int:
try:
x = 1
finally:
x = 2
return x

def f2() -> int:
try:
pass
except:
pass
finally:
x = 2
return x

def f3() -> int:
try:
x = 1
except:
pass
finally:
y = x # E: Name "x" may be undefined
return x

def f4() -> int:
try:
x = 0
except BaseException:
raise
finally:
y = x # E: Name "x" may be undefined
return y

def f5() -> int:
try:
if int():
x = 1
else:
return 0
finally:
pass
return x # No error.

def f6() -> int:
try:
if int():
x = 1
else:
return 0
finally:
a = x # E: Name "x" may be undefined
return a
[builtins fixtures/exception.pyi]

[case testTryElse]
# flags: --enable-error-code possibly-undefined
def f1() -> int:
try:
return 0
except BaseException:
x = 1
else:
x = 2
finally:
y = x
return y

def f2() -> int:
try:
pass
except:
x = 1
else:
x = 2
return x

def f3() -> int:
try:
pass
except:
x = 1
else:
pass
return x # E: Name "x" may be undefined

def f4() -> int:
try:
x = 1
except:
x = 2
else:
pass
return x

def f5() -> int:
try:
pass
except:
x = 1
else:
return 1
return x
[builtins fixtures/exception.pyi]

[case testNoReturn]
# flags: --enable-error-code possibly-undefined

Expand Down

0 comments on commit 96ac8b3

Please sign in to comment.