Skip to content

Commit

Permalink
fix[codegen]: fix transient codegen for slice and extract32 (#3874)
Browse files Browse the repository at this point in the history
This commit fixes transient storage codegen for the `slice()` and
`extract32()` builtins. Previously, some codegen routines were hardcoded
to check for storage; this commit changes them to check whether the
source location is `word_addressable` instead. This commit also refactors
the `extract32()` code generation logic to be simpler and to use more
recent APIs. The new implementation relies on the optimizer to perform
some optimizations that were previously hand-rolled in this routine.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
cyberthirst and charles-cooper authored Apr 4, 2024
1 parent ee11e3d commit 63b8d1d
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 101 deletions.
26 changes: 19 additions & 7 deletions tests/functional/builtins/codegen/test_extract32.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
def test_extract32_extraction(tx_failed, get_contract_with_gas_estimation):
extract32_code = """
y: Bytes[100]
import pytest

from vyper.evm.opcodes import version_check


@pytest.mark.parametrize("location", ["storage", "transient"])
def test_extract32_extraction(tx_failed, get_contract_with_gas_estimation, location):
if location == "transient" and not version_check(begin="cancun"):
pytest.skip(
"Skipping test as storage_location is 'transient' and EVM version is pre-Cancun"
)
if location == "storage":
decl = "y: Bytes[100]"
elif location == "transient":
decl = "y: transient(Bytes[100])"
else:
raise Exception("unreachable")
extract32_code = f"""
{decl}
@external
def extrakt32(inp: Bytes[100], index: uint256) -> bytes32:
return extract32(inp, index)
Expand Down Expand Up @@ -43,8 +59,6 @@ def extrakt32_storage(index: uint256, inp: Bytes[100]) -> bytes32:
with tx_failed():
c.extrakt32(S, i)

print("Passed bytes32 extraction test")


def test_extract32_code(tx_failed, get_contract_with_gas_estimation):
extract32_code = """
Expand Down Expand Up @@ -84,5 +98,3 @@ def foq(inp: Bytes[32]) -> address:

with tx_failed():
c.foq(b"crow" * 8)

print("Passed extract32 test")
34 changes: 30 additions & 4 deletions tests/functional/builtins/codegen/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from vyper.compiler import compile_code
from vyper.compiler.settings import OptimizationLevel, Settings
from vyper.evm.opcodes import version_check
from vyper.exceptions import ArgumentException, TypeMismatch

_fun_bytes32_bounds = [(0, 32), (3, 29), (27, 5), (0, 5), (5, 3), (30, 2)]
Expand Down Expand Up @@ -93,7 +94,9 @@ def _get_contract():
assert c.do_splice() == bytesdata[start : start + length]


@pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code"))
@pytest.mark.parametrize(
"location", ["storage", "transient", "calldata", "memory", "literal", "code"]
)
@pytest.mark.parametrize("use_literal_start", (True, False))
@pytest.mark.parametrize("use_literal_length", (True, False))
@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
Expand All @@ -112,13 +115,23 @@ def test_slice_bytes_fuzz(
use_literal_length,
length_bound,
):
if location == "transient" and not version_check(begin="cancun"):
pytest.skip(
"Skipping test as storage_location is 'transient' and EVM version is pre-Cancun"
)
preamble = ""
if location == "memory":
spliced_code = f"foo: Bytes[{length_bound}] = inp"
foo = "foo"
elif location == "storage":
preamble = f"""
foo: Bytes[{length_bound}]
"""
spliced_code = "self.foo = inp"
foo = "self.foo"
elif location == "transient":
preamble = f"""
foo: transient(Bytes[{length_bound}])
"""
spliced_code = "self.foo = inp"
foo = "self.foo"
Expand Down Expand Up @@ -194,10 +207,23 @@ def _get_contract():
assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code


def test_slice_private(get_contract):
@pytest.mark.parametrize("location", ["storage", "transient"])
def test_slice_private(get_contract, location):
if location == "transient" and not version_check(begin="cancun"):
pytest.skip(
"Skipping test as storage_location is 'transient' and EVM version is pre-Cancun"
)

# test there are no buffer overruns in the slice function
code = """
bytez: public(String[12])
if location == "storage":
decl = "bytez: public(String[12])"
elif location == "transient":
decl = "bytez: public(transient(String[12]))"
else:
raise Exception("unreachable")

code = f"""
{decl}
@internal
def _slice(start: uint256, length: uint256):
Expand Down
134 changes: 44 additions & 90 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.context import Context, VariableRecord
from vyper.codegen.core import (
LOAD,
STORE,
IRnode,
add_ofst,
Expand Down Expand Up @@ -36,7 +37,7 @@
from vyper.codegen.expr import Expr
from vyper.codegen.ir_node import Encoding, scope_multi
from vyper.codegen.keccak256_helper import keccak256_helper
from vyper.evm.address_space import MEMORY, STORAGE
from vyper.evm.address_space import MEMORY
from vyper.exceptions import (
ArgumentException,
CompilerPanic,
Expand Down Expand Up @@ -378,7 +379,7 @@ def build_IR(self, expr, args, kwargs, context):

# add 32 bytes to the buffer size bc word access might
# be unaligned (see below)
if src.location == STORAGE:
if src.location.word_addressable:
buflen += 32

# Get returntype string or bytes
Expand All @@ -405,8 +406,8 @@ def build_IR(self, expr, args, kwargs, context):
src_data = bytes_data_ptr(src)

# general case. byte-for-byte copy
if src.location == STORAGE:
# because slice uses byte-addressing but storage
if src.location.word_addressable:
# because slice uses byte-addressing but storage/tstorage
# is word-aligned, this algorithm starts at some number
# of bytes before the data section starts, and might copy
# an extra word. the pseudocode is:
Expand Down Expand Up @@ -838,19 +839,6 @@ class ECMul(_ECArith):
_precompile = 0x7


# Build an element getter for a location that is read with the load opcode `op`.
# The returned closure `f(index)` produces an IRnode that applies `op` to the
# address `_sub + 32 + 32 * index`, typed as INT128_T.
# `_sub` is not defined here: it is an IR variable name that the caller is
# expected to bind (e.g. via an IR `with` node) before the getter's IR is used.
# NOTE(review): the constant 32 offset looks like it skips a leading length
# word, so `index` counts 32-byte words of the data section; confirm.
def _generic_element_getter(op):
def f(index):
return IRnode.from_list(
[op, ["add", "_sub", ["add", 32, ["mul", 32, index]]]], typ=INT128_T
)

return f


# Storage counterpart of the generic element getter: returns an IRnode that
# performs `sload` on `_sub + 1 + index`, typed as INT128_T.
# The address arithmetic steps by 1 per element rather than by 32.
# `_sub` is an IR variable name that the caller is expected to bind.
# NOTE(review): the +1 looks like it skips the slot holding the length; confirm.
def _storage_element_getter(index):
return IRnode.from_list(["sload", ["add", "_sub", ["add", 1, index]]], typ=INT128_T)


class Extract32(BuiltinFunctionT):
_id = "extract32"
_inputs = [("b", BytesT.any()), ("start", IntegerT.unsigneds())]
Expand Down Expand Up @@ -882,81 +870,47 @@ def infer_kwarg_types(self, node):

@process_inputs
def build_IR(self, expr, args, kwargs, context):
sub, index = args
bytez, index = args
ret_type = kwargs["output_type"]

# Get length and specific element
if sub.location == STORAGE:
lengetter = IRnode.from_list(["sload", "_sub"], typ=INT128_T)
elementgetter = _storage_element_getter

else:
op = sub.location.load_op
lengetter = IRnode.from_list([op, "_sub"], typ=INT128_T)
elementgetter = _generic_element_getter(op)

# TODO rewrite all this with cache_when_complex and bitshifts

# Special case: index known to be a multiple of 32
if isinstance(index.value, int) and not index.value % 32:
o = IRnode.from_list(
[
"with",
"_sub",
sub,
elementgetter(
["div", clamp2(0, index, ["sub", lengetter, 32], signed=True), 32]
),
],
typ=ret_type,
annotation="extracting 32 bytes",
)
# General case
else:
o = IRnode.from_list(
[
"with",
"_sub",
sub,
[
"with",
"_len",
lengetter,
[
"with",
"_index",
clamp2(0, index, ["sub", "_len", 32], signed=True),
[
"with",
"_mi32",
["mod", "_index", 32],
[
"with",
"_di32",
["div", "_index", 32],
[
"if",
"_mi32",
[
"add",
["mul", elementgetter("_di32"), ["exp", 256, "_mi32"]],
[
"div",
elementgetter(["add", "_di32", 1]),
["exp", 256, ["sub", 32, "_mi32"]],
],
],
elementgetter("_di32"),
],
],
],
],
],
],
typ=ret_type,
annotation="extract32",
)
return IRnode.from_list(clamp_basetype(o), typ=ret_type)
def finalize(ret):
annotation = "extract32"
ret = IRnode.from_list(ret, typ=ret_type, annotation=annotation)
return clamp_basetype(ret)

with bytez.cache_when_complex("_sub") as (b1, bytez):
# merge
length = get_bytearray_length(bytez)
index = clamp2(0, index, ["sub", length, 32], signed=True)
with index.cache_when_complex("_index") as (b2, index):
assert not index.typ.is_signed

# "easy" case, byte-addressed locations:
if bytez.location.word_scale == 32:
word = LOAD(add_ofst(bytes_data_ptr(bytez), index))
return finalize(b1.resolve(b2.resolve(word)))

# storage and transient storage, word-addressed
assert bytez.location.word_scale == 1

slot = IRnode.from_list(["div", index, 32])
# byte offset within the slot
byte_ofst = IRnode.from_list(["mod", index, 32])

with byte_ofst.cache_when_complex("byte_ofst") as (
b3,
byte_ofst,
), slot.cache_when_complex("slot") as (b4, slot):
# perform two loads and merge
w1 = LOAD(add_ofst(bytes_data_ptr(bytez), slot))
w2 = LOAD(add_ofst(bytes_data_ptr(bytez), ["add", slot, 1]))

left_bytes = shl(["mul", 8, byte_ofst], w1)
right_bytes = shr(["mul", 8, ["sub", 32, byte_ofst]], w2)
merged = ["or", left_bytes, right_bytes]

ret = ["if", byte_ofst, merged, left_bytes]
return finalize(b1.resolve(b2.resolve(b3.resolve(b4.resolve(ret)))))


class AsWeiValue(BuiltinFunctionT):
Expand Down
4 changes: 4 additions & 0 deletions vyper/evm/address_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class AddrSpace:
# TODO maybe make positional instead of defaulting to None
store_op: Optional[str] = None

# True when this address space has `word_scale == 1`, i.e. consecutive
# addresses refer to consecutive words rather than consecutive bytes.
# Lets callers test for this addressing behavior without comparing against a
# specific address-space constant.
@property
def word_addressable(self) -> bool:
return self.word_scale == 1


# alternative:
# class Memory(AddrSpace):
Expand Down

0 comments on commit 63b8d1d

Please sign in to comment.