From b892b22bd0aff22c516b7ede7c9766ecbdd5d655 Mon Sep 17 00:00:00 2001 From: pacrob Date: Mon, 8 Aug 2022 15:52:40 -0600 Subject: [PATCH] fix bug in how eth_tester middleware filled default fields --- newsfragments/2600.bugfix.rst | 1 + .../middleware/test_eth_tester_middleware.py | 162 ++++++++++++++++++ web3/providers/eth_tester/middleware.py | 38 +++- 3 files changed, 192 insertions(+), 9 deletions(-) create mode 100644 newsfragments/2600.bugfix.rst diff --git a/newsfragments/2600.bugfix.rst b/newsfragments/2600.bugfix.rst new file mode 100644 index 0000000000..06589098e1 --- /dev/null +++ b/newsfragments/2600.bugfix.rst @@ -0,0 +1 @@ +fixed bug in how async_eth_tester_middleware fills default fields diff --git a/tests/core/middleware/test_eth_tester_middleware.py b/tests/core/middleware/test_eth_tester_middleware.py index 228f7cd80d..c82ba91796 100644 --- a/tests/core/middleware/test_eth_tester_middleware.py +++ b/tests/core/middleware/test_eth_tester_middleware.py @@ -1,5 +1,13 @@ import pytest +from unittest.mock import ( + AsyncMock, + Mock, +) +from web3.providers.eth_tester.middleware import ( + async_default_transaction_fields_middleware, + default_transaction_fields_middleware, +) from web3.types import ( BlockData, ) @@ -20,3 +28,157 @@ def test_get_block_formatters(w3): keys_diff = all_block_keys.difference(latest_block_keys) assert len(keys_diff) == 1 assert keys_diff.pop() == "mixHash" # mixHash is not implemented in eth-tester + + +sample_address_1 = "0x0000000000000000000000000000000000000001" +sample_address_2 = "0x0000000000000000000000000000000000000002" + + +@pytest.mark.parametrize( + "w3_accounts, w3_coinbase, method, from_field_added, from_field_value", + ( + (sample_address_1, sample_address_2, "eth_call", True, sample_address_2), + ( + sample_address_1, + sample_address_2, + "eth_estimateGas", + True, + sample_address_2, + ), + ( + sample_address_1, + sample_address_2, + "eth_sendTransaction", + True, + sample_address_2, + ), + (sample_address_1, sample_address_2, "eth_gasPrice", False, None), + (sample_address_1, sample_address_2, "eth_blockNumber", False, None), + (sample_address_1, sample_address_2, "meow", False, None), + (sample_address_1, None, "eth_call", True, sample_address_1), + (sample_address_1, None, "eth_estimateGas", True, sample_address_1), + (sample_address_1, None, "eth_sendTransaction", True, sample_address_1), + (sample_address_1, None, "eth_gasPrice", False, None), + (sample_address_1, None, "eth_blockNumber", False, None), + (sample_address_1, None, "meow", False, None), + (None, sample_address_2, "eth_call", True, sample_address_2), + (None, sample_address_2, "eth_estimateGas", True, sample_address_2), + (None, sample_address_2, "eth_sendTransaction", True, sample_address_2), + (None, sample_address_2, "eth_gasPrice", False, sample_address_2), + (None, sample_address_2, "eth_blockNumber", False, sample_address_2), + (None, sample_address_2, "meow", False, sample_address_2), + (None, None, "eth_call", True, None), + (None, None, "eth_estimateGas", True, None), + (None, None, "eth_sendTransaction", True, None), + (None, None, "eth_gasPrice", False, None), + (None, None, "eth_blockNumber", False, None), + (None, None, "meow", False, None), + ), +) +def test_default_transaction_fields_middleware( + w3_accounts, w3_coinbase, method, from_field_added, from_field_value +): + w3_accounts = [w3_accounts] + + def mock_request(_method, params): + return params + + mock_w3 = Mock() + mock_w3.eth.accounts = w3_accounts + mock_w3.eth.coinbase = w3_coinbase + + middleware = default_transaction_fields_middleware(mock_request, mock_w3) + base_params = {"chainId": 5} + filled_transaction = middleware(method, [base_params]) + + filled_params = filled_transaction[0] + + assert ("from" in filled_params.keys()) == from_field_added + if "from" in filled_params.keys(): + assert filled_params["from"] == from_field_value + + filled_transaction[0].pop("from", None) + assert filled_transaction[0] == base_params + + +@pytest.mark.parametrize( + "w3_accounts, w3_coinbase, method, from_field_added, from_field_value", + ( + (sample_address_1, sample_address_2, "eth_call", True, sample_address_2), + ( + sample_address_1, + sample_address_2, + "eth_estimateGas", + True, + sample_address_2, + ), + ( + sample_address_1, + sample_address_2, + "eth_sendTransaction", + True, + sample_address_2, + ), + (sample_address_1, sample_address_2, "eth_gasPrice", False, None), + (sample_address_1, sample_address_2, "eth_blockNumber", False, None), + (sample_address_1, sample_address_2, "meow", False, None), + (sample_address_1, None, "eth_call", True, sample_address_1), + (sample_address_1, None, "eth_estimateGas", True, sample_address_1), + (sample_address_1, None, "eth_sendTransaction", True, sample_address_1), + (sample_address_1, None, "eth_gasPrice", False, None), + (sample_address_1, None, "eth_blockNumber", False, None), + (sample_address_1, None, "meow", False, None), + (None, sample_address_2, "eth_call", True, sample_address_2), + (None, sample_address_2, "eth_estimateGas", True, sample_address_2), + (None, sample_address_2, "eth_sendTransaction", True, sample_address_2), + (None, sample_address_2, "eth_gasPrice", False, sample_address_2), + (None, sample_address_2, "eth_blockNumber", False, sample_address_2), + (None, sample_address_2, "meow", False, sample_address_2), + (None, None, "eth_call", True, None), + (None, None, "eth_estimateGas", True, None), + (None, None, "eth_sendTransaction", True, None), + (None, None, "eth_gasPrice", False, None), + (None, None, "eth_blockNumber", False, None), + (None, None, "meow", False, None), + ), +) +@pytest.mark.asyncio +async def test_async_default_transaction_fields_middleware( + w3_accounts, + w3_coinbase, + method, + from_field_added, + from_field_value, +): + w3_accounts = [w3_accounts] + + async def mock_request(_method, params): + return params + + async def mock_async_accounts(): + return w3_accounts + + async def mock_async_coinbase(): + return w3_coinbase + + mock_w3 = AsyncMock() + mock_w3.eth.accounts = mock_async_accounts() + mock_w3.eth.coinbase = mock_async_coinbase() + + middleware = await async_default_transaction_fields_middleware( + mock_request, mock_w3 + ) + base_params = {"chainId": 5} + filled_transaction = await middleware(method, [base_params]) + + filled_params = filled_transaction[0] + assert ("from" in filled_params.keys()) == from_field_added + if "from" in filled_params.keys(): + assert filled_params["from"] == from_field_value + + filled_transaction[0].pop("from", None) + assert filled_transaction[0] == base_params + + # clean up + mock_w3.eth.accounts.close() + mock_w3.eth.coinbase.close() diff --git a/web3/providers/eth_tester/middleware.py b/web3/providers/eth_tester/middleware.py index 3689711c6e..c51eace53c 100644 --- a/web3/providers/eth_tester/middleware.py +++ b/web3/providers/eth_tester/middleware.py @@ -315,15 +315,21 @@ async def async_ethereum_tester_middleware( # type: ignore def guess_from(w3: "Web3", _: TxParams) -> ChecksumAddress: - coinbase = w3.eth.coinbase + if w3.eth.coinbase: + return w3.eth.coinbase + elif w3.eth.accounts and len(w3.eth.accounts) > 0: + return w3.eth.accounts[0] + + return None + + +async def async_guess_from(w3: "Web3", _: TxParams) -> ChecksumAddress: + coinbase = await w3.eth.coinbase # type: ignore + accounts = await w3.eth.accounts # type: ignore if coinbase is not None: return coinbase - - try: - return w3.eth.accounts[0] - except KeyError: - # no accounts available to pre-fill, carry on - pass + elif accounts is not None and len(accounts) > 0: + return accounts[0] return None @@ -340,6 +346,18 @@ def fill_default( return assoc(transaction, field, guess_val) +@curry +async def async_fill_default( + field: str, guess_func: Callable[..., Any], w3: "Web3", transaction: TxParams +) -> TxParams: + # type ignored b/c TxParams keys must be string literal types + if field in transaction and transaction[field] is not None: # type: ignore + return transaction + else: + guess_val = await guess_func(w3, transaction) + return assoc(transaction, field, guess_val) + + def default_transaction_fields_middleware( make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" ) -> Callable[[RPCEndpoint, Any], RPCResponse]: @@ -363,7 +381,7 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: async def async_default_transaction_fields_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], web3: "Web3" + make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" ) -> Callable[[RPCEndpoint, Any], RPCResponse]: async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: if method in ( @@ -371,7 +389,9 @@ async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: "eth_estimateGas", "eth_sendTransaction", ): - filled_transaction = fill_default("from", guess_from, web3, params[0]) + filled_transaction = await async_fill_default( + "from", async_guess_from, w3, params[0] + ) return await make_request(method, [filled_transaction] + list(params)[1:]) else: return await make_request(method, params)