Skip to content

Commit

Permalink
allow eth.filter to return appropriate sync/async filter types (#2645)
Browse files Browse the repository at this point in the history
* update method_formatters.py to select the appropriate sync/async filter type
  • Loading branch information
Paul Robinson committed Oct 7, 2022
1 parent 2d84bbc commit 2776068
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 7 deletions.
49 changes: 49 additions & 0 deletions tests/core/eth-module/test_eth_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest

import pytest_asyncio

from web3 import Web3
from web3._utils.filters import (
AsyncBlockFilter,
AsyncLogFilter,
AsyncTransactionFilter,
BlockFilter,
LogFilter,
TransactionFilter,
)
from web3.eth import (
AsyncEth,
)
from web3.providers.eth_tester.main import (
AsyncEthereumTesterProvider,
)


def test_Eth_filter_creates_correct_filter_type(w3):
filter1 = w3.eth.filter("latest")
assert isinstance(filter1, BlockFilter)
filter2 = w3.eth.filter("pending")
assert isinstance(filter2, TransactionFilter)
filter3 = w3.eth.filter({})
assert isinstance(filter3, LogFilter)


# --- async --- #


@pytest_asyncio.fixture()
async def async_w3():
provider = AsyncEthereumTesterProvider()
w3 = Web3(provider, modules={"eth": [AsyncEth]}, middlewares=[])
return w3


@pytest.mark.asyncio
async def test_AsyncEth_filter_creates_correct_filter_type(async_w3):

filter1 = await async_w3.eth.filter("latest")
assert isinstance(filter1, AsyncBlockFilter)
filter2 = await async_w3.eth.filter("pending")
assert isinstance(filter2, AsyncTransactionFilter)
filter3 = await async_w3.eth.filter({})
assert isinstance(filter3, AsyncLogFilter)
34 changes: 28 additions & 6 deletions web3/_utils/method_formatters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
NoReturn,
Tuple,
Union,
cast,
)

from eth_abi import (
Expand Down Expand Up @@ -60,6 +61,9 @@
to_hex,
)
from web3._utils.filters import (
AsyncBlockFilter,
AsyncLogFilter,
AsyncTransactionFilter,
BlockFilter,
LogFilter,
TransactionFilter,
Expand Down Expand Up @@ -107,6 +111,7 @@
from web3 import Web3 # noqa: F401
from web3.module import Module # noqa: F401
from web3.eth import Eth # noqa: F401
from web3.eth import AsyncEth # noqa: F401


def bytes_to_ascii(value: bytes) -> str:
Expand Down Expand Up @@ -751,16 +756,34 @@ def raise_transaction_not_found_with_index(


def filter_wrapper(
module: "Eth",
module: Union["AsyncEth", "Eth"],
method: RPCEndpoint,
filter_id: HexStr,
) -> Union[BlockFilter, TransactionFilter, LogFilter]:
) -> Union[
AsyncBlockFilter,
AsyncTransactionFilter,
AsyncLogFilter,
BlockFilter,
TransactionFilter,
LogFilter,
]:
if method == RPC.eth_newBlockFilter:
return BlockFilter(filter_id, eth_module=module)
if module.is_async:
return AsyncBlockFilter(filter_id, eth_module=cast("AsyncEth", module))
else:
return BlockFilter(filter_id, eth_module=cast("Eth", module))
elif method == RPC.eth_newPendingTransactionFilter:
return TransactionFilter(filter_id, eth_module=module)
if module.is_async:
return AsyncTransactionFilter(
filter_id, eth_module=cast("AsyncEth", module)
)
else:
return TransactionFilter(filter_id, eth_module=cast("Eth", module))
elif method == RPC.eth_newFilter:
return LogFilter(filter_id, eth_module=module)
if module.is_async:
return AsyncLogFilter(filter_id, eth_module=cast("AsyncEth", module))
else:
return LogFilter(filter_id, eth_module=cast("Eth", module))
else:
raise NotImplementedError(
"Filter wrapper needs to be used with either "
Expand Down Expand Up @@ -794,7 +817,6 @@ def get_result_formatters(
formatters_requiring_module = combine_formatters(
(FILTER_RESULT_FORMATTERS,), method_name
)

partial_formatters = apply_module_to_formatters(
formatters_requiring_module, module, method_name
)
Expand Down
8 changes: 7 additions & 1 deletion web3/eth.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ async def get_storage_at(
) -> HexBytes:
return await self._get_storage_at(account, position, block_identifier)

filter: Method[Callable[..., Awaitable[Any]]] = Method(
_filter: Method[Callable[..., Awaitable[Any]]] = Method(
method_choice_depends_on_args=select_filter_method(
if_new_block_filter=RPC.eth_newBlockFilter,
if_new_pending_transaction_filter=RPC.eth_newPendingTransactionFilter,
Expand All @@ -635,6 +635,12 @@ async def get_storage_at(
mungers=[BaseEth.filter_munger],
)

async def filter(
self,
filter_type: Union[str, FilterParams, HexStr],
) -> HexStr:
return await self._filter(filter_type)

_get_filter_changes: Method[
Callable[[HexStr], Awaitable[List[LogReceipt]]]
] = Method(RPC.eth_getFilterChanges, mungers=[default_root_munger])
Expand Down

0 comments on commit 2776068

Please sign in to comment.