Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: fix uses analysis for nonreentrant functions #3927

Merged
merged 12 commits into from
Apr 12, 2024
78 changes: 78 additions & 0 deletions tests/functional/codegen/modules/test_nonreentrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable

implements: Foo

@external
@nonreentrant
def foo() -> uint256:
return 5
"""
main = """
import lib1

initializes: lib1

exports: lib1.foo

@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw

@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 5
with tx_failed():
c.re_enter()


def test_internal_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable

implements: Foo

@external
def foo() -> uint256:
return self._safe_fn()

@internal
@nonreentrant
def _safe_fn() -> uint256:
return 10
"""
main = """
import lib1

initializes: lib1

exports: lib1.foo

@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw

@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 10
with tx_failed():
c.re_enter()
6 changes: 4 additions & 2 deletions tests/functional/syntax/modules/test_exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from vyper.compiler import compile_code
from vyper.exceptions import ImmutableViolation, NamespaceCollision, StructureException

from .helpers import NONREENTRANT_NOTE


def test_exports_no_uses(make_input_bundle):
lib1 = """
Expand All @@ -21,7 +23,7 @@ def get_counter() -> uint256:
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand All @@ -40,7 +42,7 @@ def test_exports_no_uses_variable(make_input_bundle):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down
77 changes: 65 additions & 12 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
UndeclaredDefinition,
)

from .helpers import NONREENTRANT_NOTE


def test_initialize_uses(make_input_bundle):
lib1 = """
Expand Down Expand Up @@ -413,7 +415,7 @@ def foo():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -450,7 +452,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -491,7 +493,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -536,7 +538,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -571,7 +573,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -612,7 +614,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -656,7 +658,7 @@ def __init__():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -695,7 +697,7 @@ def foo():
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -734,7 +736,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -776,7 +778,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -819,7 +821,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib2` state!"
assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib2` or `initializes: lib2` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -853,7 +855,7 @@ def foo(new_value: uint256):
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)

assert e.value._message == "Cannot access `lib1` state!"
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

expected_hint = "add `uses: lib1` or `initializes: lib1` as a "
expected_hint += "top-level statement to your contract"
Expand Down Expand Up @@ -1296,3 +1298,54 @@ def foo():
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`"
assert e.value._hint == "try importing lib1 first"


def test_nonreentrant_exports(make_input_bundle):
lib1 = """
# lib1.vy
@external
@nonreentrant
def bar():
pass
"""
main = """
import lib1

exports: lib1.bar # line 4

@external
def foo():
pass
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 4


def test_internal_nonreentrant_import(make_input_bundle):
lib1 = """
# lib1.vy
@internal
@nonreentrant
def bar():
pass
"""
main = """
import lib1

@external
def foo():
lib1.bar() # line 6
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE

hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 6
37 changes: 21 additions & 16 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# CMC 2024-02-03 TODO: split me into function.py and expr.py
# CMC 2024-02-03 TODO: rename me to function.py

import contextlib
from typing import Optional
Expand Down Expand Up @@ -35,6 +35,7 @@
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
uses_state,
validate_expected_type,
)
from vyper.semantics.data_locations import DataLocation
Expand Down Expand Up @@ -64,30 +65,30 @@
from vyper.semantics.types.utils import type_from_annotation


def validate_functions(vy_module: vy_ast.Module) -> None:
def analyze_functions(vy_module: vy_ast.Module) -> None:
"""Analyzes a vyper ast and validates the function bodies"""
err_list = ExceptionList()

for node in vy_module.get_children(vy_ast.FunctionDef):
_validate_function_r(vy_module, node, err_list)
_analyze_function_r(vy_module, node, err_list)

for node in vy_module.get_children(vy_ast.VariableDecl):
if not node.is_public:
continue
_validate_function_r(vy_module, node._expanded_getter, err_list)
_analyze_function_r(vy_module, node._expanded_getter, err_list)

err_list.raise_if_not_empty()


def _validate_function_r(
def _analyze_function_r(
vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList
):
func_t = node._metadata["func_type"]

for call_t in func_t.called_functions:
if isinstance(call_t, ContractFunctionT):
assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy
_validate_function_r(vy_module, call_t.ast_def, err_list)
_analyze_function_r(vy_module, call_t.ast_def, err_list)

namespace = get_namespace()

Expand Down Expand Up @@ -267,7 +268,14 @@ def check_module_uses(node: vy_ast.ExprNode) -> Optional[ModuleInfo]:

for module_info in module_infos:
if module_info.ownership < ModuleOwnership.USES:
msg = f"Cannot access `{module_info.alias}` state!"
msg = f"Cannot access `{module_info.alias}` state!\n note that"
# CMC 2024-04-12 add UX note about nonreentrant. might be nice
# in the future to be more specific about exactly which state is
# used, although that requires threading a bit more context into
# this function.
msg += " use of the `@nonreentrant` decorator is also considered"
msg += " state access"

hint = f"add `uses: {module_info.alias}` or "
hint += f"`initializes: {module_info.alias}` as "
hint += "a top-level statement to your contract"
Expand Down Expand Up @@ -443,10 +451,7 @@ def _handle_modification(self, target: vy_ast.ExprNode):

info._writes.add(var_access)

def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode):
if not var_access.variable.is_state_variable():
return

def _handle_module_access(self, target: vy_ast.ExprNode):
root_module_info = check_module_uses(target)

if root_module_info is not None:
Expand Down Expand Up @@ -682,9 +687,9 @@ def visit(self, node, typ):
msg += f" `{var.decl_node.target.id}`"
raise ImmutableViolation(msg, var.decl_node, node)

variable_accesses = info._writes | info._reads
for s in variable_accesses:
self.function_analyzer._handle_module_access(s, node)
var_accesses = info._writes | info._reads
if uses_state(var_accesses):
self.function_analyzer._handle_module_access(node)

self.func.mark_variable_writes(info._writes)
self.func.mark_variable_reads(info._reads)
Expand Down Expand Up @@ -787,8 +792,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
if self.function_analyzer:
self._check_call_mutability(func_type.mutability)

for s in func_type.get_variable_accesses():
self.function_analyzer._handle_module_access(s, node.func)
if func_type.uses_state():
self.function_analyzer._handle_module_access(node.func)

if func_type.is_deploy and not self.func.is_deploy:
raise CallViolation(
Expand Down
Loading
Loading