Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: guard against mutating code in non-mutable functions #3555

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 100 additions & 5 deletions tests/parser/exceptions/test_constancy_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,111 @@ def foo():
for i in range(x):
pass""",
"""
f:int128
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test now throws for its other issue of call violation, which is checked first in visit_Expr. Hence, I removed it as there are already tests for call violation in this file.

from vyper.interfaces import ERC20

token: ERC20

@external
def a (x:int128):
self.f = 100
@view
def topup(amount: uint256):
assert self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
@view
def topup(amount: uint256):
x: bool = self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
def b():
self.a(10)""",
@view
def topup(amount: uint256):
x: bool = False
x = self.token.transferFrom(msg.sender, self, amount)
""",
"""
from vyper.interfaces import ERC20

token: ERC20

@external
@view
def topup(amount: uint256) -> bool:
return self.token.transferFrom(msg.sender, self, amount)
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
assert self.a.pop() > 123, "vyper"
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
x: uint256 = self.a.pop()
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo():
x: uint256 = 0
x = self.a.pop()
""",
"""
a: DynArray[uint256, 3]

@external
@view
def foo() -> uint256:
return self.a.pop()
""",
"""
@external
@view
def foo(x: address):
assert convert(
raw_call(
x,
b'',
max_outsize=32,
revert_on_failure=False
), uint256
) > 123, "vyper"
""",
"""
@external
@view
def foo(a: address):
x: address = create_minimal_proxy_to(a)
""",
"""
@external
@view
def foo(a: address):
x: address = empty(address)
x = create_copy_of(a)
""",
"""
@external
@view
def foo(a: address) -> address:
return create_from_blueprint(a)
""",
],
)
def test_statefulness_violations(bad_code):
Expand Down
24 changes: 24 additions & 0 deletions tests/parser/features/decorators/test_view.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from vyper.exceptions import FunctionDeclarationException


Expand Down Expand Up @@ -28,3 +30,25 @@ def foo() -> num:
assert_compile_failed(
lambda: get_contract_with_gas_estimation_for_constants(code), FunctionDeclarationException
)


good_code = [
"""
@external
@view
def foo(x: address):
assert convert(
raw_call(
x,
b'',
max_outsize=32,
is_static_call=True,
), uint256
) > 123, "vyper"
"""
]


@pytest.mark.parametrize("code", good_code)
def test_view_call_compiles(get_contract, code):
get_contract(code)
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ def set_lucky(arg1: address, arg2: int128):
print("Successfully executed an external contract call state change")


def test_constant_external_contract_call_cannot_change_state(
def test_constant_external_contract_call_cannot_change_state1(
assert_compile_failed, get_contract_with_gas_estimation
):
c = """
Expand All @@ -892,6 +892,18 @@ def set_lucky(_lucky: int128) -> int128: nonpayable
@view
def set_lucky_expr(arg1: address, arg2: int128):
Foo(arg1).set_lucky(arg2)
"""
assert_compile_failed(lambda: get_contract_with_gas_estimation(c), StateAccessViolation)

print("Successfully blocked an external contract call from a constant function")


def test_constant_external_contract_call_cannot_change_state2(
assert_compile_failed, get_contract_with_gas_estimation
):
c = """
interface Foo:
def set_lucky(_lucky: int128) -> int128: nonpayable

@external
@view
Expand Down
2 changes: 2 additions & 0 deletions vyper/builtins/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import IRnode
from vyper.exceptions import CompilerPanic, TypeMismatch
from vyper.semantics.analysis.base import StateMutability

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.base
begins an import cycle.
from vyper.semantics.analysis.utils import get_exact_type_from_node, validate_expected_type
from vyper.semantics.types import TYPE_T, KwargSettings, VyperType
from vyper.semantics.types.utils import type_from_annotation
Expand Down Expand Up @@ -77,6 +78,7 @@
class BuiltinFunction:
_has_varargs = False
_kwargs: Dict[str, KwargSettings] = {}
mutability = StateMutability.PURE

# helper function to deal with TYPE_DEFINITIONs
def _validate_single(self, arg, expected_type):
Expand Down
11 changes: 10 additions & 1 deletion vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
UnfoldableNode,
ZeroDivisionException,
)
from vyper.semantics.analysis.base import VarInfo
from vyper.semantics.analysis.base import StateMutability, VarInfo

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.analysis.base
begins an import cycle.
from vyper.semantics.analysis.utils import (
get_common_types,
get_exact_type_from_node,
Expand Down Expand Up @@ -1083,12 +1083,15 @@
"revert_on_failure": KwargSettings(BoolT(), True, require_literal=True),
}
_return_type = None
mutability = StateMutability.NONPAYABLE

def fetch_call_return(self, node):
self._validate_arg_types(node)

kwargz = {i.arg: i.value for i in node.keywords}

value = kwargz.get("value")
static_call = kwargz.get("is_static_call")
outsize = kwargz.get("max_outsize")
revert_on_failure = kwargz.get("revert_on_failure")
revert_on_failure = revert_on_failure.value if revert_on_failure is not None else True
Expand All @@ -1101,6 +1104,11 @@
if not isinstance(outsize, vy_ast.Int) or outsize.value < 0:
raise

if static_call:
self.mutability = StateMutability.VIEW
elif value:
self.mutability = StateMutability.PAYABLE

if outsize.value:
return_type = BytesT()
return_type.set_min_length(outsize.value)
Expand Down Expand Up @@ -1724,6 +1732,7 @@
"salt": KwargSettings(BYTES32_T, empty_value),
}
_return_type = AddressT()
mutability = StateMutability.PAYABLE

@process_inputs
def build_IR(self, expr, args, kwargs, context):
Expand Down
33 changes: 31 additions & 2 deletions vyper/semantics/analysis/annotation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from vyper import ast as vy_ast
from vyper.exceptions import StructureException, TypeCheckFailure
from vyper.exceptions import StateAccessViolation, StructureException, TypeCheckFailure
from vyper.semantics.analysis.utils import (
get_common_types,
get_exact_type_from_node,
get_expr_info,
get_possible_types_from_node,
)
from vyper.semantics.types import TYPE_T, BoolT, EnumT, EventT, SArrayT, StructT, is_type_t
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT
from vyper.semantics.types.function import ContractFunctionT, MemberFunctionT, StateMutability

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.types.function
begins an import cycle.


class _AnnotationVisitorBase:
Expand Down Expand Up @@ -135,7 +136,27 @@
node._metadata["type"] = node_type
self.visit(node.func)

def _check_mutability(call_type):
if (
call_type.mutability > StateMutability.VIEW
and self.func.mutability <= StateMutability.VIEW
):
raise StateAccessViolation(
f"Cannot call a mutating function from a {self.func.mutability.value} function",
node,
)

if (
self.func.mutability == StateMutability.PURE
and call_type.mutability != StateMutability.PURE
):
raise StateAccessViolation(
"Cannot call non-pure function from a pure function", node
)

if isinstance(call_type, ContractFunctionT):
_check_mutability(call_type)

# function calls
if call_type.is_internal:
self.func.called_functions.add(call_type)
Expand All @@ -157,10 +178,18 @@
):
self.visit(value, arg_type)
elif isinstance(call_type, MemberFunctionT):
if call_type.is_modifying:
# it's a dotted function call like dynarray.pop()
expr_info = get_expr_info(node.func.value)
expr_info.validate_modification(node, self.func.mutability)

assert len(node.args) == len(call_type.arg_types)
for arg, arg_type in zip(node.args, call_type.arg_types):
self.visit(arg, arg_type)
else:
if hasattr(call_type, "mutability"):
_check_mutability(call_type)

# builtin functions
arg_types = call_type.infer_arg_types(node)
for arg, arg_type in zip(node.args, arg_types):
Expand Down
23 changes: 0 additions & 23 deletions vyper/semantics/analysis/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,29 +518,6 @@ def visit_Expr(self, node):
if is_type_t(fn_type, StructT):
raise StructureException("Struct creation without assignment is disallowed", node)

if isinstance(fn_type, ContractFunctionT):
if (
fn_type.mutability > StateMutability.VIEW
and self.func.mutability <= StateMutability.VIEW
):
raise StateAccessViolation(
f"Cannot call a mutating function from a {self.func.mutability.value} function",
node,
)

if (
self.func.mutability == StateMutability.PURE
and fn_type.mutability != StateMutability.PURE
):
raise StateAccessViolation(
"Cannot call non-pure function from a pure function", node
)

if isinstance(fn_type, MemberFunctionT) and fn_type.is_modifying:
# it's a dotted function call like dynarray.pop()
expr_info = get_expr_info(node.value.func.value)
expr_info.validate_modification(node, self.func.mutability)

# NOTE: fetch_call_return validates call args.
return_value = fn_type.fetch_call_return(node.value)
if (
Expand Down
Loading