Skip to content

Commit

Permalink
chore: fix test for slice (#3633)
Browse files Browse the repository at this point in the history
fix some test cases for `slice` and simplify the test logic

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
tserg and charles-cooper authored Nov 2, 2023
1 parent 52dc413 commit 9ce56e7
Showing 1 changed file with 55 additions and 33 deletions.
88 changes: 55 additions & 33 deletions tests/parser/functions/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]:
_bytes_1024 = st.binary(min_size=0, max_size=1024)


@pytest.mark.parametrize("literal_start", (True, False))
@pytest.mark.parametrize("literal_length", (True, False))
@pytest.mark.parametrize("use_literal_start", (True, False))
@pytest.mark.parametrize("use_literal_length", (True, False))
@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
@given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024)
@settings(max_examples=100)
Expand All @@ -45,13 +45,13 @@ def test_slice_immutable(
opt_level,
bytesdata,
start,
literal_start,
use_literal_start,
length,
literal_length,
use_literal_length,
length_bound,
):
_start = start if literal_start else "start"
_length = length if literal_length else "length"
_start = start if use_literal_start else "start"
_length = length if use_literal_length else "length"

code = f"""
IMMUTABLE_BYTES: immutable(Bytes[{length_bound}])
Expand All @@ -71,10 +71,10 @@ def _get_contract():
return get_contract(code, bytesdata, start, length, override_opt_level=opt_level)

if (
(start + length > length_bound and literal_start and literal_length)
or (literal_length and length > length_bound)
or (literal_start and start > length_bound)
or (literal_length and length < 1)
(start + length > length_bound and use_literal_start and use_literal_length)
or (use_literal_length and length > length_bound)
or (use_literal_start and start > length_bound)
or (use_literal_length and length == 0)
):
assert_compile_failed(lambda: _get_contract(), ArgumentException)
elif start + length > len(bytesdata) or (len(bytesdata) > length_bound):
Expand All @@ -86,32 +86,42 @@ def _get_contract():


@pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code"))
@pytest.mark.parametrize("literal_start", (True, False))
@pytest.mark.parametrize("literal_length", (True, False))
@pytest.mark.parametrize("use_literal_start", (True, False))
@pytest.mark.parametrize("use_literal_length", (True, False))
@pytest.mark.parametrize("opt_level", list(OptimizationLevel))
@given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024)
@settings(max_examples=100)
@pytest.mark.fuzzing
def test_slice_bytes(
def test_slice_bytes_fuzz(
get_contract,
assert_compile_failed,
assert_tx_failed,
opt_level,
location,
bytesdata,
start,
literal_start,
use_literal_start,
length,
literal_length,
use_literal_length,
length_bound,
):
preamble = ""
if location == "memory":
spliced_code = f"foo: Bytes[{length_bound}] = inp"
foo = "foo"
elif location == "storage":
preamble = f"""
foo: Bytes[{length_bound}]
"""
spliced_code = "self.foo = inp"
foo = "self.foo"
elif location == "code":
preamble = f"""
IMMUTABLE_BYTES: immutable(Bytes[{length_bound}])
@external
def __init__(foo: Bytes[{length_bound}]):
IMMUTABLE_BYTES = foo
"""
spliced_code = ""
foo = "IMMUTABLE_BYTES"
elif location == "literal":
Expand All @@ -123,15 +133,11 @@ def test_slice_bytes(
else:
raise Exception("unreachable")

_start = start if literal_start else "start"
_length = length if literal_length else "length"
_start = start if use_literal_start else "start"
_length = length if use_literal_length else "length"

code = f"""
foo: Bytes[{length_bound}]
IMMUTABLE_BYTES: immutable(Bytes[{length_bound}])
@external
def __init__(foo: Bytes[{length_bound}]):
IMMUTABLE_BYTES = foo
{preamble}
@external
def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Bytes[{length_bound}]:
Expand All @@ -142,24 +148,40 @@ def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Byt
def _get_contract():
return get_contract(code, bytesdata, override_opt_level=opt_level)

data_length = len(bytesdata) if location == "literal" else length_bound
if (
(start + length > data_length and literal_start and literal_length)
or (literal_length and length > data_length)
or (location == "literal" and len(bytesdata) > length_bound)
or (literal_start and start > data_length)
or (literal_length and length < 1)
):
# length bound is the container size; input_bound is the bound on the input
# (which can be different, if the input is a literal)
input_bound = length_bound
slice_output_too_large = False

if location == "literal":
input_bound = len(bytesdata)

# ex.:
# @external
# def do_slice(inp: Bytes[1], start: uint256, length: uint256) -> Bytes[1]:
# return slice(b'\x00\x00', 0, length)
output_length = length if use_literal_length else input_bound
slice_output_too_large = output_length > length_bound

end = start + length

compile_time_oob = (
(use_literal_length and (length > input_bound or length == 0))
or (use_literal_start and start > input_bound)
or (use_literal_start and use_literal_length and start + length > input_bound)
)

if compile_time_oob or slice_output_too_large:
assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch))
elif len(bytesdata) > data_length:
elif location == "code" and len(bytesdata) > length_bound:
# deploy fail
assert_tx_failed(lambda: _get_contract())
elif start + length > len(bytesdata):
elif end > len(bytesdata) or len(bytesdata) > length_bound:
c = _get_contract()
assert_tx_failed(lambda: c.do_slice(bytesdata, start, length))
else:
c = _get_contract()
assert c.do_slice(bytesdata, start, length) == bytesdata[start : start + length], code
assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code


def test_slice_private(get_contract):
Expand Down

0 comments on commit 9ce56e7

Please sign in to comment.