Skip to content

Commit

Permalink
fix[lang]: fix uses analysis for nonreentrant functions (vyperlang#…
Browse files Browse the repository at this point in the history
…3927)

`uses` analysis ignores nonreentrant functions, even though those that
(implicitly) use state.

this commit adds checks both for internally (called) and external
(exported) modules

misc/refactor:
- factor out `uses_state()` util
- rename `validate_functions` to more accurate `analyze_functions`
- improve locality of exceptions thrown in check_module_uses
  • Loading branch information
charles-cooper authored and electriclilies committed Apr 27, 2024
1 parent 7db4f87 commit bb1e03b
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 41 deletions.
78 changes: 78 additions & 0 deletions tests/functional/codegen/modules/test_nonreentrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable
implements: Foo
@external
@nonreentrant
def foo() -> uint256:
return 5
"""
main = """
import lib1
initializes: lib1
exports: lib1.foo
@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw
@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 5
with tx_failed():
c.re_enter()


def test_internal_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable
implements: Foo
@external
def foo() -> uint256:
return self._safe_fn()
@internal
@nonreentrant
def _safe_fn() -> uint256:
return 10
"""
main = """
import lib1
initializes: lib1
exports: lib1.foo
@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw
@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 10
with tx_failed():
c.re_enter()
3 changes: 3 additions & 0 deletions tests/functional/syntax/modules/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
NONREENTRANT_NOTE = (
"\n note that use of the `@nonreentrant` decorator is also considered state access"
)
6 changes: 4 additions & 2 deletions tests/functional/syntax/modules/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from vyper.compiler import compile_code
from vyper.exceptions import ImmutableViolation, NamespaceCollision, StructureException

from .helpers import NONREENTRANT_NOTE


def test_exports_no_uses(make_input_bundle):
lib1 = """
Expand All @@ -21,7 +23,7 @@ def get_counter() -> uint256:
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand All @@ -40,7 +42,7 @@ def test_exports_no_uses_variable(make_input_bundle):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down
77 changes: 65 additions & 12 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
UndeclaredDefinition,
)

from .helpers import NONREENTRANT_NOTE


def test_initialize_uses(make_input_bundle):
lib1 = """
Expand Down Expand Up @@ -413,7 +415,7 @@ def foo():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -450,7 +452,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -491,7 +493,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -536,7 +538,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -571,7 +573,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -612,7 +614,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -656,7 +658,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -695,7 +697,7 @@ def foo():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -734,7 +736,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -776,7 +778,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -819,7 +821,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -853,7 +855,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -1296,3 +1298,54 @@ def foo():
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`"
assert e.value._hint == "try importing lib1 first"


def test_nonreentrant_exports(make_input_bundle):
lib1 = """
# lib1.vy
@external
@nonreentrant
def bar():
pass
"""
main = """
import lib1
exports: lib1.bar # line 4
@external
def foo():
pass
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 4


def test_internal_nonreentrant_import(make_input_bundle):
lib1 = """
# lib1.vy
@internal
@nonreentrant
def bar():
pass
"""
main = """
import lib1
@external
def foo():
lib1.bar() # line 6
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 6
37 changes: 21 additions & 16 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# CMC 2024-02-03 TODO: split me into function.py and expr.py
# CMC 2024-02-03 TODO: rename me to function.py

import contextlib
from typing import Optional
Expand Down Expand Up @@ -35,6 +35,7 @@
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
uses_state,
validate_expected_type,
)
from vyper.semantics.data_locations import DataLocation
Expand Down Expand Up @@ -64,30 +65,30 @@
from vyper.semantics.types.utils import type_from_annotation


def validate_functions(vy_module: vy_ast.Module) -> None:
def analyze_functions(vy_module: vy_ast.Module) -> None:
"""Analyzes a vyper ast and validates the function bodies"""
err_list = ExceptionList()

for node in vy_module.get_children(vy_ast.FunctionDef):
_validate_function_r(vy_module, node, err_list)
_analyze_function_r(vy_module, node, err_list)

for node in vy_module.get_children(vy_ast.VariableDecl):
if not node.is_public:
continue
_validate_function_r(vy_module, node._expanded_getter, err_list)
_analyze_function_r(vy_module, node._expanded_getter, err_list)

err_list.raise_if_not_empty()


def _validate_function_r(
def _analyze_function_r(
vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList
):
func_t = node._metadata["func_type"]

for call_t in func_t.called_functions:
if isinstance(call_t, ContractFunctionT):
assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy
_validate_function_r(vy_module, call_t.ast_def, err_list)
_analyze_function_r(vy_module, call_t.ast_def, err_list)

namespace = get_namespace()

Expand Down Expand Up @@ -267,7 +268,14 @@ def check_module_uses(node: vy_ast.ExprNode) -> Optional[ModuleInfo]:

for module_info in module_infos:
if module_info.ownership < ModuleOwnership.USES:
msg = f"Cannot access `{module_info.alias}` state!"
msg = f"Cannot access `{module_info.alias}` state!\n note that"
# CMC 2024-04-12 add UX note about nonreentrant. might be nice
# in the future to be more specific about exactly which state is
# used, although that requires threading a bit more context into
# this function.
msg += " use of the `@nonreentrant` decorator is also considered"
msg += " state access"

hint = f"add `uses: {module_info.alias}` or "
hint += f"`initializes: {module_info.alias}` as "
hint += "a top-level statement to your contract"
Expand Down Expand Up @@ -443,10 +451,7 @@ def _handle_modification(self, target: vy_ast.ExprNode):

info._writes.add(var_access)

def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode):
if not var_access.variable.is_state_variable():
return

def _handle_module_access(self, target: vy_ast.ExprNode):
root_module_info = check_module_uses(target)

if root_module_info is not None:
Expand Down Expand Up @@ -682,9 +687,9 @@ def visit(self, node, typ):
msg += f" `{var.decl_node.target.id}`"
raise ImmutableViolation(msg, var.decl_node, node)

variable_accesses = info._writes | info._reads
for s in variable_accesses:
self.function_analyzer._handle_module_access(s, node)
var_accesses = info._writes | info._reads
if uses_state(var_accesses):
self.function_analyzer._handle_module_access(node)

self.func.mark_variable_writes(info._writes)
self.func.mark_variable_reads(info._reads)
Expand Down Expand Up @@ -787,8 +792,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
if self.function_analyzer:
self._check_call_mutability(func_type.mutability)

for s in func_type.get_variable_accesses():
self.function_analyzer._handle_module_access(s, node.func)
if func_type.uses_state():
self.function_analyzer._handle_module_access(node.func)

if func_type.is_deploy and not self.func.is_deploy:
raise CallViolation(
Expand Down
Loading

0 comments on commit bb1e03b

Please sign in to comment.