diff --git a/tests/core/eth-module/test_eth_filter.py b/tests/core/eth-module/test_eth_filter.py new file mode 100644 index 0000000000..e33bd6af2d --- /dev/null +++ b/tests/core/eth-module/test_eth_filter.py @@ -0,0 +1,49 @@ +import pytest + +import pytest_asyncio + +from web3 import Web3 +from web3._utils.filters import ( + AsyncBlockFilter, + AsyncLogFilter, + AsyncTransactionFilter, + BlockFilter, + LogFilter, + TransactionFilter, +) +from web3.eth import ( + AsyncEth, +) +from web3.providers.eth_tester.main import ( + AsyncEthereumTesterProvider, +) + + +def test_Eth_filter_creates_correct_filter_type(w3): + filter1 = w3.eth.filter("latest") + assert isinstance(filter1, BlockFilter) + filter2 = w3.eth.filter("pending") + assert isinstance(filter2, TransactionFilter) + filter3 = w3.eth.filter({}) + assert isinstance(filter3, LogFilter) + + +# --- async --- # + + +@pytest_asyncio.fixture() +async def async_w3(): + provider = AsyncEthereumTesterProvider() + w3 = Web3(provider, modules={"eth": [AsyncEth]}, middlewares=[]) + return w3 + + +@pytest.mark.asyncio +async def test_AsyncEth_filter_creates_correct_filter_type(async_w3): + + filter1 = await async_w3.eth.filter("latest") + assert isinstance(filter1, AsyncBlockFilter) + filter2 = await async_w3.eth.filter("pending") + assert isinstance(filter2, AsyncTransactionFilter) + filter3 = await async_w3.eth.filter({}) + assert isinstance(filter3, AsyncLogFilter) diff --git a/web3/_utils/method_formatters.py b/web3/_utils/method_formatters.py index cb653ffd43..fdb3028b50 100644 --- a/web3/_utils/method_formatters.py +++ b/web3/_utils/method_formatters.py @@ -10,6 +10,7 @@ NoReturn, Tuple, Union, + cast, ) from eth_abi import ( @@ -60,6 +61,9 @@ to_hex, ) from web3._utils.filters import ( + AsyncBlockFilter, + AsyncLogFilter, + AsyncTransactionFilter, BlockFilter, LogFilter, TransactionFilter, @@ -107,6 +111,7 @@ from web3 import Web3 # noqa: F401 from web3.module import Module # noqa: F401 from web3.eth import Eth # noqa: F401 + from web3.eth import AsyncEth # noqa: F401 def bytes_to_ascii(value: bytes) -> str: @@ -751,16 +756,34 @@ def raise_transaction_not_found_with_index( def filter_wrapper( - module: "Eth", + module: Union["AsyncEth", "Eth"], method: RPCEndpoint, filter_id: HexStr, -) -> Union[BlockFilter, TransactionFilter, LogFilter]: +) -> Union[ + AsyncBlockFilter, + AsyncTransactionFilter, + AsyncLogFilter, + BlockFilter, + TransactionFilter, + LogFilter, +]: if method == RPC.eth_newBlockFilter: - return BlockFilter(filter_id, eth_module=module) + if module.is_async: + return AsyncBlockFilter(filter_id, eth_module=cast("AsyncEth", module)) + else: + return BlockFilter(filter_id, eth_module=cast("Eth", module)) elif method == RPC.eth_newPendingTransactionFilter: - return TransactionFilter(filter_id, eth_module=module) + if module.is_async: + return AsyncTransactionFilter( + filter_id, eth_module=cast("AsyncEth", module) + ) + else: + return TransactionFilter(filter_id, eth_module=cast("Eth", module)) elif method == RPC.eth_newFilter: - return LogFilter(filter_id, eth_module=module) + if module.is_async: + return AsyncLogFilter(filter_id, eth_module=cast("AsyncEth", module)) + else: + return LogFilter(filter_id, eth_module=cast("Eth", module)) else: raise NotImplementedError( "Filter wrapper needs to be used with either " @@ -794,7 +817,6 @@ def get_result_formatters( formatters_requiring_module = combine_formatters( (FILTER_RESULT_FORMATTERS,), method_name ) - partial_formatters = apply_module_to_formatters( formatters_requiring_module, module, method_name ) diff --git a/web3/eth.py b/web3/eth.py index 17b013d6e1..861ae3ba12 100644 --- a/web3/eth.py +++ b/web3/eth.py @@ -626,7 +626,7 @@ async def get_storage_at( ) -> HexBytes: return await self._get_storage_at(account, position, block_identifier) - filter: Method[Callable[..., Awaitable[Any]]] = Method( + _filter: Method[Callable[..., Awaitable[Any]]] = Method( method_choice_depends_on_args=select_filter_method( if_new_block_filter=RPC.eth_newBlockFilter, if_new_pending_transaction_filter=RPC.eth_newPendingTransactionFilter, @@ -635,6 +635,12 @@ async def get_storage_at( mungers=[BaseEth.filter_munger], ) + async def filter( + self, + filter_type: Union[str, FilterParams, HexStr], + ) -> HexStr: + return await self._filter(filter_type) + _get_filter_changes: Method[ Callable[[HexStr], Awaitable[List[LogReceipt]]] ] = Method(RPC.eth_getFilterChanges, mungers=[default_root_munger])