diff --git a/newsfragments/1978.feature.rst b/newsfragments/1978.feature.rst new file mode 100644 index 0000000000..8ccdad5364 --- /dev/null +++ b/newsfragments/1978.feature.rst @@ -0,0 +1,3 @@ +Add new AsyncHTTPProvider. No middleware or session caching support yet. + +Also adds async ``w3.eth.gas_price``, and async ``w3.isConnected()`` methods. diff --git a/setup.py b/setup.py index 10c2ba37ee..788acc73d0 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ url='https://github.com/ethereum/web3.py', include_package_data=True, install_requires=[ + "aiohttp>=3.7.4.post0,<4", "eth-abi>=2.0.0b6,<3.0.0", "eth-account>=0.5.3,<0.6.0", "eth-hash[pycryptodome]>=0.2.0,<1.0.0", diff --git a/tests/integration/go_ethereum/common.py b/tests/integration/go_ethereum/common.py index ff747167dd..ce3ca6063a 100644 --- a/tests/integration/go_ethereum/common.py +++ b/tests/integration/go_ethereum/common.py @@ -4,6 +4,7 @@ import pytest from web3._utils.module_testing import ( # noqa: F401 + AsyncEthModuleTest, EthModuleTest, GoEthereumAdminModuleTest, GoEthereumPersonalModuleTest, @@ -80,3 +81,7 @@ class GoEthereumAdminModuleTest(GoEthereumAdminModuleTest): class GoEthereumPersonalModuleTest(GoEthereumPersonalModuleTest): pass + + +class GoEthereumAsyncEthModuleTest(AsyncEthModuleTest): + pass diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index fca859ff9a..0e353af3db 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -4,9 +4,16 @@ get_open_port, ) from web3 import Web3 +from web3.eth import ( + AsyncEth, +) +from web3.providers.async_rpc import ( + AsyncHTTPProvider, +) from .common import ( GoEthereumAdminModuleTest, + GoEthereumAsyncEthModuleTest, GoEthereumEthModuleTest, GoEthereumNetModuleTest, GoEthereumPersonalModuleTest, @@ -14,6 +21,7 @@ GoEthereumVersionModuleTest, ) from .utils import ( + wait_for_aiohttp, wait_for_http, ) @@ -63,6 +71,18 @@ def web3(geth_process, endpoint_uri): return _web3 +@pytest.fixture(scope="module") +async def async_w3_http(geth_process, endpoint_uri): + await wait_for_aiohttp(endpoint_uri) + _web3 = Web3( + AsyncHTTPProvider(endpoint_uri), + middlewares=[], + modules={ + 'async_eth': (AsyncEth,), + }) + return _web3 + + class TestGoEthereumTest(GoEthereumTest): pass @@ -97,3 +117,7 @@ class TestGoEthereumNetModuleTest(GoEthereumNetModuleTest): class TestGoEthereumPersonalModuleTest(GoEthereumPersonalModuleTest): pass + + +class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest): + pass diff --git a/tests/integration/go_ethereum/utils.py b/tests/integration/go_ethereum/utils.py index b9c1176038..a9db62c41b 100644 --- a/tests/integration/go_ethereum/utils.py +++ b/tests/integration/go_ethereum/utils.py @@ -2,6 +2,7 @@ import socket import time +import aiohttp import requests @@ -29,6 +30,18 @@ def wait_for_http(endpoint_uri, timeout=60): break +async def wait_for_aiohttp(endpoint_uri, timeout=60): + start = time.time() + while time.time() < start + timeout: + try: + async with aiohttp.ClientSession() as session: + await session.get(endpoint_uri) + except aiohttp.client_exceptions.ClientConnectorError: + time.sleep(0.01) + else: + break + + def wait_for_popen(proc, timeout): start = time.time() while time.time() < start + timeout: diff --git a/web3/__init__.py b/web3/__init__.py index c0e0e71e4e..5c8edfacef 100644 --- a/web3/__init__.py +++ b/web3/__init__.py @@ -18,6 +18,9 @@ from web3.providers.rpc import ( # noqa: E402 HTTPProvider, ) +from web3.providers.async_rpc import ( # noqa: E402 + AsyncHTTPProvider, +) from web3.providers.websocket import ( # noqa: E402 WebsocketProvider, ) @@ -45,4 +48,5 @@ "TestRPCProvider", "EthereumTesterProvider", "Account", + "AsyncHTTPProvider", ] diff --git a/web3/_utils/module_testing/__init__.py b/web3/_utils/module_testing/__init__.py index 3abf2c5437..94228bebad 100644 --- a/web3/_utils/module_testing/__init__.py +++ b/web3/_utils/module_testing/__init__.py @@ -1,4 +1,5 @@ from .eth_module import ( # noqa: F401 + AsyncEthModuleTest, EthModuleTest, ) from .go_ethereum_admin_module import ( # noqa: F401 diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index bb25323542..f2851d616d 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -58,6 +58,18 @@ from web3.contract import Contract # noqa: F401 +class AsyncEthModuleTest: + @pytest.mark.asyncio + async def test_eth_gas_price(self, async_w3_http: "Web3") -> None: + gas_price = await async_w3_http.async_eth.gas_price + assert gas_price > 0 + + @pytest.mark.asyncio + async def test_isConnected(self, async_w3_http: "Web3") -> None: + is_connected = await async_w3_http.isConnected() # type: ignore + assert is_connected is True + + class EthModuleTest: def test_eth_protocol_version(self, web3: "Web3") -> None: with pytest.warns(DeprecationWarning, diff --git a/web3/_utils/request.py b/web3/_utils/request.py index f55f59718f..0ddfc81354 100644 --- a/web3/_utils/request.py +++ b/web3/_utils/request.py @@ -1,7 +1,12 @@ +import os from typing import ( Any, ) +from aiohttp import ( + ClientSession, + ClientTimeout, +) from eth_typing import ( URI, ) @@ -13,6 +18,10 @@ ) +def get_default_http_endpoint() -> URI: + return URI(os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545')) + + def cache_session(endpoint_uri: URI, session: requests.Session) -> None: cache_key = generate_cache_key(endpoint_uri) _session_cache[cache_key] = session @@ -40,3 +49,15 @@ def make_post_request(endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any) response.raise_for_status() return response.content + + +async def async_make_post_request( + endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any +) -> bytes: + kwargs.setdefault('timeout', ClientTimeout(10)) + async with ClientSession(raise_for_status=True) as session: + async with session.post(endpoint_uri, + data=data, + *args, + **kwargs) as response: + return await response.read() diff --git a/web3/eth.py b/web3/eth.py index af857e18b8..ed30fe2aff 100644 --- a/web3/eth.py +++ b/web3/eth.py @@ -104,7 +104,23 @@ ) -class Eth(Module): +class BaseEth(Module): + _gas_price: Method[Callable[[], Wei]] = Method( + RPC.eth_gasPrice, + mungers=None, + ) + + +class AsyncEth(BaseEth): + is_async = True + + @property + async def gas_price(self) -> Wei: + # types ignored b/c mypy conflict with BlockingEth properties + return await self._gas_price() # type: ignore + + +class Eth(BaseEth, Module): account = Account() _default_account: Union[ChecksumAddress, Empty] = empty _default_block: BlockIdentifier = "latest" @@ -175,11 +191,6 @@ def mining(self) -> bool: def hashrate(self) -> int: return self.get_hashrate() - _gas_price: Method[Callable[[], Wei]] = Method( - RPC.eth_gasPrice, - mungers=None, - ) - @property def gas_price(self) -> Wei: return self._gas_price() diff --git a/web3/main.py b/web3/main.py index 6cbfd6d909..edc3c05b71 100644 --- a/web3/main.py +++ b/web3/main.py @@ -53,6 +53,7 @@ abi_ens_resolver, ) from web3.eth import ( + AsyncEth, Eth, ) from web3.geth import ( @@ -157,6 +158,7 @@ class Web3: parity: Parity geth: Geth net: Net + async_eth: AsyncEth def __init__( self, diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py new file mode 100644 index 0000000000..a585a2039b --- /dev/null +++ b/web3/providers/async_base.py @@ -0,0 +1,71 @@ +import itertools +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + cast, +) + +from eth_utils import ( + to_bytes, + to_text, +) + +from web3._utils.encoding import ( + FriendlyJsonSerde, +) +from web3.types import ( + MiddlewareOnion, + RPCEndpoint, + RPCResponse, +) + +if TYPE_CHECKING: + from web3 import Web3 # noqa: F401 + + +class AsyncBaseProvider: + def request_func( + self, web3: "Web3", outer_middlewares: MiddlewareOnion + ) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: + # Placeholder - manager calls self.provider.request_func + # Eventually this will handle caching and return make_request + # along with all the middleware + return self.make_request + + async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: + raise NotImplementedError("Providers must implement this method") + + async def isConnected(self) -> bool: + raise NotImplementedError("Providers must implement this method") + + +class AsyncJSONBaseProvider(AsyncBaseProvider): + def __init__(self) -> None: + self.request_counter = itertools.count() + + async def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes: + rpc_dict = { + "jsonrpc": "2.0", + "method": method, + "params": params or [], + "id": next(self.request_counter), + } + encoded = FriendlyJsonSerde().json_encode(rpc_dict) + return to_bytes(text=encoded) + + async def decode_rpc_response(self, raw_response: bytes) -> RPCResponse: + text_response = to_text(raw_response) + return cast(RPCResponse, FriendlyJsonSerde().json_decode(text_response)) + + async def isConnected(self) -> bool: + try: + response = await self.make_request(RPCEndpoint('web3_clientVersion'), []) + except IOError: + return False + + assert response['jsonrpc'] == '2.0' + assert 'error' not in response + + return True diff --git a/web3/providers/async_rpc.py b/web3/providers/async_rpc.py new file mode 100644 index 0000000000..db24f27a61 --- /dev/null +++ b/web3/providers/async_rpc.py @@ -0,0 +1,83 @@ +import logging +from typing import ( + Any, + Dict, + Iterable, + Optional, + Tuple, + Union, +) + +from eth_typing import ( + URI, +) +from eth_utils import ( + to_dict, +) + +from web3._utils.http import ( + construct_user_agent, +) +from web3._utils.request import ( + async_make_post_request, + get_default_http_endpoint, +) +from web3.types import ( + RPCEndpoint, + RPCResponse, +) + +from .async_base import ( + AsyncJSONBaseProvider, +) + + +class AsyncHTTPProvider(AsyncJSONBaseProvider): + logger = logging.getLogger("web3.providers.HTTPProvider") + endpoint_uri = None + _request_kwargs = None + + def __init__( + self, endpoint_uri: Optional[Union[URI, str]] = None, + request_kwargs: Optional[Any] = None, + session: Optional[Any] = None + ) -> None: + if endpoint_uri is None: + self.endpoint_uri = get_default_http_endpoint() + else: + self.endpoint_uri = URI(endpoint_uri) + + self._request_kwargs = request_kwargs or {} + + super().__init__() + + def __str__(self) -> str: + return "RPC connection {0}".format(self.endpoint_uri) + + @to_dict + def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]: + if 'headers' not in self._request_kwargs: + yield 'headers', self.get_request_headers() + for key, value in self._request_kwargs.items(): + yield key, value + + def get_request_headers(self) -> Dict[str, str]: + return { + 'Content-Type': 'application/json', + 'User-Agent': construct_user_agent(str(type(self))), + } + + async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: + self.logger.debug("Making request HTTP. URI: %s, Method: %s", + self.endpoint_uri, method) + request_data = await self.encode_rpc_request(method, params) + raw_response = await async_make_post_request( + self.endpoint_uri, + request_data, + **self.get_request_kwargs() + ) + response = await self.decode_rpc_response(raw_response) + self.logger.debug("Getting response HTTP. URI: %s, " + "Method: %s, Response: %s", + self.endpoint_uri, method, response) + return response diff --git a/web3/providers/rpc.py b/web3/providers/rpc.py index af6b78fd8f..0be744f71d 100644 --- a/web3/providers/rpc.py +++ b/web3/providers/rpc.py @@ -1,5 +1,4 @@ import logging -import os from typing import ( Any, Dict, @@ -21,6 +20,7 @@ ) from web3._utils.request import ( cache_session, + get_default_http_endpoint, make_post_request, ) from web3.datastructures import ( @@ -40,10 +40,6 @@ ) -def get_default_endpoint() -> URI: - return URI(os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545')) - - class HTTPProvider(JSONBaseProvider): logger = logging.getLogger("web3.providers.HTTPProvider") endpoint_uri = None @@ -58,7 +54,7 @@ def __init__( session: Optional[Any] = None ) -> None: if endpoint_uri is None: - self.endpoint_uri = get_default_endpoint() + self.endpoint_uri = get_default_http_endpoint() else: self.endpoint_uri = URI(endpoint_uri)