Skip to content

Commit

Permalink
Async eth.get_balance, eth.get_code, eth.get_transaction_count (#2056)
Browse files Browse the repository at this point in the history
  • Loading branch information
kclowes authored Jul 21, 2021
1 parent 2d36855 commit c2048f1
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 31 deletions.
1 change: 1 addition & 0 deletions newsfragments/2056.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add async ``eth.get_balance``, ``eth.get_code``, ``eth.get_transaction_count`` methods.
48 changes: 48 additions & 0 deletions web3/_utils/module_testing/eth_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,54 @@ async def test_eth_get_raw_transaction_raises_error(
with pytest.raises(TransactionNotFound, match=f"Transaction with hash: '{UNKNOWN_HASH}'"):
web3.eth.get_raw_transaction(UNKNOWN_HASH)

@pytest.mark.asyncio
async def test_eth_get_balance(self, async_w3: "Web3") -> None:
coinbase = await async_w3.eth.coinbase # type: ignore

with pytest.raises(InvalidAddress):
await async_w3.eth.get_balance( # type: ignore
ChecksumAddress(HexAddress(HexStr(coinbase.lower())))
)

balance = await async_w3.eth.get_balance(coinbase) # type: ignore

assert is_integer(balance)
assert balance >= 0

@pytest.mark.asyncio
async def test_eth_get_code(
self, async_w3: "Web3", math_contract_address: ChecksumAddress
) -> None:
code = await async_w3.eth.get_code(math_contract_address) # type: ignore
assert isinstance(code, HexBytes)
assert len(code) > 0

@pytest.mark.asyncio
async def test_eth_get_code_invalid_address(
self, async_w3: "Web3", math_contract: "Contract"
) -> None:
with pytest.raises(InvalidAddress):
await async_w3.eth.get_code( # type: ignore
ChecksumAddress(HexAddress(HexStr(math_contract.address.lower())))
)

@pytest.mark.asyncio
async def test_eth_get_code_with_block_identifier(
self, async_w3: "Web3", emitter_contract: "Contract"
) -> None:
block_id = await async_w3.eth.block_number # type: ignore
code = await async_w3.eth.get_code(emitter_contract.address, block_id) # type: ignore
assert isinstance(code, HexBytes)
assert len(code) > 0

@pytest.mark.asyncio
async def test_eth_get_transaction_count(
self, async_w3: "Web3", unlocked_account_dual_type: ChecksumAddress
) -> None:
transaction_count = await async_w3.eth.get_transaction_count(unlocked_account_dual_type) # type: ignore # noqa E501
assert is_integer(transaction_count)
assert transaction_count >= 0


class EthModuleTest:
def test_eth_protocol_version(self, web3: "Web3") -> None:
Expand Down
96 changes: 66 additions & 30 deletions web3/eth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import (
Any,
Awaitable,
Callable,
List,
NoReturn,
Expand Down Expand Up @@ -105,13 +106,23 @@

class BaseEth(Module):
_default_account: Union[ChecksumAddress, Empty] = empty
_default_block: BlockIdentifier = "latest"
gasPriceStrategy = None

_gas_price: Method[Callable[[], Wei]] = Method(
RPC.eth_gasPrice,
mungers=None,
)

""" property default_block """
@property
def default_block(self) -> BlockIdentifier:
return self._default_block

@default_block.setter
def default_block(self, value: BlockIdentifier) -> None:
self._default_block = value

@property
def default_account(self) -> Union[ChecksumAddress, Empty]:
return self._default_account
Expand Down Expand Up @@ -193,6 +204,15 @@ def get_block_munger(
mungers=None,
)

def block_id_munger(
self,
account: Union[Address, ChecksumAddress, ENS],
block_identifier: Optional[BlockIdentifier] = None
) -> Tuple[Union[Address, ChecksumAddress, ENS], BlockIdentifier]:
if block_identifier is None:
block_identifier = self.default_block
return (account, block_identifier)


class AsyncEth(BaseEth):
is_async = True
Expand Down Expand Up @@ -243,10 +263,45 @@ async def coinbase(self) -> ChecksumAddress:
# types ignored b/c mypy conflict with BlockingEth properties
return await self.get_coinbase() # type: ignore

_get_balance: Method[Callable[..., Awaitable[Wei]]] = Method(
RPC.eth_getBalance,
mungers=[BaseEth.block_id_munger],
)

async def get_balance(
self,
account: Union[Address, ChecksumAddress, ENS],
block_identifier: Optional[BlockIdentifier] = None
) -> Wei:
return await self._get_balance(account, block_identifier)

_get_code: Method[Callable[..., Awaitable[HexBytes]]] = Method(
RPC.eth_getCode,
mungers=[BaseEth.block_id_munger]
)

async def get_code(
self,
account: Union[Address, ChecksumAddress, ENS],
block_identifier: Optional[BlockIdentifier] = None
) -> HexBytes:
return await self._get_code(account, block_identifier)

_get_transaction_count: Method[Callable[..., Awaitable[Nonce]]] = Method(
RPC.eth_getTransactionCount,
mungers=[BaseEth.block_id_munger],
)

async def get_transaction_count(
self,
account: Union[Address, ChecksumAddress, ENS],
block_identifier: Optional[BlockIdentifier] = None
) -> Nonce:
return await self._get_transaction_count(account, block_identifier)


class Eth(BaseEth, Module):
account = Account()
_default_block: BlockIdentifier = "latest"
defaultContractFactory: Type[Union[Contract, ConciseContract, ContractCaller]] = Contract # noqa: E704,E501
iban = Iban

Expand Down Expand Up @@ -383,15 +438,10 @@ def defaultAccount(self, account: Union[ChecksumAddress, Empty]) -> None:
)
self._default_account = account

""" property default_block """

@property
def default_block(self) -> BlockIdentifier:
return self._default_block

@default_block.setter
def default_block(self, value: BlockIdentifier) -> None:
self._default_block = value
get_balance: Method[Callable[..., Wei]] = Method(
RPC.eth_getBalance,
mungers=[BaseEth.block_id_munger],
)

@property
def defaultBlock(self) -> BlockIdentifier:
Expand All @@ -409,20 +459,6 @@ def defaultBlock(self, value: BlockIdentifier) -> None:
)
self._default_block = value

def block_id_munger(
self,
account: Union[Address, ChecksumAddress, ENS],
block_identifier: Optional[BlockIdentifier] = None
) -> Tuple[Union[Address, ChecksumAddress, ENS], BlockIdentifier]:
if block_identifier is None:
block_identifier = self.default_block
return (account, block_identifier)

get_balance: Method[Callable[..., Wei]] = Method(
RPC.eth_getBalance,
mungers=[block_id_munger],
)

def get_storage_at_munger(
self,
account: Union[Address, ChecksumAddress, ENS],
Expand Down Expand Up @@ -458,16 +494,16 @@ def get_proof_munger(
mungers=[get_proof_munger],
)

get_code: Method[Callable[..., HexBytes]] = Method(
RPC.eth_getCode,
mungers=[block_id_munger]
)

def get_block(
self, block_identifier: BlockIdentifier, full_transactions: bool = False
) -> BlockData:
return self._get_block(block_identifier, full_transactions)

get_code: Method[Callable[..., HexBytes]] = Method(
RPC.eth_getCode,
mungers=[BaseEth.block_id_munger]
)

"""
`eth_getBlockTransactionCountByHash`
`eth_getBlockTransactionCountByNumber`
Expand Down Expand Up @@ -557,7 +593,7 @@ def wait_for_transaction_receipt(

get_transaction_count: Method[Callable[..., Nonce]] = Method(
RPC.eth_getTransactionCount,
mungers=[block_id_munger],
mungers=[BaseEth.block_id_munger],
)

@deprecated_for("replace_transaction")
Expand Down
3 changes: 2 additions & 1 deletion web3/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from web3 import Web3 # noqa: F401
from web3.module import Module # noqa: F401


Munger = Callable[..., Any]


Expand Down Expand Up @@ -154,7 +155,7 @@ def method_selector_fn(self) -> Callable[..., Union[RPCEndpoint, Callable[..., R
def input_munger(
self, module: "Module", args: Any, kwargs: Any
) -> List[Any]:
# This function takes the "root_munger" - the first munger in
# This function takes the "root_munger" - (the first munger in
# the list of mungers) and then pipes the return value of the
# previous munger as an argument to the next munger to return
# an array of arguments that have been formatted.
Expand Down

0 comments on commit c2048f1

Please sign in to comment.