From 73d1956f3ca4a3694c2745d5b293dbc47805d37a Mon Sep 17 00:00:00 2001 From: Paul Robinson Date: Fri, 7 Oct 2022 14:59:37 -0600 Subject: [PATCH] Asyncify filter middleware (#2663) * asyncify filter middleware and tests --- .../core/middleware/test_filter_middleware.py | 182 ++++++++++- web3/eth.py | 8 +- web3/middleware/__init__.py | 3 + web3/middleware/filter.py | 286 +++++++++++++++++- 4 files changed, 466 insertions(+), 13 deletions(-) diff --git a/tests/core/middleware/test_filter_middleware.py b/tests/core/middleware/test_filter_middleware.py index afa486bb82..d1fd94f7a0 100644 --- a/tests/core/middleware/test_filter_middleware.py +++ b/tests/core/middleware/test_filter_middleware.py @@ -3,19 +3,29 @@ from hexbytes import ( HexBytes, ) +import pytest_asyncio from web3 import Web3 from web3.datastructures import ( AttributeDict, ) +from web3.eth import ( + AsyncEth, +) from web3.middleware import ( + async_construct_result_generator_middleware, + async_local_filter_middleware, construct_result_generator_middleware, local_filter_middleware, ) from web3.middleware.filter import ( + async_iter_latest_block_ranges, block_ranges, iter_latest_block_ranges, ) +from web3.providers.async_base import ( + AsyncBaseProvider, +) from web3.providers.base import ( BaseProvider, ) @@ -184,6 +194,15 @@ def test_block_ranges(start, stop, expected): (None, None), ], ), + ( + 10, + 10, + [10, 10], + [ + (10, 10), + (None, None), + ], + ), ], ) def test_iter_latest_block_ranges( @@ -206,18 +225,171 @@ def test_local_filter_middleware(w3, iter_block_number): block_filter = w3.eth.filter("latest") block_filter.get_new_entries() iter_block_number.send(1) + assert w3.eth.get_filter_changes(block_filter.filter_id) == [HexBytes(BLOCK_HASH)] log_filter = w3.eth.filter(filter_params={"fromBlock": "latest"}) + iter_block_number.send(2) + log_changes = w3.eth.get_filter_changes(log_filter.filter_id) + assert log_changes == FILTER_LOG + assert w3.eth.get_filter_logs(log_filter.filter_id) == FILTER_LOG - assert w3.eth.get_filter_changes(block_filter.filter_id) == [HexBytes(BLOCK_HASH)] + log_filter_from_hex_string = w3.eth.filter( + filter_params={"fromBlock": "0x0", "toBlock": "0x2"} + ) + log_filter_from_int = w3.eth.filter(filter_params={"fromBlock": 1, "toBlock": 3}) + + filter_ids = ( + block_filter.filter_id, + log_filter.filter_id, + log_filter_from_hex_string.filter_id, + log_filter_from_int.filter_id, + ) + + # Test that all ids are str types + assert all(isinstance(_filter_id, (str,)) for _filter_id in filter_ids) + + # Test that all ids are unique + assert len(filter_ids) == len(set(filter_ids)) + + +# --- async --- # + + +class AsyncDummyProvider(AsyncBaseProvider): + async def make_request(self, method, params): + raise NotImplementedError(f"Cannot make request for {method}:{params}") + + +@pytest_asyncio.fixture(scope="function") +async def async_result_generator_middleware(iter_block_number): + return await async_construct_result_generator_middleware( + { + "eth_getLogs": lambda *_: FILTER_LOG, + "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, + "net_version": lambda *_: 1, + "eth_blockNumber": lambda *_: next(iter_block_number), + } + ) + +@pytest.fixture(scope="function") +def async_w3_base(): + return Web3( + provider=AsyncDummyProvider(), modules={"eth": (AsyncEth)}, middlewares=[] + ) + + +@pytest.fixture(scope="function") +def async_w3(async_w3_base, async_result_generator_middleware): + async_w3_base.middleware_onion.add(async_result_generator_middleware) + async_w3_base.middleware_onion.add(async_local_filter_middleware) + return async_w3_base + + +@pytest.mark.parametrize( + "from_block,to_block,current_block,expected", + [ + ( + 0, + 10, + [10], + [ + (0, 10), + ], + ), + ( + 0, + 55, + [0, 19, 55], + [ + (0, 0), + (1, 19), + (20, 55), + ], + ), + ( + 0, + None, + [10], + [ + (0, 10), + ], + ), + ( + 0, + 10, + [12], + [ + (None, None), + ], + ), + ( + 12, + 10, + [12], + [ + (None, None), + ], + ), + ( + 12, + 10, + [None], + [ + (None, None), + ], + ), + ( + 10, + 10, + [10, 10], + [ + (10, 10), + (None, None), + ], + ), + ], +) +@pytest.mark.asyncio +async def test_async_iter_latest_block_ranges( + async_w3, iter_block_number, from_block, to_block, current_block, expected +): + latest_block_ranges = async_iter_latest_block_ranges(async_w3, from_block, to_block) + for index, block in enumerate(current_block): + iter_block_number.send(block) + expected_tuple = expected[index] + actual_tuple = await latest_block_ranges.__anext__() + assert actual_tuple == expected_tuple + + +@pytest.mark.asyncio +async def test_async_local_filter_middleware(async_w3, iter_block_number): + block_filter = await async_w3.eth.filter("latest") + await block_filter.get_new_entries() + iter_block_number.send(1) + block_changes = await async_w3.eth.get_filter_changes(block_filter.filter_id) + assert block_changes == [HexBytes(BLOCK_HASH)] + + log_filter = await async_w3.eth.filter(filter_params={"fromBlock": "latest"}) iter_block_number.send(2) - results = w3.eth.get_filter_changes(log_filter.filter_id) - assert results == FILTER_LOG + log_changes = await async_w3.eth.get_filter_changes(log_filter.filter_id) + assert log_changes == FILTER_LOG + logs = await async_w3.eth.get_filter_logs(log_filter.filter_id) + assert logs == FILTER_LOG - assert w3.eth.get_filter_logs(log_filter.filter_id) == FILTER_LOG + log_filter_from_hex_string = await async_w3.eth.filter( + filter_params={"fromBlock": "0x0", "toBlock": "0x2"} + ) + log_filter_from_int = await async_w3.eth.filter( + filter_params={"fromBlock": 1, "toBlock": 3} + ) - filter_ids = (block_filter.filter_id, log_filter.filter_id) + filter_ids = ( + block_filter.filter_id, + log_filter.filter_id, + log_filter_from_hex_string.filter_id, + log_filter_from_int.filter_id, + ) # Test that all ids are str types assert all(isinstance(_filter_id, (str,)) for _filter_id in filter_ids) diff --git a/web3/eth.py b/web3/eth.py index 92910493e8..7c4ab1526b 100644 --- a/web3/eth.py +++ b/web3/eth.py @@ -634,7 +634,7 @@ async def get_storage_at( ) -> HexBytes: return await self._get_storage_at(account, position, block_identifier) - _filter: Method[Callable[..., Awaitable[Any]]] = Method( + filter: Method[Callable[..., Awaitable[Any]]] = Method( method_choice_depends_on_args=select_filter_method( if_new_block_filter=RPC.eth_newBlockFilter, if_new_pending_transaction_filter=RPC.eth_newPendingTransactionFilter, @@ -643,12 +643,6 @@ async def get_storage_at( mungers=[BaseEth.filter_munger], ) - async def filter( - self, - filter_type: Union[str, FilterParams, HexStr], - ) -> HexStr: - return await self._filter(filter_type) - _get_filter_changes: Method[ Callable[[HexStr], Awaitable[List[LogReceipt]]] ] = Method(RPC.eth_getFilterChanges, mungers=[default_root_munger]) diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 7c0ddb8c1f..af99aed36f 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -39,9 +39,12 @@ http_retry_request_middleware, ) from .filter import ( # noqa: F401 + async_local_filter_middleware, local_filter_middleware, ) from .fixture import ( # noqa: F401 + async_construct_error_generator_middleware, + async_construct_result_generator_middleware, construct_error_generator_middleware, construct_fixture_middleware, construct_result_generator_middleware, diff --git a/web3/middleware/filter.py b/web3/middleware/filter.py index f19bfd73da..9666e1510d 100644 --- a/web3/middleware/filter.py +++ b/web3/middleware/filter.py @@ -3,8 +3,11 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterable, + AsyncIterator, Callable, Dict, + Generator, Iterable, Iterator, List, @@ -19,6 +22,7 @@ BlockNumber, ChecksumAddress, Hash32, + HexStr, ) from eth_utils import ( apply_key_map, @@ -40,6 +44,7 @@ RPC, ) from web3.types import ( # noqa: F401 + Coroutine, FilterParams, LatestBlockParam, LogReceipt, @@ -251,7 +256,7 @@ def to_block(self) -> BlockNumber: to_block = self.w3.eth.block_number elif self._to_block == "latest": to_block = self.w3.eth.block_number - elif is_hex(self._to_block): + elif is_string(self._to_block) and is_hex(self._to_block): to_block = BlockNumber(hex_to_integer(self._to_block)) # type: ignore else: to_block = self._to_block @@ -367,6 +372,7 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: _filter = filters[filter_id] if method == RPC.eth_getFilterChanges: return {"result": next(_filter.filter_changes)} + elif method == RPC.eth_getFilterLogs: # type ignored b/c logic prevents RequestBlocks which # doesn't implement get_logs @@ -377,3 +383,281 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: return make_request(method, params) return middleware + + +# --- async --- # + + +async def async_iter_latest_block( + w3: "Web3", to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None +) -> AsyncIterable[BlockNumber]: + """Returns a generator that dispenses the latest block, if + any new blocks have been mined since last iteration. + + If there are no new blocks or the latest block is greater than + the ``to_block`` None is returned. + + >>> new_blocks = iter_latest_block(w3, 0, 10) + >>> next(new_blocks) # Latest block = 0 + 0 + >>> next(new_blocks) # No new blocks + >>> next(new_blocks) # Latest block = 1 + 1 + >>> next(new_blocks) # Latest block = 10 + 10 + >>> next(new_blocks) # latest block > to block + """ + _last = None + + is_bounded_range = to_block is not None and to_block != "latest" + + while True: + latest_block = await w3.eth.block_number # type: ignore + # type ignored b/c is_bounded_range prevents unsupported comparison + if is_bounded_range and latest_block > to_block: + yield None + # No new blocks since last iteration. + if _last is not None and _last == latest_block: + yield None + else: + yield latest_block + _last = latest_block + + +async def async_iter_latest_block_ranges( + w3: "Web3", + from_block: BlockNumber, + to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, +) -> AsyncIterable[Tuple[Optional[BlockNumber], Optional[BlockNumber]]]: + """Returns an iterator unloading ranges of available blocks + + starting from `fromBlock` to the latest mined block, + until reaching toBlock. e.g.: + + + >>> blocks_to_filter = iter_latest_block_ranges(w3, 0, 50) + >>> next(blocks_to_filter) # latest block number = 11 + (0, 11) + >>> next(blocks_to_filter) # latest block number = 45 + (12, 45) + >>> next(blocks_to_filter) # latest block number = 50 + (46, 50) + """ + latest_block_iterator = async_iter_latest_block(w3, to_block) + async for latest_block in latest_block_iterator: + if latest_block is None: + yield (None, None) + elif from_block > latest_block: + yield (None, None) + else: + yield (from_block, latest_block) + from_block = BlockNumber(latest_block + 1) + + +async def async_get_logs_multipart( + w3: "Web3", + startBlock: BlockNumber, + stopBlock: BlockNumber, + address: Union[Address, ChecksumAddress, List[Union[Address, ChecksumAddress]]], + topics: List[Optional[Union[_Hash32, List[_Hash32]]]], + max_blocks: int, +) -> AsyncIterable[List[LogReceipt]]: + """Used to break up requests to ``eth_getLogs`` + + The getLog request is partitioned into multiple calls of the max number of blocks + ``max_blocks``. + """ + _block_ranges = block_ranges(startBlock, stopBlock, max_blocks) + for from_block, to_block in _block_ranges: + params = { + "fromBlock": from_block, + "toBlock": to_block, + "address": address, + "topics": topics, + } + params_with_none_dropped = cast( + FilterParams, drop_items_with_none_value(params) + ) + next_logs = await w3.eth.get_logs(params_with_none_dropped) # type: ignore + yield next_logs + + +class AsyncRequestLogs: + _from_block: BlockNumber + + def __init__( + self, + w3: "Web3", + from_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, + to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, + address: Optional[ + Union[Address, ChecksumAddress, List[Union[Address, ChecksumAddress]]] + ] = None, + topics: Optional[List[Optional[Union[_Hash32, List[_Hash32]]]]] = None, + ) -> None: + self.address = address + self.topics = topics + self.w3 = w3 + self._from_block_arg = from_block + self._to_block = to_block + self.filter_changes = self._get_filter_changes() + + def __await__(self) -> Generator[Any, None, "AsyncRequestLogs"]: + async def closure() -> "AsyncRequestLogs": + if self._from_block_arg is None or self._from_block_arg == "latest": + self.block_number = await self.w3.eth.block_number # type: ignore + self._from_block = BlockNumber(self.block_number + 1) + elif is_string(self._from_block_arg) and is_hex(self._from_block_arg): + self._from_block = BlockNumber( + hex_to_integer(cast(HexStr, self._from_block_arg)) + ) + else: + self._from_block = self._from_block_arg + + return self + + return closure().__await__() + + @property + async def from_block(self) -> BlockNumber: + return self._from_block + + @property + async def to_block(self) -> BlockNumber: + if self._to_block is None or self._to_block == "latest": + to_block = await self.w3.eth.block_number # type: ignore + elif is_string(self._to_block) and is_hex(self._to_block): + to_block = BlockNumber(hex_to_integer(cast(HexStr, self._to_block))) + else: + to_block = self._to_block + + return to_block + + async def _get_filter_changes(self) -> AsyncIterator[List[LogReceipt]]: + self_from_block = await self.from_block + self_to_block = await self.to_block + async for start, stop in async_iter_latest_block_ranges( + self.w3, self_from_block, self_to_block + ): + if None in (start, stop): + yield [] + else: + yield [ + item + async for sublist in async_get_logs_multipart( + self.w3, + start, + stop, + self.address, + self.topics, + max_blocks=MAX_BLOCK_REQUEST, + ) + for item in sublist + ] + + async def get_logs(self) -> List[LogReceipt]: + self_from_block = await self.from_block + self_to_block = await self.to_block + return [ + item + async for sublist in async_get_logs_multipart( + self.w3, + self_from_block, + self_to_block, + self.address, + self.topics, + max_blocks=MAX_BLOCK_REQUEST, + ) + for item in sublist + ] + + +class AsyncRequestBlocks: + def __init__(self, w3: "Web3") -> None: + self.w3 = w3 + + def __await__(self) -> Generator[Any, None, "AsyncRequestBlocks"]: + async def closure() -> "AsyncRequestBlocks": + self.block_number = await self.w3.eth.block_number # type: ignore + self.start_block = BlockNumber(self.block_number + 1) + return self + + return closure().__await__() + + @property + def filter_changes(self) -> AsyncIterator[List[Hash32]]: + return self.get_filter_changes() + + async def get_filter_changes(self) -> AsyncIterator[List[Hash32]]: + block_range_iter = async_iter_latest_block_ranges( + self.w3, self.start_block, None + ) + async for block_range in block_range_iter: + hash = await async_block_hashes_in_range(self.w3, block_range) + yield hash + + +async def async_block_hashes_in_range( + w3: "Web3", block_range: Tuple[BlockNumber, BlockNumber] +) -> List[Union[None, Hash32]]: + from_block, to_block = block_range + if from_block is None or to_block is None: + return [] + + block_hashes = [] + for block_number in range(from_block, to_block + 1): + w3_get_block = await w3.eth.get_block(BlockNumber(block_number)) # type: ignore + block_hashes.append(getattr(w3_get_block, "hash", None)) + + return block_hashes + + +async def async_local_filter_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" +) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: + filters = {} + filter_id_counter = map(to_hex, itertools.count()) + + async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + if method in NEW_FILTER_METHODS: + + filter_id = next(filter_id_counter) + + _filter: Union[AsyncRequestLogs, AsyncRequestBlocks] + if method == RPC.eth_newFilter: + _filter = await AsyncRequestLogs( + w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) + ) + + elif method == RPC.eth_newBlockFilter: + _filter = await AsyncRequestBlocks(w3) + + else: + raise NotImplementedError(method) + + filters[filter_id] = _filter + return {"result": filter_id} + + elif method in FILTER_CHANGES_METHODS: + + filter_id = params[0] + # Pass through to filters not created by middleware + if filter_id not in filters: + return await make_request(method, params) + _filter = filters[filter_id] + + if method == RPC.eth_getFilterChanges: + changes = await _filter.filter_changes.__anext__() + return {"result": changes} + + elif method == RPC.eth_getFilterLogs: + # type ignored b/c logic prevents RequestBlocks which + # doesn't implement get_logs + logs = await _filter.get_logs() # type: ignore + return {"result": logs} + else: + raise NotImplementedError(method) + else: + return await make_request(method, params) + + return middleware