From 9ce56e7d8b0196a5d51d706a8d2376b98d3e8ad7 Mon Sep 17 00:00:00 2001 From: tserg <8017125+tserg@users.noreply.github.com> Date: Fri, 3 Nov 2023 00:33:20 +0800 Subject: [PATCH] chore: fix test for `slice` (#3633) fix some test cases for `slice` and simplify the test logic --------- Co-authored-by: Charles Cooper --- tests/parser/functions/test_slice.py | 88 +++++++++++++++++----------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/tests/parser/functions/test_slice.py b/tests/parser/functions/test_slice.py index 3090dafda0..53e092019f 100644 --- a/tests/parser/functions/test_slice.py +++ b/tests/parser/functions/test_slice.py @@ -32,8 +32,8 @@ def slice_tower_test(inp1: Bytes[50]) -> Bytes[50]: _bytes_1024 = st.binary(min_size=0, max_size=1024) -@pytest.mark.parametrize("literal_start", (True, False)) -@pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("use_literal_start", (True, False)) +@pytest.mark.parametrize("use_literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) @settings(max_examples=100) @@ -45,13 +45,13 @@ def test_slice_immutable( opt_level, bytesdata, start, - literal_start, + use_literal_start, length, - literal_length, + use_literal_length, length_bound, ): - _start = start if literal_start else "start" - _length = length if literal_length else "length" + _start = start if use_literal_start else "start" + _length = length if use_literal_length else "length" code = f""" IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) @@ -71,10 +71,10 @@ def _get_contract(): return get_contract(code, bytesdata, start, length, override_opt_level=opt_level) if ( - (start + length > length_bound and literal_start and literal_length) - or (literal_length and length > length_bound) - or (literal_start and start > length_bound) - or (literal_length and length < 1) + (start + length > length_bound and use_literal_start and use_literal_length) + or (use_literal_length and length > length_bound) + or (use_literal_start and start > length_bound) + or (use_literal_length and length == 0) ): assert_compile_failed(lambda: _get_contract(), ArgumentException) elif start + length > len(bytesdata) or (len(bytesdata) > length_bound): @@ -86,13 +86,13 @@ def _get_contract(): @pytest.mark.parametrize("location", ("storage", "calldata", "memory", "literal", "code")) -@pytest.mark.parametrize("literal_start", (True, False)) -@pytest.mark.parametrize("literal_length", (True, False)) +@pytest.mark.parametrize("use_literal_start", (True, False)) +@pytest.mark.parametrize("use_literal_length", (True, False)) @pytest.mark.parametrize("opt_level", list(OptimizationLevel)) @given(start=_draw_1024, length=_draw_1024, length_bound=_draw_1024_1, bytesdata=_bytes_1024) @settings(max_examples=100) @pytest.mark.fuzzing -def test_slice_bytes( +def test_slice_bytes_fuzz( get_contract, assert_compile_failed, assert_tx_failed, @@ -100,18 +100,28 @@ def test_slice_bytes( location, bytesdata, start, - literal_start, + use_literal_start, length, - literal_length, + use_literal_length, length_bound, ): + preamble = "" if location == "memory": spliced_code = f"foo: Bytes[{length_bound}] = inp" foo = "foo" elif location == "storage": + preamble = f""" +foo: Bytes[{length_bound}] + """ spliced_code = "self.foo = inp" foo = "self.foo" elif location == "code": + preamble = f""" +IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) +@external +def __init__(foo: Bytes[{length_bound}]): + IMMUTABLE_BYTES = foo + """ spliced_code = "" foo = "IMMUTABLE_BYTES" elif location == "literal": @@ -123,15 +133,11 @@ def test_slice_bytes( else: raise Exception("unreachable") - _start = start if literal_start else "start" - _length = length if literal_length else "length" + _start = start if use_literal_start else "start" + _length = length if use_literal_length else "length" code = f""" -foo: Bytes[{length_bound}] -IMMUTABLE_BYTES: immutable(Bytes[{length_bound}]) -@external -def __init__(foo: Bytes[{length_bound}]): - IMMUTABLE_BYTES = foo +{preamble} @external def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Bytes[{length_bound}]: @@ -142,24 +148,40 @@ def do_slice(inp: Bytes[{length_bound}], start: uint256, length: uint256) -> Byt def _get_contract(): return get_contract(code, bytesdata, override_opt_level=opt_level) - data_length = len(bytesdata) if location == "literal" else length_bound - if ( - (start + length > data_length and literal_start and literal_length) - or (literal_length and length > data_length) - or (location == "literal" and len(bytesdata) > length_bound) - or (literal_start and start > data_length) - or (literal_length and length < 1) - ): + # length bound is the container size; input_bound is the bound on the input + # (which can be different, if the input is a literal) + input_bound = length_bound + slice_output_too_large = False + + if location == "literal": + input_bound = len(bytesdata) + + # ex.: + # @external + # def do_slice(inp: Bytes[1], start: uint256, length: uint256) -> Bytes[1]: + # return slice(b'\x00\x00', 0, length) + output_length = length if use_literal_length else input_bound + slice_output_too_large = output_length > length_bound + + end = start + length + + compile_time_oob = ( + (use_literal_length and (length > input_bound or length == 0)) + or (use_literal_start and start > input_bound) + or (use_literal_start and use_literal_length and start + length > input_bound) + ) + + if compile_time_oob or slice_output_too_large: assert_compile_failed(lambda: _get_contract(), (ArgumentException, TypeMismatch)) - elif len(bytesdata) > data_length: + elif location == "code" and len(bytesdata) > length_bound: # deploy fail assert_tx_failed(lambda: _get_contract()) - elif start + length > len(bytesdata): + elif end > len(bytesdata) or len(bytesdata) > length_bound: c = _get_contract() assert_tx_failed(lambda: c.do_slice(bytesdata, start, length)) else: c = _get_contract() - assert c.do_slice(bytesdata, start, length) == bytesdata[start : start + length], code + assert c.do_slice(bytesdata, start, length) == bytesdata[start:end], code def test_slice_private(get_contract):