Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Don't crash on unreachable statements #16311

Merged
merged 1 commit into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ def __init__(
self.runtime_args: list[list[RuntimeArg]] = [[]]
self.function_name_stack: list[str] = []
self.class_ir_stack: list[ClassIR] = []
# Keep track of whether the next statement in a block is reachable
# or not, separately for each block nesting level
self.block_reachable_stack: list[bool] = [True]

self.current_module = current_module
self.mapper = mapper
Expand Down Expand Up @@ -1302,6 +1305,14 @@ def is_native_attr_ref(self, expr: MemberExpr) -> bool:
and not obj_rtype.class_ir.get_method(expr.name)
)

def mark_block_unreachable(self) -> None:
"""Mark statements in the innermost block being processed as unreachable.

This should be called after a statement that unconditionally leaves the
block, such as 'break' or 'return'.
"""
self.block_reachable_stack[-1] = False

# Lacks a good type because there wasn't a reasonable type in 3.5 :(
def catch_errors(self, line: int) -> Any:
return catch_errors(self.module_path, line)
Expand Down
5 changes: 5 additions & 0 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,13 @@

def transform_block(builder: IRBuilder, block: Block) -> None:
if not block.is_unreachable:
builder.block_reachable_stack.append(True)
for stmt in block.body:
builder.accept(stmt)
if not builder.block_reachable_stack[-1]:
# The rest of the block is unreachable, so skip it
break
builder.block_reachable_stack.pop()
# Raise a RuntimeError if we hit a non-empty unreachable block.
# Don't complain about empty unreachable blocks, since mypy inserts
# those after `if MYPY`.
Expand Down
4 changes: 4 additions & 0 deletions mypyc/irbuild/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def visit_expression_stmt(self, stmt: ExpressionStmt) -> None:

def visit_return_stmt(self, stmt: ReturnStmt) -> None:
transform_return_stmt(self.builder, stmt)
self.builder.mark_block_unreachable()

def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
transform_assignment_stmt(self.builder, stmt)
Expand All @@ -212,12 +213,15 @@ def visit_for_stmt(self, stmt: ForStmt) -> None:

def visit_break_stmt(self, stmt: BreakStmt) -> None:
transform_break_stmt(self.builder, stmt)
self.builder.mark_block_unreachable()

def visit_continue_stmt(self, stmt: ContinueStmt) -> None:
transform_continue_stmt(self.builder, stmt)
self.builder.mark_block_unreachable()

def visit_raise_stmt(self, stmt: RaiseStmt) -> None:
transform_raise_stmt(self.builder, stmt)
self.builder.mark_block_unreachable()

def visit_try_stmt(self, stmt: TryStmt) -> None:
transform_try_stmt(self.builder, stmt)
Expand Down
137 changes: 136 additions & 1 deletion mypyc/test-data/irbuild-unreachable.test
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Test cases for unreachable expressions
# Test cases for unreachable expressions and statements

[case testUnreachableMemberExpr]
import sys
Expand Down Expand Up @@ -104,3 +104,138 @@ L5:
L6:
y = r11
return 1

[case testUnreachableStatementAfterReturn]
def f(x: bool) -> int:
if x:
return 1
f(False)
return 2
[out]
def f(x):
x :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
return 2
L2:
return 4

[case testUnreachableStatementAfterContinue]
def c() -> bool:
return False

def f() -> None:
n = True
while n:
if c():
continue
if int():
f()
n = False
[out]
def c():
L0:
return 0
def f():
n, r0 :: bool
L0:
n = 1
L1:
if n goto L2 else goto L5 :: bool
L2:
r0 = c()
if r0 goto L3 else goto L4 :: bool
L3:
goto L1
L4:
n = 0
goto L1
L5:
return 1

[case testUnreachableStatementAfterBreak]
def c() -> bool:
return False

def f() -> None:
n = True
while n:
if c():
break
if int():
f()
n = False
[out]
def c():
L0:
return 0
def f():
n, r0 :: bool
L0:
n = 1
L1:
if n goto L2 else goto L5 :: bool
L2:
r0 = c()
if r0 goto L3 else goto L4 :: bool
L3:
goto L5
L4:
n = 0
goto L1
L5:
return 1

[case testUnreachableStatementAfterRaise]
def f(x: bool) -> int:
if x:
raise ValueError()
print('hello')
return 2
[out]
def f(x):
x :: bool
r0 :: object
r1 :: str
r2, r3 :: object
L0:
if x goto L1 else goto L2 :: bool
L1:
r0 = builtins :: module
r1 = 'ValueError'
r2 = CPyObject_GetAttr(r0, r1)
r3 = PyObject_CallFunctionObjArgs(r2, 0)
CPy_Raise(r3)
unreachable
L2:
return 4

[case testUnreachableStatementAfterAssertFalse]
def f(x: bool) -> int:
if x:
assert False
print('hello')
return 2
[out]
def f(x):
x, r0 :: bool
r1 :: str
r2 :: object
r3 :: str
r4, r5 :: object
L0:
if x goto L1 else goto L4 :: bool
L1:
if 0 goto L3 else goto L2 :: bool
L2:
r0 = raise AssertionError
unreachable
L3:
r1 = 'hello'
r2 = builtins :: module
r3 = 'print'
r4 = CPyObject_GetAttr(r2, r3)
r5 = PyObject_CallFunctionObjArgs(r4, r1, 0)
L4:
return 4