Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[lang]: fix uses analysis for nonreentrant functions #3927

Merged
merged 12 commits into from
Apr 12, 2024
78 changes: 78 additions & 0 deletions tests/functional/codegen/modules/test_nonreentrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
def test_export_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable

implements: Foo

@external
@nonreentrant
def foo() -> uint256:
return 5
"""
main = """
import lib1

initializes: lib1

exports: lib1.foo

@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw

@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 5
with tx_failed():
c.re_enter()


def test_internal_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable

implements: Foo

@external
def foo() -> uint256:
return self._safe_fn()

@internal
@nonreentrant
def _safe_fn() -> uint256:
return 10
"""
main = """
import lib1

initializes: lib1

exports: lib1.foo

@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw

@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 10
with tx_failed():
c.re_enter()
50 changes: 50 additions & 0 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,3 +1292,53 @@ def foo():
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "`lib2` uses `lib1`, but it is not initialized with `lib1`"
assert e.value._hint == "try importing lib1 first"


def test_nonreentrant_exports(make_input_bundle):
lib1 = """
# lib1.vy
@external
@nonreentrant
def bar():
pass
"""
main = """
import lib1

exports: lib1.bar # line 4

@external
def foo():
pass
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!"
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 4


def test_internal_nonreentrant_import(make_input_bundle):
lib1 = """
# lib1.vy
@internal
@nonreentrant
def bar():
pass
"""
main = """
import lib1

@external
def foo():
lib1.bar() # line 6
"""
input_bundle = make_input_bundle({"lib1.vy": lib1})
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=input_bundle)
assert e.value._message == "Cannot access `lib1` state!"
hint = "add `uses: lib1` or `initializes: lib1` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 6
28 changes: 13 additions & 15 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# CMC 2024-02-03 TODO: split me into function.py and expr.py
# CMC 2024-02-03 TODO: rename me to function.py

import contextlib
from typing import Optional
Expand Down Expand Up @@ -33,6 +33,7 @@
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
uses_state,
validate_expected_type,
)
from vyper.semantics.data_locations import DataLocation
Expand Down Expand Up @@ -64,30 +65,30 @@
from vyper.semantics.types.utils import type_from_annotation


def validate_functions(vy_module: vy_ast.Module) -> None:
def analyze_functions(vy_module: vy_ast.Module) -> None:
"""Analyzes a vyper ast and validates the function bodies"""
err_list = ExceptionList()

for node in vy_module.get_children(vy_ast.FunctionDef):
_validate_function_r(vy_module, node, err_list)
_analyze_function_r(vy_module, node, err_list)

for node in vy_module.get_children(vy_ast.VariableDecl):
if not node.is_public:
continue
_validate_function_r(vy_module, node._expanded_getter, err_list)
_analyze_function_r(vy_module, node._expanded_getter, err_list)

err_list.raise_if_not_empty()


def _validate_function_r(
def _analyze_function_r(
vy_module: vy_ast.Module, node: vy_ast.FunctionDef, err_list: ExceptionList
):
func_t = node._metadata["func_type"]

for call_t in func_t.called_functions:
if isinstance(call_t, ContractFunctionT):
assert isinstance(call_t.ast_def, vy_ast.FunctionDef) # help mypy
_validate_function_r(vy_module, call_t.ast_def, err_list)
_analyze_function_r(vy_module, call_t.ast_def, err_list)

namespace = get_namespace()

Expand Down Expand Up @@ -443,10 +444,7 @@ def _handle_modification(self, target: vy_ast.ExprNode):

info._writes.add(var_access)

def _handle_module_access(self, var_access: VarAccess, target: vy_ast.ExprNode):
if not var_access.variable.is_state_variable():
return

def _handle_module_access(self, target: vy_ast.ExprNode):
root_module_info = check_module_uses(target)

if root_module_info is not None:
Expand Down Expand Up @@ -684,9 +682,9 @@ def visit(self, node, typ):
msg += f" `{var.decl_node.target.id}`"
raise ImmutableViolation(msg, var.decl_node, node)

variable_accesses = info._writes | info._reads
for s in variable_accesses:
self.function_analyzer._handle_module_access(s, node)
var_accesses = info._writes | info._reads
if uses_state(var_accesses):
self.function_analyzer._handle_module_access(node)

self.func.mark_variable_writes(info._writes)
self.func.mark_variable_reads(info._reads)
Expand Down Expand Up @@ -789,8 +787,8 @@ def visit_Call(self, node: vy_ast.Call, typ: VyperType) -> None:
if self.function_analyzer:
self._check_call_mutability(func_type.mutability)

for s in func_type.get_variable_accesses():
self.function_analyzer._handle_module_access(s, node.func)
if func_type.uses_state():
self.function_analyzer._handle_module_access(node.func)

if func_type.is_deploy and not self.func.is_deploy:
raise CallViolation(
Expand Down
7 changes: 3 additions & 4 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from vyper.semantics.analysis.constant_folding import constant_fold
from vyper.semantics.analysis.getters import generate_public_variable_getters
from vyper.semantics.analysis.import_graph import ImportGraph
from vyper.semantics.analysis.local import ExprVisitor, check_module_uses, validate_functions
from vyper.semantics.analysis.local import ExprVisitor, analyze_functions, check_module_uses
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
Expand Down Expand Up @@ -102,7 +102,7 @@ def _analyze_module_r(
# if this is an interface, the function is already validated
# in `ContractFunction.from_vyi()`
if not is_interface:
validate_functions(module_ast)
analyze_functions(module_ast)
analyzer.validate_initialized_modules()
analyzer.validate_used_modules()

Expand Down Expand Up @@ -560,8 +560,7 @@ def visit_ExportsDecl(self, node):
funcs.append(func_t)

# check module uses
var_accesses = func_t.get_variable_accesses()
if any(s.variable.is_state_variable() for s in var_accesses):
if func_t.uses_state():
module_info = check_module_uses(item)
assert module_info is not None # guaranteed by above checks
used_modules.add(module_info)
Expand Down
8 changes: 6 additions & 2 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Callable, List
from typing import Callable, Iterable, List

from vyper import ast as vy_ast
from vyper.exceptions import (
Expand All @@ -17,7 +17,7 @@
ZeroDivisionException,
)
from vyper.semantics import types
from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarInfo
from vyper.semantics.analysis.base import ExprInfo, Modifiability, ModuleInfo, VarAccess, VarInfo
Dismissed Show dismissed Hide dismissed
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
Expand Down Expand Up @@ -48,6 +48,10 @@ def _validate_op(node, types_list, validation_fn_name):
raise err_list[0]


def uses_state(var_accesses: Iterable[VarAccess]) -> bool:
return any(s.variable.is_state_variable() for s in var_accesses)


class _ExprAnalyser:
"""
Node type-checker class.
Expand Down
5 changes: 5 additions & 0 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from vyper.semantics.analysis.utils import (
check_modifiability,
get_exact_type_from_node,
uses_state,
validate_expected_type,
)
from vyper.semantics.data_locations import DataLocation
Expand Down Expand Up @@ -163,7 +164,11 @@ def get_variable_writes(self):
def get_variable_accesses(self):
return self._variable_reads | self._variable_writes

def uses_state(self):
return self.nonreentrant or uses_state(self.get_variable_accesses())

def get_used_modules(self):
# _used_modules is populated during analysis
return self._used_modules

def mark_used_module(self, module_info):
Expand Down
Loading