Skip to content

Commit

Permalink
refactor: make assert_tx_failed a contextmanager (#3706)
Browse files Browse the repository at this point in the history
rename `assert_tx_failed` to `tx_failed` and change it into a context
manager which has a similar API to `pytest.raises()`.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
DanielSchiavini and charles-cooper authored Dec 23, 2023
1 parent 88c09a2 commit 2e41873
Show file tree
Hide file tree
Showing 63 changed files with 1,051 additions and 825 deletions.
4 changes: 2 additions & 2 deletions docs/testing-contracts-ethtester.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ To test events and failed transactions we expand our simple storage contract to

Next, we take a look at the two fixtures that will allow us to read the event logs and to check for failed transactions.

.. literalinclude:: ../tests/base_conftest.py
.. literalinclude:: ../tests/conftest.py
:language: python
:pyobject: assert_tx_failed
:pyobject: tx_failed

The fixture to assert failed transactions defaults to check for a ``TransactionFailed`` exception, but can be used to check for different exceptions too, as shown below. Also note that the chain gets reverted to the state before the failed transaction.

Expand Down
28 changes: 6 additions & 22 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from contextlib import contextmanager
from functools import wraps

import hypothesis
Expand Down Expand Up @@ -411,23 +412,6 @@ def assert_compile_failed(function_to_test, exception=Exception):
return assert_compile_failed


# TODO this should not be a fixture
@pytest.fixture
def search_for_sublist():
def search_for_sublist(ir, sublist):
_list = ir.to_list() if hasattr(ir, "to_list") else ir
if _list == sublist:
return True
if isinstance(_list, list):
for i in _list:
ret = search_for_sublist(i, sublist)
if ret is True:
return ret
return False

return search_for_sublist


@pytest.fixture
def create2_address_of(keccak):
def _f(_addr, _salt, _initcode):
Expand Down Expand Up @@ -484,16 +468,16 @@ def get_logs(tx_hash, c, event_name):
return get_logs


# TODO replace me with function like `with anchor_state()`
@pytest.fixture(scope="module")
def assert_tx_failed(tester):
def assert_tx_failed(function_to_test, exception=TransactionFailed, exc_text=None):
def tx_failed(tester):
@contextmanager
def fn(exception=TransactionFailed, exc_text=None):
snapshot_id = tester.take_snapshot()
with pytest.raises(exception) as excinfo:
function_to_test()
yield excinfo
tester.revert_to_snapshot(snapshot_id)
if exc_text:
# TODO test equality
assert exc_text in str(excinfo.value), (exc_text, excinfo.value)

return assert_tx_failed
return fn
25 changes: 15 additions & 10 deletions tests/functional/builtins/codegen/test_abi_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def abi_decode(x: Bytes[32]) -> uint256:
b"\x01" * 96, # Length of byte array is beyond size bound of output type
],
)
def test_clamper(get_contract, assert_tx_failed, input_):
def test_clamper(get_contract, tx_failed, input_):
contract = """
@external
def abi_decode(x: Bytes[96]) -> (uint256, uint256):
Expand All @@ -341,10 +341,11 @@ def abi_decode(x: Bytes[96]) -> (uint256, uint256):
return a, b
"""
c = get_contract(contract)
assert_tx_failed(lambda: c.abi_decode(input_))
with tx_failed():
c.abi_decode(input_)


def test_clamper_nested_uint8(get_contract, assert_tx_failed):
def test_clamper_nested_uint8(get_contract, tx_failed):
# check that _abi_decode clamps on word-types even when it is in a nested expression
# decode -> validate uint8 -> revert if input >= 256 -> cast back to uint256
contract = """
Expand All @@ -355,10 +356,11 @@ def abi_decode(x: uint256) -> uint256:
"""
c = get_contract(contract)
assert c.abi_decode(255) == 255
assert_tx_failed(lambda: c.abi_decode(256))
with tx_failed():
c.abi_decode(256)


def test_clamper_nested_bytes(get_contract, assert_tx_failed):
def test_clamper_nested_bytes(get_contract, tx_failed):
# check that _abi_decode clamps dynamic even when it is in a nested expression
# decode -> validate Bytes[20] -> revert if len(input) > 20 -> convert back to -> add 1
contract = """
Expand All @@ -369,7 +371,8 @@ def abi_decode(x: Bytes[96]) -> Bytes[21]:
"""
c = get_contract(contract)
assert c.abi_decode(abi.encode("(bytes)", (b"bc",))) == b"abc"
assert_tx_failed(lambda: c.abi_decode(abi.encode("(bytes)", (b"a" * 22,))))
with tx_failed():
c.abi_decode(abi.encode("(bytes)", (b"a" * 22,)))


@pytest.mark.parametrize(
Expand All @@ -381,7 +384,7 @@ def abi_decode(x: Bytes[96]) -> Bytes[21]:
("Bytes[5]", b"\x01" * 192),
],
)
def test_clamper_dynamic(get_contract, assert_tx_failed, output_typ, input_):
def test_clamper_dynamic(get_contract, tx_failed, output_typ, input_):
contract = f"""
@external
def abi_decode(x: Bytes[192]) -> {output_typ}:
Expand All @@ -390,7 +393,8 @@ def abi_decode(x: Bytes[192]) -> {output_typ}:
return a
"""
c = get_contract(contract)
assert_tx_failed(lambda: c.abi_decode(input_))
with tx_failed():
c.abi_decode(input_)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -422,7 +426,7 @@ def abi_decode(x: Bytes[160]) -> uint256:
("Bytes[5]", "address", b"\x01" * 128),
],
)
def test_clamper_dynamic_tuple(get_contract, assert_tx_failed, output_typ1, output_typ2, input_):
def test_clamper_dynamic_tuple(get_contract, tx_failed, output_typ1, output_typ2, input_):
contract = f"""
@external
def abi_decode(x: Bytes[224]) -> ({output_typ1}, {output_typ2}):
Expand All @@ -432,7 +436,8 @@ def abi_decode(x: Bytes[224]) -> ({output_typ1}, {output_typ2}):
return a, b
"""
c = get_contract(contract)
assert_tx_failed(lambda: c.abi_decode(input_))
with tx_failed():
c.abi_decode(input_)


FAIL_LIST = [
Expand Down
5 changes: 3 additions & 2 deletions tests/functional/builtins/codegen/test_addmod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def test_uint256_addmod(assert_tx_failed, get_contract_with_gas_estimation):
def test_uint256_addmod(tx_failed, get_contract_with_gas_estimation):
uint256_code = """
@external
def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256:
Expand All @@ -11,7 +11,8 @@ def _uint256_addmod(x: uint256, y: uint256, z: uint256) -> uint256:
assert c._uint256_addmod(32, 2, 32) == 2
assert c._uint256_addmod((2**256) - 1, 0, 2) == 1
assert c._uint256_addmod(2**255, 2**255, 6) == 4
assert_tx_failed(lambda: c._uint256_addmod(1, 2, 0))
with tx_failed():
c._uint256_addmod(1, 2, 0)


def test_uint256_addmod_ext_call(
Expand Down
16 changes: 9 additions & 7 deletions tests/functional/builtins/codegen/test_as_wei_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@pytest.mark.parametrize("denom,multiplier", wei_denoms.items())
def test_wei_uint256(get_contract, assert_tx_failed, denom, multiplier):
def test_wei_uint256(get_contract, tx_failed, denom, multiplier):
code = f"""
@external
def foo(a: uint256) -> uint256:
Expand All @@ -36,11 +36,12 @@ def foo(a: uint256) -> uint256:
assert c.foo(value) == value * (10**multiplier)

value = (2**256 - 1) // (10 ** (multiplier - 1))
assert_tx_failed(lambda: c.foo(value))
with tx_failed():
c.foo(value)


@pytest.mark.parametrize("denom,multiplier", wei_denoms.items())
def test_wei_int128(get_contract, assert_tx_failed, denom, multiplier):
def test_wei_int128(get_contract, tx_failed, denom, multiplier):
code = f"""
@external
def foo(a: int128) -> uint256:
Expand All @@ -54,7 +55,7 @@ def foo(a: int128) -> uint256:


@pytest.mark.parametrize("denom,multiplier", wei_denoms.items())
def test_wei_decimal(get_contract, assert_tx_failed, denom, multiplier):
def test_wei_decimal(get_contract, tx_failed, denom, multiplier):
code = f"""
@external
def foo(a: decimal) -> uint256:
Expand All @@ -69,20 +70,21 @@ def foo(a: decimal) -> uint256:

@pytest.mark.parametrize("value", (-1, -(2**127)))
@pytest.mark.parametrize("data_type", ["decimal", "int128"])
def test_negative_value_reverts(get_contract, assert_tx_failed, value, data_type):
def test_negative_value_reverts(get_contract, tx_failed, value, data_type):
code = f"""
@external
def foo(a: {data_type}) -> uint256:
return as_wei_value(a, "ether")
"""

c = get_contract(code)
assert_tx_failed(lambda: c.foo(value))
with tx_failed():
c.foo(value)


@pytest.mark.parametrize("denom,multiplier", wei_denoms.items())
@pytest.mark.parametrize("data_type", ["decimal", "int128", "uint256"])
def test_zero_value(get_contract, assert_tx_failed, denom, multiplier, data_type):
def test_zero_value(get_contract, tx_failed, denom, multiplier, data_type):
code = f"""
@external
def foo(a: {data_type}) -> uint256:
Expand Down
13 changes: 8 additions & 5 deletions tests/functional/builtins/codegen/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def bar(a: uint256) -> Roles:
@pytest.mark.parametrize("typ", ["uint8", "int128", "int256", "uint256"])
@pytest.mark.parametrize("val", [1, 2, 3, 4, 2**128, 2**256 - 1, 2**256 - 2])
def test_flag_conversion_2(
get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, val, typ
get_contract_with_gas_estimation, assert_compile_failed, tx_failed, val, typ
):
contract = f"""
flag Status:
Expand All @@ -529,7 +529,8 @@ def foo(a: {typ}) -> Status:
if lo <= val <= hi:
assert c.foo(val) == val
else:
assert_tx_failed(lambda: c.foo(val))
with tx_failed():
c.foo(val)
else:
assert_compile_failed(lambda: get_contract_with_gas_estimation(contract), TypeMismatch)

Expand Down Expand Up @@ -608,7 +609,7 @@ def foo() -> {t_bytes}:
@pytest.mark.parametrize("i_typ,o_typ,val", generate_reverting_cases())
@pytest.mark.fuzzing
def test_conversion_failures(
get_contract_with_gas_estimation, assert_compile_failed, assert_tx_failed, i_typ, o_typ, val
get_contract_with_gas_estimation, assert_compile_failed, tx_failed, i_typ, o_typ, val
):
"""
Test multiple contracts and check for a specific exception.
Expand Down Expand Up @@ -650,7 +651,8 @@ def foo():
"""

c2 = get_contract_with_gas_estimation(contract_2)
assert_tx_failed(lambda: c2.foo())
with tx_failed():
c2.foo()

contract_3 = f"""
@external
Expand All @@ -659,4 +661,5 @@ def foo(bar: {i_typ}) -> {o_typ}:
"""

c3 = get_contract_with_gas_estimation(contract_3)
assert_tx_failed(lambda: c3.foo(val))
with tx_failed():
c3.foo(val)
Loading

0 comments on commit 2e41873

Please sign in to comment.