diff --git a/tests/functional/codegen/modules/test_nonreentrant.py b/tests/functional/codegen/modules/test_nonreentrant.py new file mode 100644 index 0000000000..69b17cbfa2 --- /dev/null +++ b/tests/functional/codegen/modules/test_nonreentrant.py @@ -0,0 +1,78 @@ +def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed): + lib1 = """ +interface Foo: + def foo() -> uint256: nonpayable + +implements: Foo + +@external +@nonreentrant +def foo() -> uint256: + return 5 + """ + main = """ +import lib1 + +initializes: lib1 + +exports: lib1.foo + +@external +@nonreentrant +def re_enter(): + extcall lib1.Foo(self).foo() # should always throw + +@external +def __default__(): + # sanity: make sure we don't revert due to bad selector + pass + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 5 + with tx_failed(): + c.re_enter() + + +def test_internal_nonreentrant(make_input_bundle, get_contract, tx_failed): + lib1 = """ +interface Foo: + def foo() -> uint256: nonpayable + +implements: Foo + +@external +def foo() -> uint256: + return self._safe_fn() + +@internal +@nonreentrant +def _safe_fn() -> uint256: + return 10 + """ + main = """ +import lib1 + +initializes: lib1 + +exports: lib1.foo + +@external +@nonreentrant +def re_enter(): + extcall lib1.Foo(self).foo() # should always throw + +@external +def __default__(): + # sanity: make sure we don't revert due to bad selector + pass + """ + + input_bundle = make_input_bundle({"lib1.vy": lib1}) + + c = get_contract(main, input_bundle=input_bundle) + assert c.foo() == 10 + with tx_failed(): + c.re_enter() diff --git a/tests/functional/syntax/modules/helpers.py b/tests/functional/syntax/modules/helpers.py new file mode 100644 index 0000000000..2a54073afb --- /dev/null +++ b/tests/functional/syntax/modules/helpers.py @@ -0,0 +1,3 @@ +NONREENTRANT_NOTE = ( + "\n note that use of the `@nonreentrant` decorator is also considered state access" +) diff --git a/tests/functional/syntax/modules/test_exports.py b/tests/functional/syntax/modules/test_exports.py index 24a233da9d..1edb99bc7f 100644 --- a/tests/functional/syntax/modules/test_exports.py +++ b/tests/functional/syntax/modules/test_exports.py @@ -3,6 +3,8 @@ from vyper.compiler import compile_code from vyper.exceptions import ImmutableViolation, NamespaceCollision, StructureException +from .helpers import NONREENTRANT_NOTE + def test_exports_no_uses(make_input_bundle): lib1 = """ @@ -21,7 +23,7 @@ def get_counter() -> uint256: with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -40,7 +42,7 @@ def test_exports_no_uses_variable(make_input_bundle): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" diff --git a/tests/functional/syntax/modules/test_initializers.py b/tests/functional/syntax/modules/test_initializers.py index f6697afea1..2193050a5f 100644 --- a/tests/functional/syntax/modules/test_initializers.py +++ b/tests/functional/syntax/modules/test_initializers.py @@ -17,6 +17,8 @@ UndeclaredDefinition, ) +from .helpers import NONREENTRANT_NOTE + def test_initialize_uses(make_input_bundle): lib1 = """ @@ -413,7 +415,7 @@ def foo(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -450,7 +452,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -491,7 +493,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -536,7 +538,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -571,7 +573,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -612,7 +614,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -656,7 +658,7 @@ def __init__(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -695,7 +697,7 @@ def foo(): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -734,7 +736,7 @@ def foo(new_value: uint256): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib2` state!" + assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib2` or `initializes: lib2` as a " expected_hint += "top-level statement to your contract" @@ -776,7 +778,7 @@ def foo(new_value: uint256): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib2` state!" + assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib2` or `initializes: lib2` as a " expected_hint += "top-level statement to your contract" @@ -819,7 +821,7 @@ def foo(new_value: uint256): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib2` state!" + assert e.value._message == "Cannot access `lib2` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib2` or `initializes: lib2` as a " expected_hint += "top-level statement to your contract" @@ -853,7 +855,7 @@ def foo(new_value: uint256): with pytest.raises(ImmutableViolation) as e: compile_code(main, input_bundle=input_bundle) - assert e.value._message == "Cannot access `lib1` state!" + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE expected_hint = "add `uses: lib1` or `initializes: lib1` as a " expected_hint += "top-level statement to your contract" @@ -1296,3 +1298,54 @@ def foo(): compile_code(main, input_bundle=input_bundle) assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`" assert e.value._hint == "try importing lib1 first" + + +def test_nonreentrant_exports(make_input_bundle): + lib1 = """ +# lib1.vy +@external +@nonreentrant +def bar(): + pass + """ + main = """ +import lib1 + +exports: lib1.bar # line 4 + +@external +def foo(): + pass + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE + hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract" + assert e.value._hint == hint + assert e.value.annotations[0].lineno == 4 + + +def test_internal_nonreentrant_import(make_input_bundle): + lib1 = """ +# lib1.vy +@internal +@nonreentrant +def bar(): + pass + """ + main = """ +import lib1 + +@external +def foo(): + lib1.bar() # line 6 + """ + input_bundle = make_input_bundle({"lib1.vy": lib1}) + with pytest.raises(ImmutableViolation) as e: + compile_code(main, input_bundle=input_bundle) + assert e.value._message == "Cannot access `lib1` state!" + NONREENTRANT_NOTE + + hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract" + assert e.value._hint == hint + assert e.value.annotations[0].lineno == 6 diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 37ba371dd8..5b20ef773a 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -1,4 +1,4 @@ -# CMC 2024-02-03 TODO: split me into function.py and expr.py +# CMC 2024-02-03 TODO: rename me to function.py import contextlib from typing import Optional @@ -35,6 +35,7 @@ get_exact_type_from_node, get_expr_info, get_possible_types_from_node, + uses_state, validate_expected_type, ) from vyper.semantics.data_locations import DataLocation @@ -64,22 +65,22 @@ from vyper.semantics.types.utils import type_from_annotation -def validate_functions(vy_module: vy_ast.Module) -> None: +def analyze_functions(vy_module: vy_ast.Module) -> None: """Analyzes a vyper ast and validates the function bodies""" err_list = ExceptionList() for node in vy_module.get_children(vy_ast.FunctionDef): - _validate_function_r(vy_module, node, err_list) + _analyze_function_r(vy_module, node, err_list) for node in vy_module.get_children(vy_ast.VariableDecl): if not node.is_public: continue - _validate_function_r(vy_module, node._expanded_getter, err_list) + _analyze_function_r(vy_module, node._expanded_getter, err_list) err_list.raise_if_not_empty() -def _validate_function_r( +def _analyze_function_r( vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList ): func_t = node._metadata["func_type"] @@ -87,7 +88,7 @@ def _validate_function_r( for call_t in func_t.called_functions: if isinstance(call_t, ContractFunctionT): assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy - _validate_function_r(vy_module, call_t.ast_def, err_list) + _analyze_function_r(vy_module, call_t.ast_def, err_list) namespace = get_namespace() @@ -267,7 +268,14 @@ def check_module_uses(node: vy_ast.ExprNode) -> Optional[ModuleInfo]: for module_info in module_infos: if module_info.ownership < ModuleOwnership.USES: - msg = f"Cannot access `{module_info.alias}` state!" + msg = f"Cannot access `{module_info.alias}` state!\n note that" + # CMC 2024-04-12 add UX note about nonreentrant. might be nice + # in the future to be more specific about exactly which state is + # used, although that requires threading a bit more context into + # this function. + msg += " use of the `@nonreentrant` decorator is also considered" + msg += " state access" + hint = f"add `uses: {module_info.alias}` or " hint += f"`initializes: {module_info.alias}` as " hint += "a top-level statement to your contract" @@ -443,10 +451,7 @@ def _handle_modification(self, target: vy_ast.ExprNode): info._writes.add(var_access) - def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode): - if not var_access.variable.is_state_variable(): - return - + def _handle_module_access(self, target: vy_ast.ExprNode): root_module_info = check_module_uses(target) if root_module_info is not None: @@ -682,9 +687,9 @@ def visit(self, node, typ): msg += f" `{var.decl_node.target.id}`" raise ImmutableViolation(msg, var.decl_node, node) - variable_accesses = info._writes | info._reads - for s in variable_accesses: - self.function_analyzer._handle_module_access(s, node) + var_accesses = info._writes | info._reads + if uses_state(var_accesses): + self.function_analyzer._handle_module_access(node) self.func.mark_variable_writes(info._writes) self.func.mark_variable_reads(info._reads) @@ -787,8 +792,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None: if self.function_analyzer: self._check_call_mutability(func_type.mutability) - for s in func_type.get_variable_accesses(): - self.function_analyzer._handle_module_access(s, node.func) + if func_type.uses_state(): + self.function_analyzer._handle_module_access(node.func) if func_type.is_deploy and not self.func.is_deploy: raise CallViolation( diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 5cb6ae4f5c..9d3f9ae1ff 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -43,7 +43,7 @@ from vyper.semantics.analysis.constant_folding import constant_fold from vyper.semantics.analysis.getters import generate_public_variable_getters from vyper.semantics.analysis.import_graph import ImportGraph -from vyper.semantics.analysis.local import ExprVisitor, check_module_uses, validate_functions +from vyper.semantics.analysis.local import ExprVisitor, analyze_functions, check_module_uses from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, @@ -102,7 +102,7 @@ def _analyze_module_r( # if this is an interface, the function is already validated # in `ContractFunction.from_vyi()` if not is_interface: - validate_functions(module_ast) + analyze_functions(module_ast) analyzer.validate_initialized_modules() analyzer.validate_used_modules() @@ -557,14 +557,13 @@ def visit_ExportsDecl(self, node): with tag_exceptions(item): # tag with specific item self._self_t.typ.add_member(func_t.name, func_t) - funcs.append(func_t) + funcs.append(func_t) - # check module uses - var_accesses = func_t.get_variable_accesses() - if any(s.variable.is_state_variable() for s in var_accesses): - module_info = check_module_uses(item) - assert module_info is not None # guaranteed by above checks - used_modules.add(module_info) + # check module uses + if func_t.uses_state(): + module_info = check_module_uses(item) + assert module_info is not None # guaranteed by above checks + used_modules.add(module_info) node._metadata["exports_info"] = ExportsInfo(funcs, used_modules) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index 4b751e7406..b4b31ca358 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -1,5 +1,5 @@ import itertools -from typing import Callable, List +from typing import Callable, Iterable, List from vyper import ast as vy_ast from vyper.exceptions import ( @@ -17,7 +17,7 @@ ZeroDivisionException, ) from vyper.semantics import types -from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo +from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarAccess, VarInfo from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions from vyper.semantics.namespace import get_namespace from vyper.semantics.types.base import TYPE_T, VyperType @@ -48,6 +48,10 @@ def _validate_op(node, types_list, validation_fn_name): raise err_list[0] +def uses_state(var_accesses: Iterable[VarAccess]) -> bool: + return any(s.variable.is_state_variable() for s in var_accesses) + + class _ExprAnalyser: """ Node type-checker class. diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index fbeb3e37cd..86fd90f0f9 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -27,6 +27,7 @@ from vyper.semantics.analysis.utils import ( check_modifiability, get_exact_type_from_node, + uses_state, validate_expected_type, ) from vyper.semantics.data_locations import DataLocation @@ -163,7 +164,11 @@ def get_variable_writes(self): def get_variable_accesses(self): return self._variable_reads | self._variable_writes + def uses_state(self): + return self.nonreentrant or uses_state(self.get_variable_accesses()) + def get_used_modules(self): + # _used_modules is populated during analysis return self._used_modules def mark_used_module(self, module_info):