diff --git a/newsfragments/2056.feature.rst b/newsfragments/2056.feature.rst new file mode 100644 index 0000000000..24bba7c6d7 --- /dev/null +++ b/newsfragments/2056.feature.rst @@ -0,0 +1 @@ +Add async ``eth.get_balance``, ``eth.get_code``, ``eth.get_transaction_count`` methods. diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index 8cf6a06360..d1eafb651e 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -374,6 +374,54 @@ async def test_eth_get_raw_transaction_raises_error( with pytest.raises(TransactionNotFound, match=f"Transaction with hash: '{UNKNOWN_HASH}'"): web3.eth.get_raw_transaction(UNKNOWN_HASH) + @pytest.mark.asyncio + async def test_eth_get_balance(self, async_w3: "Web3") -> None: + coinbase = await async_w3.eth.coinbase # type: ignore + + with pytest.raises(InvalidAddress): + await async_w3.eth.get_balance( # type: ignore + ChecksumAddress(HexAddress(HexStr(coinbase.lower()))) + ) + + balance = await async_w3.eth.get_balance(coinbase) # type: ignore + + assert is_integer(balance) + assert balance >= 0 + + @pytest.mark.asyncio + async def test_eth_get_code( + self, async_w3: "Web3", math_contract_address: ChecksumAddress + ) -> None: + code = await async_w3.eth.get_code(math_contract_address) # type: ignore + assert isinstance(code, HexBytes) + assert len(code) > 0 + + @pytest.mark.asyncio + async def test_eth_get_code_invalid_address( + self, async_w3: "Web3", math_contract: "Contract" + ) -> None: + with pytest.raises(InvalidAddress): + await async_w3.eth.get_code( # type: ignore + ChecksumAddress(HexAddress(HexStr(math_contract.address.lower()))) + ) + + @pytest.mark.asyncio + async def test_eth_get_code_with_block_identifier( + self, async_w3: "Web3", emitter_contract: "Contract" + ) -> None: + block_id = await async_w3.eth.block_number # type: ignore + code = await async_w3.eth.get_code(emitter_contract.address, block_id) # type: ignore + assert isinstance(code, HexBytes) + assert len(code) > 0 + + @pytest.mark.asyncio + async def test_eth_get_transaction_count( + self, async_w3: "Web3", unlocked_account_dual_type: ChecksumAddress + ) -> None: + transaction_count = await async_w3.eth.get_transaction_count(unlocked_account_dual_type) # type: ignore # noqa E501 + assert is_integer(transaction_count) + assert transaction_count >= 0 + class EthModuleTest: def test_eth_protocol_version(self, web3: "Web3") -> None: diff --git a/web3/eth.py b/web3/eth.py index 33132a0a3b..29936213a2 100644 --- a/web3/eth.py +++ b/web3/eth.py @@ -1,5 +1,6 @@ from typing import ( Any, + Awaitable, Callable, List, NoReturn, @@ -105,6 +106,7 @@ class BaseEth(Module): _default_account: Union[ChecksumAddress, Empty] = empty + _default_block: BlockIdentifier = "latest" gasPriceStrategy = None _gas_price: Method[Callable[[], Wei]] = Method( @@ -112,6 +114,15 @@ class BaseEth(Module): mungers=None, ) + """ property default_block """ + @property + def default_block(self) -> BlockIdentifier: + return self._default_block + + @default_block.setter + def default_block(self, value: BlockIdentifier) -> None: + self._default_block = value + @property def default_account(self) -> Union[ChecksumAddress, Empty]: return self._default_account @@ -193,6 +204,15 @@ def get_block_munger( mungers=None, ) + def block_id_munger( + self, + account: Union[Address, ChecksumAddress, ENS], + block_identifier: Optional[BlockIdentifier] = None + ) -> Tuple[Union[Address, ChecksumAddress, ENS], BlockIdentifier]: + if block_identifier is None: + block_identifier = self.default_block + return (account, block_identifier) + class AsyncEth(BaseEth): is_async = True @@ -243,10 +263,45 @@ async def coinbase(self) -> ChecksumAddress: # types ignored b/c mypy conflict with BlockingEth properties return await self.get_coinbase() # type: ignore + _get_balance: Method[Callable[..., Awaitable[Wei]]] = Method( + RPC.eth_getBalance, + mungers=[BaseEth.block_id_munger], + ) + + async def get_balance( + self, + account: Union[Address, ChecksumAddress, ENS], + block_identifier: Optional[BlockIdentifier] = None + ) -> Wei: + return await self._get_balance(account, block_identifier) + + _get_code: Method[Callable[..., Awaitable[HexBytes]]] = Method( + RPC.eth_getCode, + mungers=[BaseEth.block_id_munger] + ) + + async def get_code( + self, + account: Union[Address, ChecksumAddress, ENS], + block_identifier: Optional[BlockIdentifier] = None + ) -> HexBytes: + return await self._get_code(account, block_identifier) + + _get_transaction_count: Method[Callable[..., Awaitable[Nonce]]] = Method( + RPC.eth_getTransactionCount, + mungers=[BaseEth.block_id_munger], + ) + + async def get_transaction_count( + self, + account: Union[Address, ChecksumAddress, ENS], + block_identifier: Optional[BlockIdentifier] = None + ) -> Nonce: + return await self._get_transaction_count(account, block_identifier) + class Eth(BaseEth, Module): account = Account() - _default_block: BlockIdentifier = "latest" defaultContractFactory: Type[Union[Contract, ConciseContract, ContractCaller]] = Contract # noqa: E704,E501 iban = Iban @@ -383,15 +438,10 @@ def defaultAccount(self, account: Union[ChecksumAddress, Empty]) -> None: ) self._default_account = account - """ property default_block """ - - @property - def default_block(self) -> BlockIdentifier: - return self._default_block - - @default_block.setter - def default_block(self, value: BlockIdentifier) -> None: - self._default_block = value + get_balance: Method[Callable[..., Wei]] = Method( + RPC.eth_getBalance, + mungers=[BaseEth.block_id_munger], + ) @property def defaultBlock(self) -> BlockIdentifier: @@ -409,20 +459,6 @@ def defaultBlock(self, value: BlockIdentifier) -> None: ) self._default_block = value - def block_id_munger( - self, - account: Union[Address, ChecksumAddress, ENS], - block_identifier: Optional[BlockIdentifier] = None - ) -> Tuple[Union[Address, ChecksumAddress, ENS], BlockIdentifier]: - if block_identifier is None: - block_identifier = self.default_block - return (account, block_identifier) - - get_balance: Method[Callable[..., Wei]] = Method( - RPC.eth_getBalance, - mungers=[block_id_munger], - ) - def get_storage_at_munger( self, account: Union[Address, ChecksumAddress, ENS], @@ -458,16 +494,16 @@ def get_proof_munger( mungers=[get_proof_munger], ) - get_code: Method[Callable[..., HexBytes]] = Method( - RPC.eth_getCode, - mungers=[block_id_munger] - ) - def get_block( self, block_identifier: BlockIdentifier, full_transactions: bool = False ) -> BlockData: return self._get_block(block_identifier, full_transactions) + get_code: Method[Callable[..., HexBytes]] = Method( + RPC.eth_getCode, + mungers=[BaseEth.block_id_munger] + ) + """ `eth_getBlockTransactionCountByHash` `eth_getBlockTransactionCountByNumber` @@ -557,7 +593,7 @@ def wait_for_transaction_receipt( get_transaction_count: Method[Callable[..., Nonce]] = Method( RPC.eth_getTransactionCount, - mungers=[block_id_munger], + mungers=[BaseEth.block_id_munger], ) @deprecated_for("replace_transaction") diff --git a/web3/method.py b/web3/method.py index 601a2ee31b..78f68a68e3 100644 --- a/web3/method.py +++ b/web3/method.py @@ -40,6 +40,7 @@ from web3 import Web3 # noqa: F401 from web3.module import Module # noqa: F401 + Munger = Callable[..., Any] @@ -154,7 +155,7 @@ def method_selector_fn(self) -> Callable[..., Union[RPCEndpoint, Callable[..., R def input_munger( self, module: "Module", args: Any, kwargs: Any ) -> List[Any]: - # This function takes the "root_munger" - the first munger in + # This function takes the "root_munger" - (the first munger in # the list of mungers) and then pipes the return value of the # previous munger as an argument to the next munger to return # an array of arguments that have been formatted.