Skip to content

Commit

Permalink
rewrite function termination detection
Browse files Browse the repository at this point in the history
fix dead code detection
fix detection of returns after valid branching returns
add tests
  • Loading branch information
charles-cooper committed Jan 14, 2024
1 parent 25bd14a commit 7b2c74b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 42 deletions.
1 change: 0 additions & 1 deletion tests/functional/codegen/features/test_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ def foo(i: bool) -> int128:
else:
assert 2 != 0
return 7
return 11
"""

c = get_contract_with_gas_estimation(conditional_return_code)
Expand Down
43 changes: 33 additions & 10 deletions tests/functional/syntax/test_unbalanced_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""
@external
def foo() -> int128:
pass
pass # missing return
""",
FunctionDeclarationException,
),
Expand All @@ -18,6 +18,7 @@ def foo() -> int128:
def foo() -> int128:
if False:
return 123
# missing return
""",
FunctionDeclarationException,
),
Expand All @@ -27,19 +28,19 @@ def foo() -> int128:
def test() -> int128:
if 1 == 1 :
return 1
if True:
if True: # unreachable
return 0
else:
assert msg.sender != msg.sender
""",
FunctionDeclarationException,
StructureException,
),
(
"""
@internal
def valid_address(sender: address) -> bool:
selfdestruct(sender)
return True
return True # unreachable
""",
StructureException,
),
Expand All @@ -48,7 +49,7 @@ def valid_address(sender: address) -> bool:
@internal
def valid_address(sender: address) -> bool:
selfdestruct(sender)
a: address = sender
a: address = sender # unreachable
""",
StructureException,
),
Expand All @@ -58,7 +59,7 @@ def valid_address(sender: address) -> bool:
def valid_address(sender: address) -> bool:
if sender == empty(address):
selfdestruct(sender)
_sender: address = sender
_sender: address = sender # unreachable
else:
return False
""",
Expand All @@ -69,7 +70,7 @@ def valid_address(sender: address) -> bool:
@internal
def foo() -> bool:
raw_revert(b"vyper")
return True
return True # unreachable
""",
StructureException,
),
Expand All @@ -78,7 +79,7 @@ def foo() -> bool:
@internal
def foo() -> bool:
raw_revert(b"vyper")
x: uint256 = 3
x: uint256 = 3 # unreachable
""",
StructureException,
),
Expand All @@ -88,12 +89,35 @@ def foo() -> bool:
def foo(x: uint256) -> bool:
if x == 2:
raw_revert(b"vyper")
a: uint256 = 3
a: uint256 = 3 # unreachable
else:
return False
""",
StructureException,
),
(
"""
@internal
def foo():
return
return # unreachable
""",
StructureException,
),
(
"""
@internal
def foo() -> uint256:
if block.number % 2 == 0:
return 5
elif block.number % 3 == 0:
return 6
else:
return 10
return 0 # unreachable
""",
StructureException,
),
]


Expand Down Expand Up @@ -154,7 +178,6 @@ def test() -> int128:
else:
x = keccak256(x)
return 1
return 1
""",
"""
@external
Expand Down
57 changes: 26 additions & 31 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,35 +69,28 @@ def validate_functions(vy_module: vy_ast.Module) -> None:
err_list.raise_if_not_empty()


def check_for_terminus(node_list: list) -> bool:
terminus_nodes = []

# Check for invalid code after returns
last_node_pos = len(node_list) - 1
for idx, n in enumerate(node_list):
if n.is_terminus:
terminus_nodes.append(n)
if idx < last_node_pos:
# is not last statement in body.
raise StructureException(
"Exit statement with succeeding code (that will not execute).",
node_list[idx + 1],
)
# finds the terminus node for a list of nodes.
# raises an exception if any nodes are unreachable
def find_terminating_node(node_list: list) -> Optional[vy_ast.VyperNode]:
ret = None

if len(terminus_nodes) > 1:
raise StructureException(
"Too many exit statements (return, raise or selfdestruct).", terminus_nodes[-1]
)
elif len(terminus_nodes) == 1:
return True
for node in node_list:
if ret is not None:
raise StructureException("Unreachable code!", node)
if node.is_terminus:
ret = node

for node in [i for i in node_list if isinstance(i, vy_ast.If)][::-1]:
if not node.orelse or not check_for_terminus(node.orelse):
continue
if not check_for_terminus(node.body):
continue
return True
return False
if isinstance(node, vy_ast.If):
body_terminates = find_terminating_node(node.body)

else_terminates = None
if node.orelse is not None:
else_terminates = find_terminating_node(node.orelse)

if body_terminates is not None and else_terminates is not None:
ret = else_terminates

return ret


def _check_iterator_modification(
Expand Down Expand Up @@ -213,11 +206,13 @@ def analyze(self):
self.visit(node)

if self.func.return_type:
if not check_for_terminus(self.fn_node.body):
if not find_terminating_node(self.fn_node.body):
raise FunctionDeclarationException(
f"Missing or unmatched return statements in function '{self.fn_node.name}'",
self.fn_node,
f"Missing return statement in function '{self.fn_node.name}'", self.fn_node
)
else:
# call find_terminator for its unreachable code detection side effect
find_terminating_node(self.fn_node.body)

# visit default args
assert self.func.n_keyword_args == len(self.fn_node.args.defaults)
Expand Down Expand Up @@ -519,7 +514,7 @@ def visit_Return(self, node):
raise FunctionDeclarationException("Return statement is missing a value", node)
return
elif self.func.return_type is None:
raise FunctionDeclarationException("Function does not return any values", node)
raise FunctionDeclarationException("Function should not return any values", node)

if isinstance(values, vy_ast.Tuple):
values = values.elements
Expand Down

0 comments on commit 7b2c74b

Please sign in to comment.