diff --git a/conftest.py b/conftest.py index 9c276371f1..652c23def1 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,6 @@ import asyncio import pytest +import pytest_asyncio import time import warnings @@ -14,27 +15,11 @@ EthereumTesterProvider, ) - -class PollDelayCounter: - def __init__(self, initial_delay=0, max_delay=1, initial_step=0.01): - self.initial_delay = initial_delay - self.initial_step = initial_step - self.max_delay = max_delay - self.current_delay = initial_delay - - def __call__(self): - delay = self.current_delay - - if self.current_delay == 0: - self.current_delay += self.initial_step - else: - self.current_delay *= 2 - self.current_delay = min(self.current_delay, self.max_delay) - - return delay - - def reset(self): - self.current_delay = self.initial_delay +from tests.utils import ( + PollDelayCounter, + _async_wait_for_block_fixture_logic, + _async_wait_for_transaction_fixture_logic, +) @pytest.fixture() @@ -46,25 +31,12 @@ def is_testrpc_provider(provider): return isinstance(provider, EthereumTesterProvider) -def is_async_testrpc_provider(provider): - return isinstance(provider, AsyncEthereumTesterProvider) - - @pytest.fixture() def skip_if_testrpc(): - def _skip_if_testrpc(w3): if is_testrpc_provider(w3.provider): pytest.skip() - return _skip_if_testrpc - - -@pytest.fixture() -def async_skip_if_testrpc(): - def _skip_if_testrpc(async_w3): - if is_async_testrpc_provider(async_w3.provider): - pytest.skip() return _skip_if_testrpc @@ -76,6 +48,7 @@ def _wait_for_miner_start(w3, timeout=60): while not w3.eth.mining or not w3.eth.hashrate: time.sleep(poll_delay_counter()) timeout.check() + return _wait_for_miner_start @@ -89,6 +62,7 @@ def _wait_for_block(w3, block_number=1, timeout=None): while w3.eth.block_number < block_number: w3.manager.request_blocking("evm_mine", []) timeout.sleep(poll_delay_counter()) + return _wait_for_block @@ -105,6 +79,7 @@ def _wait_for_transaction(w3, txn_hash, timeout=120): timeout.check() return txn_receipt + return _wait_for_transaction @@ -123,4 +98,30 @@ def w3_strict_abi(): @pytest.fixture(autouse=True) def print_warnings(): - warnings.simplefilter('always') + warnings.simplefilter("always") + + +# --- async --- # + + +def is_async_testrpc_provider(provider): + return isinstance(provider, AsyncEthereumTesterProvider) + + +@pytest.fixture() +def async_skip_if_testrpc(): + def _skip_if_testrpc(async_w3): + if is_async_testrpc_provider(async_w3.provider): + pytest.skip() + + return _skip_if_testrpc + + +@pytest_asyncio.fixture() +async def async_wait_for_block(): + return _async_wait_for_block_fixture_logic + + +@pytest_asyncio.fixture() +async def async_wait_for_transaction(): + return _async_wait_for_transaction_fixture_logic diff --git a/newsfragments/2744.feature.rst b/newsfragments/2744.feature.rst new file mode 100644 index 0000000000..0221ab2143 --- /dev/null +++ b/newsfragments/2744.feature.rst @@ -0,0 +1 @@ +Added async functionality to filter diff --git a/tests/core/contracts/conftest.py b/tests/core/contracts/conftest.py index 59d13b7ed9..be175f3f1a 100644 --- a/tests/core/contracts/conftest.py +++ b/tests/core/contracts/conftest.py @@ -7,11 +7,13 @@ ) import pytest_asyncio -from _utils import ( +from tests.core.contracts.utils import ( async_deploy, - async_partial, deploy, ) +from tests.utils import ( + async_partial, +) from web3._utils.module_testing.emitter_contract import ( CONTRACT_EMITTER_ABI, CONTRACT_EMITTER_CODE, diff --git a/tests/core/contracts/test_contract_call_interface.py b/tests/core/contracts/test_contract_call_interface.py index 7d1338e944..24e1d84e52 100644 --- a/tests/core/contracts/test_contract_call_interface.py +++ b/tests/core/contracts/test_contract_call_interface.py @@ -16,7 +16,7 @@ ) import pytest_asyncio -from _utils import ( +from tests.core.contracts.utils import ( async_deploy, deploy, ) diff --git a/tests/core/contracts/test_contract_caller_interface.py b/tests/core/contracts/test_contract_caller_interface.py index 46c05b6e21..ebda46b5a3 100644 --- a/tests/core/contracts/test_contract_caller_interface.py +++ b/tests/core/contracts/test_contract_caller_interface.py @@ -2,7 +2,7 @@ import pytest_asyncio -from _utils import ( +from tests.core.contracts.utils import ( async_deploy, deploy, ) diff --git a/tests/core/contracts/_utils.py b/tests/core/contracts/utils.py similarity index 79% rename from tests/core/contracts/_utils.py rename to tests/core/contracts/utils.py index 12d1e5564a..75ba472fe2 100644 --- a/tests/core/contracts/_utils.py +++ b/tests/core/contracts/utils.py @@ -1,5 +1,3 @@ -import asyncio - from eth_utils.toolz import ( identity, ) @@ -27,13 +25,3 @@ async def async_deploy(async_web3, Contract, apply_func=identity, args=None): assert contract.address == address assert len(await async_web3.eth.get_code(contract.address)) > 0 return contract - - -def async_partial(f, *args, **kwargs): - async def f2(*args2, **kwargs2): - result = f(*args, *args2, **kwargs, **kwargs2) - if asyncio.iscoroutinefunction(f): - result = await result - return result - - return f2 diff --git a/tests/core/eth-module/test_eth_filter.py b/tests/core/eth-module/test_eth_filter.py new file mode 100644 index 0000000000..e33bd6af2d --- /dev/null +++ b/tests/core/eth-module/test_eth_filter.py @@ -0,0 +1,49 @@ +import pytest + +import pytest_asyncio + +from web3 import Web3 +from web3._utils.filters import ( + AsyncBlockFilter, + AsyncLogFilter, + AsyncTransactionFilter, + BlockFilter, + LogFilter, + TransactionFilter, +) +from web3.eth import ( + AsyncEth, +) +from web3.providers.eth_tester.main import ( + AsyncEthereumTesterProvider, +) + + +def test_Eth_filter_creates_correct_filter_type(w3): + filter1 = w3.eth.filter("latest") + assert isinstance(filter1, BlockFilter) + filter2 = w3.eth.filter("pending") + assert isinstance(filter2, TransactionFilter) + filter3 = w3.eth.filter({}) + assert isinstance(filter3, LogFilter) + + +# --- async --- # + + +@pytest_asyncio.fixture() +async def async_w3(): + provider = AsyncEthereumTesterProvider() + w3 = Web3(provider, modules={"eth": [AsyncEth]}, middlewares=[]) + return w3 + + +@pytest.mark.asyncio +async def test_AsyncEth_filter_creates_correct_filter_type(async_w3): + + filter1 = await async_w3.eth.filter("latest") + assert isinstance(filter1, AsyncBlockFilter) + filter2 = await async_w3.eth.filter("pending") + assert isinstance(filter2, AsyncTransactionFilter) + filter3 = await async_w3.eth.filter({}) + assert isinstance(filter3, AsyncLogFilter) diff --git a/tests/core/filtering/conftest.py b/tests/core/filtering/conftest.py index d3ddd6f153..58457e856c 100644 --- a/tests/core/filtering/conftest.py +++ b/tests/core/filtering/conftest.py @@ -6,19 +6,22 @@ encode_hex, event_signature_to_log_topic, ) +import pytest_asyncio -from web3 import Web3 +from tests.core.filtering.utils import ( + _async_emitter_fixture_logic, + _async_w3_fixture_logic, + _emitter_fixture_logic, + _w3_fixture_logic, +) +from tests.utils import ( + async_partial, +) from web3._utils.module_testing.emitter_contract import ( CONTRACT_EMITTER_ABI, CONTRACT_EMITTER_CODE, CONTRACT_EMITTER_RUNTIME, ) -from web3.middleware import ( - local_filter_middleware, -) -from web3.providers.eth_tester import ( - EthereumTesterProvider, -) @pytest.fixture( @@ -27,17 +30,7 @@ ids=["local_filter_middleware", "node_based_filter"], ) def w3(request): - use_filter_middleware = request.param - provider = EthereumTesterProvider() - w3 = Web3(provider) - if use_filter_middleware: - w3.middleware_onion.add(local_filter_middleware) - return w3 - - -@pytest.fixture(autouse=True) -def wait_for_mining_start(w3, wait_for_block): - wait_for_block(w3) + return _w3_fixture_logic(request) @pytest.fixture() @@ -71,16 +64,9 @@ def Emitter(w3, EMITTER): @pytest.fixture() def emitter(w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func): - wait_for_block(w3) - deploy_txn_hash = Emitter.constructor().transact({"gas": 10000000}) - deploy_receipt = wait_for_transaction(w3, deploy_txn_hash) - contract_address = address_conversion_func(deploy_receipt["contractAddress"]) - - bytecode = w3.eth.get_code(contract_address) - assert bytecode == Emitter.bytecode_runtime - _emitter = Emitter(address=contract_address) - assert _emitter.address == contract_address - return _emitter + return _emitter_fixture_logic( + w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func + ) class LogFunctions: @@ -163,7 +149,7 @@ def emitter_log_topics(): return LogTopics -def return_filter(contract=None, args=[]): +def return_filter(contract, args): event_name = args[0] kwargs = apply_key_map({"filter": "argument_filters"}, args[1]) if "fromBlock" not in kwargs: @@ -174,3 +160,50 @@ def return_filter(contract=None, args=[]): @pytest.fixture(scope="module") def create_filter(request): return functools.partial(return_filter) + + +# --- async --- # + + +@pytest.fixture( + scope="function", + params=[True, False], + ids=["async_local_filter_middleware", "node_based_filter"], +) +def async_w3(request): + return _async_w3_fixture_logic(request) + + +@pytest.fixture() +def AsyncEmitter(async_w3, EMITTER): + return async_w3.eth.contract(**EMITTER) + + +@pytest_asyncio.fixture() +async def async_emitter( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, +): + return await _async_emitter_fixture_logic( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, + ) + + +async def async_return_filter(contract, args): + event_name = args[0] + kwargs = apply_key_map({"filter": "argument_filters"}, args[1]) + if "fromBlock" not in kwargs: + kwargs["fromBlock"] = "latest" + return await contract.events[event_name].create_filter(**kwargs) + + +@pytest_asyncio.fixture(scope="module") +async def async_create_filter(request): + return async_partial(async_return_filter) diff --git a/tests/core/filtering/test_basic_filter_tests.py b/tests/core/filtering/test_basic_filter_tests.py index 43a2413f19..cb47271de3 100644 --- a/tests/core/filtering/test_basic_filter_tests.py +++ b/tests/core/filtering/test_basic_filter_tests.py @@ -1,3 +1,6 @@ +import pytest + + def test_filtering_sequential_blocks_with_bounded_range( w3, emitter, Emitter, wait_for_transaction ): @@ -30,3 +33,46 @@ def test_requesting_results_with_no_new_blocks(w3, emitter): builder = emitter.events.LogNoArguments.build_filter() filter_ = builder.deploy(w3) assert len(filter_.get_new_entries()) == 0 + + +# --- async --- # + + +@pytest.mark.asyncio +async def test_async_filtering_sequential_blocks_with_bounded_range( + async_w3, async_emitter +): + builder = async_emitter.events.LogNoArguments.build_filter() + builder.fromBlock = "latest" + initial_block_number = await async_w3.eth.block_number + builder.toBlock = initial_block_number + 100 + filter_ = await builder.deploy(async_w3) + for i in range(100): + await async_emitter.functions.logNoArgs(which=1).transact() + eth_block_number = await async_w3.eth.block_number + assert eth_block_number == initial_block_number + 100 + new_entries = await filter_.get_new_entries() + assert len(new_entries) == 100 + + +@pytest.mark.asyncio +async def test_async_filtering_starting_block_range(async_w3, async_emitter): + for i in range(10): + await async_emitter.functions.logNoArgs(which=1).transact() + builder = async_emitter.events.LogNoArguments.build_filter() + filter_ = await builder.deploy(async_w3) + initial_block_number = await async_w3.eth.block_number + for i in range(10): + await async_emitter.functions.logNoArgs(which=1).transact() + eth_block_number = await async_w3.eth.block_number + assert eth_block_number == initial_block_number + 10 + new_entries = await filter_.get_new_entries() + assert len(new_entries) == 10 + + +@pytest.mark.asyncio +async def test_async_requesting_results_with_no_new_blocks(async_w3, async_emitter): + builder = async_emitter.events.LogNoArguments.build_filter() + filter_ = await builder.deploy(async_w3) + new_entries = await filter_.get_new_entries() + assert len(new_entries) == 0 diff --git a/tests/core/filtering/test_contract_data_filters.py b/tests/core/filtering/test_contract_data_filters.py index eeeeb3c51e..4ac35346a9 100644 --- a/tests/core/filtering/test_contract_data_filters.py +++ b/tests/core/filtering/test_contract_data_filters.py @@ -1,3 +1,4 @@ +import asyncio import pytest from hypothesis import ( @@ -5,33 +6,23 @@ settings, strategies as st, ) +import pytest_asyncio -from web3 import Web3 +from tests.core.filtering.utils import ( + _async_emitter_fixture_logic, + _async_w3_fixture_logic, + _emitter_fixture_logic, + _w3_fixture_logic, +) +from tests.utils import ( + _async_wait_for_block_fixture_logic, + _async_wait_for_transaction_fixture_logic, +) from web3._utils.module_testing.emitter_contract import ( CONTRACT_EMITTER_ABI, CONTRACT_EMITTER_CODE, CONTRACT_EMITTER_RUNTIME, ) -from web3.middleware import ( - local_filter_middleware, -) -from web3.providers.eth_tester import ( - EthereumTesterProvider, -) - - -@pytest.fixture( - scope="module", - params=[True, False], - ids=["local_filter_middleware", "node_based_filter"], -) -def w3(request): - use_filter_middleware = request.param - provider = EthereumTesterProvider() - w3 = Web3(provider) - if use_filter_middleware: - w3.middleware_onion.add(local_filter_middleware) - return w3 @pytest.fixture(scope="module") @@ -58,25 +49,6 @@ def EMITTER(EMITTER_CODE, EMITTER_RUNTIME, EMITTER_ABI): } -@pytest.fixture(scope="module") -def Emitter(w3, EMITTER): - return w3.eth.contract(**EMITTER) - - -@pytest.fixture(scope="module") -def emitter(w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func): - wait_for_block(w3) - deploy_txn_hash = Emitter.constructor().transact({"gas": 10000000}) - deploy_receipt = wait_for_transaction(w3, deploy_txn_hash) - contract_address = address_conversion_func(deploy_receipt["contractAddress"]) - - bytecode = w3.eth.get_code(contract_address) - assert bytecode == Emitter.bytecode_runtime - _emitter = Emitter(address=contract_address) - assert _emitter.address == contract_address - return _emitter - - def not_empty_string(x): return x != "" @@ -130,6 +102,30 @@ def array_values(draw): return (matching, non_matching) +# --- sync --- # + + +@pytest.fixture( + scope="module", + params=[True, False], + ids=["local_filter_middleware", "node_based_filter"], +) +def w3(request): + return _w3_fixture_logic(request) + + +@pytest.fixture(scope="module") +def Emitter(w3, EMITTER): + return w3.eth.contract(**EMITTER) + + +@pytest.fixture(scope="module") +def emitter(w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func): + return _emitter_fixture_logic( + w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func + ) + + @pytest.mark.parametrize("api_style", ("v4", "build_filter")) @given(vals=dynamic_values()) @settings(max_examples=5, deadline=None) @@ -284,3 +280,218 @@ def test_data_filters_with_list_arguments( else: with pytest.raises(TypeError): create_filter(emitter, ["LogListArgs", {"filter": {"arg1": matching}}]) + + +# --- async --- # + + +@pytest_asyncio.fixture(scope="module") +async def async_wait_for_block(): + return _async_wait_for_block_fixture_logic + + +@pytest_asyncio.fixture(scope="module") +async def async_wait_for_transaction(): + return _async_wait_for_transaction_fixture_logic + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture( + scope="module", + params=[True, False], + ids=["local_filter_middleware", "node_based_filter"], +) +def async_w3(request): + return _async_w3_fixture_logic(request) + + +@pytest.fixture(scope="module") +def AsyncEmitter(async_w3, EMITTER): + return async_w3.eth.contract(**EMITTER) + + +@pytest_asyncio.fixture(scope="module") +async def async_emitter( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, +): + return await _async_emitter_fixture_logic( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +@given(vals=dynamic_values()) +@settings(max_examples=5, deadline=None) +async def test_async_data_filters_with_dynamic_arguments( + async_w3, + async_wait_for_transaction, + async_create_filter, + async_emitter, + api_style, + vals, +): + if api_style == "build_filter": + filter_builder = async_emitter.events.LogDynamicArgs.build_filter() + filter_builder.args["arg1"].match_single(vals["matching"]) + event_filter = await filter_builder.deploy(async_w3) + else: + event_filter = await async_create_filter( + async_emitter, ["LogDynamicArgs", {"filter": {"arg1": vals["matching"]}}] + ) + + txn_hashes = [ + await async_emitter.functions.logDynamicArgs( + arg0=vals["matching"], arg1=vals["matching"] + ).transact( + {"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9, "gas": 400000} + ), + await async_emitter.functions.logDynamicArgs( + arg0=vals["non_matching"][0], arg1=vals["non_matching"][0] + ).transact( + {"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9, "gas": 400000} + ), + ] + + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + log_entries = await event_filter.get_new_entries() + assert len(log_entries) == 1 + assert log_entries[0]["transactionHash"] == txn_hashes[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +@given(vals=fixed_values()) +@settings(max_examples=5, deadline=None) +async def test_async_data_filters_with_fixed_arguments( + async_w3, + async_emitter, + async_wait_for_transaction, + async_create_filter, + api_style, + vals, +): + if api_style == "build_filter": + filter_builder = async_emitter.events.LogQuadrupleArg.build_filter() + filter_builder.args["arg0"].match_single(vals["matching"][0]) + filter_builder.args["arg1"].match_single(vals["matching"][1]) + filter_builder.args["arg2"].match_single(vals["matching"][2]) + filter_builder.args["arg3"].match_single(vals["matching"][3]) + event_filter = await filter_builder.deploy(async_w3) + else: + event_filter = await async_create_filter( + async_emitter, + [ + "LogQuadrupleArg", + { + "filter": { + "arg0": vals["matching"][0], + "arg1": vals["matching"][1], + "arg2": vals["matching"][2], + "arg3": vals["matching"][3], + } + }, + ], + ) + + txn_hashes = [] + txn_hashes.append( + await async_emitter.functions.logQuadruple( + which=5, + arg0=vals["matching"][0], + arg1=vals["matching"][1], + arg2=vals["matching"][2], + arg3=vals["matching"][3], + ).transact( + {"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9, "gas": 100000} + ) + ) + txn_hashes.append( + await async_emitter.functions.logQuadruple( + which=5, + arg0=vals["non_matching"][0], + arg1=vals["non_matching"][1], + arg2=vals["non_matching"][2], + arg3=vals["non_matching"][3], + ).transact( + {"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9, "gas": 100000} + ) + ) + + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + log_entries = await event_filter.get_new_entries() + assert len(log_entries) == 1 + assert log_entries[0]["transactionHash"] == txn_hashes[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +@given(vals=array_values()) +@settings(max_examples=5, deadline=None) +async def test_async_data_filters_with_list_arguments( + async_w3, + async_emitter, + async_wait_for_transaction, + async_create_filter, + api_style, + vals, +): + matching, non_matching = vals + + if api_style == "build_filter": + filter_builder = async_emitter.events.LogListArgs.build_filter() + filter_builder.args["arg1"].match_single(matching) + event_filter = await filter_builder.deploy(async_w3) + + txn_hashes = [] + txn_hashes.append( + await async_emitter.functions.logListArgs( + arg0=matching, arg1=matching + ).transact({"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9}) + ) + txn_hashes.append( + await async_emitter.functions.logListArgs( + arg0=non_matching, arg1=non_matching + ).transact({"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9}) + ) + txn_hashes.append( + await async_emitter.functions.logListArgs( + arg0=non_matching, arg1=matching + ).transact({"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9}) + ) + txn_hashes.append( + await async_emitter.functions.logListArgs( + arg0=matching, arg1=non_matching + ).transact({"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9}) + ) + + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + log_entries = await event_filter.get_new_entries() + assert len(log_entries) == 2 + assert log_entries[0]["transactionHash"] == txn_hashes[0] + assert log_entries[1]["transactionHash"] == txn_hashes[2] + else: + with pytest.raises(TypeError): + await async_create_filter( + async_emitter, ["LogListArgs", {"filter": {"arg1": matching}}] + ) diff --git a/tests/core/filtering/test_contract_get_logs.py b/tests/core/filtering/test_contract_get_logs.py index b7a0a0d518..4759ef97aa 100644 --- a/tests/core/filtering/test_contract_get_logs.py +++ b/tests/core/filtering/test_contract_get_logs.py @@ -1,3 +1,6 @@ +import pytest + + def test_contract_get_available_events( emitter, ): @@ -87,3 +90,114 @@ def test_contract_get_logs_argument_filter( argument_filters={"arg0": 1}, ) assert len(partial_logs) == 4 + + +# --- async --- # + + +def test_async_contract_get_available_events( + async_emitter, +): + """We can iterate over available contract events""" + contract = async_emitter + events = list(contract.events) + assert len(events) == 19 + + +@pytest.mark.asyncio +async def test_async_contract_get_logs_all( + async_w3, + async_emitter, + async_wait_for_transaction, + emitter_event_ids, +): + contract = async_emitter + event_id = emitter_event_ids.LogNoArguments + + txn_hash = await contract.functions.logNoArgs(event_id).transact() + await async_wait_for_transaction(async_w3, txn_hash) + + contract_logs = await contract.events.LogNoArguments.get_logs() + log_entries = list(contract_logs) + assert len(log_entries) == 1 + assert log_entries[0]["transactionHash"] == txn_hash + + +@pytest.mark.asyncio +async def test_async_contract_get_logs_range( + async_w3, + async_emitter, + async_wait_for_transaction, + emitter_event_ids, +): + contract = async_emitter + event_id = emitter_event_ids.LogNoArguments + + eth_block_number = await async_w3.eth.block_number + assert eth_block_number == 2 + txn_hash = await contract.functions.logNoArgs(event_id).transact() + # Mined as block 3 + await async_wait_for_transaction(async_w3, txn_hash) + eth_block_number = await async_w3.eth.block_number + assert eth_block_number == 3 + + contract_logs = await contract.events.LogNoArguments.get_logs() + log_entries = list(contract_logs) + assert len(log_entries) == 1 + + contract_logs = await contract.events.LogNoArguments.get_logs( + from_block=2, to_block=3 + ) + log_entries = list(contract_logs) + assert len(log_entries) == 1 + + contract_logs = await contract.events.LogNoArguments.get_logs( + from_block=1, to_block=2 + ) + log_entries = list(contract_logs) + assert len(log_entries) == 0 + + +@pytest.mark.asyncio +async def test_async_contract_get_logs_argument_filter( + async_w3, async_emitter, async_wait_for_transaction, emitter_event_ids +): + + contract = async_emitter + + txn_hashes = [] + event_id = emitter_event_ids.LogTripleWithIndex + # 1 = arg0 + # 4 = arg1 + # 1 = arg2 + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 4, 1).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 1, 2).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 2, 2).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 3, 1).transact() + ) + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + all_logs = await contract.events.LogTripleWithIndex.get_logs(from_block=1) + assert len(all_logs) == 4 + + # Filter all entries where arg1 in (1, 2) + partial_logs = await contract.events.LogTripleWithIndex.get_logs( + from_block=1, + argument_filters={"arg1": [1, 2]}, + ) + assert len(partial_logs) == 2 + + # Filter all entries where arg0 == 1 + partial_logs = await contract.events.LogTripleWithIndex.get_logs( + from_block=1, + argument_filters={"arg0": 1}, + ) + assert len(partial_logs) == 4 diff --git a/tests/core/filtering/test_contract_on_event_filtering.py b/tests/core/filtering/test_contract_on_event_filtering.py index 8eb69a70aa..a5907eca31 100644 --- a/tests/core/filtering/test_contract_on_event_filtering.py +++ b/tests/core/filtering/test_contract_on_event_filtering.py @@ -1,3 +1,4 @@ +import asyncio import pytest from eth_utils import ( @@ -212,3 +213,262 @@ def test_on_sync_filter_with_topic_filter_options_on_old_apis( old_logs = post_event_filter.get_all_entries() assert len(old_logs) == 4 + + +# --- async --- # + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("call_as_instance", (True, False)) +async def test_async_create_filter_address_parameter( + async_emitter, AsyncEmitter, call_as_instance +): + if call_as_instance: + event_filter = await async_emitter.events.LogNoArguments.create_filter( + fromBlock="latest" + ) + else: + event_filter = await AsyncEmitter.events.LogNoArguments.create_filter( + fromBlock="latest" + ) + + if call_as_instance: + # Assert this is a single string value, and not a list of addresses + assert is_address(event_filter.filter_params["address"]) + else: + # Undeployed contract shouldnt have address... + assert "address" not in event_filter.filter_params + + +@pytest.mark.asyncio +@pytest.mark.parametrize("call_as_instance", (True, False)) +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_on_async_filter_using_get_entries_interface( + async_w3, + async_emitter, + AsyncEmitter, + async_wait_for_transaction, + emitter_event_ids, + call_as_instance, + api_style, + async_create_filter, +): + + if call_as_instance: + contract = async_emitter + else: + contract = AsyncEmitter + + if api_style == "build_filter": + event_filter = await contract.events.LogNoArguments.build_filter().deploy( + async_w3 + ) + else: + event_filter = await async_create_filter(async_emitter, ["LogNoArguments", {}]) + + txn_hash = await async_emitter.functions.logNoArgs( + emitter_event_ids.LogNoArguments + ).transact() + await async_wait_for_transaction(async_w3, txn_hash) + + log_entries = await event_filter.get_new_entries() + assert len(log_entries) == 1 + assert log_entries[0]["transactionHash"] == txn_hash + + # a second call is empty because all events have been retrieved + new_entries = await event_filter.get_new_entries() + assert len(new_entries) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("call_as_instance", (True, False)) +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_on_async_filter_with_event_name_and_single_argument( + async_w3, + async_emitter, + AsyncEmitter, + async_wait_for_transaction, + emitter_event_ids, + call_as_instance, + api_style, + async_create_filter, +): + + if call_as_instance: + contract = async_emitter + else: + contract = AsyncEmitter + + if api_style == "build_filter": + builder = contract.events.LogTripleWithIndex.build_filter() + builder.args["arg1"].match_single(2) + event_filter = await builder.deploy(async_w3) + else: + event_filter = await async_create_filter( + contract, + [ + "LogTripleWithIndex", + { + "filter": { + "arg1": 2, + } + }, + ], + ) + + txn_hashes = [] + event_id = emitter_event_ids.LogTripleWithIndex + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 2, 1, 3).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 2, 3).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 12345, 2, 54321).transact() + ) + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + seen_logs = await event_filter.get_new_entries() + assert len(seen_logs) == 2 + assert {log["transactionHash"] for log in seen_logs} == set(txn_hashes[1:]) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("call_as_instance", (True, False)) +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_on_async_filter_with_event_name_and_non_indexed_argument( + async_w3, + async_emitter, + AsyncEmitter, + async_wait_for_transaction, + emitter_event_ids, + call_as_instance, + api_style, + async_create_filter, +): + + if call_as_instance: + contract = async_emitter + else: + contract = AsyncEmitter + + if api_style == "build_filter": + builder = contract.events.LogTripleWithIndex.build_filter() + builder.args["arg0"].match_single(1) + builder.args["arg1"].match_single(2) + event_filter = await builder.deploy(async_w3) + else: + event_filter = await async_create_filter( + contract, + [ + "LogTripleWithIndex", + { + "filter": { + "arg0": 1, + "arg1": 2, + } + }, + ], + ) + + txn_hashes = [] + event_id = emitter_event_ids.LogTripleWithIndex + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 2, 1, 3).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 2, 3).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 12345, 2, 54321).transact() + ) + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + seen_logs = await event_filter.get_new_entries() + assert len(seen_logs) == 1 + assert seen_logs[0]["transactionHash"] == txn_hashes[1] + + post_event_filter = await contract.events.LogTripleWithIndex.create_filter( + argument_filters={"arg0": 1, "arg1": 2}, + fromBlock=0, + ) + + old_logs = await post_event_filter.get_all_entries() + assert len(old_logs) == 1 + assert old_logs[0]["transactionHash"] == txn_hashes[1] + + +@pytest.mark.asyncio +async def test_async_filter_with_contract_address( + async_w3, async_emitter, emitter_event_ids, async_wait_for_transaction +): + event_filter = await async_w3.eth.filter( + filter_params={"address": async_emitter.address} + ) + txn_hash = await async_emitter.functions.logNoArgs( + emitter_event_ids.LogNoArguments + ).transact() + await async_wait_for_transaction(async_w3, txn_hash) + seen_logs = await event_filter.get_new_entries() + assert len(seen_logs) == 1 + assert seen_logs[0]["transactionHash"] == txn_hash + + +@pytest.mark.asyncio +@pytest.mark.parametrize("call_as_instance", (True, False)) +async def test_on_async_filter_with_topic_filter_options_on_old_apis( + async_w3, + async_emitter, + AsyncEmitter, + async_wait_for_transaction, + emitter_event_ids, + call_as_instance, + async_create_filter, +): + + if call_as_instance: + contract = async_emitter + else: + contract = AsyncEmitter + + event_filter = await async_create_filter( + contract, ["LogTripleWithIndex", {"filter": {"arg1": [1, 2], "arg2": [1, 2]}}] + ) + + txn_hashes = [] + event_id = emitter_event_ids.LogTripleWithIndex + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 1, 1).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 1, 2).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 2, 2).transact() + ) + txn_hashes.append( + await async_emitter.functions.logTriple(event_id, 1, 2, 1).transact() + ) + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + seen_logs = await event_filter.get_new_entries() + assert len(seen_logs) == 4 + + post_event_filter = await contract.events.LogTripleWithIndex.create_filter( + argument_filters={"arg1": [1, 2], "arg2": [1, 2]}, + fromBlock=0, + ) + + old_logs = await post_event_filter.get_all_entries() + assert len(old_logs) == 4 diff --git a/tests/core/filtering/test_contract_past_event_filtering.py b/tests/core/filtering/test_contract_past_event_filtering.py index 3fc66e2aa7..40e8f13b4a 100644 --- a/tests/core/filtering/test_contract_past_event_filtering.py +++ b/tests/core/filtering/test_contract_past_event_filtering.py @@ -1,3 +1,4 @@ +import asyncio import pytest from eth_utils import ( @@ -85,3 +86,101 @@ def test_get_all_entries_returned_block_data( assert event_data["transactionIndex"] == txn_receipt["transactionIndex"] assert is_same_address(event_data["address"], emitter.address) assert event_data["event"] == "LogNoArguments" + + +# --- async --- # + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("call_as_instance", (True, False)) +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_on_async_filter_using_get_all_entries_interface( + async_w3, + async_emitter, + AsyncEmitter, + async_wait_for_transaction, + emitter_event_ids, + call_as_instance, + api_style, + async_create_filter, +): + if call_as_instance: + contract = async_emitter + else: + contract = AsyncEmitter + + if api_style == "build_filter": + builder = contract.events.LogNoArguments.build_filter() + builder.fromBlock = "latest" + event_filter = await builder.deploy(async_w3) + else: + event_filter = await async_create_filter( + contract, ["LogNoArguments", {"fromBlock": "latest"}] + ) + + txn_hash = await async_emitter.functions.logNoArgs( + emitter_event_ids.LogNoArguments + ).transact() + await async_wait_for_transaction(async_w3, txn_hash) + + log_entries = await event_filter.get_all_entries() + + assert len(log_entries) == 1 + assert log_entries[0]["transactionHash"] == txn_hash + + # a second call still retrieves all results + log_entries_2 = await event_filter.get_all_entries() + + assert len(log_entries_2) == 1 + assert log_entries_2[0]["transactionHash"] == txn_hash + + +@pytest.mark.asyncio +@pytest.mark.parametrize("call_as_instance", (True, False)) +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_async_get_all_entries_returned_block_data( + async_w3, + async_emitter, + AsyncEmitter, + async_wait_for_transaction, + emitter_event_ids, + call_as_instance, + api_style, + async_create_filter, +): + txn_hash = await async_emitter.functions.logNoArgs( + emitter_event_ids.LogNoArguments + ).transact() + txn_receipt = await async_wait_for_transaction(async_w3, txn_hash) + + if call_as_instance: + contract = async_emitter + else: + contract = AsyncEmitter + + if api_style == "build_filter": + builder = contract.events.LogNoArguments.build_filter() + builder.fromBlock = txn_receipt["blockNumber"] + event_filter = await builder.deploy(async_w3) + else: + event_filter = await async_create_filter( + contract, ["LogNoArguments", {"fromBlock": txn_receipt["blockNumber"]}] + ) + + log_entries = await event_filter.get_all_entries() + + assert len(log_entries) == 1 + event_data = log_entries[0] + assert event_data["args"] == {} + assert event_data["blockHash"] == txn_receipt["blockHash"] + assert event_data["blockNumber"] == txn_receipt["blockNumber"] + assert event_data["transactionIndex"] == txn_receipt["transactionIndex"] + assert is_same_address(event_data["address"], async_emitter.address) + assert event_data["event"] == "LogNoArguments" diff --git a/tests/core/filtering/test_contract_topic_filters.py b/tests/core/filtering/test_contract_topic_filters.py index 0ea0d74a96..9712978828 100644 --- a/tests/core/filtering/test_contract_topic_filters.py +++ b/tests/core/filtering/test_contract_topic_filters.py @@ -1,3 +1,4 @@ +import asyncio import pytest from hypothesis import ( @@ -5,38 +6,23 @@ settings, strategies as st, ) +import pytest_asyncio -from web3 import Web3 +from tests.core.filtering.utils import ( + _async_emitter_fixture_logic, + _async_w3_fixture_logic, + _emitter_fixture_logic, + _w3_fixture_logic, +) +from tests.utils import ( + _async_wait_for_block_fixture_logic, + _async_wait_for_transaction_fixture_logic, +) from web3._utils.module_testing.emitter_contract import ( CONTRACT_EMITTER_ABI, CONTRACT_EMITTER_CODE, CONTRACT_EMITTER_RUNTIME, ) -from web3.middleware import ( - local_filter_middleware, -) -from web3.providers.eth_tester import ( - EthereumTesterProvider, -) - - -@pytest.fixture( - scope="module", - params=[True, False], - ids=["local_filter_middleware", "node_based_filter"], -) -def w3(request): - use_filter_middleware = request.param - provider = EthereumTesterProvider() - w3 = Web3(provider) - if use_filter_middleware: - w3.middleware_onion.add(local_filter_middleware) - return w3 - - -@pytest.fixture(autouse=True) -def wait_for_mining_start(w3, wait_for_block): - wait_for_block(w3) @pytest.fixture(scope="module") @@ -63,25 +49,6 @@ def EMITTER(EMITTER_CODE, EMITTER_RUNTIME, EMITTER_ABI): } -@pytest.fixture(scope="module") -def Emitter(w3, EMITTER): - return w3.eth.contract(**EMITTER) - - -@pytest.fixture(scope="module") -def emitter(w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func): - wait_for_block(w3) - deploy_txn_hash = Emitter.constructor().transact({"gas": 10000000}) - deploy_receipt = wait_for_transaction(w3, deploy_txn_hash) - contract_address = address_conversion_func(deploy_receipt["contractAddress"]) - - bytecode = w3.eth.get_code(contract_address) - assert bytecode == Emitter.bytecode_runtime - _emitter = Emitter(address=contract_address) - assert _emitter.address == contract_address - return _emitter - - def not_empty_string(x): return x != "" @@ -135,6 +102,30 @@ def array_values(draw): return (matching, non_matching) +# --- sync --- # + + +@pytest.fixture( + scope="module", + params=[True, False], + ids=["local_filter_middleware", "node_based_filter"], +) +def w3(request): + return _w3_fixture_logic(request) + + +@pytest.fixture(scope="module") +def Emitter(w3, EMITTER): + return w3.eth.contract(**EMITTER) + + +@pytest.fixture(scope="module") +def emitter(w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func): + return _emitter_fixture_logic( + w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func + ) + + @pytest.mark.parametrize("api_style", ("v4", "build_filter")) @given(vals=dynamic_values()) @settings(max_examples=5, deadline=None) @@ -171,16 +162,13 @@ def test_topic_filters_with_dynamic_arguments( assert log_entries[0]["transactionHash"] == txn_hashes[0] -@pytest.mark.parametrize("call_as_instance", (True, False)) @pytest.mark.parametrize("api_style", ("v4", "build_filter")) @given(vals=fixed_values()) @settings(max_examples=5, deadline=None) def test_topic_filters_with_fixed_arguments( w3, emitter, - Emitter, wait_for_transaction, - call_as_instance, create_filter, api_style, vals, @@ -240,12 +228,11 @@ def test_topic_filters_with_fixed_arguments( assert log_entries[0]["transactionHash"] == txn_hashes[0] -@pytest.mark.parametrize("call_as_instance", (True, False)) @pytest.mark.parametrize("api_style", ("v4", "build_filter")) @given(vals=array_values()) @settings(max_examples=5, deadline=None) def test_topic_filters_with_list_arguments( - w3, emitter, wait_for_transaction, call_as_instance, create_filter, api_style, vals + w3, emitter, wait_for_transaction, create_filter, api_style, vals ): matching, non_matching = vals @@ -274,3 +261,207 @@ def test_topic_filters_with_list_arguments( else: with pytest.raises(TypeError): create_filter(emitter, ["LogListArgs", {"filter": {"arg0": matching}}]) + + +# --- async --- # + + +@pytest_asyncio.fixture(scope="module") +async def async_wait_for_block(): + return _async_wait_for_block_fixture_logic + + +@pytest_asyncio.fixture(scope="module") +async def async_wait_for_transaction(): + return _async_wait_for_transaction_fixture_logic + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture( + scope="module", + params=[True, False], + ids=["local_filter_middleware", "node_based_filter"], +) +def async_w3(request): + return _async_w3_fixture_logic(request) + + +@pytest_asyncio.fixture(scope="module") +def AsyncEmitter(async_w3, EMITTER): + return async_w3.eth.contract(**EMITTER) + + +@pytest_asyncio.fixture(scope="module") +async def async_emitter( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, +): + return await _async_emitter_fixture_logic( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +@given(vals=dynamic_values()) +@settings(max_examples=5, deadline=None) +async def test_async_topic_filters_with_dynamic_arguments( + async_w3, + async_emitter, + async_wait_for_transaction, + async_create_filter, + api_style, + vals, +): + if api_style == "build_filter": + + filter_builder = async_emitter.events.LogDynamicArgs.build_filter() + filter_builder.args["arg0"].match_single(vals["matching"]) + event_filter = await filter_builder.deploy(async_w3) + else: + event_filter = await async_create_filter( + async_emitter, ["LogDynamicArgs", {"filter": {"arg0": vals["matching"]}}] + ) + + txn_hashes = [ + await async_emitter.functions.logDynamicArgs( + arg0=vals["matching"], arg1=vals["matching"] + ).transact( + {"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9, "gas": 60000} + ), + await async_emitter.functions.logDynamicArgs( + arg0=vals["non_matching"][0], arg1=vals["non_matching"][0] + ).transact( + {"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9, "gas": 60000} + ), + ] + + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + log_entries = await event_filter.get_new_entries() + assert len(log_entries) == 1 + assert log_entries[0]["transactionHash"] == txn_hashes[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +@given(vals=fixed_values()) +@settings(max_examples=5, deadline=None) +async def test_async_topic_filters_with_fixed_arguments( + async_w3, + async_emitter, + async_wait_for_transaction, + async_create_filter, + api_style, + vals, +): + if api_style == "build_filter": + filter_builder = async_emitter.events.LogQuadrupleWithIndex.build_filter() + filter_builder.args["arg0"].match_single(vals["matching"][0]) + filter_builder.args["arg1"].match_single(vals["matching"][1]) + filter_builder.args["arg2"].match_single(vals["matching"][2]) + filter_builder.args["arg3"].match_single(vals["matching"][3]) + event_filter = await filter_builder.deploy(async_w3) + else: + event_filter = await async_create_filter( + async_emitter, + [ + "LogQuadrupleWithIndex", + { + "filter": { + "arg0": vals["matching"][0], + "arg1": vals["matching"][1], + "arg2": vals["matching"][2], + "arg3": vals["matching"][3], + } + }, + ], + ) + + txn_hashes = [] + txn_hashes.append( + await async_emitter.functions.logQuadruple( + which=11, + arg0=vals["matching"][0], + arg1=vals["matching"][1], + arg2=vals["matching"][2], + arg3=vals["matching"][3], + ).transact( + {"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9, "gas": 60000} + ) + ) + txn_hashes.append( + await async_emitter.functions.logQuadruple( + which=11, + arg0=vals["non_matching"][0], + arg1=vals["non_matching"][1], + arg2=vals["non_matching"][2], + arg3=vals["non_matching"][3], + ).transact( + {"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9, "gas": 60000} + ) + ) + + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + log_entries = await event_filter.get_new_entries() + assert len(log_entries) == 1 + assert log_entries[0]["transactionHash"] == txn_hashes[0] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +@given(vals=array_values()) +@settings(max_examples=5, deadline=None) +async def test_async_topic_filters_with_list_arguments( + async_w3, + async_emitter, + async_wait_for_transaction, + async_create_filter, + api_style, + vals, +): + matching, non_matching = vals + + if api_style == "build_filter": + filter_builder = async_emitter.events.LogListArgs.build_filter() + filter_builder.args["arg0"].match_single(matching) + event_filter = await filter_builder.deploy(async_w3) + txn_hashes = [] + txn_hashes.append( + await async_emitter.functions.logListArgs( + arg0=matching, arg1=matching + ).transact({"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9}) + ) + txn_hashes.append( + await async_emitter.functions.logListArgs( + arg0=non_matching, arg1=non_matching + ).transact({"maxFeePerGas": 10**9, "maxPriorityFeePerGas": 10**9}) + ) + + for txn_hash in txn_hashes: + await async_wait_for_transaction(async_w3, txn_hash) + + log_entries = await event_filter.get_new_entries() + assert len(log_entries) == 1 + assert log_entries[0]["transactionHash"] == txn_hashes[0] + else: + with pytest.raises(TypeError): + await async_create_filter( + async_emitter, ["LogListArgs", {"filter": {"arg0": matching}}] + ) diff --git a/tests/core/filtering/test_existing_filter_instance.py b/tests/core/filtering/test_existing_filter_instance.py index d97523e463..8ae2cdd2ee 100644 --- a/tests/core/filtering/test_existing_filter_instance.py +++ b/tests/core/filtering/test_existing_filter_instance.py @@ -1,18 +1,14 @@ import pytest +import pytest_asyncio + from web3._utils.threads import ( Timeout, ) -from web3.providers.eth_tester import ( - EthereumTesterProvider, -) @pytest.fixture() def filter_id(w3): - if not isinstance(w3.provider, EthereumTesterProvider): - w3.provider = EthereumTesterProvider() - block_filter = w3.eth.filter("latest") return block_filter.filter_id @@ -43,3 +39,46 @@ def test_instantiate_existing_filter(w3, sleep_interval, wait_for_block, filter_ w3.eth.get_block(n + 1).hash for n in range(current_block, current_block + 3) ] assert found_block_hashes == expected_block_hashes + + +# --- async --- # + + +@pytest_asyncio.fixture() +async def async_filter_id(async_w3): + block_filter = await async_w3.eth.filter("latest") + return block_filter.filter_id + + +@pytest.mark.asyncio +async def test_async_instantiate_existing_filter( + async_w3, sleep_interval, async_wait_for_block, async_filter_id +): + with pytest.raises(TypeError): + await async_w3.eth.filter("latest", async_filter_id) + with pytest.raises(TypeError): + await async_w3.eth.filter("latest", filter_id=async_filter_id) + with pytest.raises(TypeError): + await async_w3.eth.filter(filter_params="latest", filter_id=async_filter_id) + + block_filter = await async_w3.eth.filter(filter_id=async_filter_id) + + current_block = await async_w3.eth.block_number + + await async_wait_for_block(async_w3, current_block + 3) + + found_block_hashes = [] + with Timeout(5) as timeout: + while len(found_block_hashes) < 3: + new_entries = await block_filter.get_new_entries() + found_block_hashes.extend(new_entries) + await timeout.async_sleep(sleep_interval()) + + assert len(found_block_hashes) == 3 + + expected_block_hashes = [] + for n in range(current_block, current_block + 3): + next_block = await async_w3.eth.get_block(n + 1) + expected_block_hashes.append(next_block.hash) + + assert found_block_hashes == expected_block_hashes diff --git a/tests/core/filtering/test_filter_against_latest_blocks.py b/tests/core/filtering/test_filter_against_latest_blocks.py index b12e76b580..856154b528 100644 --- a/tests/core/filtering/test_filter_against_latest_blocks.py +++ b/tests/core/filtering/test_filter_against_latest_blocks.py @@ -1,19 +1,13 @@ +import pytest + from web3._utils.threads import ( Timeout, ) -from web3.providers.eth_tester import ( - EthereumTesterProvider, -) def test_sync_filter_against_latest_blocks(w3, sleep_interval, wait_for_block): - if not isinstance(w3.provider, EthereumTesterProvider): - w3.provider = EthereumTesterProvider() - txn_filter = w3.eth.filter("latest") - current_block = w3.eth.block_number - wait_for_block(w3, current_block + 3) found_block_hashes = [] @@ -28,3 +22,31 @@ def test_sync_filter_against_latest_blocks(w3, sleep_interval, wait_for_block): w3.eth.get_block(n + 1).hash for n in range(current_block, current_block + 3) ] assert found_block_hashes == expected_block_hashes + + +# --- async --- # + + +@pytest.mark.asyncio +async def test_async_filter_against_latest_blocks( + async_w3, sleep_interval, async_wait_for_block +): + txn_filter = await async_w3.eth.filter("latest") + current_block = await async_w3.eth.block_number + await async_wait_for_block(async_w3, current_block + 3) + + found_block_hashes = [] + with Timeout(5) as timeout: + while len(found_block_hashes) < 3: + new_entries = await txn_filter.get_new_entries() + found_block_hashes.extend(new_entries) + await timeout.async_sleep(sleep_interval()) + + assert len(found_block_hashes) == 3 + + expected_block_hashes = [] + for n in range(current_block, current_block + 3): + block = await async_w3.eth.get_block(n + 1) + expected_block_hashes.append(block.hash) + + assert found_block_hashes == expected_block_hashes diff --git a/tests/core/filtering/test_filter_against_pending_transactions.py b/tests/core/filtering/test_filter_against_pending_transactions.py index c23da77412..82f5dc1e4a 100644 --- a/tests/core/filtering/test_filter_against_pending_transactions.py +++ b/tests/core/filtering/test_filter_against_pending_transactions.py @@ -1,25 +1,11 @@ import pytest -import random -from flaky import ( - flaky, -) -from web3._utils.threads import ( - Timeout, -) - - -@pytest.mark.skip(reason="fixture 'w3_empty' not found") -@flaky(max_runs=3) def test_sync_filter_against_pending_transactions( - w3_empty, wait_for_transaction, skip_if_testrpc + w3, + wait_for_transaction, ): - w3 = w3_empty - skip_if_testrpc(w3) - txn_filter = w3.eth.filter("pending") - txn_1_hash = w3.eth.send_transaction( { "from": w3.eth.coinbase, @@ -38,51 +24,37 @@ def test_sync_filter_against_pending_transactions( wait_for_transaction(w3, txn_1_hash) wait_for_transaction(w3, txn_2_hash) - with Timeout(5) as timeout: - while not txn_filter.get_new_entries(): - timeout.sleep(random.random()) - seen_txns = txn_filter.get_new_entries() assert txn_1_hash in seen_txns assert txn_2_hash in seen_txns -@pytest.mark.skip(reason="fixture 'w3_empty' not found") -@flaky(max_runs=3) -def test_async_filter_against_pending_transactions( - w3_empty, wait_for_transaction, skip_if_testrpc +@pytest.mark.asyncio +async def test_async_filter_against_pending_transactions( + async_w3, async_wait_for_transaction ): - w3 = w3_empty - skip_if_testrpc(w3) - - seen_txns = [] - txn_filter = w3.eth.filter("pending") - txn_filter.watch(seen_txns.append) - - txn_1_hash = w3.eth.send_transaction( + txn_filter = await async_w3.eth.filter("pending") + async_w3_eth_coinbase = await async_w3.eth.coinbase + txn_1_hash = await async_w3.eth.send_transaction( { - "from": w3.eth.coinbase, + "from": async_w3_eth_coinbase, "to": "0xd3CdA913deB6f67967B99D67aCDFa1712C293601", "value": 12345, } ) - txn_2_hash = w3.eth.send_transaction( + txn_2_hash = await async_w3.eth.send_transaction( { - "from": w3.eth.coinbase, + "from": async_w3_eth_coinbase, "to": "0xd3CdA913deB6f67967B99D67aCDFa1712C293601", "value": 54321, } ) - wait_for_transaction(w3, txn_1_hash) - wait_for_transaction(w3, txn_2_hash) - - with Timeout(5) as timeout: - while not seen_txns: - timeout.sleep(random.random()) + await async_wait_for_transaction(async_w3, txn_1_hash) + await async_wait_for_transaction(async_w3, txn_2_hash) - txn_filter.stop_watching(30) + seen_txns = await txn_filter.get_new_entries() assert txn_1_hash in seen_txns assert txn_2_hash in seen_txns diff --git a/tests/core/filtering/test_filter_against_transaction_logs.py b/tests/core/filtering/test_filter_against_transaction_logs.py index 1a769c12ba..49172f0891 100644 --- a/tests/core/filtering/test_filter_against_transaction_logs.py +++ b/tests/core/filtering/test_filter_against_transaction_logs.py @@ -1,24 +1,10 @@ import pytest -import random -from flaky import ( - flaky, -) -from web3._utils.threads import ( - Timeout, -) - - -@pytest.mark.skip(reason="fixture 'w3_empty' not found") -@flaky(max_runs=3) def test_sync_filter_against_log_events( - w3_empty, emitter, wait_for_transaction, emitter_log_topics, emitter_event_ids + w3, emitter, wait_for_transaction, emitter_event_ids ): - w3 = w3_empty - txn_filter = w3.eth.filter({}) - txn_hashes = set() txn_hashes.add( emitter.functions.logNoArgs(emitter_event_ids.LogNoArguments).transact() @@ -27,39 +13,26 @@ def test_sync_filter_against_log_events( for txn_hash in txn_hashes: wait_for_transaction(w3, txn_hash) - with Timeout(5) as timeout: - while not txn_filter.get_new_entries(): - timeout.sleep(random.random()) - seen_logs = txn_filter.get_new_entries() assert txn_hashes == {log["transactionHash"] for log in seen_logs} -@pytest.mark.skip(reason="fixture 'w3_empty' not found") -@flaky(max_runs=3) -def test_async_filter_against_log_events( - w3_empty, emitter, wait_for_transaction, emitter_log_topics, emitter_event_ids +@pytest.mark.asyncio +async def test_async_filter_against_log_events( + async_w3, async_emitter, async_wait_for_transaction, emitter_event_ids ): - w3 = w3_empty - - seen_logs = [] - txn_filter = w3.eth.filter({}) - txn_filter.watch(seen_logs.append) - + txn_filter = await async_w3.eth.filter({}) txn_hashes = set() - txn_hashes.add( - emitter.functions.logNoArgs(emitter_event_ids.LogNoArguments).transact() + await async_emitter.functions.logNoArgs( + emitter_event_ids.LogNoArguments + ).transact() ) for txn_hash in txn_hashes: - wait_for_transaction(w3, txn_hash) - - with Timeout(5) as timeout: - while not seen_logs: - timeout.sleep(random.random()) + await async_wait_for_transaction(async_w3, txn_hash) - txn_filter.stop_watching(30) + seen_logs = await txn_filter.get_new_entries() assert txn_hashes == {log["transactionHash"] for log in seen_logs} diff --git a/tests/core/filtering/test_filters_against_many_blocks.py b/tests/core/filtering/test_filters_against_many_blocks.py index 0acfa45494..9a8633d230 100644 --- a/tests/core/filtering/test_filters_against_many_blocks.py +++ b/tests/core/filtering/test_filters_against_many_blocks.py @@ -68,10 +68,8 @@ def test_event_filter_new_events( assert len(event_filter.get_new_entries()) == expected_match_counter -@pytest.mark.xfail(reason="Suspected eth-tester bug") def test_block_filter(w3): block_filter = w3.eth.filter("latest") - while w3.eth.block_number < 50: pad_with_transactions(w3) @@ -91,17 +89,11 @@ def test_transaction_filter_with_mining(w3): assert len(transaction_filter.get_new_entries()) == transaction_counter -@pytest.mark.xfail(reason="Suspected eth-tester bug") def test_transaction_filter_without_mining(w3): - - w3.providers[0].ethereum_tester.auto_mine_transactions = False transaction_filter = w3.eth.filter("pending") - transaction_counter = 0 - - transact_once = single_transaction(w3) while transaction_counter < 100: - next(transact_once) + single_transaction(w3) transaction_counter += 1 assert len(transaction_filter.get_new_entries()) == transaction_counter @@ -153,3 +145,170 @@ def gen_non_matching_transact(): pad_with_transactions(w3) assert len(event_filter.get_new_entries()) == expected_match_counter + + +# --- async --- # + + +async def async_deploy_contracts(async_w3, contract, async_wait_for_transaction): + txs = [] + for i in range(25): + tx_hash = await contract.constructor().transact() + await async_wait_for_transaction(async_w3, tx_hash) + tx = await async_w3.eth.get_transaction_receipt(tx_hash) + + txs.append(tx["contractAddress"]) + + return tuple(txs) + + +async def async_pad_with_transactions(async_w3): + accounts = await async_w3.eth.accounts + for tx_count in range(random.randint(0, 10)): + _from = accounts[random.randint(0, len(accounts) - 1)] + _to = accounts[random.randint(0, len(accounts) - 1)] + value = 50 + tx_count + await async_w3.eth.send_transaction({"from": _from, "to": _to, "value": value}) + + +async def async_single_transaction(async_w3): + accounts = await async_w3.eth.accounts + _from = accounts[random.randint(0, len(accounts) - 1)] + _to = accounts[random.randint(0, len(accounts) - 1)] + value = 50 + tx_hash = await async_w3.eth.send_transaction( + {"from": _from, "to": _to, "value": value} + ) + return tx_hash + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_async_event_filter_new_events( + async_w3, + async_emitter, + async_wait_for_transaction, + api_style, +): + + matching_transact = async_emitter.functions.logNoArgs(which=1).transact + non_matching_transact = async_emitter.functions.logNoArgs(which=0).transact + + if api_style == "build_filter": + builder = async_emitter.events.LogNoArguments.build_filter() + builder.fromBlock = "latest" + event_filter = await builder.deploy(async_w3) + else: + event_filter = await async_emitter.events.LogNoArguments().create_filter( + fromBlock="latest" + ) + + expected_match_counter = 0 + + eth_block_number = await async_w3.eth.block_number + while eth_block_number < 50: + is_match = bool(random.randint(0, 1)) + if is_match: + expected_match_counter += 1 + awaited_matching_transact = await matching_transact() + await async_wait_for_transaction(async_w3, awaited_matching_transact) + await async_pad_with_transactions(async_w3) + continue + await non_matching_transact() + await async_pad_with_transactions(async_w3) + eth_block_number = await async_w3.eth.block_number + + new_entries = await event_filter.get_new_entries() + assert len(new_entries) == expected_match_counter + + +@pytest.mark.asyncio +async def test_async_block_filter(async_w3): + block_filter = await async_w3.eth.filter("latest") + + eth_block_number = await async_w3.eth.block_number + while eth_block_number < 50: + await async_pad_with_transactions(async_w3) + eth_block_number = await async_w3.eth.block_number + + new_entries = await block_filter.get_new_entries() + eth_block_number = await async_w3.eth.block_number + assert len(new_entries) == eth_block_number + + +@pytest.mark.asyncio +async def async_test_transaction_filter_with_mining(async_w3): + transaction_filter = await async_w3.eth.filter("pending") + transaction_counter = 0 + while transaction_counter < 100: + await async_single_transaction(async_w3) + transaction_counter += 1 + + new_entries = await transaction_filter.get_new_entries() + assert len(new_entries) == transaction_counter + + +@pytest.mark.asyncio +async def test_async_transaction_filter_without_mining(async_w3): + transaction_filter = await async_w3.eth.filter("pending") + transaction_counter = 0 + while transaction_counter < 100: + await async_single_transaction(async_w3) + transaction_counter += 1 + + new_entries = await transaction_filter.get_new_entries() + assert len(new_entries) == transaction_counter + + +@pytest.mark.asyncio +@pytest.mark.parametrize("api_style", ("v4", "build_filter")) +async def test_async_event_filter_new_events_many_deployed_contracts( + async_w3, + async_emitter, + AsyncEmitter, + async_wait_for_transaction, + api_style, +): + + matching_transact = async_emitter.functions.logNoArgs(which=1).transact + + deployed_contract_addresses = await async_deploy_contracts( + async_w3, AsyncEmitter, async_wait_for_transaction + ) + + async def gen_non_matching_transact(): + while True: + contract_address = deployed_contract_addresses[ + random.randint(0, len(deployed_contract_addresses) - 1) + ] + yield async_w3.eth.contract( + address=contract_address, abi=AsyncEmitter.abi + ).functions.logNoArgs(which=1).transact + + non_matching_transact = gen_non_matching_transact() + + if api_style == "build_filter": + builder = async_emitter.events.LogNoArguments.build_filter() + builder.fromBlock = "latest" + event_filter = await builder.deploy(async_w3) + else: + event_filter = await async_emitter.events.LogNoArguments().create_filter( + fromBlock="latest" + ) + + expected_match_counter = 0 + + eth_block_number = await async_w3.eth.block_number + while eth_block_number < 50: + is_match = bool(random.randint(0, 1)) + if is_match: + expected_match_counter += 1 + await matching_transact() + await async_pad_with_transactions(async_w3) + continue + await non_matching_transact.__anext__() + await async_pad_with_transactions(async_w3) + eth_block_number = await async_w3.eth.block_number + + new_entries = await event_filter.get_new_entries() + assert len(new_entries) == expected_match_counter diff --git a/tests/core/filtering/utils.py b/tests/core/filtering/utils.py new file mode 100644 index 0000000000..5c38af984d --- /dev/null +++ b/tests/core/filtering/utils.py @@ -0,0 +1,68 @@ +from web3 import Web3 +from web3.eth import ( + AsyncEth, +) +from web3.middleware import ( + async_local_filter_middleware, + local_filter_middleware, +) +from web3.providers.eth_tester import ( + AsyncEthereumTesterProvider, + EthereumTesterProvider, +) + + +def _w3_fixture_logic(request): + use_filter_middleware = request.param + provider = EthereumTesterProvider() + w3 = Web3(provider) + if use_filter_middleware: + w3.middleware_onion.add(local_filter_middleware) + return w3 + + +def _emitter_fixture_logic( + w3, Emitter, wait_for_transaction, wait_for_block, address_conversion_func +): + wait_for_block(w3) + deploy_txn_hash = Emitter.constructor().transact({"gas": 10000000}) + deploy_receipt = wait_for_transaction(w3, deploy_txn_hash) + contract_address = address_conversion_func(deploy_receipt["contractAddress"]) + + bytecode = w3.eth.get_code(contract_address) + assert bytecode == Emitter.bytecode_runtime + _emitter = Emitter(address=contract_address) + assert _emitter.address == contract_address + return _emitter + + +# --- async --- # + + +def _async_w3_fixture_logic(request): + use_filter_middleware = request.param + provider = AsyncEthereumTesterProvider() + async_w3 = Web3(provider, modules={"eth": [AsyncEth]}, middlewares=[]) + + if use_filter_middleware: + async_w3.middleware_onion.add(async_local_filter_middleware) + return async_w3 + + +async def _async_emitter_fixture_logic( + async_w3, + AsyncEmitter, + async_wait_for_transaction, + async_wait_for_block, + address_conversion_func, +): + await async_wait_for_block(async_w3) + deploy_txn_hash = await AsyncEmitter.constructor().transact({"gas": 10000000}) + deploy_receipt = await async_wait_for_transaction(async_w3, deploy_txn_hash) + contract_address = address_conversion_func(deploy_receipt["contractAddress"]) + + bytecode = await async_w3.eth.get_code(contract_address) + assert bytecode == AsyncEmitter.bytecode_runtime + _emitter = AsyncEmitter(address=contract_address) + assert _emitter.address == contract_address + return _emitter diff --git a/tests/core/middleware/test_filter_middleware.py b/tests/core/middleware/test_filter_middleware.py index afa486bb82..d1fd94f7a0 100644 --- a/tests/core/middleware/test_filter_middleware.py +++ b/tests/core/middleware/test_filter_middleware.py @@ -3,19 +3,29 @@ from hexbytes import ( HexBytes, ) +import pytest_asyncio from web3 import Web3 from web3.datastructures import ( AttributeDict, ) +from web3.eth import ( + AsyncEth, +) from web3.middleware import ( + async_construct_result_generator_middleware, + async_local_filter_middleware, construct_result_generator_middleware, local_filter_middleware, ) from web3.middleware.filter import ( + async_iter_latest_block_ranges, block_ranges, iter_latest_block_ranges, ) +from web3.providers.async_base import ( + AsyncBaseProvider, +) from web3.providers.base import ( BaseProvider, ) @@ -184,6 +194,15 @@ def test_block_ranges(start, stop, expected): (None, None), ], ), + ( + 10, + 10, + [10, 10], + [ + (10, 10), + (None, None), + ], + ), ], ) def test_iter_latest_block_ranges( @@ -206,18 +225,171 @@ def test_local_filter_middleware(w3, iter_block_number): block_filter = w3.eth.filter("latest") block_filter.get_new_entries() iter_block_number.send(1) + assert w3.eth.get_filter_changes(block_filter.filter_id) == [HexBytes(BLOCK_HASH)] log_filter = w3.eth.filter(filter_params={"fromBlock": "latest"}) + iter_block_number.send(2) + log_changes = w3.eth.get_filter_changes(log_filter.filter_id) + assert log_changes == FILTER_LOG + assert w3.eth.get_filter_logs(log_filter.filter_id) == FILTER_LOG - assert w3.eth.get_filter_changes(block_filter.filter_id) == [HexBytes(BLOCK_HASH)] + log_filter_from_hex_string = w3.eth.filter( + filter_params={"fromBlock": "0x0", "toBlock": "0x2"} + ) + log_filter_from_int = w3.eth.filter(filter_params={"fromBlock": 1, "toBlock": 3}) + + filter_ids = ( + block_filter.filter_id, + log_filter.filter_id, + log_filter_from_hex_string.filter_id, + log_filter_from_int.filter_id, + ) + + # Test that all ids are str types + assert all(isinstance(_filter_id, (str,)) for _filter_id in filter_ids) + + # Test that all ids are unique + assert len(filter_ids) == len(set(filter_ids)) + + +# --- async --- # + + +class AsyncDummyProvider(AsyncBaseProvider): + async def make_request(self, method, params): + raise NotImplementedError(f"Cannot make request for {method}:{params}") + + +@pytest_asyncio.fixture(scope="function") +async def async_result_generator_middleware(iter_block_number): + return await async_construct_result_generator_middleware( + { + "eth_getLogs": lambda *_: FILTER_LOG, + "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, + "net_version": lambda *_: 1, + "eth_blockNumber": lambda *_: next(iter_block_number), + } + ) + +@pytest.fixture(scope="function") +def async_w3_base(): + return Web3( + provider=AsyncDummyProvider(), modules={"eth": (AsyncEth)}, middlewares=[] + ) + + +@pytest.fixture(scope="function") +def async_w3(async_w3_base, async_result_generator_middleware): + async_w3_base.middleware_onion.add(async_result_generator_middleware) + async_w3_base.middleware_onion.add(async_local_filter_middleware) + return async_w3_base + + +@pytest.mark.parametrize( + "from_block,to_block,current_block,expected", + [ + ( + 0, + 10, + [10], + [ + (0, 10), + ], + ), + ( + 0, + 55, + [0, 19, 55], + [ + (0, 0), + (1, 19), + (20, 55), + ], + ), + ( + 0, + None, + [10], + [ + (0, 10), + ], + ), + ( + 0, + 10, + [12], + [ + (None, None), + ], + ), + ( + 12, + 10, + [12], + [ + (None, None), + ], + ), + ( + 12, + 10, + [None], + [ + (None, None), + ], + ), + ( + 10, + 10, + [10, 10], + [ + (10, 10), + (None, None), + ], + ), + ], +) +@pytest.mark.asyncio +async def test_async_iter_latest_block_ranges( + async_w3, iter_block_number, from_block, to_block, current_block, expected +): + latest_block_ranges = async_iter_latest_block_ranges(async_w3, from_block, to_block) + for index, block in enumerate(current_block): + iter_block_number.send(block) + expected_tuple = expected[index] + actual_tuple = await latest_block_ranges.__anext__() + assert actual_tuple == expected_tuple + + +@pytest.mark.asyncio +async def test_async_local_filter_middleware(async_w3, iter_block_number): + block_filter = await async_w3.eth.filter("latest") + await block_filter.get_new_entries() + iter_block_number.send(1) + block_changes = await async_w3.eth.get_filter_changes(block_filter.filter_id) + assert block_changes == [HexBytes(BLOCK_HASH)] + + log_filter = await async_w3.eth.filter(filter_params={"fromBlock": "latest"}) iter_block_number.send(2) - results = w3.eth.get_filter_changes(log_filter.filter_id) - assert results == FILTER_LOG + log_changes = await async_w3.eth.get_filter_changes(log_filter.filter_id) + assert log_changes == FILTER_LOG + logs = await async_w3.eth.get_filter_logs(log_filter.filter_id) + assert logs == FILTER_LOG - assert w3.eth.get_filter_logs(log_filter.filter_id) == FILTER_LOG + log_filter_from_hex_string = await async_w3.eth.filter( + filter_params={"fromBlock": "0x0", "toBlock": "0x2"} + ) + log_filter_from_int = await async_w3.eth.filter( + filter_params={"fromBlock": 1, "toBlock": 3} + ) - filter_ids = (block_filter.filter_id, log_filter.filter_id) + filter_ids = ( + block_filter.filter_id, + log_filter.filter_id, + log_filter_from_hex_string.filter_id, + log_filter_from_int.filter_id, + ) # Test that all ids are str types assert all(isinstance(_filter_id, (str,)) for _filter_id in filter_ids) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 9d682122da..cddb026238 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -95,7 +95,7 @@ def async_offchain_lookup_contract_factory(async_w3): @pytest.fixture(scope="module") -def event_loop(request): +def event_loop(): loop = asyncio.get_event_loop_policy().new_event_loop() yield loop loop.close() diff --git a/tests/utils.py b/tests/utils.py index 43a9a1aa2e..7485ca8a30 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,32 @@ import websockets +from web3._utils.threads import ( + Timeout, +) + + +class PollDelayCounter: + def __init__(self, initial_delay=0, max_delay=1, initial_step=0.01): + self.initial_delay = initial_delay + self.initial_step = initial_step + self.max_delay = max_delay + self.current_delay = initial_delay + + def __call__(self): + delay = self.current_delay + + if self.current_delay == 0: + self.current_delay += self.initial_step + else: + self.current_delay *= 2 + self.current_delay = min(self.current_delay, self.max_delay) + + return delay + + def reset(self): + self.current_delay = self.initial_delay + def get_open_port(): sock = socket.socket() @@ -23,3 +49,39 @@ async def wait_for_ws(endpoint_uri, timeout=10): await asyncio.sleep(0.01) else: break + + +async def _async_wait_for_block_fixture_logic(async_w3, block_number=1, timeout=None): + if not timeout: + current_block_number = await async_w3.eth.block_number # type:ignore + timeout = (block_number - current_block_number) * 3 + poll_delay_counter = PollDelayCounter() + with Timeout(timeout) as timeout: + eth_block_number = await async_w3.eth.block_number + while eth_block_number < block_number: + await async_w3.manager.coro_request("evm_mine", []) + await timeout.async_sleep(poll_delay_counter()) + eth_block_number = await async_w3.eth.block_number + + +async def _async_wait_for_transaction_fixture_logic(async_w3, txn_hash, timeout=120): + poll_delay_counter = PollDelayCounter() + with Timeout(timeout) as timeout: + while True: + txn_receipt = await async_w3.eth.get_transaction_receipt(txn_hash) + if txn_receipt is not None: + break + asyncio.sleep(poll_delay_counter()) + timeout.check() + + return txn_receipt + + +def async_partial(f, *args, **kwargs): + async def f2(*args2, **kwargs2): + result = f(*args, *args2, **kwargs, **kwargs2) + if asyncio.iscoroutinefunction(f): + result = await result + return result + + return f2 diff --git a/web3/_utils/events.py b/web3/_utils/events.py index d3405af39f..d4f210a80b 100644 --- a/web3/_utils/events.py +++ b/web3/_utils/events.py @@ -88,6 +88,7 @@ if TYPE_CHECKING: from web3 import Web3 # noqa: F401 from web3._utils.filters import ( # noqa: F401 + AsyncLogFilter, LogFilter, ) @@ -310,7 +311,7 @@ def is_indexed(arg: Any) -> bool: is_not_indexed = complement(is_indexed) -class EventFilterBuilder: +class BaseEventFilterBuilder: formatter = None _fromBlock = None _toBlock = None @@ -410,6 +411,8 @@ def filter_params(self) -> FilterParams: } return valfilter(lambda x: x is not None, params) + +class EventFilterBuilder(BaseEventFilterBuilder): def deploy(self, w3: "Web3") -> "LogFilter": if not isinstance(w3, web3.Web3): raise ValueError(f"Invalid web3 argument: got: {w3!r}") @@ -427,6 +430,24 @@ def deploy(self, w3: "Web3") -> "LogFilter": return log_filter +class AsyncEventFilterBuilder(BaseEventFilterBuilder): + async def deploy(self, async_w3: "Web3") -> "AsyncLogFilter": + if not isinstance(async_w3, web3.Web3): + raise ValueError(f"Invalid web3 argument: got: {async_w3!r}") + + for arg in AttributeDict.values(self.args): + arg._immutable = True + self._immutable = True + + log_filter = await async_w3.eth.filter(self.filter_params) # type: ignore + log_filter.filter_params = self.filter_params + log_filter.set_data_filters(self.data_argument_values) + log_filter.builder = self + if self.formatter is not None: + log_filter.log_entry_formatter = self.formatter + return log_filter + + def initialize_event_topics(event_abi: ABIEvent) -> Union[bytes, List[Any]]: if event_abi["anonymous"] is False: # https://github.com/python/mypy/issues/4976 diff --git a/web3/_utils/filters.py b/web3/_utils/filters.py index 25e224d15c..72cc16003a 100644 --- a/web3/_utils/filters.py +++ b/web3/_utils/filters.py @@ -41,6 +41,7 @@ ) from web3._utils.events import ( + AsyncEventFilterBuilder, EventFilterBuilder, construct_event_data_set, construct_event_topic_set, @@ -62,6 +63,7 @@ if TYPE_CHECKING: from web3 import Web3 # noqa: F401 from web3.eth import Eth # noqa: F401 + from web3.eth import AsyncEth # noqa: F401 def construct_event_filter_params( @@ -126,14 +128,13 @@ def construct_event_filter_params( return data_filters_set, filter_params -class Filter: +class BaseFilter: callbacks: List[Callable[..., Any]] = None stopped = False poll_interval = None filter_id = None - def __init__(self, filter_id: HexStr, eth_module: "Eth") -> None: - self.eth_module = eth_module + def __init__(self, filter_id: HexStr) -> None: self.filter_id = filter_id self.callbacks = [] super().__init__() @@ -159,6 +160,23 @@ def _filter_valid_entries( ) -> Iterator[LogReceipt]: return filter(self.is_valid_entry, entries) + def _format_log_entries( + self, log_entries: Optional[Iterator[LogReceipt]] = None + ) -> List[LogReceipt]: + if log_entries is None: + return [] + + formatted_log_entries = [ + self.format_entry(log_entry) for log_entry in log_entries + ] + return formatted_log_entries + + +class Filter(BaseFilter): + def __init__(self, filter_id: HexStr, eth_module: "Eth") -> None: + self.eth_module = eth_module + super(Filter, self).__init__(filter_id) + def get_new_entries(self) -> List[LogReceipt]: log_entries = self._filter_valid_entries( self.eth_module.get_filter_changes(self.filter_id) @@ -171,26 +189,39 @@ def get_all_entries(self) -> List[LogReceipt]: ) return self._format_log_entries(log_entries) - def _format_log_entries( - self, log_entries: Optional[Iterator[LogReceipt]] = None - ) -> List[LogReceipt]: - if log_entries is None: - return [] - formatted_log_entries = [ - self.format_entry(log_entry) for log_entry in log_entries - ] - return formatted_log_entries +class AsyncFilter(BaseFilter): + def __init__(self, filter_id: HexStr, eth_module: "AsyncEth") -> None: + self.eth_module = eth_module + super(AsyncFilter, self).__init__(filter_id) + + async def get_new_entries(self) -> List[LogReceipt]: + filter_changes = await self.eth_module.get_filter_changes(self.filter_id) + log_entries = self._filter_valid_entries(filter_changes) + return self._format_log_entries(log_entries) + + async def get_all_entries(self) -> List[LogReceipt]: + filter_logs = await self.eth_module.get_filter_logs(self.filter_id) + log_entries = self._filter_valid_entries(filter_logs) + return self._format_log_entries(log_entries) class BlockFilter(Filter): pass +class AsyncBlockFilter(AsyncFilter): + pass + + class TransactionFilter(Filter): pass +class AsyncTransactionFilter(AsyncFilter): + pass + + class LogFilter(Filter): data_filter_set = None data_filter_set_regex = None @@ -233,6 +264,48 @@ def is_valid_entry(self, entry: LogReceipt) -> bool: return bool(self.data_filter_set_function(entry["data"])) +class AsyncLogFilter(AsyncFilter): + data_filter_set = None + data_filter_set_regex = None + data_filter_set_function = None + log_entry_formatter = None + filter_params: FilterParams = None + builder: AsyncEventFilterBuilder = None + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.log_entry_formatter = kwargs.pop( + "log_entry_formatter", + self.log_entry_formatter, + ) + if "data_filter_set" in kwargs: + self.set_data_filters(kwargs.pop("data_filter_set")) + super().__init__(*args, **kwargs) + + def format_entry(self, entry: LogReceipt) -> LogReceipt: + if self.log_entry_formatter: + return self.log_entry_formatter(entry) + return entry + + def set_data_filters( + self, data_filter_set: Collection[Tuple[TypeStr, Any]] + ) -> None: + """Sets the data filters (non indexed argument filters) + + Expects a set of tuples with the type and value, e.g.: + (('uint256', [12345, 54321]), ('string', ('a-single-string',))) + """ + self.data_filter_set = data_filter_set + if any(data_filter_set): + self.data_filter_set_function = match_fn( + self.eth_module.codec, data_filter_set + ) + + def is_valid_entry(self, entry: LogReceipt) -> bool: + if not self.data_filter_set: + return True + return bool(self.data_filter_set_function(entry["data"])) + + def decode_utf8_bytes(value: bytes) -> str: return value.decode("utf-8") diff --git a/web3/_utils/method_formatters.py b/web3/_utils/method_formatters.py index 9c48a8a416..12c0bf4b6c 100644 --- a/web3/_utils/method_formatters.py +++ b/web3/_utils/method_formatters.py @@ -10,6 +10,7 @@ NoReturn, Tuple, Union, + cast, ) from eth_abi import ( @@ -60,6 +61,9 @@ to_hex, ) from web3._utils.filters import ( + AsyncBlockFilter, + AsyncLogFilter, + AsyncTransactionFilter, BlockFilter, LogFilter, TransactionFilter, @@ -106,6 +110,7 @@ from web3 import Web3 # noqa: F401 from web3.module import Module # noqa: F401 from web3.eth import Eth # noqa: F401 + from web3.eth import AsyncEth # noqa: F401 def bytes_to_ascii(value: bytes) -> str: @@ -751,16 +756,34 @@ def raise_transaction_not_found_with_index( def filter_wrapper( - module: "Eth", + module: Union["AsyncEth", "Eth"], method: RPCEndpoint, filter_id: HexStr, -) -> Union[BlockFilter, TransactionFilter, LogFilter]: +) -> Union[ + AsyncBlockFilter, + AsyncTransactionFilter, + AsyncLogFilter, + BlockFilter, + TransactionFilter, + LogFilter, +]: if method == RPC.eth_newBlockFilter: - return BlockFilter(filter_id, eth_module=module) + if module.is_async: + return AsyncBlockFilter(filter_id, eth_module=cast("AsyncEth", module)) + else: + return BlockFilter(filter_id, eth_module=cast("Eth", module)) elif method == RPC.eth_newPendingTransactionFilter: - return TransactionFilter(filter_id, eth_module=module) + if module.is_async: + return AsyncTransactionFilter( + filter_id, eth_module=cast("AsyncEth", module) + ) + else: + return TransactionFilter(filter_id, eth_module=cast("Eth", module)) elif method == RPC.eth_newFilter: - return LogFilter(filter_id, eth_module=module) + if module.is_async: + return AsyncLogFilter(filter_id, eth_module=cast("AsyncEth", module)) + else: + return LogFilter(filter_id, eth_module=cast("Eth", module)) else: raise NotImplementedError( "Filter wrapper needs to be used with either " @@ -794,7 +817,6 @@ def get_result_formatters( formatters_requiring_module = combine_formatters( (FILTER_RESULT_FORMATTERS,), method_name ) - partial_formatters = apply_module_to_formatters( formatters_requiring_module, module, method_name ) diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index 3db2b5decc..717e9bcf7b 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -1465,6 +1465,64 @@ async def test_eth_getBlockTransactionCountByNumber_block_with_txn( assert is_integer(transaction_count) assert transaction_count >= 1 + @pytest.mark.asyncio + async def test_async_eth_new_filter(self, async_w3: "Web3") -> None: + filter = await async_w3.eth.filter({}) # type: ignore + + changes = await async_w3.eth.get_filter_changes( + filter.filter_id + ) # type: ignore + assert is_list_like(changes) + assert not changes + + logs = await async_w3.eth.get_filter_logs(filter.filter_id) # type: ignore + assert is_list_like(logs) + assert not logs + + result = await async_w3.eth.uninstall_filter(filter.filter_id) # type: ignore + assert result is True + + @pytest.mark.asyncio + async def test_async_eth_new_block_filter(self, async_w3: "Web3") -> None: + filter = await async_w3.eth.filter("latest") # type: ignore + assert is_string(filter.filter_id) + + changes = await async_w3.eth.get_filter_changes( + filter.filter_id + ) # type: ignore + assert is_list_like(changes) + assert not changes + + result = await async_w3.eth.uninstall_filter(filter.filter_id) # type: ignore + assert result is True + + @pytest.mark.asyncio + async def test_async_eth_new_pending_transaction_filter( + self, async_w3: "Web3" + ) -> None: + filter = await async_w3.eth.filter("pending") # type: ignore + assert is_string(filter.filter_id) + + changes = await async_w3.eth.get_filter_changes( + filter.filter_id + ) # type: ignore + assert is_list_like(changes) + assert not changes + + result = await async_w3.eth.uninstall_filter(filter.filter_id) # type: ignore + assert result is True + + @pytest.mark.asyncio + async def test_async_eth_uninstall_filter(self, async_w3: "Web3") -> None: + filter = await async_w3.eth.filter({}) # type: ignore + assert is_string(filter.filter_id) + + success = await async_w3.eth.uninstall_filter(filter.filter_id) # type: ignore + assert success is True + + failure = await async_w3.eth.uninstall_filter(filter.filter_id) # type: ignore + assert failure is False + class EthModuleTest: def test_eth_syncing(self, w3: "Web3") -> None: @@ -3171,7 +3229,7 @@ def test_eth_getUncleByBlockNumberAndIndex(self, w3: "Web3") -> None: # TODO: how do we make uncles.... pass - def test_eth_newFilter(self, w3: "Web3") -> None: + def test_eth_new_filter(self, w3: "Web3") -> None: filter = w3.eth.filter({}) changes = w3.eth.get_filter_changes(filter.filter_id) @@ -3185,7 +3243,7 @@ def test_eth_newFilter(self, w3: "Web3") -> None: result = w3.eth.uninstall_filter(filter.filter_id) assert result is True - def test_eth_newBlockFilter(self, w3: "Web3") -> None: + def test_eth_new_block_filter(self, w3: "Web3") -> None: filter = w3.eth.filter("latest") assert is_string(filter.filter_id) @@ -3193,15 +3251,10 @@ def test_eth_newBlockFilter(self, w3: "Web3") -> None: assert is_list_like(changes) assert not changes - # TODO: figure out why this fails in go-ethereum - # logs = w3.eth.get_filter_logs(filter.filter_id) - # assert is_list_like(logs) - # assert not logs - result = w3.eth.uninstall_filter(filter.filter_id) assert result is True - def test_eth_newPendingTransactionFilter(self, w3: "Web3") -> None: + def test_eth_new_pending_transaction_filter(self, w3: "Web3") -> None: filter = w3.eth.filter("pending") assert is_string(filter.filter_id) @@ -3209,11 +3262,6 @@ def test_eth_newPendingTransactionFilter(self, w3: "Web3") -> None: assert is_list_like(changes) assert not changes - # TODO: figure out why this fails in go-ethereum - # logs = w3.eth.get_filter_logs(filter.filter_id) - # assert is_list_like(logs) - # assert not logs - result = w3.eth.uninstall_filter(filter.filter_id) assert result is True diff --git a/web3/_utils/threads.py b/web3/_utils/threads.py index c69485604d..991c7cf254 100644 --- a/web3/_utils/threads.py +++ b/web3/_utils/threads.py @@ -1,6 +1,7 @@ """ A minimal implementation of the various gevent APIs used within this codebase. """ +import asyncio import threading import time from types import ( @@ -97,6 +98,10 @@ def sleep(self, seconds: float) -> None: time.sleep(seconds) self.check() + async def async_sleep(self, seconds: float) -> None: + await asyncio.sleep(seconds) + self.check() + class ThreadWithReturn(threading.Thread, Generic[TReturn]): def __init__( diff --git a/web3/contract.py b/web3/contract.py index b7a3029835..59c41ed965 100644 --- a/web3/contract.py +++ b/web3/contract.py @@ -85,11 +85,13 @@ to_hex, ) from web3._utils.events import ( + AsyncEventFilterBuilder, EventFilterBuilder, get_event_data, is_dynamic_sized_type, ) from web3._utils.filters import ( + AsyncLogFilter, LogFilter, construct_event_filter_params, ) @@ -1474,19 +1476,64 @@ def process_log(self, log: HexStr) -> EventData: return get_event_data(self.w3.codec, self.abi, log) @combomethod - def create_filter( + def _get_event_filter_params( + self, + abi: ABIEvent, + argument_filters: Optional[Dict[str, Any]] = None, + fromBlock: Optional[BlockIdentifier] = None, + toBlock: Optional[BlockIdentifier] = None, + blockHash: Optional[HexBytes] = None, + ) -> FilterParams: + + if not self.address: + raise TypeError( + "This method can be only called on " + "an instantiated contract with an address" + ) + + if argument_filters is None: + argument_filters = dict() + + _filters = dict(**argument_filters) + + blkhash_set = blockHash is not None + blknum_set = fromBlock is not None or toBlock is not None + if blkhash_set and blknum_set: + raise ValidationError( + "blockHash cannot be set at the same time as fromBlock or toBlock" + ) + + # Construct JSON-RPC raw filter presentation based on human readable + # Python descriptions. Namely, convert event names to their keccak signatures + data_filter_set, event_filter_params = construct_event_filter_params( + abi, + self.w3.codec, + contract_address=self.address, + argument_filters=_filters, + fromBlock=fromBlock, + toBlock=toBlock, + address=self.address, + ) + + if blockHash is not None: + event_filter_params["blockHash"] = blockHash + + return event_filter_params + + @classmethod + def factory(cls, class_name: str, **kwargs: Any) -> PropertyCheckingFactory: + return PropertyCheckingFactory(class_name, (cls,), kwargs) + + @combomethod + def _set_up_filter_builder( self, - *, # PEP 3102 argument_filters: Optional[Dict[str, Any]] = None, fromBlock: Optional[BlockIdentifier] = None, toBlock: BlockIdentifier = "latest", address: Optional[ChecksumAddress] = None, topics: Optional[Sequence[Any]] = None, - ) -> LogFilter: - """ - Create filter object that tracks logs emitted by this contract event. - :param filter_params: other parameters to limit the events - """ + filter_builder: Union[EventFilterBuilder, AsyncEventFilterBuilder] = None, + ) -> None: if fromBlock is None: raise TypeError( "Missing mandatory keyword argument to create_filter: fromBlock" @@ -1512,7 +1559,6 @@ def create_filter( topics=topics, ) - filter_builder = EventFilterBuilder(event_abi, self.w3.codec) filter_builder.address = cast( ChecksumAddress, event_filter_params.get("address") ) @@ -1536,73 +1582,6 @@ def create_filter( for arg, value in match_single_vals.items(): filter_builder.args[arg].match_single(value) - log_filter = filter_builder.deploy(self.w3) - log_filter.log_entry_formatter = get_event_data( - self.w3.codec, self._get_event_abi() - ) - log_filter.builder = filter_builder - - return log_filter - - @combomethod - def build_filter(self) -> EventFilterBuilder: - builder = EventFilterBuilder( - self._get_event_abi(), - self.w3.codec, - formatter=get_event_data(self.w3.codec, self._get_event_abi()), - ) - builder.address = self.address - return builder - - @combomethod - def _get_event_filter_params( - self, - abi: ABIEvent, - argument_filters: Optional[Dict[str, Any]] = None, - fromBlock: Optional[BlockIdentifier] = None, - toBlock: Optional[BlockIdentifier] = None, - blockHash: Optional[HexBytes] = None, - ) -> FilterParams: - - if not self.address: - raise TypeError( - "This method can be only called on " - "an instated contract with an address" - ) - - if argument_filters is None: - argument_filters = dict() - - _filters = dict(**argument_filters) - - blkhash_set = blockHash is not None - blknum_set = fromBlock is not None or toBlock is not None - if blkhash_set and blknum_set: - raise ValidationError( - "blockHash cannot be set at the same time as fromBlock or toBlock" - ) - - # Construct JSON-RPC raw filter presentation based on human readable - # Python descriptions. Namely, convert event names to their keccak signatures - data_filter_set, event_filter_params = construct_event_filter_params( - abi, - self.w3.codec, - contract_address=self.address, - argument_filters=_filters, - fromBlock=fromBlock, - toBlock=toBlock, - address=self.address, - ) - - if blockHash is not None: - event_filter_params["blockHash"] = blockHash - - return event_filter_params - - @classmethod - def factory(cls, class_name: str, **kwargs: Any) -> PropertyCheckingFactory: - return PropertyCheckingFactory(class_name, (cls,), kwargs) - class ContractEvent(BaseContractEvent): @combomethod @@ -1678,6 +1657,46 @@ def get_logs( # Convert raw binary data to Python proxy objects as described by ABI return tuple(get_event_data(self.w3.codec, abi, entry) for entry in logs) + @combomethod + def create_filter( + self, + *, # PEP 3102 + argument_filters: Optional[Dict[str, Any]] = None, + fromBlock: Optional[BlockIdentifier] = None, + toBlock: BlockIdentifier = "latest", + address: Optional[ChecksumAddress] = None, + topics: Optional[Sequence[Any]] = None, + ) -> LogFilter: + """ + Create filter object that tracks logs emitted by this contract event. + """ + filter_builder = EventFilterBuilder(self._get_event_abi(), self.w3.codec) + self._set_up_filter_builder( + argument_filters, + fromBlock, + toBlock, + address, + topics, + filter_builder, + ) + log_filter = filter_builder.deploy(self.w3) + log_filter.log_entry_formatter = get_event_data( + self.w3.codec, self._get_event_abi() + ) + log_filter.builder = filter_builder + + return log_filter + + @combomethod + def build_filter(self) -> EventFilterBuilder: + builder = EventFilterBuilder( + self._get_event_abi(), + self.w3.codec, + formatter=get_event_data(self.w3.codec, self._get_event_abi()), + ) + builder.address = self.address + return builder + class AsyncContractEvent(BaseContractEvent): @combomethod @@ -1755,6 +1774,46 @@ async def get_logs( get_event_data(self.w3.codec, abi, entry) for entry in logs # type: ignore ) + @combomethod + async def create_filter( + self, + *, # PEP 3102 + argument_filters: Optional[Dict[str, Any]] = None, + fromBlock: Optional[BlockIdentifier] = None, + toBlock: BlockIdentifier = "latest", + address: Optional[ChecksumAddress] = None, + topics: Optional[Sequence[Any]] = None, + ) -> AsyncLogFilter: + """ + Create filter object that tracks logs emitted by this contract event. + """ + filter_builder = AsyncEventFilterBuilder(self._get_event_abi(), self.w3.codec) + self._set_up_filter_builder( + argument_filters, + fromBlock, + toBlock, + address, + topics, + filter_builder, + ) + log_filter = await filter_builder.deploy(self.w3) + log_filter.log_entry_formatter = get_event_data( + self.w3.codec, self._get_event_abi() + ) + log_filter.builder = filter_builder + + return log_filter + + @combomethod + def build_filter(self) -> AsyncEventFilterBuilder: + builder = AsyncEventFilterBuilder( + self._get_event_abi(), + self.w3.codec, + formatter=get_event_data(self.w3.codec, self._get_event_abi()), + ) + builder.address = self.address + return builder + class BaseContractCaller: """ diff --git a/web3/eth.py b/web3/eth.py index b95f712a8c..f4333ee5b4 100644 --- a/web3/eth.py +++ b/web3/eth.py @@ -51,6 +51,8 @@ fee_history_priority_fee, ) from web3._utils.filters import ( + AsyncFilter, + Filter, select_filter_method, ) from web3._utils.rpc_abi import ( @@ -359,6 +361,35 @@ def set_contract_factory( ) -> None: self.defaultContractFactory = contractFactory + def filter_munger( + self, + filter_params: Optional[Union[str, FilterParams]] = None, + filter_id: Optional[HexStr] = None, + ) -> Union[List[FilterParams], List[HexStr], List[str]]: + if filter_id and filter_params: + raise TypeError( + "Ambiguous invocation: provide either a `filter_params` or a " + "`filter_id` argument. Both were supplied." + ) + if isinstance(filter_params, dict): + return [filter_params] + elif is_string(filter_params): + if filter_params in ["latest", "pending"]: + return [filter_params] + else: + raise ValueError( + "The filter API only accepts the values of `pending` or " + "`latest` for string based filters" + ) + elif filter_id and not filter_params: + return [filter_id] + else: + raise TypeError( + "Must provide either filter_params as a string or " + "a valid filter object, or a filter_id as a string " + "or hex." + ) + class AsyncEth(BaseEth): is_async = True @@ -605,6 +636,39 @@ async def get_storage_at( ) -> HexBytes: return await self._get_storage_at(account, position, block_identifier) + filter: Method[ + Callable[[Optional[Union[str, FilterParams, HexStr]]], Awaitable[AsyncFilter]] + ] = Method( + method_choice_depends_on_args=select_filter_method( + if_new_block_filter=RPC.eth_newBlockFilter, + if_new_pending_transaction_filter=RPC.eth_newPendingTransactionFilter, + if_new_filter=RPC.eth_newFilter, + ), + mungers=[BaseEth.filter_munger], + ) + + _get_filter_changes: Method[ + Callable[[HexStr], Awaitable[List[LogReceipt]]] + ] = Method(RPC.eth_getFilterChanges, mungers=[default_root_munger]) + + async def get_filter_changes(self, filter_id: HexStr) -> List[LogReceipt]: + return await self._get_filter_changes(filter_id) + + _get_filter_logs: Method[Callable[[HexStr], Awaitable[List[LogReceipt]]]] = Method( + RPC.eth_getFilterLogs, mungers=[default_root_munger] + ) + + async def get_filter_logs(self, filter_id: HexStr) -> List[LogReceipt]: + return await self._get_filter_logs(filter_id) + + _uninstall_filter: Method[Callable[[HexStr], Awaitable[bool]]] = Method( + RPC.eth_uninstallFilter, + mungers=[default_root_munger], + ) + + async def uninstall_filter(self, filter_id: HexStr) -> bool: + return await self._uninstall_filter(filter_id) + class Eth(BaseEth): defaultContractFactory: Type[Union[Contract, ContractCaller]] = Contract @@ -897,42 +961,15 @@ def fee_history( ) -> FeeHistory: return self._fee_history(block_count, newest_block, reward_percentiles) - def filter_munger( - self, - filter_params: Optional[Union[str, FilterParams]] = None, - filter_id: Optional[HexStr] = None, - ) -> Union[List[FilterParams], List[HexStr], List[str]]: - if filter_id and filter_params: - raise TypeError( - "Ambiguous invocation: provide either a `filter_params` or a " - "`filter_id` argument. Both were supplied." - ) - if isinstance(filter_params, dict): - return [filter_params] - elif is_string(filter_params): - if filter_params in {"latest", "pending"}: - return [filter_params] - else: - raise ValueError( - "The filter API only accepts the values of `pending` or " - "`latest` for string based filters" - ) - elif filter_id and not filter_params: - return [filter_id] - else: - raise TypeError( - "Must provide either filter_params as a string or " - "a valid filter object, or a filter_id as a string " - "or hex." - ) - - filter: Method[Callable[..., Any]] = Method( + filter: Method[ + Callable[[Optional[Union[str, FilterParams, HexStr]]], Filter] + ] = Method( method_choice_depends_on_args=select_filter_method( if_new_block_filter=RPC.eth_newBlockFilter, if_new_pending_transaction_filter=RPC.eth_newPendingTransactionFilter, if_new_filter=RPC.eth_newFilter, ), - mungers=[filter_munger], + mungers=[BaseEth.filter_munger], ) get_filter_changes: Method[Callable[[HexStr], List[LogReceipt]]] = Method( @@ -943,6 +980,11 @@ def filter_munger( RPC.eth_getFilterLogs, mungers=[default_root_munger] ) + uninstall_filter: Method[Callable[[HexStr], bool]] = Method( + RPC.eth_uninstallFilter, + mungers=[default_root_munger], + ) + get_logs: Method[Callable[[FilterParams], List[LogReceipt]]] = Method( RPC.eth_getLogs, mungers=[default_root_munger] ) @@ -957,11 +999,6 @@ def filter_munger( mungers=[default_root_munger], ) - uninstall_filter: Method[Callable[[HexStr], bool]] = Method( - RPC.eth_uninstallFilter, - mungers=[default_root_munger], - ) - get_work: Method[Callable[[], List[HexBytes]]] = Method( RPC.eth_getWork, is_property=True, diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 7c0ddb8c1f..af99aed36f 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -39,9 +39,12 @@ http_retry_request_middleware, ) from .filter import ( # noqa: F401 + async_local_filter_middleware, local_filter_middleware, ) from .fixture import ( # noqa: F401 + async_construct_error_generator_middleware, + async_construct_result_generator_middleware, construct_error_generator_middleware, construct_fixture_middleware, construct_result_generator_middleware, diff --git a/web3/middleware/filter.py b/web3/middleware/filter.py index f19bfd73da..75de96d8c4 100644 --- a/web3/middleware/filter.py +++ b/web3/middleware/filter.py @@ -3,8 +3,11 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterable, + AsyncIterator, Callable, Dict, + Generator, Iterable, Iterator, List, @@ -19,6 +22,7 @@ BlockNumber, ChecksumAddress, Hash32, + HexStr, ) from eth_utils import ( apply_key_map, @@ -40,6 +44,7 @@ RPC, ) from web3.types import ( # noqa: F401 + Coroutine, FilterParams, LatestBlockParam, LogReceipt, @@ -193,8 +198,8 @@ def drop_items_with_none_value(params: Dict[str, Any]) -> Dict[str, Any]: def get_logs_multipart( w3: "Web3", - startBlock: BlockNumber, - stopBlock: BlockNumber, + start_block: BlockNumber, + stop_block: BlockNumber, address: Union[Address, ChecksumAddress, List[Union[Address, ChecksumAddress]]], topics: List[Optional[Union[_Hash32, List[_Hash32]]]], max_blocks: int, @@ -204,7 +209,7 @@ def get_logs_multipart( The getLog request is partitioned into multiple calls of the max number of blocks ``max_blocks``. """ - _block_ranges = block_ranges(startBlock, stopBlock, max_blocks) + _block_ranges = block_ranges(start_block, stop_block, max_blocks) for from_block, to_block in _block_ranges: params = { "fromBlock": from_block, @@ -251,7 +256,7 @@ def to_block(self) -> BlockNumber: to_block = self.w3.eth.block_number elif self._to_block == "latest": to_block = self.w3.eth.block_number - elif is_hex(self._to_block): + elif is_string(self._to_block) and is_hex(self._to_block): to_block = BlockNumber(hex_to_integer(self._to_block)) # type: ignore else: to_block = self._to_block @@ -367,6 +372,7 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: _filter = filters[filter_id] if method == RPC.eth_getFilterChanges: return {"result": next(_filter.filter_changes)} + elif method == RPC.eth_getFilterLogs: # type ignored b/c logic prevents RequestBlocks which # doesn't implement get_logs @@ -377,3 +383,281 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: return make_request(method, params) return middleware + + +# --- async --- # + + +async def async_iter_latest_block( + w3: "Web3", to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None +) -> AsyncIterable[BlockNumber]: + """Returns a generator that dispenses the latest block, if + any new blocks have been mined since last iteration. + + If there are no new blocks or the latest block is greater than + the ``to_block`` None is returned. + + >>> new_blocks = iter_latest_block(w3, 0, 10) + >>> next(new_blocks) # Latest block = 0 + 0 + >>> next(new_blocks) # No new blocks + >>> next(new_blocks) # Latest block = 1 + 1 + >>> next(new_blocks) # Latest block = 10 + 10 + >>> next(new_blocks) # latest block > to block + """ + _last = None + + is_bounded_range = to_block is not None and to_block != "latest" + + while True: + latest_block = await w3.eth.block_number # type: ignore + # type ignored b/c is_bounded_range prevents unsupported comparison + if is_bounded_range and latest_block > to_block: + yield None + # No new blocks since last iteration. + if _last is not None and _last == latest_block: + yield None + else: + yield latest_block + _last = latest_block + + +async def async_iter_latest_block_ranges( + w3: "Web3", + from_block: BlockNumber, + to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, +) -> AsyncIterable[Tuple[Optional[BlockNumber], Optional[BlockNumber]]]: + """Returns an iterator unloading ranges of available blocks + + starting from `from_block` to the latest mined block, + until reaching to_block. e.g.: + + + >>> blocks_to_filter = iter_latest_block_ranges(w3, 0, 50) + >>> next(blocks_to_filter) # latest block number = 11 + (0, 11) + >>> next(blocks_to_filter) # latest block number = 45 + (12, 45) + >>> next(blocks_to_filter) # latest block number = 50 + (46, 50) + """ + latest_block_iterator = async_iter_latest_block(w3, to_block) + async for latest_block in latest_block_iterator: + if latest_block is None: + yield (None, None) + elif from_block > latest_block: + yield (None, None) + else: + yield (from_block, latest_block) + from_block = BlockNumber(latest_block + 1) + + +async def async_get_logs_multipart( + w3: "Web3", + start_block: BlockNumber, + stop_block: BlockNumber, + address: Union[Address, ChecksumAddress, List[Union[Address, ChecksumAddress]]], + topics: List[Optional[Union[_Hash32, List[_Hash32]]]], + max_blocks: int, +) -> AsyncIterable[List[LogReceipt]]: + """Used to break up requests to ``eth_getLogs`` + + The getLog request is partitioned into multiple calls of the max number of blocks + ``max_blocks``. + """ + _block_ranges = block_ranges(start_block, stop_block, max_blocks) + for from_block, to_block in _block_ranges: + params = { + "fromBlock": from_block, + "toBlock": to_block, + "address": address, + "topics": topics, + } + params_with_none_dropped = cast( + FilterParams, drop_items_with_none_value(params) + ) + next_logs = await w3.eth.get_logs(params_with_none_dropped) # type: ignore + yield next_logs + + +class AsyncRequestLogs: + _from_block: BlockNumber + + def __init__( + self, + w3: "Web3", + from_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, + to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, + address: Optional[ + Union[Address, ChecksumAddress, List[Union[Address, ChecksumAddress]]] + ] = None, + topics: Optional[List[Optional[Union[_Hash32, List[_Hash32]]]]] = None, + ) -> None: + self.address = address + self.topics = topics + self.w3 = w3 + self._from_block_arg = from_block + self._to_block = to_block + self.filter_changes = self._get_filter_changes() + + def __await__(self) -> Generator[Any, None, "AsyncRequestLogs"]: + async def closure() -> "AsyncRequestLogs": + if self._from_block_arg is None or self._from_block_arg == "latest": + self.block_number = await self.w3.eth.block_number # type: ignore + self._from_block = BlockNumber(self.block_number + 1) + elif is_string(self._from_block_arg) and is_hex(self._from_block_arg): + self._from_block = BlockNumber( + hex_to_integer(cast(HexStr, self._from_block_arg)) + ) + else: + self._from_block = self._from_block_arg + + return self + + return closure().__await__() + + @property + async def from_block(self) -> BlockNumber: + return self._from_block + + @property + async def to_block(self) -> BlockNumber: + if self._to_block is None or self._to_block == "latest": + to_block = await self.w3.eth.block_number # type: ignore + elif is_string(self._to_block) and is_hex(self._to_block): + to_block = BlockNumber(hex_to_integer(cast(HexStr, self._to_block))) + else: + to_block = self._to_block + + return to_block + + async def _get_filter_changes(self) -> AsyncIterator[List[LogReceipt]]: + self_from_block = await self.from_block + self_to_block = await self.to_block + async for start, stop in async_iter_latest_block_ranges( + self.w3, self_from_block, self_to_block + ): + if None in (start, stop): + yield [] + else: + yield [ + item + async for sublist in async_get_logs_multipart( + self.w3, + start, + stop, + self.address, + self.topics, + max_blocks=MAX_BLOCK_REQUEST, + ) + for item in sublist + ] + + async def get_logs(self) -> List[LogReceipt]: + self_from_block = await self.from_block + self_to_block = await self.to_block + return [ + item + async for sublist in async_get_logs_multipart( + self.w3, + self_from_block, + self_to_block, + self.address, + self.topics, + max_blocks=MAX_BLOCK_REQUEST, + ) + for item in sublist + ] + + +class AsyncRequestBlocks: + def __init__(self, w3: "Web3") -> None: + self.w3 = w3 + + def __await__(self) -> Generator[Any, None, "AsyncRequestBlocks"]: + async def closure() -> "AsyncRequestBlocks": + self.block_number = await self.w3.eth.block_number # type: ignore + self.start_block = BlockNumber(self.block_number + 1) + return self + + return closure().__await__() + + @property + def filter_changes(self) -> AsyncIterator[List[Hash32]]: + return self.get_filter_changes() + + async def get_filter_changes(self) -> AsyncIterator[List[Hash32]]: + block_range_iter = async_iter_latest_block_ranges( + self.w3, self.start_block, None + ) + async for block_range in block_range_iter: + hash = await async_block_hashes_in_range(self.w3, block_range) + yield hash + + +async def async_block_hashes_in_range( + w3: "Web3", block_range: Tuple[BlockNumber, BlockNumber] +) -> List[Union[None, Hash32]]: + from_block, to_block = block_range + if from_block is None or to_block is None: + return [] + + block_hashes = [] + for block_number in range(from_block, to_block + 1): + w3_get_block = await w3.eth.get_block(BlockNumber(block_number)) # type: ignore + block_hashes.append(getattr(w3_get_block, "hash", None)) + + return block_hashes + + +async def async_local_filter_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" +) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: + filters = {} + filter_id_counter = map(to_hex, itertools.count()) + + async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + if method in NEW_FILTER_METHODS: + + filter_id = next(filter_id_counter) + + _filter: Union[AsyncRequestLogs, AsyncRequestBlocks] + if method == RPC.eth_newFilter: + _filter = await AsyncRequestLogs( + w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) + ) + + elif method == RPC.eth_newBlockFilter: + _filter = await AsyncRequestBlocks(w3) + + else: + raise NotImplementedError(method) + + filters[filter_id] = _filter + return {"result": filter_id} + + elif method in FILTER_CHANGES_METHODS: + + filter_id = params[0] + # Pass through to filters not created by middleware + if filter_id not in filters: + return await make_request(method, params) + _filter = filters[filter_id] + + if method == RPC.eth_getFilterChanges: + changes = await _filter.filter_changes.__anext__() + return {"result": changes} + + elif method == RPC.eth_getFilterLogs: + # type ignored b/c logic prevents RequestBlocks which + # doesn't implement get_logs + logs = await _filter.get_logs() # type: ignore + return {"result": logs} + else: + raise NotImplementedError(method) + else: + return await make_request(method, params) + + return middleware diff --git a/web3/module.py b/web3/module.py index cc5ac7cb8d..b41168d9eb 100644 --- a/web3/module.py +++ b/web3/module.py @@ -17,6 +17,7 @@ ) from web3._utils.filters import ( + AsyncLogFilter, LogFilter, _UseExistingFilter, ) @@ -72,11 +73,15 @@ def caller(*args: Any, **kwargs: Any) -> Union[TReturn, LogFilter]: @curry def retrieve_async_method_call_fn( w3: "Web3", module: "Module", method: Method[Callable[..., Any]] -) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: - async def caller(*args: Any, **kwargs: Any) -> RPCResponse: - (method_str, params), response_formatters = method.process_params( - module, *args, **kwargs - ) +) -> Callable[..., Coroutine[Any, Any, Union[RPCResponse, AsyncLogFilter]]]: + async def caller(*args: Any, **kwargs: Any) -> Union[RPCResponse, AsyncLogFilter]: + try: + (method_str, params), response_formatters = method.process_params( + module, *args, **kwargs + ) + + except _UseExistingFilter as err: + return AsyncLogFilter(eth_module=module, filter_id=err.filter_id) ( result_formatters, error_formatters,