Skip to content

Commit

Permalink
fix bug in how eth_tester middleware filled default fields
Browse files Browse the repository at this point in the history
  • Loading branch information
pacrob committed Aug 10, 2022
1 parent 41f88cb commit b892b22
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 9 deletions.
1 change: 1 addition & 0 deletions newsfragments/2600.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fixed bug in how async_eth_tester_middleware fills default fields
162 changes: 162 additions & 0 deletions tests/core/middleware/test_eth_tester_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
import pytest
from unittest.mock import (
AsyncMock,
Mock,
)

from web3.providers.eth_tester.middleware import (
async_default_transaction_fields_middleware,
default_transaction_fields_middleware,
)
from web3.types import (
BlockData,
)
Expand All @@ -20,3 +28,157 @@ def test_get_block_formatters(w3):
keys_diff = all_block_keys.difference(latest_block_keys)
assert len(keys_diff) == 1
assert keys_diff.pop() == "mixHash" # mixHash is not implemented in eth-tester


sample_address_1 = "0x0000000000000000000000000000000000000001"
sample_address_2 = "0x0000000000000000000000000000000000000002"


@pytest.mark.parametrize(
"w3_accounts, w3_coinbase, method, from_field_added, from_field_value",
(
(sample_address_1, sample_address_2, "eth_call", True, sample_address_2),
(
sample_address_1,
sample_address_2,
"eth_estimateGas",
True,
sample_address_2,
),
(
sample_address_1,
sample_address_2,
"eth_sendTransaction",
True,
sample_address_2,
),
(sample_address_1, sample_address_2, "eth_gasPrice", False, None),
(sample_address_1, sample_address_2, "eth_blockNumber", False, None),
(sample_address_1, sample_address_2, "meow", False, None),
(sample_address_1, None, "eth_call", True, sample_address_1),
(sample_address_1, None, "eth_estimateGas", True, sample_address_1),
(sample_address_1, None, "eth_sendTransaction", True, sample_address_1),
(sample_address_1, None, "eth_gasPrice", False, None),
(sample_address_1, None, "eth_blockNumber", False, None),
(sample_address_1, None, "meow", False, None),
(None, sample_address_2, "eth_call", True, sample_address_2),
(None, sample_address_2, "eth_estimateGas", True, sample_address_2),
(None, sample_address_2, "eth_sendTransaction", True, sample_address_2),
(None, sample_address_2, "eth_gasPrice", False, sample_address_2),
(None, sample_address_2, "eth_blockNumber", False, sample_address_2),
(None, sample_address_2, "meow", False, sample_address_2),
(None, None, "eth_call", True, None),
(None, None, "eth_estimateGas", True, None),
(None, None, "eth_sendTransaction", True, None),
(None, None, "eth_gasPrice", False, None),
(None, None, "eth_blockNumber", False, None),
(None, None, "meow", False, None),
),
)
def test_default_transaction_fields_middleware(
w3_accounts, w3_coinbase, method, from_field_added, from_field_value
):
w3_accounts = [w3_accounts]

def mock_request(_method, params):
return params

mock_w3 = Mock()
mock_w3.eth.accounts = w3_accounts
mock_w3.eth.coinbase = w3_coinbase

middleware = default_transaction_fields_middleware(mock_request, mock_w3)
base_params = {"chainId": 5}
filled_transaction = middleware(method, [base_params])

filled_params = filled_transaction[0]

assert ("from" in filled_params.keys()) == from_field_added
if "from" in filled_params.keys():
assert filled_params["from"] == from_field_value

filled_transaction[0].pop("from", None)
assert filled_transaction[0] == base_params


@pytest.mark.parametrize(
"w3_accounts, w3_coinbase, method, from_field_added, from_field_value",
(
(sample_address_1, sample_address_2, "eth_call", True, sample_address_2),
(
sample_address_1,
sample_address_2,
"eth_estimateGas",
True,
sample_address_2,
),
(
sample_address_1,
sample_address_2,
"eth_sendTransaction",
True,
sample_address_2,
),
(sample_address_1, sample_address_2, "eth_gasPrice", False, None),
(sample_address_1, sample_address_2, "eth_blockNumber", False, None),
(sample_address_1, sample_address_2, "meow", False, None),
(sample_address_1, None, "eth_call", True, sample_address_1),
(sample_address_1, None, "eth_estimateGas", True, sample_address_1),
(sample_address_1, None, "eth_sendTransaction", True, sample_address_1),
(sample_address_1, None, "eth_gasPrice", False, None),
(sample_address_1, None, "eth_blockNumber", False, None),
(sample_address_1, None, "meow", False, None),
(None, sample_address_2, "eth_call", True, sample_address_2),
(None, sample_address_2, "eth_estimateGas", True, sample_address_2),
(None, sample_address_2, "eth_sendTransaction", True, sample_address_2),
(None, sample_address_2, "eth_gasPrice", False, sample_address_2),
(None, sample_address_2, "eth_blockNumber", False, sample_address_2),
(None, sample_address_2, "meow", False, sample_address_2),
(None, None, "eth_call", True, None),
(None, None, "eth_estimateGas", True, None),
(None, None, "eth_sendTransaction", True, None),
(None, None, "eth_gasPrice", False, None),
(None, None, "eth_blockNumber", False, None),
(None, None, "meow", False, None),
),
)
@pytest.mark.asyncio
async def test_async_default_transaction_fields_middleware(
w3_accounts,
w3_coinbase,
method,
from_field_added,
from_field_value,
):
w3_accounts = [w3_accounts]

async def mock_request(_method, params):
return params

async def mock_async_accounts():
return w3_accounts

async def mock_async_coinbase():
return w3_coinbase

mock_w3 = AsyncMock()
mock_w3.eth.accounts = mock_async_accounts()
mock_w3.eth.coinbase = mock_async_coinbase()

middleware = await async_default_transaction_fields_middleware(
mock_request, mock_w3
)
base_params = {"chainId": 5}
filled_transaction = await middleware(method, [base_params])

filled_params = filled_transaction[0]
assert ("from" in filled_params.keys()) == from_field_added
if "from" in filled_params.keys():
assert filled_params["from"] == from_field_value

filled_transaction[0].pop("from", None)
assert filled_transaction[0] == base_params

# clean up
mock_w3.eth.accounts.close()
mock_w3.eth.coinbase.close()
38 changes: 29 additions & 9 deletions web3/providers/eth_tester/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,15 +315,21 @@ async def async_ethereum_tester_middleware( # type: ignore


def guess_from(w3: "Web3", _: TxParams) -> ChecksumAddress:
coinbase = w3.eth.coinbase
if w3.eth.coinbase:
return w3.eth.coinbase
elif w3.eth.accounts and len(w3.eth.accounts) > 0:
return w3.eth.accounts[0]

return None


async def async_guess_from(w3: "Web3", _: TxParams) -> ChecksumAddress:
coinbase = await w3.eth.coinbase # type: ignore
accounts = await w3.eth.accounts # type: ignore
if coinbase is not None:
return coinbase

try:
return w3.eth.accounts[0]
except KeyError:
# no accounts available to pre-fill, carry on
pass
elif accounts is not None and len(accounts) > 0:
return accounts[0]

return None

Expand All @@ -340,6 +346,18 @@ def fill_default(
return assoc(transaction, field, guess_val)


@curry
async def async_fill_default(
field: str, guess_func: Callable[..., Any], w3: "Web3", transaction: TxParams
) -> TxParams:
# type ignored b/c TxParams keys must be string literal types
if field in transaction and transaction[field] is not None: # type: ignore
return transaction
else:
guess_val = await guess_func(w3, transaction)
return assoc(transaction, field, guess_val)


def default_transaction_fields_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3"
) -> Callable[[RPCEndpoint, Any], RPCResponse]:
Expand All @@ -363,15 +381,17 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:


async def async_default_transaction_fields_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], web3: "Web3"
make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3"
) -> Callable[[RPCEndpoint, Any], RPCResponse]:
async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
if method in (
"eth_call",
"eth_estimateGas",
"eth_sendTransaction",
):
filled_transaction = fill_default("from", guess_from, web3, params[0])
filled_transaction = await async_fill_default(
"from", async_guess_from, w3, params[0]
)
return await make_request(method, [filled_transaction] + list(params)[1:])
else:
return await make_request(method, params)
Expand Down

0 comments on commit b892b22

Please sign in to comment.