Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: guard against mutating code in non-mutable functions #3555

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 100 additions & 5 deletions tests/parser/exceptions/test_constancy_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,111 @@ def foo():
for i in range(x):
pass""",
"""
f:int128
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test now throws for its other issue of call violation, which is checked first in visit_Expr. Hence, I removed it as there are already tests for call violation in this file.

from vyper.interfaces import ERC20

token: ERC20

@external
def a (x:int128):
self.f = 100
@view
def topup(amount: uint256):
assert self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
@view
def topup(amount: uint256):
x: bool = self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
def b():
self.a(10)""",
@view
def topup(amount: uint256):
x: bool = False
x = self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
@view
def topup(amount: uint256) -> bool:
return self.token.transferFrom(msg.sender, self, amount)
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
assert self.a.pop() > 123, "vyper"
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
x: uint256 = self.a.pop()
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
x: uint256 = 0
x = self.a.pop()
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo() -> uint256:
return self.a.pop()
""",
"""
@external
@view
def foo(x: address):
assert convert(
raw_call(
x,
b'',
max_outsize=32,
revert_on_failure=False
), uint256
) > 123, "vyper"
""",
"""
@external
@view
def foo(a: address):
x: address = create_minimal_proxy_to(a)
""",
"""
@external
@view
def foo(a: address):
x: address = empty(address)
x = create_copy_of(a)
""",
"""
@external
@view
def foo(a: address) -> address:
return create_from_blueprint(a)
""",
],
)
def test_statefulness_violations(bad_code):
Expand Down
24 changes: 24 additions & 0 deletions tests/parser/features/decorators/test_view.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from vyper.exceptions import FunctionDeclarationException


Expand Down Expand Up @@ -28,3 +30,25 @@ def foo() -> num:
assert_compile_failed(
lambda: get_contract_with_gas_estimation_for_constants(code), FunctionDeclarationException
)


good_code = [
"""
@external
@view
def foo(x: address):
assert convert(
raw_call(
x,
b'',
max_outsize=32,
is_static_call=True,
), uint256
) > 123, "vyper"
"""
]


@pytest.mark.parametrize("code", good_code)
def test_view_call_compiles(get_contract, code):
get_contract(code)
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def set_lucky(arg1: address, arg2: int128):
print("Successfully executed an external contract call state change")


def test_constant_external_contract_call_cannot_change_state(
def test_constant_external_contract_call_cannot_change_state1(
assert_compile_failed, get_contract_with_gas_estimation
):
c = """
Expand All @@ -892,6 +892,18 @@ def set_lucky(_lucky: int128) -> int128: nonpayable
@view
def set_lucky_expr(arg1: address, arg2: int128):
Foo(arg1).set_lucky(arg2)
"""
assert_compile_failed(lambda: get_contract_with_gas_estimation(c), StateAccessViolation)

print("Successfully blocked an external contract call from a constant function")


def test_constant_external_contract_call_cannot_change_state2(
assert_compile_failed, get_contract_with_gas_estimation
):
c = """
interface Foo:
def set_lucky(_lucky: int128) -> int128: nonpayable

@external
@view
Expand Down
42 changes: 40 additions & 2 deletions vyper/semantics/analysis/annotation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from vyper import ast as vy_ast
from vyper.exceptions import StructureException, TypeCheckFailure
from vyper.exceptions import StateAccessViolation, StructureException, TypeCheckFailure
from vyper.semantics.analysis.utils import (
get_common_types,
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
)
from vyper.semantics.types import TYPE_T, BoolT, EnumT, EventT, SArrayT, StructT, is_type_t
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.types.function
begins an import cycle.


class _AnnotationVisitorBase:
Expand Down Expand Up @@ -136,6 +137,23 @@
self.visit(node.func)

if isinstance(call_type, ContractFunctionT):
if (
call_type.mutability > StateMutability.VIEW
and self.func.mutability <= StateMutability.VIEW
):
raise StateAccessViolation(
f"Cannot call a mutating function from a {self.func.mutability.value} function",
node,
)

if (
self.func.mutability == StateMutability.PURE
and call_type.mutability != StateMutability.PURE
):
raise StateAccessViolation(
"Cannot call non-pure function from a pure function", node
)

# function calls
if call_type.is_internal:
self.func.called_functions.add(call_type)
Expand All @@ -157,10 +175,30 @@
):
self.visit(value, arg_type)
elif isinstance(call_type, MemberFunctionT):
if call_type.is_modifying:
# it's a dotted function call like dynarray.pop()
expr_info = get_expr_info(node.func.value)
expr_info.validate_modification(node, self.func.mutability)

assert len(node.args) == len(call_type.arg_types)
for arg, arg_type in zip(node.args, call_type.arg_types):
self.visit(arg, arg_type)
else:
# note that mutability for`raw_call` is handled in its `build_IR` function
mutable_builtins = (
"create_minimal_proxy_to",
"create_copy_of",
"create_from_blueprint",
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an overlap here with #3546 (comment).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option here could be to check constancy of the context in _CreateBase.build_IR(), similar to how it is being checked for raw_call. However, this would not resolve #3546.

if (
self.func.mutability <= StateMutability.VIEW
and node.get("func.id") in mutable_builtins
):
raise StateAccessViolation(
f"Cannot call a mutating builtin from a {self.func.mutability.value} function",
node,
)

# builtin functions
arg_types = call_type.infer_arg_types(node)
for arg, arg_type in zip(node.args, arg_types):
Expand Down
23 changes: 0 additions & 23 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,29 +518,6 @@ def visit_Expr(self, node):
if is_type_t(fn_type, StructT):
raise StructureException("Struct creation without assignment is disallowed", node)

if isinstance(fn_type, ContractFunctionT):
if (
fn_type.mutability > StateMutability.VIEW
and self.func.mutability <= StateMutability.VIEW
):
raise StateAccessViolation(
f"Cannot call a mutating function from a {self.func.mutability.value} function",
node,
)

if (
self.func.mutability == StateMutability.PURE
and fn_type.mutability != StateMutability.PURE
):
raise StateAccessViolation(
"Cannot call non-pure function from a pure function", node
)

if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying:
# it's a dotted function call like dynarray.pop()
expr_info = get_expr_info(node.value.func.value)
expr_info.validate_modification(node, self.func.mutability)

# NOTE: fetch_call_return validates call args.
return_value = fn_type.fetch_call_return(node.value)
if (
Expand Down
Loading