From 63b8d1da768727bb5f3e96d0d532d4d9f22e5441 Mon Sep 17 00:00:00 2001 From: cyberthirst Date: Thu, 4 Apr 2024 23:29:07 +0200 Subject: [PATCH] fix[codegen]: fix transient codegen for `slice` and `extract32` (#3874) this commit fixes transient storage codegen for `slice()` and `extract32()` builtins. previously, some codegen routines were hardcoded to check for storage; this commit changes them to check if they are `word_addressable` instead. this commit also refactors the `extract32()` code generation logic to be simpler and use more recent APIs. the new implementation relies on the optimizer to do some optimizations which in this routine were previously hand-rolled. --------- Co-authored-by: Charles Cooper --- .../builtins/codegen/test_extract32.py | 26 +++- .../functional/builtins/codegen/test_slice.py | 34 ++++- vyper/builtins/functions.py | 134 ++++++------------ vyper/evm/address_space.py | 4 + 4 files changed, 97 insertions(+), 101 deletions(-) diff --git a/tests/functional/builtins/codegen/test_extract32.py b/tests/functional/builtins/codegen/test_extract32.py index a95b57b5ab..96280ce862 100644 --- a/tests/functional/builtins/codegen/test_extract32.py +++ b/tests/functional/builtins/codegen/test_extract32.py @@ -1,6 +1,22 @@ -def test_extract32_extraction(tx_failed, get_contract_with_gas_estimation): - extract32_code = """ -y: Bytes[100] +import pytest + +from vyper.evm.opcodes import version_check + + +@pytest.mark.parametrize("location", ["storage", "transient"]) +def test_extract32_extraction(tx_failed, get_contract_with_gas_estimation, location): + if location == "transient" and not version_check(begin="cancun"): + pytest.skip( + "Skipping test as storage_location is 'transient' and EVM version is pre-Cancun" + ) + if location == "storage": + decl = "y: Bytes[100]" + elif location == "transient": + decl = "y: transient(Bytes[100])" + else: + raise Exception("unreachable") + extract32_code = f""" +{decl} @external def extrakt32(inp: Bytes[100], index: uint256) -> bytes32: return extract32(inp, index) @@ -43,8 +59,6 @@ def extrakt32_storage(index: uint256, inp: Bytes[100]) -> bytes32: with tx_failed(): c.extrakt32(S, i) - print("Passed bytes32 extraction test") - def test_extract32_code(tx_failed, get_contract_with_gas_estimation): extract32_code = """ @@ -84,5 +98,3 @@ def foq(inp: Bytes[32]) -> address: with tx_failed(): c.foq(b"crow" * 8) - - print("Passed extract32 test") diff --git a/tests/functional/builtins/codegen/test_slice.py b/tests/functional/builtins/codegen/test_slice.py index 0ff6ee1d06..03dc7cc56d 100644 --- a/tests/functional/builtins/codegen/test_slice.py +++ b/tests/functional/builtins/codegen/test_slice.py @@ -4,6 +4,7 @@ from vyper.compiler import compile_code from vyper.compiler.settings import OptimizationLevel, Settings +from vyper.evm.opcodes import version_check from vyper.exceptions import ArgumentException, TypeMismatch _fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)] @@ -93,7 +94,9 @@ def _get_contract(): assert c.do_splice() == bytesdata[start : start + length] -@pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code")) +@pytest.mark.parametrize( + "location", ["storage", "transient", "calldata", "memory", "literal", "code"] +) @pytest.mark.parametrize("use_literal_start", (True, False)) @pytest.mark.parametrize("use_literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @@ -112,6 +115,10 @@ def test_slice_bytes_fuzz( use_literal_length, length_bound, ): + if location == "transient" and not version_check(begin="cancun"): + pytest.skip( + "Skipping test as storage_location is 'transient' and EVM version is pre-Cancun" + ) preamble = "" if location == "memory": spliced_code = f"foo: Bytes[{length_bound}] = inp" @@ -119,6 +126,12 @@ def test_slice_bytes_fuzz( elif location == "storage": preamble = f""" foo: Bytes[{length_bound}] + """ + spliced_code = "self.foo = inp" + foo = "self.foo" + elif location == "transient": + preamble = f""" +foo: transient(Bytes[{length_bound}]) """ spliced_code = "self.foo = inp" foo = "self.foo" @@ -194,10 +207,23 @@ def _get_contract(): assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code -def test_slice_private(get_contract): +@pytest.mark.parametrize("location", ["storage", "transient"]) +def test_slice_private(get_contract, location): + if location == "transient" and not version_check(begin="cancun"): + pytest.skip( + "Skipping test as storage_location is 'transient' and EVM version is pre-Cancun" + ) + # test there are no buffer overruns in the slice function - code = """ -bytez: public(String[12]) + if location == "storage": + decl = "bytez: public(String[12])" + elif location == "transient": + decl = "bytez: public(transient(String[12]))" + else: + raise Exception("unreachable") + + code = f""" +{decl} @internal def _slice(start: uint256, length: uint256): diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index f29fd0ef61..05d6dcb8b3 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -8,6 +8,7 @@ from vyper.codegen.abi_encoder import abi_encode from vyper.codegen.context import Context, VariableRecord from vyper.codegen.core import ( + LOAD, STORE, IRnode, add_ofst, @@ -36,7 +37,7 @@ from vyper.codegen.expr import Expr from vyper.codegen.ir_node import Encoding, scope_multi from vyper.codegen.keccak256_helper import keccak256_helper -from vyper.evm.address_space import MEMORY, STORAGE +from vyper.evm.address_space import MEMORY from vyper.exceptions import ( ArgumentException, CompilerPanic, @@ -378,7 +379,7 @@ def build_IR(self, expr, args, kwargs, context): # add 32 bytes to the buffer size bc word access might # be unaligned (see below) - if src.location == STORAGE: + if src.location.word_addressable: buflen += 32 # Get returntype string or bytes @@ -405,8 +406,8 @@ def build_IR(self, expr, args, kwargs, context): src_data = bytes_data_ptr(src) # general case. byte-for-byte copy - if src.location == STORAGE: - # because slice uses byte-addressing but storage + if src.location.word_addressable: + # because slice uses byte-addressing but storage/tstorage # is word-aligned, this algorithm starts at some number # of bytes before the data section starts, and might copy # an extra word. the pseudocode is: @@ -838,19 +839,6 @@ class ECMul(_ECArith): _precompile = 0x7 -def _generic_element_getter(op): - def f(index): - return IRnode.from_list( - [op, ["add", "_sub", ["add", 32, ["mul", 32, index]]]], typ=INT128_T - ) - - return f - - -def _storage_element_getter(index): - return IRnode.from_list(["sload", ["add", "_sub", ["add", 1, index]]], typ=INT128_T) - - class Extract32(BuiltinFunctionT): _id = "extract32" _inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())] @@ -882,81 +870,47 @@ def infer_kwarg_types(self, node): @process_inputs def build_IR(self, expr, args, kwargs, context): - sub, index = args + bytez, index = args ret_type = kwargs["output_type"] - # Get length and specific element - if sub.location == STORAGE: - lengetter = IRnode.from_list(["sload", "_sub"], typ=INT128_T) - elementgetter = _storage_element_getter - - else: - op = sub.location.load_op - lengetter = IRnode.from_list([op, "_sub"], typ=INT128_T) - elementgetter = _generic_element_getter(op) - - # TODO rewrite all this with cache_when_complex and bitshifts - - # Special case: index known to be a multiple of 32 - if isinstance(index.value, int) and not index.value % 32: - o = IRnode.from_list( - [ - "with", - "_sub", - sub, - elementgetter( - ["div", clamp2(0, index, ["sub", lengetter, 32], signed=True), 32] - ), - ], - typ=ret_type, - annotation="extracting 32 bytes", - ) - # General case - else: - o = IRnode.from_list( - [ - "with", - "_sub", - sub, - [ - "with", - "_len", - lengetter, - [ - "with", - "_index", - clamp2(0, index, ["sub", "_len", 32], signed=True), - [ - "with", - "_mi32", - ["mod", "_index", 32], - [ - "with", - "_di32", - ["div", "_index", 32], - [ - "if", - "_mi32", - [ - "add", - ["mul", elementgetter("_di32"), ["exp", 256, "_mi32"]], - [ - "div", - elementgetter(["add", "_di32", 1]), - ["exp", 256, ["sub", 32, "_mi32"]], - ], - ], - elementgetter("_di32"), - ], - ], - ], - ], - ], - ], - typ=ret_type, - annotation="extract32", - ) - return IRnode.from_list(clamp_basetype(o), typ=ret_type) + def finalize(ret): + annotation = "extract32" + ret = IRnode.from_list(ret, typ=ret_type, annotation=annotation) + return clamp_basetype(ret) + + with bytez.cache_when_complex("_sub") as (b1, bytez): + # merge + length = get_bytearray_length(bytez) + index = clamp2(0, index, ["sub", length, 32], signed=True) + with index.cache_when_complex("_index") as (b2, index): + assert not index.typ.is_signed + + # "easy" case, byte- addressed locations: + if bytez.location.word_scale == 32: + word = LOAD(add_ofst(bytes_data_ptr(bytez), index)) + return finalize(b1.resolve(b2.resolve(word))) + + # storage and transient storage, word-addressed + assert bytez.location.word_scale == 1 + + slot = IRnode.from_list(["div", index, 32]) + # byte offset within the slot + byte_ofst = IRnode.from_list(["mod", index, 32]) + + with byte_ofst.cache_when_complex("byte_ofst") as ( + b3, + byte_ofst, + ), slot.cache_when_complex("slot") as (b4, slot): + # perform two loads and merge + w1 = LOAD(add_ofst(bytes_data_ptr(bytez), slot)) + w2 = LOAD(add_ofst(bytes_data_ptr(bytez), ["add", slot, 1])) + + left_bytes = shl(["mul", 8, byte_ofst], w1) + right_bytes = shr(["mul", 8, ["sub", 32, byte_ofst]], w2) + merged = ["or", left_bytes, right_bytes] + + ret = ["if", byte_ofst, merged, left_bytes] + return finalize(b1.resolve(b2.resolve(b3.resolve(b4.resolve(ret))))) class AsWeiValue(BuiltinFunctionT): diff --git a/vyper/evm/address_space.py b/vyper/evm/address_space.py index fcbd4bcf63..08bef88e58 100644 --- a/vyper/evm/address_space.py +++ b/vyper/evm/address_space.py @@ -28,6 +28,10 @@ class AddrSpace: # TODO maybe make positional instead of defaulting to None store_op: Optional[str] = None + @property + def word_addressable(self) -> bool: + return self.word_scale == 1 + # alternative: # class Memory(AddrSpace):