From e1abe7ed1b3d862e5f2e9390bd09c256e232c9f7 Mon Sep 17 00:00:00 2001 From: Paul Robinson Date: Mon, 21 Nov 2022 13:35:18 -0700 Subject: [PATCH] async contract filter methods & start async filter testing (#2715) # async contract filter methods * set up async filter testing fixtures * async a few filter tests --- conftest.py | 71 +++--- tests/core/contracts/conftest.py | 6 +- .../contracts/test_contract_call_interface.py | 2 +- .../test_contract_caller_interface.py | 2 +- tests/core/contracts/{_utils.py => utils.py} | 12 - tests/core/filtering/conftest.py | 81 ++++--- .../core/filtering/test_basic_filter_tests.py | 50 +++++ ...t_contract_create_filter_topic_merging.py} | 0 ...t_getLogs.py => test_contract_get_logs.py} | 0 .../test_filter_against_latest_blocks.py | 41 +++- .../test_filters_against_many_blocks.py | 181 ++++++++++++++- tests/core/filtering/utils.py | 76 +++++++ tests/utils.py | 62 +++++ web3/_utils/threads.py | 5 + web3/contract.py | 211 +++++++++++------- 15 files changed, 629 insertions(+), 171 deletions(-) rename tests/core/contracts/{_utils.py => utils.py} (79%) rename tests/core/filtering/{test_contract_createFilter_topic_merging.py => test_contract_create_filter_topic_merging.py} (100%) rename tests/core/filtering/{test_contract_getLogs.py => test_contract_get_logs.py} (100%) create mode 100644 tests/core/filtering/utils.py diff --git a/conftest.py b/conftest.py index 9c276371f1..652c23def1 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,6 @@ import asyncio import pytest +import pytest_asyncio import time import warnings @@ -14,27 +15,11 @@ EthereumTesterProvider, ) - -class PollDelayCounter: - def __init__(self, initial_delay=0, max_delay=1, initial_step=0.01): - self.initial_delay = initial_delay - self.initial_step = initial_step - self.max_delay = max_delay - self.current_delay = initial_delay - - def __call__(self): - delay = self.current_delay - - if self.current_delay == 0: - self.current_delay += self.initial_step - else: - self.current_delay *= 2 - self.current_delay = min(self.current_delay, self.max_delay) - - return delay - - def reset(self): - self.current_delay = self.initial_delay +from tests.utils import ( + PollDelayCounter, + _async_wait_for_block_fixture_logic, + _async_wait_for_transaction_fixture_logic, +) @pytest.fixture() @@ -46,25 +31,12 @@ def is_testrpc_provider(provider): return isinstance(provider, EthereumTesterProvider) -def is_async_testrpc_provider(provider): - return isinstance(provider, AsyncEthereumTesterProvider) - - @pytest.fixture() def skip_if_testrpc(): - def _skip_if_testrpc(w3): if is_testrpc_provider(w3.provider): pytest.skip() - return _skip_if_testrpc - - -@pytest.fixture() -def async_skip_if_testrpc(): - def _skip_if_testrpc(async_w3): - if is_async_testrpc_provider(async_w3.provider): - pytest.skip() return _skip_if_testrpc @@ -76,6 +48,7 @@ def _wait_for_miner_start(w3, timeout=60): while not w3.eth.mining or not w3.eth.hashrate: time.sleep(poll_delay_counter()) timeout.check() + return _wait_for_miner_start @@ -89,6 +62,7 @@ def _wait_for_block(w3, block_number=1, timeout=None): while w3.eth.block_number < block_number: w3.manager.request_blocking("evm_mine", []) timeout.sleep(poll_delay_counter()) + return _wait_for_block @@ -105,6 +79,7 @@ def _wait_for_transaction(w3, txn_hash, timeout=120): timeout.check() return txn_receipt + return _wait_for_transaction @@ -123,4 +98,30 @@ def w3_strict_abi(): @pytest.fixture(autouse=True) def print_warnings(): - warnings.simplefilter('always') + warnings.simplefilter("always") + + +# --- async --- # + + +def is_async_testrpc_provider(provider): + return isinstance(provider, AsyncEthereumTesterProvider) + + +@pytest.fixture() +def async_skip_if_testrpc(): + def _skip_if_testrpc(async_w3): + if is_async_testrpc_provider(async_w3.provider): + pytest.skip() + + return _skip_if_testrpc + + +@pytest_asyncio.fixture() +async def async_wait_for_block(): + return _async_wait_for_block_fixture_logic + + +@pytest_asyncio.fixture() +async def async_wait_for_transaction(): + return _async_wait_for_transaction_fixture_logic diff --git a/tests/core/contracts/conftest.py b/tests/core/contracts/conftest.py index 59d13b7ed9..be175f3f1a 100644 --- a/tests/core/contracts/conftest.py +++ b/tests/core/contracts/conftest.py @@ -7,11 +7,13 @@ ) import pytest_asyncio -from _utils import ( +from tests.core.contracts.utils import ( async_deploy, - async_partial, deploy, ) +from tests.utils import ( + async_partial, +) from web3._utils.module_testing.emitter_contract import ( CONTRACT_EMITTER_ABI, CONTRACT_EMITTER_CODE, diff --git a/tests/core/contracts/test_contract_call_interface.py b/tests/core/contracts/test_contract_call_interface.py index 7d1338e944..24e1d84e52 100644 --- a/tests/core/contracts/test_contract_call_interface.py +++ b/tests/core/contracts/test_contract_call_interface.py @@ -16,7 +16,7 @@ ) import pytest_asyncio -from _utils import ( +from tests.core.contracts.utils import ( async_deploy, deploy, ) diff --git a/tests/core/contracts/test_contract_caller_interface.py b/tests/core/contracts/test_contract_caller_interface.py index 46c05b6e21..ebda46b5a3 100644 --- a/tests/core/contracts/test_contract_caller_interface.py +++ b/tests/core/contracts/test_contract_caller_interface.py @@ -2,7 +2,7 @@ import pytest_asyncio -from _utils import ( +from tests.core.contracts.utils import ( async_deploy, deploy, ) diff --git a/tests/core/contracts/_utils.py b/tests/core/contracts/utils.py similarity index 79% rename from tests/core/contracts/_utils.py rename to tests/core/contracts/utils.py index 12d1e5564a..75ba472fe2 100644 --- a/tests/core/contracts/_utils.py +++ b/tests/core/contracts/utils.py @@ -1,5 +1,3 @@ -import asyncio - from eth_utils.toolz import ( identity, ) @@ -27,13 +25,3 @@ async def async_deploy(async_web3, Contract, apply_func=identity, args=None): assert contract.address == address assert len(await async_web3.eth.get_code(contract.address)) > 0 return contract - - -def async_partial(f, *args, **kwargs): - async def f2(*args2, **kwargs2): - result = f(*args, *args2, **kwargs, **kwargs2) - if asyncio.iscoroutinefunction(f): - result = await result - return result - - return f2 diff --git a/tests/core/filtering/conftest.py b/tests/core/filtering/conftest.py index d3ddd6f153..56463280ca 100644 --- a/tests/core/filtering/conftest.py +++ b/tests/core/filtering/conftest.py @@ -6,19 +6,22 @@ encode_hex, event_signature_to_log_topic, ) +import pytest_asyncio -from web3 import Web3 +from tests.core.filtering.utils import ( + _async_emitter_fixture_logic, + _async_w3_fixture_logic, + _emitter_fixture_logic, + _w3_fixture_logic, +) +from tests.utils import ( + async_partial, +) from web3._utils.module_testing.emitter_contract import ( CONTRACT_EMITTER_ABI, CONTRACT_EMITTER_CODE, CONTRACT_EMITTER_RUNTIME, ) -from web3.middleware import ( - local_filter_middleware, -) -from web3.providers.eth_tester import ( - EthereumTesterProvider, -) @pytest.fixture( @@ -27,17 +30,7 @@ ids=["local_filter_middleware", "node_based_filter"], ) def w3(request): - use_filter_middleware = request.param - provider = EthereumTesterProvider() - w3 = Web3(provider) - if use_filter_middleware: - w3.middleware_onion.add(local_filter_middleware) - return w3 - - -@pytest.fixture(autouse=True) -def wait_for_mining_start(w3, wait_for_block): - wait_for_block(w3) + return _w3_fixture_logic(request) @pytest.fixture() @@ -71,16 +64,9 @@ def Emitter(w3, EMITTER): @pytest.fixture() def emitter(w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func): - wait_for_block(w3) - deploy_txn_hash = Emitter.constructor().transact({"gas": 10000000}) - deploy_receipt = wait_for_transaction(w3, deploy_txn_hash) - contract_address = address_conversion_func(deploy_receipt["contractAddress"]) - - bytecode = w3.eth.get_code(contract_address) - assert bytecode == Emitter.bytecode_runtime - _emitter = Emitter(address=contract_address) - assert _emitter.address == contract_address - return _emitter + return _emitter_fixture_logic( + w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func + ) class LogFunctions: @@ -174,3 +160,42 @@ def return_filter(contract=None, args=[]): @pytest.fixture(scope="module") def create_filter(request): return functools.partial(return_filter) + + +# --- async --- # + + +@pytest.fixture( + scope="function", + params=[True, False], + ids=["async_local_filter_middleware", "node_based_filter"], +) +def async_w3(request): + return _async_w3_fixture_logic(request) + + +@pytest.fixture() +def AsyncEmitter(async_w3, EMITTER): + return async_w3.eth.contract(**EMITTER) + + +@pytest_asyncio.fixture() +async def async_emitter( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, +): + return await _async_emitter_fixture_logic( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, + ) + + +@pytest.fixture(scope="module") +def async_create_filter(request): + return async_partial(return_filter) diff --git a/tests/core/filtering/test_basic_filter_tests.py b/tests/core/filtering/test_basic_filter_tests.py index 43a2413f19..957e0f08e3 100644 --- a/tests/core/filtering/test_basic_filter_tests.py +++ b/tests/core/filtering/test_basic_filter_tests.py @@ -1,3 +1,10 @@ +import pytest + +from tests.core.filtering.utils import ( + async_range, +) + + def test_filtering_sequential_blocks_with_bounded_range( w3, emitter, Emitter, wait_for_transaction ): @@ -30,3 +37,46 @@ def test_requesting_results_with_no_new_blocks(w3, emitter): builder = emitter.events.LogNoArguments.build_filter() filter_ = builder.deploy(w3) assert len(filter_.get_new_entries()) == 0 + + +# --- async --- # + + +@pytest.mark.asyncio +async def test_async_filtering_sequential_blocks_with_bounded_range( + async_w3, async_emitter +): + builder = async_emitter.events.LogNoArguments.build_filter() + builder.fromBlock = "latest" + initial_block_number = await async_w3.eth.block_number + builder.toBlock = initial_block_number + 100 + filter_ = await builder.deploy(async_w3) + async for i in async_range(100): + await async_emitter.functions.logNoArgs(which=1).transact() + eth_block_number = await async_w3.eth.block_number + assert eth_block_number == initial_block_number + 100 + new_entries = await filter_.get_new_entries() + assert len(new_entries) == 100 + + +@pytest.mark.asyncio +async def test_async_filtering_starting_block_range(async_w3, async_emitter): + async for i in async_range(10): + await async_emitter.functions.logNoArgs(which=1).transact() + builder = async_emitter.events.LogNoArguments.build_filter() + filter_ = await builder.deploy(async_w3) + initial_block_number = await async_w3.eth.block_number + async for i in async_range(10): + await async_emitter.functions.logNoArgs(which=1).transact() + eth_block_number = await async_w3.eth.block_number + assert eth_block_number == initial_block_number + 10 + new_entries = await filter_.get_new_entries() + assert len(new_entries) == 10 + + +@pytest.mark.asyncio +async def test_async_requesting_results_with_no_new_blocks(async_w3, async_emitter): + builder = async_emitter.events.LogNoArguments.build_filter() + filter_ = await builder.deploy(async_w3) + new_entries = await filter_.get_new_entries() + assert len(new_entries) == 0 diff --git a/tests/core/filtering/test_contract_createFilter_topic_merging.py b/tests/core/filtering/test_contract_create_filter_topic_merging.py similarity index 100% rename from tests/core/filtering/test_contract_createFilter_topic_merging.py rename to tests/core/filtering/test_contract_create_filter_topic_merging.py diff --git a/tests/core/filtering/test_contract_getLogs.py b/tests/core/filtering/test_contract_get_logs.py similarity index 100% rename from tests/core/filtering/test_contract_getLogs.py rename to tests/core/filtering/test_contract_get_logs.py diff --git a/tests/core/filtering/test_filter_against_latest_blocks.py b/tests/core/filtering/test_filter_against_latest_blocks.py index b12e76b580..246f80804a 100644 --- a/tests/core/filtering/test_filter_against_latest_blocks.py +++ b/tests/core/filtering/test_filter_against_latest_blocks.py @@ -1,19 +1,16 @@ +import pytest + +from tests.core.filtering.utils import ( + async_range, +) from web3._utils.threads import ( Timeout, ) -from web3.providers.eth_tester import ( - EthereumTesterProvider, -) def test_sync_filter_against_latest_blocks(w3, sleep_interval, wait_for_block): - if not isinstance(w3.provider, EthereumTesterProvider): - w3.provider = EthereumTesterProvider() - txn_filter = w3.eth.filter("latest") - current_block = w3.eth.block_number - wait_for_block(w3, current_block + 3) found_block_hashes = [] @@ -28,3 +25,31 @@ def test_sync_filter_against_latest_blocks(w3, sleep_interval, wait_for_block): w3.eth.get_block(n + 1).hash for n in range(current_block, current_block + 3) ] assert found_block_hashes == expected_block_hashes + + +# --- async --- # + + +@pytest.mark.asyncio +async def test_async_filter_against_latest_blocks( + async_w3, sleep_interval, async_wait_for_block +): + txn_filter = await async_w3.eth.filter("latest") + current_block = await async_w3.eth.block_number + await async_wait_for_block(async_w3, current_block + 3) + + found_block_hashes = [] + with Timeout(5) as timeout: + while len(found_block_hashes) < 3: + new_entries = await txn_filter.get_new_entries() + found_block_hashes.extend(new_entries) + await timeout.async_sleep(sleep_interval()) + + assert len(found_block_hashes) == 3 + + expected_block_hashes = [] + async for n in async_range(current_block, current_block + 3): + block = await async_w3.eth.get_block(n + 1) + expected_block_hashes.append(block.hash) + + assert found_block_hashes == expected_block_hashes diff --git a/tests/core/filtering/test_filters_against_many_blocks.py b/tests/core/filtering/test_filters_against_many_blocks.py index 0acfa45494..968cd1666b 100644 --- a/tests/core/filtering/test_filters_against_many_blocks.py +++ b/tests/core/filtering/test_filters_against_many_blocks.py @@ -5,6 +5,10 @@ to_tuple, ) +from tests.core.filtering.utils import ( + async_range, +) + @to_tuple def deploy_contracts(w3, contract, wait_for_transaction): @@ -68,10 +72,8 @@ def test_event_filter_new_events( assert len(event_filter.get_new_entries()) == expected_match_counter -@pytest.mark.xfail(reason="Suspected eth-tester bug") def test_block_filter(w3): block_filter = w3.eth.filter("latest") - while w3.eth.block_number < 50: pad_with_transactions(w3) @@ -91,17 +93,11 @@ def test_transaction_filter_with_mining(w3): assert len(transaction_filter.get_new_entries()) == transaction_counter -@pytest.mark.xfail(reason="Suspected eth-tester bug") def test_transaction_filter_without_mining(w3): - - w3.providers[0].ethereum_tester.auto_mine_transactions = False transaction_filter = w3.eth.filter("pending") - transaction_counter = 0 - - transact_once = single_transaction(w3) while transaction_counter < 100: - next(transact_once) + single_transaction(w3) transaction_counter += 1 assert len(transaction_filter.get_new_entries()) == transaction_counter @@ -153,3 +149,170 @@ def gen_non_matching_transact(): pad_with_transactions(w3) assert len(event_filter.get_new_entries()) == expected_match_counter + + +# --- async --- # + + +async def async_deploy_contracts(async_w3, contract, async_wait_for_transaction): + txs = [] + async for i in async_range(25): + tx_hash = await contract.constructor().transact() + await async_wait_for_transaction(async_w3, tx_hash) + tx = await async_w3.eth.get_transaction_receipt(tx_hash) + + txs.append(tx["contractAddress"]) + + return tuple(txs) + + +async def async_pad_with_transactions(async_w3): + accounts = await async_w3.eth.accounts + async for tx_count in async_range(random.randint(0, 10)): + _from = accounts[random.randint(0, len(accounts) - 1)] + _to = accounts[random.randint(0, len(accounts) - 1)] + value = 50 + tx_count + await async_w3.eth.send_transaction({"from": _from, "to": _to, "value": value}) + + +async def async_single_transaction(async_w3): + accounts = await async_w3.eth.accounts + _from = accounts[random.randint(0, len(accounts) - 1)] + _to = accounts[random.randint(0, len(accounts) - 1)] + value = 50 + tx_hash = await async_w3.eth.send_transaction( + {"from": _from, "to": _to, "value": value} + ) + return tx_hash + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_async_event_filter_new_events( + async_w3, + async_emitter, + async_wait_for_transaction, + api_style, +): + + matching_transact = async_emitter.functions.logNoArgs(which=1).transact + non_matching_transact = async_emitter.functions.logNoArgs(which=0).transact + + if api_style == "build_filter": + builder = async_emitter.events.LogNoArguments.build_filter() + builder.fromBlock = "latest" + event_filter = await builder.deploy(async_w3) + else: + event_filter = await async_emitter.events.LogNoArguments().create_filter( + fromBlock="latest" + ) + + expected_match_counter = 0 + + eth_block_number = await async_w3.eth.block_number + while eth_block_number < 50: + is_match = bool(random.randint(0, 1)) + if is_match: + expected_match_counter += 1 + awaited_matching_transact = await matching_transact() + await async_wait_for_transaction(async_w3, awaited_matching_transact) + await async_pad_with_transactions(async_w3) + continue + await non_matching_transact() + await async_pad_with_transactions(async_w3) + eth_block_number = await async_w3.eth.block_number + + new_entries = await event_filter.get_new_entries() + assert len(new_entries) == expected_match_counter + + +@pytest.mark.asyncio +async def test_async_block_filter(async_w3): + block_filter = await async_w3.eth.filter("latest") + + eth_block_number = await async_w3.eth.block_number + while eth_block_number < 50: + await async_pad_with_transactions(async_w3) + eth_block_number = await async_w3.eth.block_number + + new_entries = await block_filter.get_new_entries() + eth_block_number = await async_w3.eth.block_number + assert len(new_entries) == eth_block_number + + +@pytest.mark.asyncio +async def async_test_transaction_filter_with_mining(async_w3): + transaction_filter = await async_w3.eth.filter("pending") + transaction_counter = 0 + while transaction_counter < 100: + await async_single_transaction(async_w3) + transaction_counter += 1 + + new_entries = await transaction_filter.get_new_entries() + assert len(new_entries) == transaction_counter + + +@pytest.mark.asyncio +async def test_async_transaction_filter_without_mining(async_w3): + transaction_filter = await async_w3.eth.filter("pending") + transaction_counter = 0 + while transaction_counter < 100: + await async_single_transaction(async_w3) + transaction_counter += 1 + + new_entries = await transaction_filter.get_new_entries() + assert len(new_entries) == transaction_counter + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_async_event_filter_new_events_many_deployed_contracts( + async_w3, + async_emitter, + AsyncEmitter, + async_wait_for_transaction, + api_style, +): + + matching_transact = async_emitter.functions.logNoArgs(which=1).transact + + deployed_contract_addresses = await async_deploy_contracts( + async_w3, AsyncEmitter, async_wait_for_transaction + ) + + async def gen_non_matching_transact(): + while True: + contract_address = deployed_contract_addresses[ + random.randint(0, len(deployed_contract_addresses) - 1) + ] + yield async_w3.eth.contract( + address=contract_address, abi=AsyncEmitter.abi + ).functions.logNoArgs(which=1).transact + + non_matching_transact = gen_non_matching_transact() + + if api_style == "build_filter": + builder = async_emitter.events.LogNoArguments.build_filter() + builder.fromBlock = "latest" + event_filter = await builder.deploy(async_w3) + else: + event_filter = await async_emitter.events.LogNoArguments().create_filter( + fromBlock="latest" + ) + + expected_match_counter = 0 + + eth_block_number = await async_w3.eth.block_number + while eth_block_number < 50: + is_match = bool(random.randint(0, 1)) + if is_match: + expected_match_counter += 1 + await matching_transact() + await async_pad_with_transactions(async_w3) + continue + await non_matching_transact.__anext__() + await async_pad_with_transactions(async_w3) + eth_block_number = await async_w3.eth.block_number + + new_entries = await event_filter.get_new_entries() + assert len(new_entries) == expected_match_counter diff --git a/tests/core/filtering/utils.py b/tests/core/filtering/utils.py new file mode 100644 index 0000000000..7382cc2d2b --- /dev/null +++ b/tests/core/filtering/utils.py @@ -0,0 +1,76 @@ +import asyncio + +from web3 import Web3 +from web3.eth import ( + AsyncEth, +) +from web3.middleware import ( + async_local_filter_middleware, + local_filter_middleware, +) +from web3.providers.eth_tester import ( + AsyncEthereumTesterProvider, + EthereumTesterProvider, +) + + +def _w3_fixture_logic(request): + use_filter_middleware = request.param + provider = EthereumTesterProvider() + w3 = Web3(provider) + if use_filter_middleware: + w3.middleware_onion.add(local_filter_middleware) + return w3 + + +def _emitter_fixture_logic( + w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func +): + wait_for_block(w3) + deploy_txn_hash = Emitter.constructor().transact({"gas": 10000000}) + deploy_receipt = wait_for_transaction(w3, deploy_txn_hash) + contract_address = address_conversion_func(deploy_receipt["contractAddress"]) + + bytecode = w3.eth.get_code(contract_address) + assert bytecode == Emitter.bytecode_runtime + _emitter = Emitter(address=contract_address) + assert _emitter.address == contract_address + return _emitter + + +# --- async --- # + + +def _async_w3_fixture_logic(request): + use_filter_middleware = request.param + provider = AsyncEthereumTesterProvider() + async_w3 = Web3(provider, modules={"eth": [AsyncEth]}, middlewares=[]) + + if use_filter_middleware: + async_w3.middleware_onion.add(async_local_filter_middleware) + return async_w3 + + +async def _async_emitter_fixture_logic( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, +): + await async_wait_for_block(async_w3) + deploy_txn_hash = await AsyncEmitter.constructor().transact({"gas": 10000000}) + deploy_receipt = await async_wait_for_transaction(async_w3, deploy_txn_hash) + contract_address = address_conversion_func(deploy_receipt["contractAddress"]) + + bytecode = await async_w3.eth.get_code(contract_address) + assert bytecode == AsyncEmitter.bytecode_runtime + _emitter = AsyncEmitter(address=contract_address) + assert _emitter.address == contract_address + return _emitter + + +async def async_range(*args): + for i in range(*args): + yield (i) + await asyncio.sleep(0.0) diff --git a/tests/utils.py b/tests/utils.py index 43a9a1aa2e..7485ca8a30 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,32 @@ import websockets +from web3._utils.threads import ( + Timeout, +) + + +class PollDelayCounter: + def __init__(self, initial_delay=0, max_delay=1, initial_step=0.01): + self.initial_delay = initial_delay + self.initial_step = initial_step + self.max_delay = max_delay + self.current_delay = initial_delay + + def __call__(self): + delay = self.current_delay + + if self.current_delay == 0: + self.current_delay += self.initial_step + else: + self.current_delay *= 2 + self.current_delay = min(self.current_delay, self.max_delay) + + return delay + + def reset(self): + self.current_delay = self.initial_delay + def get_open_port(): sock = socket.socket() @@ -23,3 +49,39 @@ async def wait_for_ws(endpoint_uri, timeout=10): await asyncio.sleep(0.01) else: break + + +async def _async_wait_for_block_fixture_logic(async_w3, block_number=1, timeout=None): + if not timeout: + current_block_number = await async_w3.eth.block_number # type:ignore + timeout = (block_number - current_block_number) * 3 + poll_delay_counter = PollDelayCounter() + with Timeout(timeout) as timeout: + eth_block_number = await async_w3.eth.block_number + while eth_block_number < block_number: + await async_w3.manager.coro_request("evm_mine", []) + await timeout.async_sleep(poll_delay_counter()) + eth_block_number = await async_w3.eth.block_number + + +async def _async_wait_for_transaction_fixture_logic(async_w3, txn_hash, timeout=120): + poll_delay_counter = PollDelayCounter() + with Timeout(timeout) as timeout: + while True: + txn_receipt = await async_w3.eth.get_transaction_receipt(txn_hash) + if txn_receipt is not None: + break + asyncio.sleep(poll_delay_counter()) + timeout.check() + + return txn_receipt + + +def async_partial(f, *args, **kwargs): + async def f2(*args2, **kwargs2): + result = f(*args, *args2, **kwargs, **kwargs2) + if asyncio.iscoroutinefunction(f): + result = await result + return result + + return f2 diff --git a/web3/_utils/threads.py b/web3/_utils/threads.py index c69485604d..991c7cf254 100644 --- a/web3/_utils/threads.py +++ b/web3/_utils/threads.py @@ -1,6 +1,7 @@ """ A minimal implementation of the various gevent APIs used within this codebase. """ +import asyncio import threading import time from types import ( @@ -97,6 +98,10 @@ def sleep(self, seconds: float) -> None: time.sleep(seconds) self.check() + async def async_sleep(self, seconds: float) -> None: + await asyncio.sleep(seconds) + self.check() + class ThreadWithReturn(threading.Thread, Generic[TReturn]): def __init__( diff --git a/web3/contract.py b/web3/contract.py index d938888f0b..63bec34b91 100644 --- a/web3/contract.py +++ b/web3/contract.py @@ -85,11 +85,13 @@ to_hex, ) from web3._utils.events import ( + AsyncEventFilterBuilder, EventFilterBuilder, get_event_data, is_dynamic_sized_type, ) from web3._utils.filters import ( + AsyncLogFilter, LogFilter, construct_event_filter_params, ) @@ -1474,19 +1476,64 @@ def processLog(self, log: HexStr) -> EventData: return get_event_data(self.w3.codec, self.abi, log) @combomethod - def create_filter( + def _get_event_filter_params( + self, + abi: ABIEvent, + argument_filters: Optional[Dict[str, Any]] = None, + fromBlock: Optional[BlockIdentifier] = None, + toBlock: Optional[BlockIdentifier] = None, + blockHash: Optional[HexBytes] = None, + ) -> FilterParams: + + if not self.address: + raise TypeError( + "This method can be only called on " + "an instated contract with an address" + ) + + if argument_filters is None: + argument_filters = dict() + + _filters = dict(**argument_filters) + + blkhash_set = blockHash is not None + blknum_set = fromBlock is not None or toBlock is not None + if blkhash_set and blknum_set: + raise ValidationError( + "blockHash cannot be set at the same time as fromBlock or toBlock" + ) + + # Construct JSON-RPC raw filter presentation based on human readable + # Python descriptions. Namely, convert event names to their keccak signatures + data_filter_set, event_filter_params = construct_event_filter_params( + abi, + self.w3.codec, + contract_address=self.address, + argument_filters=_filters, + fromBlock=fromBlock, + toBlock=toBlock, + address=self.address, + ) + + if blockHash is not None: + event_filter_params["blockHash"] = blockHash + + return event_filter_params + + @classmethod + def factory(cls, class_name: str, **kwargs: Any) -> PropertyCheckingFactory: + return PropertyCheckingFactory(class_name, (cls,), kwargs) + + @combomethod + def _set_up_filter_builder( self, - *, # PEP 3102 argument_filters: Optional[Dict[str, Any]] = None, fromBlock: Optional[BlockIdentifier] = None, toBlock: BlockIdentifier = "latest", address: Optional[ChecksumAddress] = None, topics: Optional[Sequence[Any]] = None, - ) -> LogFilter: - """ - Create filter object that tracks logs emitted by this contract event. - :param filter_params: other parameters to limit the events - """ + filter_builder: Union[EventFilterBuilder, AsyncEventFilterBuilder] = None, + ) -> None: if fromBlock is None: raise TypeError( "Missing mandatory keyword argument to create_filter: fromBlock" @@ -1512,7 +1559,6 @@ def create_filter( topics=topics, ) - filter_builder = EventFilterBuilder(event_abi, self.w3.codec) filter_builder.address = cast( ChecksumAddress, event_filter_params.get("address") ) @@ -1536,73 +1582,6 @@ def create_filter( for arg, value in match_single_vals.items(): filter_builder.args[arg].match_single(value) - log_filter = filter_builder.deploy(self.w3) - log_filter.log_entry_formatter = get_event_data( - self.w3.codec, self._get_event_abi() - ) - log_filter.builder = filter_builder - - return log_filter - - @combomethod - def build_filter(self) -> EventFilterBuilder: - builder = EventFilterBuilder( - self._get_event_abi(), - self.w3.codec, - formatter=get_event_data(self.w3.codec, self._get_event_abi()), - ) - builder.address = self.address - return builder - - @combomethod - def _get_event_filter_params( - self, - abi: ABIEvent, - argument_filters: Optional[Dict[str, Any]] = None, - fromBlock: Optional[BlockIdentifier] = None, - toBlock: Optional[BlockIdentifier] = None, - blockHash: Optional[HexBytes] = None, - ) -> FilterParams: - - if not self.address: - raise TypeError( - "This method can be only called on " - "an instated contract with an address" - ) - - if argument_filters is None: - argument_filters = dict() - - _filters = dict(**argument_filters) - - blkhash_set = blockHash is not None - blknum_set = fromBlock is not None or toBlock is not None - if blkhash_set and blknum_set: - raise ValidationError( - "blockHash cannot be set at the same time as fromBlock or toBlock" - ) - - # Construct JSON-RPC raw filter presentation based on human readable - # Python descriptions. Namely, convert event names to their keccak signatures - data_filter_set, event_filter_params = construct_event_filter_params( - abi, - self.w3.codec, - contract_address=self.address, - argument_filters=_filters, - fromBlock=fromBlock, - toBlock=toBlock, - address=self.address, - ) - - if blockHash is not None: - event_filter_params["blockHash"] = blockHash - - return event_filter_params - - @classmethod - def factory(cls, class_name: str, **kwargs: Any) -> PropertyCheckingFactory: - return PropertyCheckingFactory(class_name, (cls,), kwargs) - class ContractEvent(BaseContractEvent): @combomethod @@ -1678,6 +1657,47 @@ def getLogs( # Convert raw binary data to Python proxy objects as described by ABI return tuple(get_event_data(self.w3.codec, abi, entry) for entry in logs) + @combomethod + def create_filter( + self, + *, # PEP 3102 + argument_filters: Optional[Dict[str, Any]] = None, + fromBlock: Optional[BlockIdentifier] = None, + toBlock: BlockIdentifier = "latest", + address: Optional[ChecksumAddress] = None, + topics: Optional[Sequence[Any]] = None, + ) -> LogFilter: + """ + Create filter object that tracks logs emitted by this contract event. + # optional --- update the params descriptions here or remove this line + """ + filter_builder = EventFilterBuilder(self._get_event_abi(), self.w3.codec) + self._set_up_filter_builder( + argument_filters, + fromBlock, + toBlock, + address, + topics, + filter_builder, + ) + log_filter = filter_builder.deploy(self.w3) + log_filter.log_entry_formatter = get_event_data( + self.w3.codec, self._get_event_abi() + ) + log_filter.builder = filter_builder + + return log_filter + + @combomethod + def build_filter(self) -> EventFilterBuilder: + builder = EventFilterBuilder( + self._get_event_abi(), + self.w3.codec, + formatter=get_event_data(self.w3.codec, self._get_event_abi()), + ) + builder.address = self.address + return builder + class AsyncContractEvent(BaseContractEvent): @combomethod @@ -1755,6 +1775,47 @@ async def getLogs( get_event_data(self.w3.codec, abi, entry) for entry in logs # type: ignore ) + @combomethod + async def create_filter( + self, + *, # PEP 3102 + argument_filters: Optional[Dict[str, Any]] = None, + fromBlock: Optional[BlockIdentifier] = None, + toBlock: BlockIdentifier = "latest", + address: Optional[ChecksumAddress] = None, + topics: Optional[Sequence[Any]] = None, + ) -> AsyncLogFilter: + """ + Create filter object that tracks logs emitted by this contract event. + # optional --- update the params descriptions here or remove this line + """ + filter_builder = AsyncEventFilterBuilder(self._get_event_abi(), self.w3.codec) + self._set_up_filter_builder( + argument_filters, + fromBlock, + toBlock, + address, + topics, + filter_builder, + ) + log_filter = await filter_builder.deploy(self.w3) + log_filter.log_entry_formatter = get_event_data( + self.w3.codec, self._get_event_abi() + ) + log_filter.builder = filter_builder + + return log_filter + + @combomethod + def build_filter(self) -> AsyncEventFilterBuilder: + builder = AsyncEventFilterBuilder( + self._get_event_abi(), + self.w3.codec, + formatter=get_event_data(self.w3.codec, self._get_event_abi()), + ) + builder.address = self.address + return builder + class BaseContractCaller: """