From 199f2b65e43e3d3f055756039ef4a9bce7f6f3cf Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Tue, 13 Feb 2024 06:56:41 -0800 Subject: [PATCH] feat[lang]: remove named reentrancy locks (#3769) this commit removes "fine-grained" nonreentrancy locks (i.e., reentrancy locks with names) from vyper. they aren't really used (all known production contracts just use a single global named lock) , and in any case such a use case should better be implemented manually by the user. this simplifies the language and allows moderate simplification to the storage allocator, although some complexity is added because the global restriction has to have special handling (it cannot be handled simply in the recursion into child modules). refactors: - the routine for allocating nonreentrant keys has been refactored into a helper function. --- docs/control-structures.rst | 12 +- .../features/decorators/test_nonreentrant.py | 139 ++++++++++++++---- .../exceptions/test_structure_exception.py | 31 ---- .../test_invalid_function_decorators.py | 15 +- .../cli/storage_layout/test_storage_layout.py | 75 ++++++---- .../test_storage_layout_overrides.py | 34 ++++- tests/unit/semantics/test_storage_slots.py | 11 +- vyper/semantics/analysis/data_positions.py | 83 ++++++----- vyper/semantics/types/function.py | 70 ++++----- 9 files changed, 291 insertions(+), 179 deletions(-) diff --git a/docs/control-structures.rst b/docs/control-structures.rst index a0aa927261..4e18a21bd8 100644 --- a/docs/control-structures.rst +++ b/docs/control-structures.rst @@ -100,22 +100,24 @@ Functions marked with ``@pure`` cannot call non-``pure`` functions. Re-entrancy Locks ----------------- -The ``@nonreentrant()`` decorator places a lock on a function, and all functions with the same ```` value. An attempt by an external contract to call back into any of these functions causes the transaction to revert. +The ``@nonreentrant`` decorator places a global nonreentrancy lock on a function. An attempt by an external contract to call back into any other ``@nonreentrant`` function causes the transaction to revert. .. code-block:: vyper @external - @nonreentrant("lock") + @nonreentrant def make_a_call(_addr: address): # this function is protected from re-entrancy ... -You can put the ``@nonreentrant()`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way. +You can put the ``@nonreentrant`` decorator on a ``__default__`` function but we recommend against it because in most circumstances it will not work in a meaningful way. Nonreentrancy locks work by setting a specially allocated storage slot to a ```` value on function entrance, and setting it to an ```` value on function exit. On function entrance, if the storage slot is detected to be the ```` value, execution reverts. You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can put it on a ``view`` function, but it only checks that the function is not in a callback (the storage slot is not in the ```` state), as ``view`` functions can only read the state, not change it. +You can view where the nonreentrant key is physically laid out in storage by using ``vyper`` with the ``-f layout`` option (e.g., ``vyper -f layout foo.vy``). Unless it is overriden, the compiler will allocate it at slot ``0``. + .. note:: A mutable function can protect a ``view`` function from being called back into (which is useful for instance, if a ``view`` function would return inconsistent state during a mutable function), but a ``view`` function cannot protect itself from being called back into. Note that mutable functions can never be called from a ``view`` function because all external calls out from a ``view`` function are protected by the use of the ``STATICCALL`` opcode. @@ -123,6 +125,8 @@ You cannot put the ``@nonreentrant`` decorator on a ``pure`` function. You can p A nonreentrant lock has an ```` value of 3, and a ```` value of 2. Nonzero values are used to take advantage of net gas metering - as of the Berlin hard fork, the net cost for utilizing a nonreentrant lock is 2300 gas. Prior to v0.3.4, the ```` and ```` values were 0 and 1, respectively. +.. note:: + Prior to 0.4.0, nonreentrancy keys took a "key" argument for fine-grained nonreentrancy control. As of 0.4.0, only a global nonreentrancy lock is available. The ``__default__`` Function ---------------------------- @@ -194,7 +198,7 @@ Decorator Description ``@pure`` Function does not read contract state or environment variables ``@view`` Function does not alter contract state ``@payable`` Function is able to receive Ether -``@nonreentrant()`` Function cannot be called back into during an external call +``@nonreentrant`` Function cannot be called back into during an external call =============================== =========================================================== ``if`` statements diff --git a/tests/functional/codegen/features/decorators/test_nonreentrant.py b/tests/functional/codegen/features/decorators/test_nonreentrant.py index 9329605678..92a21cd302 100644 --- a/tests/functional/codegen/features/decorators/test_nonreentrant.py +++ b/tests/functional/codegen/features/decorators/test_nonreentrant.py @@ -2,30 +2,103 @@ from vyper.exceptions import FunctionDeclarationException - # TODO test functions in this module across all evm versions # once we have cancun support. + + def test_nonreentrant_decorator(get_contract, tx_failed): - calling_contract_code = """ -interface SpecialContract: + malicious_code = """ +interface ProtectedContract: + def protected_function(callback_address: address): nonpayable + +@external +def do_callback(): + ProtectedContract(msg.sender).protected_function(self) + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_nonreentrant_view_function(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: + def protected_function(): nonpayable + def protected_view_fn() -> uint256: view + +@external +def do_callback() -> uint256: + return ProtectedContract(msg.sender).protected_view_fn() + """ + + protected_code = """ +interface Callbackable: + def do_callback(): nonpayable + +@external +@nonreentrant +def protected_function(c: Callbackable): + c.do_callback() + +@external +@nonreentrant +@view +def protected_view_fn() -> uint256: + return 10 + +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) + + with tx_failed(): + contract.protected_function(malicious.address) + + +def test_multi_function_nonreentrant(get_contract, tx_failed): + malicious_code = """ +interface ProtectedContract: def unprotected_function(val: String[100], do_callback: bool): nonpayable def protected_function(val: String[100], do_callback: bool): nonpayable def special_value() -> String[100]: nonpayable @external def updated(): - SpecialContract(msg.sender).unprotected_function('surprise!', False) + ProtectedContract(msg.sender).unprotected_function('surprise!', False) @external def updated_protected(): # This should fail. - SpecialContract(msg.sender).protected_function('surprise protected!', False) + ProtectedContract(msg.sender).protected_function('surprise protected!', False) """ - reentrant_code = """ + protected_code = """ interface Callback: def updated(): nonpayable def updated_protected(): nonpayable + interface Self: def protected_function(val: String[100], do_callback: bool) -> uint256: nonpayable def protected_function2(val: String[100], do_callback: bool) -> uint256: nonpayable @@ -39,7 +112,7 @@ def set_callback(c: address): self.callback = Callback(c) @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val @@ -50,7 +123,7 @@ def protected_function(val: String[100], do_callback: bool) -> uint256: return 2 @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function2(val: String[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: @@ -60,7 +133,7 @@ def protected_function2(val: String[100], do_callback: bool) -> uint256: return 2 @external -@nonreentrant('protect_special_value') +@nonreentrant def protected_function3(val: String[100], do_callback: bool) -> uint256: self.special_value = val if do_callback: @@ -71,7 +144,8 @@ def protected_function3(val: String[100], do_callback: bool) -> uint256: @external -@nonreentrant('protect_special_value') +@nonreentrant +@view def protected_view_fn() -> String[100]: return self.special_value @@ -81,37 +155,42 @@ def unprotected_function(val: String[100], do_callback: bool): if do_callback: self.callback.updated() - """ - reentrant_contract = get_contract(reentrant_code) - calling_contract = get_contract(calling_contract_code) +# add a default function so we know the callback didn't fail for any reason +# besides nonreentrancy +@external +def __default__(): + pass + """ + contract = get_contract(protected_code) + malicious = get_contract(malicious_code) - reentrant_contract.set_callback(calling_contract.address, transact={}) - assert reentrant_contract.callback() == calling_contract.address + contract.set_callback(malicious.address, transact={}) + assert contract.callback() == malicious.address # Test unprotected function. - reentrant_contract.unprotected_function("some value", True, transact={}) - assert reentrant_contract.special_value() == "surprise!" + contract.unprotected_function("some value", True, transact={}) + assert contract.special_value() == "surprise!" # Test protected function. - reentrant_contract.protected_function("some value", False, transact={}) - assert reentrant_contract.special_value() == "some value" - assert reentrant_contract.protected_view_fn() == "some value" + contract.protected_function("some value", False, transact={}) + assert contract.special_value() == "some value" + assert contract.protected_view_fn() == "some value" with tx_failed(): - reentrant_contract.protected_function("zzz value", True, transact={}) + contract.protected_function("zzz value", True, transact={}) - reentrant_contract.protected_function2("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function2("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function2("zzz value", True, transact={}) + contract.protected_function2("zzz value", True, transact={}) - reentrant_contract.protected_function3("another value", False, transact={}) - assert reentrant_contract.special_value() == "another value" + contract.protected_function3("another value", False, transact={}) + assert contract.special_value() == "another value" with tx_failed(): - reentrant_contract.protected_function3("zzz value", True, transact={}) + contract.protected_function3("zzz value", True, transact={}) def test_nonreentrant_decorator_for_default(w3, get_contract, tx_failed): @@ -145,7 +224,7 @@ def set_callback(c: address): @external @payable -@nonreentrant("lock") +@nonreentrant def protected_function(val: String[100], do_callback: bool) -> uint256: self.special_value = val _amount: uint256 = msg.value @@ -169,7 +248,7 @@ def unprotected_function(val: String[100], do_callback: bool): @external @payable -@nonreentrant("lock") +@nonreentrant def __default__(): pass """ @@ -209,7 +288,7 @@ def test_disallow_on_init_function(get_contract): code = """ @external -@nonreentrant("lock") +@nonreentrant def __init__(): foo: uint256 = 0 """ diff --git a/tests/functional/syntax/exceptions/test_structure_exception.py b/tests/functional/syntax/exceptions/test_structure_exception.py index afc7a35012..e530487fea 100644 --- a/tests/functional/syntax/exceptions/test_structure_exception.py +++ b/tests/functional/syntax/exceptions/test_structure_exception.py @@ -44,42 +44,11 @@ def foo() -> int128: return x.codesize() """, """ -@external -@nonreentrant("B") -@nonreentrant("C") -def double_nonreentrant(): - pass - """, - """ struct X: int128[5]: int128[7] """, """ @external -@nonreentrant(" ") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("123") -def invalid_nonreentrant_key(): - pass - """, - """ -@external -@nonreentrant("!123abcd") -def invalid_nonreentrant_key(): - pass - """, - """ -@external def foo(): true: int128 = 3 """, diff --git a/tests/functional/syntax/signatures/test_invalid_function_decorators.py b/tests/functional/syntax/signatures/test_invalid_function_decorators.py index b3d4219a2d..a7a500efc7 100644 --- a/tests/functional/syntax/signatures/test_invalid_function_decorators.py +++ b/tests/functional/syntax/signatures/test_invalid_function_decorators.py @@ -7,10 +7,23 @@ """ @external @pure -@nonreentrant('lock') +@nonreentrant def nonreentrant_foo() -> uint256: return 1 + """, """ +@external +@nonreentrant +@nonreentrant +def nonreentrant_foo() -> uint256: + return 1 + """, + """ +@external +@nonreentrant("foo") +def nonreentrant_foo() -> uint256: + return 1 + """, ] diff --git a/tests/unit/cli/storage_layout/test_storage_layout.py b/tests/unit/cli/storage_layout/test_storage_layout.py index f0ee25f747..9724dd723c 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout.py +++ b/tests/unit/cli/storage_layout/test_storage_layout.py @@ -6,18 +6,18 @@ def test_storage_layout(): foo: HashMap[address, uint256] @external -@nonreentrant("foo") +@nonreentrant def public_foo1(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo2(): pass @internal -@nonreentrant("bar") +@nonreentrant def _bar(): pass @@ -28,12 +28,12 @@ def _bar(): bar: uint256 @external -@nonreentrant("bar") +@nonreentrant def public_bar(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo3(): pass """ @@ -41,12 +41,11 @@ def public_foo3(): out = compile_code(code, output_formats=["layout"]) assert out["layout"]["storage_layout"] == { - "nonreentrant.foo": {"type": "nonreentrant lock", "slot": 0}, - "nonreentrant.bar": {"type": "nonreentrant lock", "slot": 1}, - "foo": {"type": "HashMap[address, uint256]", "slot": 2}, - "arr": {"type": "DynArray[uint256, 3]", "slot": 3}, - "baz": {"type": "Bytes[65]", "slot": 7}, - "bar": {"type": "uint256", "slot": 11}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "foo": {"slot": 1, "type": "HashMap[address, uint256]"}, + "arr": {"slot": 2, "type": "DynArray[uint256, 3]"}, + "baz": {"slot": 6, "type": "Bytes[65]"}, + "bar": {"slot": 10, "type": "uint256"}, } @@ -64,10 +63,13 @@ def __init__(): expected_layout = { "code_layout": { - "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, "SYMBOL": {"length": 64, "offset": 0, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 64, "type": "uint8"}, + }, + "storage_layout": { + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "name": {"slot": 1, "type": "String[32]"}, }, - "storage_layout": {"name": {"slot": 0, "type": "String[32]"}}, } out = compile_code(code, output_formats=["layout"]) @@ -107,14 +109,15 @@ def __init__(): "code_layout": { "some_immutable": {"length": 352, "offset": 0, "type": "DynArray[uint256, 10]"}, "a_library": { - "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, "SYMBOL": {"length": 64, "offset": 352, "type": "String[32]"}, + "DECIMALS": {"length": 32, "offset": 416, "type": "uint8"}, }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "counter2": {"slot": 1, "type": "uint256"}, - "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "counter2": {"slot": 2, "type": "uint256"}, + "a_library": {"supply": {"slot": 3, "type": "uint256"}}, }, } @@ -160,9 +163,10 @@ def __init__(): }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "a_library": {"supply": {"slot": 1, "type": "uint256"}}, - "counter2": {"slot": 2, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "a_library": {"supply": {"slot": 2, "type": "uint256"}}, + "counter2": {"slot": 3, "type": "uint256"}, }, } @@ -171,7 +175,8 @@ def __init__(): def test_storage_layout_module_uses(make_input_bundle): - # test module storage layout, with initializes/uses + # test module storage layout, with initializes/uses and a nonreentrant + # lock lib1 = """ supply: uint256 SYMBOL: immutable(String[32]) @@ -197,6 +202,11 @@ def __init__(s: uint256): @internal def decimals() -> uint8: return lib1.DECIMALS + +@external +@nonreentrant +def foo(): + pass """ code = """ import lib1 as a_library @@ -218,6 +228,11 @@ def __init__(): some_immutable = [1, 2, 3] lib2.__init__(17) + +@external +@nonreentrant +def bar(): + pass """ input_bundle = make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2}) @@ -231,10 +246,11 @@ def __init__(): }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, - "lib2": {"storage_variable": {"slot": 1, "type": "uint256"}}, - "counter2": {"slot": 2, "type": "uint256"}, - "a_library": {"supply": {"slot": 3, "type": "uint256"}}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, + "lib2": {"storage_variable": {"slot": 2, "type": "uint256"}}, + "counter2": {"slot": 3, "type": "uint256"}, + "a_library": {"supply": {"slot": 4, "type": "uint256"}}, }, } @@ -309,12 +325,13 @@ def foo() -> uint256: }, }, "storage_layout": { - "counter": {"slot": 0, "type": "uint256"}, + "$.nonreentrant_key": {"slot": 0, "type": "nonreentrant lock"}, + "counter": {"slot": 1, "type": "uint256"}, "lib2": { - "lib1": {"supply": {"slot": 1, "type": "uint256"}}, - "storage_variable": {"slot": 2, "type": "uint256"}, + "lib1": {"supply": {"slot": 2, "type": "uint256"}}, + "storage_variable": {"slot": 3, "type": "uint256"}, }, - "counter2": {"slot": 3, "type": "uint256"}, + "counter2": {"slot": 4, "type": "uint256"}, }, } diff --git a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py index f4c11b7ae6..707c94c3fc 100644 --- a/tests/unit/cli/storage_layout/test_storage_layout_overrides.py +++ b/tests/unit/cli/storage_layout/test_storage_layout_overrides.py @@ -1,3 +1,5 @@ +import re + import pytest from vyper.compiler import compile_code @@ -28,18 +30,18 @@ def test_storage_layout_for_more_complex(): foo: HashMap[address, uint256] @external -@nonreentrant("foo") +@nonreentrant def public_foo1(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo2(): pass @internal -@nonreentrant("bar") +@nonreentrant def _bar(): pass @@ -48,19 +50,18 @@ def _bar(): bar: uint256 @external -@nonreentrant("bar") +@nonreentrant def public_bar(): pass @external -@nonreentrant("foo") +@nonreentrant def public_foo3(): pass """ storage_layout_override = { - "nonreentrant.foo": {"type": "nonreentrant lock", "slot": 8}, - "nonreentrant.bar": {"type": "nonreentrant lock", "slot": 7}, + "$.nonreentrant_key": {"type": "nonreentrant lock", "slot": 8}, "foo": {"type": "HashMap[address, uint256]", "slot": 1}, "baz": {"type": "Bytes[65]", "slot": 2}, "bar": {"type": "uint256", "slot": 6}, @@ -110,6 +111,25 @@ def test_overflow(): ) +def test_override_nonreentrant_slot(): + code = """ +@nonreentrant +@external +def foo(): + pass + """ + + storage_layout_override = {"$.nonreentrant_key": {"slot": 2**256, "type": "nonreentrant key"}} + + exception_regex = re.escape( + f"Invalid storage slot for var $.nonreentrant_key, out of bounds: {2**256}" + ) + with pytest.raises(StorageLayoutException, match=exception_regex): + compile_code( + code, output_formats=["layout"], storage_layout_override=storage_layout_override + ) + + def test_incomplete_overrides(): code = """ name: public(String[64]) diff --git a/tests/unit/semantics/test_storage_slots.py b/tests/unit/semantics/test_storage_slots.py index 3620ef64b9..1dc70fd1ba 100644 --- a/tests/unit/semantics/test_storage_slots.py +++ b/tests/unit/semantics/test_storage_slots.py @@ -47,15 +47,9 @@ def __init__(): self.foo[1] = [123, 456, 789] @external -@nonreentrant('lock') +@nonreentrant def with_lock(): pass - - -@external -@nonreentrant('otherlock') -def with_other_lock(): - pass """ @@ -84,7 +78,6 @@ def test_reentrancy_lock(get_contract): # if re-entrancy locks are incorrectly placed within storage, these # calls will either revert or correupt the data that we read later c.with_lock() - c.with_other_lock() assert c.a() == ("ok", [4, 5, 6]) assert [c.b(i) for i in range(2)] == [7, 8] @@ -105,7 +98,7 @@ def test_reentrancy_lock(get_contract): def test_allocator_overflow(get_contract): code = """ -x: uint256 +# --> global nonreentrancy slot allocated here <-- y: uint256[max_value(uint256)] """ with pytest.raises( diff --git a/vyper/semantics/analysis/data_positions.py b/vyper/semantics/analysis/data_positions.py index 604bc6b594..bb4322c7b2 100644 --- a/vyper/semantics/analysis/data_positions.py +++ b/vyper/semantics/analysis/data_positions.py @@ -43,10 +43,15 @@ def __setitem__(self, k, v): super().__setitem__(k, v) +# some name that the user cannot assign to a variable +GLOBAL_NONREENTRANT_KEY = "$.nonreentrant_key" + + class SimpleAllocator: def __init__(self, max_slot: int = 2**256, starting_slot: int = 0): # Allocate storage slots from 0 # note storage is word-addressable, not byte-addressable + self._starting_slot = starting_slot self._slot = starting_slot self._max_slot = max_slot @@ -61,12 +66,19 @@ def allocate_slot(self, n, var_name, node=None): self._slot += n return ret + def allocate_global_nonreentrancy_slot(self): + slot = self.allocate_slot(1, GLOBAL_NONREENTRANT_KEY) + assert slot == self._starting_slot + return slot + class Allocators: storage_allocator: SimpleAllocator transient_storage_allocator: SimpleAllocator immutables_allocator: SimpleAllocator + _global_nonreentrancy_key_slot: int + def __init__(self): self.storage_allocator = SimpleAllocator(max_slot=2**256) self.transient_storage_allocator = SimpleAllocator(max_slot=2**256) @@ -82,6 +94,16 @@ def get_allocator(self, location: DataLocation): raise CompilerPanic("unreachable") # pragma: nocover + def allocate_global_nonreentrancy_slot(self): + location = get_reentrancy_key_location() + + allocator = self.get_allocator(location) + slot = allocator.allocate_global_nonreentrancy_slot() + self._global_nonreentrancy_key_slot = slot + + def get_global_nonreentrant_key_slot(self): + return self._global_nonreentrancy_key_slot + class OverridingStorageAllocator: """ @@ -127,7 +149,6 @@ def set_storage_slots_with_overrides( Returns the layout as a dict of variable name -> variable info (Doesn't handle modules, or transient storage) """ - ret: InsertableOnceDict[str, dict] = InsertableOnceDict() reserved_slots = OverridingStorageAllocator() @@ -136,15 +157,13 @@ def set_storage_slots_with_overrides( type_ = node._metadata["func_type"] # Ignore functions without non-reentrant - if type_.nonreentrant is None: + if not type_.nonreentrant: continue - variable_name = f"nonreentrant.{type_.nonreentrant}" + variable_name = GLOBAL_NONREENTRANT_KEY # re-entrant key was already identified if variable_name in ret: - _slot = ret[variable_name]["slot"] - type_.set_reentrancy_key_position(VarOffset(_slot)) continue # Expect to find this variable within the storage layout override @@ -210,6 +229,20 @@ def get_reentrancy_key_location() -> DataLocation: } +def _allocate_nonreentrant_keys(vyper_module, allocators): + SLOT = allocators.get_global_nonreentrant_key_slot() + + for node in vyper_module.get_children(vy_ast.FunctionDef): + type_ = node._metadata["func_type"] + if not type_.nonreentrant: + continue + + # a nonreentrant key can appear many times in a module but it + # only takes one slot. after the first time we see it, do not + # increment the storage slot. + type_.set_reentrancy_key_position(VarOffset(SLOT)) + + def _allocate_layout_r( vyper_module: vy_ast.Module, allocators: Allocators = None, immutables_only=False ) -> StorageLayout: @@ -217,42 +250,26 @@ def _allocate_layout_r( Parse module-level Vyper AST to calculate the layout of storage variables. Returns the layout as a dict of variable name -> variable info """ + global_ = False if allocators is None: + global_ = True allocators = Allocators() + # always allocate nonreentrancy slot, so that adding or removing + # reentrancy protection from a contract does not change its layout + allocators.allocate_global_nonreentrancy_slot() ret: defaultdict[str, InsertableOnceDict[str, dict]] = defaultdict(InsertableOnceDict) - for node in vyper_module.get_children(vy_ast.FunctionDef): - if immutables_only: - break - - type_ = node._metadata["func_type"] - if type_.nonreentrant is None: - continue - - variable_name = f"nonreentrant.{type_.nonreentrant}" - reentrancy_key_location = get_reentrancy_key_location() - layout_key = _LAYOUT_KEYS[reentrancy_key_location] - - # a nonreentrant key can appear many times in a module but it - # only takes one slot. after the first time we see it, do not - # increment the storage slot. - if variable_name in ret[layout_key]: - _slot = ret[layout_key][variable_name]["slot"] - type_.set_reentrancy_key_position(VarOffset(_slot)) - continue - - # TODO use one byte - or bit - per reentrancy key - # requires either an extra SLOAD or caching the value of the - # location in memory at entrance - allocator = allocators.get_allocator(reentrancy_key_location) - slot = allocator.allocate_slot(1, variable_name, node) - - type_.set_reentrancy_key_position(VarOffset(slot)) + # tag functions with the global nonreentrant key + if not immutables_only: + _allocate_nonreentrant_keys(vyper_module, allocators) + layout_key = _LAYOUT_KEYS[get_reentrancy_key_location()] # TODO this could have better typing but leave it untyped until # we nail down the format better - ret[layout_key][variable_name] = {"type": "nonreentrant lock", "slot": slot} + if global_ and GLOBAL_NONREENTRANT_KEY not in ret[layout_key]: + slot = allocators.get_global_nonreentrant_key_slot() + ret[layout_key][GLOBAL_NONREENTRANT_KEY] = {"type": "nonreentrant lock", "slot": slot} for node in _get_allocatable(vyper_module): if isinstance(node, vy_ast.InitializesDecl): diff --git a/vyper/semantics/types/function.py b/vyper/semantics/types/function.py index 705470a798..43d553288e 100644 --- a/vyper/semantics/types/function.py +++ b/vyper/semantics/types/function.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List, Optional, Tuple from vyper import ast as vy_ast -from vyper.ast.identifiers import validate_identifier from vyper.ast.validation import validate_call_args from vyper.exceptions import ( ArgumentException, @@ -78,8 +77,8 @@ class ContractFunctionT(VyperType): enum indicating the external visibility of a function. state_mutability : StateMutability enum indicating the authority a function has to mutate it's own state. - nonreentrant : Optional[str] - Re-entrancy lock name. + nonreentrant : bool + Whether this function is marked `@nonreentrant` or not """ _is_callable = True @@ -93,7 +92,7 @@ def __init__( function_visibility: FunctionVisibility, state_mutability: StateMutability, from_interface: bool = False, - nonreentrant: Optional[str] = None, + nonreentrant: bool = False, ast_def: Optional[vy_ast.VyperNode] = None, ) -> None: super().__init__() @@ -107,6 +106,9 @@ def __init__( self.nonreentrant = nonreentrant self.from_interface = from_interface + # sanity check, nonreentrant used to be Optional[str] + assert isinstance(self.nonreentrant, bool) + self.ast_def = ast_def self._analysed = False @@ -279,7 +281,7 @@ def from_InterfaceDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=True, - nonreentrant=None, + nonreentrant=False, ast_def=funcdef, ) @@ -298,12 +300,10 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) - if nonreentrant_key is not None: - raise FunctionDeclarationException( - "nonreentrant key not allowed in interfaces", funcdef - ) + if nonreentrant: + raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef) if funcdef.name == "__init__": raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef) @@ -332,7 +332,7 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=True, - nonreentrant=nonreentrant_key, + nonreentrant=nonreentrant, ast_def=funcdef, ) @@ -350,7 +350,7 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": ------- ContractFunctionT """ - function_visibility, state_mutability, nonreentrant_key = _parse_decorators(funcdef) + function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef) positional_args, keyword_args = _parse_args(funcdef) @@ -403,15 +403,16 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT": function_visibility, state_mutability, from_interface=False, - nonreentrant=nonreentrant_key, + nonreentrant=nonreentrant, ast_def=funcdef, ) def set_reentrancy_key_position(self, position: VarOffset) -> None: if hasattr(self, "reentrancy_key_position"): raise CompilerPanic("Position was already assigned") - if self.nonreentrant is None: - raise CompilerPanic(f"No reentrant key {self}") + if not self.nonreentrant: + raise CompilerPanic(f"Not nonreentrant {self}", self.ast_def) + self.reentrancy_key_position = position @classmethod @@ -660,32 +661,30 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]: def _parse_decorators( funcdef: vy_ast.FunctionDef, -) -> tuple[Optional[FunctionVisibility], StateMutability, Optional[str]]: +) -> tuple[Optional[FunctionVisibility], StateMutability, bool]: function_visibility = None state_mutability = None - nonreentrant_key = None + nonreentrant_node = None for decorator in funcdef.decorator_list: if isinstance(decorator, vy_ast.Call): - if nonreentrant_key is not None: - raise StructureException( - "nonreentrant decorator is already set with key: " f"{nonreentrant_key}", - funcdef, - ) - - if decorator.get("func.id") != "nonreentrant": - raise StructureException("Decorator is not callable", decorator) - if len(decorator.args) != 1 or not isinstance(decorator.args[0], vy_ast.Str): - raise StructureException( - "@nonreentrant name must be given as a single string literal", decorator - ) + msg = "Decorator is not callable" + hint = None + if decorator.get("func.id") == "nonreentrant": + hint = "use `@nonreentrant` with no arguments. the " + hint += "`@nonreentrant` decorator does not accept any " + hint += "arguments since vyper 0.4.0." + raise StructureException(msg, decorator, hint=hint) + + if decorator.get("id") == "nonreentrant": + if nonreentrant_node is not None: + raise StructureException("nonreentrant decorator is already set", nonreentrant_node) if funcdef.name == "__init__": - msg = "Nonreentrant decorator disallowed on `__init__`" + msg = "`@nonreentrant` decorator disallowed on `__init__`" raise FunctionDeclarationException(msg, decorator) - nonreentrant_key = decorator.args[0].value - validate_identifier(nonreentrant_key, decorator.args[0]) + nonreentrant_node = decorator elif isinstance(decorator, vy_ast.Name): if FunctionVisibility.is_valid_value(decorator.id): @@ -726,12 +725,13 @@ def _parse_decorators( # default to nonpayable state_mutability = StateMutability.NONPAYABLE - if state_mutability == StateMutability.PURE and nonreentrant_key is not None: - raise StructureException("Cannot use reentrancy guard on pure functions", funcdef) + if state_mutability == StateMutability.PURE and nonreentrant_node is not None: + raise StructureException("Cannot use reentrancy guard on pure functions", nonreentrant_node) # assert function_visibility is not None # mypy # assert state_mutability is not None # mypy - return function_visibility, state_mutability, nonreentrant_key + nonreentrant = nonreentrant_node is not None + return function_visibility, state_mutability, nonreentrant def _parse_args(