Skip to content

Commit

Permalink
Async HTTP Provider
Browse files Browse the repository at this point in the history
  • Loading branch information
kclowes committed May 7, 2021
1 parent cfd97bd commit 9cee725
Show file tree
Hide file tree
Showing 12 changed files with 219 additions and 6 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
url='https://github.com/ethereum/web3.py',
include_package_data=True,
install_requires=[
"aiohttp>=3.7.4.post0,<4",
"eth-abi>=2.0.0b6,<3.0.0",
"eth-account>=0.5.3,<0.6.0",
"eth-hash[pycryptodome]>=0.2.0,<1.0.0",
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/go_ethereum/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from web3._utils.module_testing import ( # noqa: F401
AsyncEthModuleTest,
EthModuleTest,
GoEthereumAdminModuleTest,
GoEthereumPersonalModuleTest,
Expand Down Expand Up @@ -80,3 +81,7 @@ class GoEthereumAdminModuleTest(GoEthereumAdminModuleTest):

class GoEthereumPersonalModuleTest(GoEthereumPersonalModuleTest):
pass


class GoEthereumAsyncEthModuleTest(AsyncEthModuleTest):
pass
24 changes: 24 additions & 0 deletions tests/integration/go_ethereum/test_goethereum_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@
get_open_port,
)
from web3 import Web3
from web3.eth import (
AsyncEth,
)
from web3.providers.rpc import (
AsyncHTTPProvider,
)

from .common import (
GoEthereumAdminModuleTest,
GoEthereumAsyncEthModuleTest,
GoEthereumEthModuleTest,
GoEthereumNetModuleTest,
GoEthereumPersonalModuleTest,
GoEthereumTest,
GoEthereumVersionModuleTest,
)
from .utils import (
wait_for_aiohttp,
wait_for_http,
)

Expand Down Expand Up @@ -63,6 +71,18 @@ def web3(geth_process, endpoint_uri):
return _web3


@pytest.fixture(scope="module")
async def async_w3_http(geth_process, endpoint_uri):
await wait_for_aiohttp(endpoint_uri)
_web3 = Web3(
AsyncHTTPProvider(endpoint_uri),
middlewares=[],
modules={
'async_eth': (AsyncEth,),
})
return _web3


class TestGoEthereumTest(GoEthereumTest):
pass

Expand Down Expand Up @@ -97,3 +117,7 @@ class TestGoEthereumNetModuleTest(GoEthereumNetModuleTest):

class TestGoEthereumPersonalModuleTest(GoEthereumPersonalModuleTest):
pass


class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest):
pass
13 changes: 13 additions & 0 deletions tests/integration/go_ethereum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import socket
import time

import aiohttp
import requests


Expand Down Expand Up @@ -29,6 +30,18 @@ def wait_for_http(endpoint_uri, timeout=60):
break


async def wait_for_aiohttp(endpoint_uri, timeout=60):
start = time.time()
while time.time() < start + timeout:
try:
async with aiohttp.ClientSession() as session:
await session.get(endpoint_uri)
except aiohttp.client_exceptions.ClientConnectorError:
time.sleep(0.01)
else:
break


def wait_for_popen(proc, timeout):
start = time.time()
while time.time() < start + timeout:
Expand Down
2 changes: 2 additions & 0 deletions web3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from web3.providers.rpc import ( # noqa: E402
HTTPProvider,
AsyncHTTPProvider,
)
from web3.providers.websocket import ( # noqa: E402
WebsocketProvider,
Expand Down Expand Up @@ -45,4 +46,5 @@
"TestRPCProvider",
"EthereumTesterProvider",
"Account",
"AsyncHTTPProvider",
]
1 change: 1 addition & 0 deletions web3/_utils/module_testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .eth_module import ( # noqa: F401
AsyncEthModuleTest,
EthModuleTest,
)
from .go_ethereum_admin_module import ( # noqa: F401
Expand Down
12 changes: 12 additions & 0 deletions web3/_utils/module_testing/eth_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@
from web3.contract import Contract # noqa: F401


class AsyncEthModuleTest:
@pytest.mark.asyncio
async def test_eth_gas_price(self, async_w3_http: "Web3") -> None:
gas_price = await async_w3_http.async_eth.gas_price
assert gas_price > 0

@pytest.mark.asyncio
async def test_isConnected(self, async_w3_http: "Web3") -> None:
is_connected = await async_w3_http.isConnected() # type: ignore
assert is_connected is True


class EthModuleTest:
def test_eth_protocol_version(self, web3: "Web3") -> None:
with pytest.warns(DeprecationWarning,
Expand Down
16 changes: 16 additions & 0 deletions web3/_utils/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
Any,
)

from aiohttp import (
ClientSession,
ClientTimeout,
)
from eth_typing import (
URI,
)
Expand Down Expand Up @@ -40,3 +44,15 @@ def make_post_request(endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any)
response.raise_for_status()

return response.content


async def async_make_post_request(
endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any
) -> bytes:
kwargs.setdefault('timeout', ClientTimeout(10))
async with ClientSession(raise_for_status=True) as session:
async with session.post(endpoint_uri,
data=data,
*args,
**kwargs) as response:
return await response.read()
23 changes: 17 additions & 6 deletions web3/eth.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,23 @@
)


class Eth(Module):
class BaseEth(Module):
_gas_price: Method[Callable[[], Wei]] = Method(
RPC.eth_gasPrice,
mungers=None,
)


class AsyncEth(BaseEth):
is_async = True

@property
async def gas_price(self) -> Wei:
# types ignored b/c mypy conflict with BlockingEth properties
return await self._gas_price() # type: ignore


class Eth(BaseEth, Module):
account = Account()
_default_account: Union[ChecksumAddress, Empty] = empty
_default_block: BlockIdentifier = "latest"
Expand Down Expand Up @@ -175,11 +191,6 @@ def mining(self) -> bool:
def hashrate(self) -> int:
return self.get_hashrate()

_gas_price: Method[Callable[[], Wei]] = Method(
RPC.eth_gasPrice,
mungers=None,
)

@property
def gas_price(self) -> Wei:
return self._gas_price()
Expand Down
2 changes: 2 additions & 0 deletions web3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
abi_ens_resolver,
)
from web3.eth import (
AsyncEth,
Eth,
)
from web3.geth import (
Expand Down Expand Up @@ -157,6 +158,7 @@ class Web3:
parity: Parity
geth: Geth
net: Net
async_eth: AsyncEth

def __init__(
self,
Expand Down
71 changes: 71 additions & 0 deletions web3/providers/async_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import itertools
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
cast,
)

from eth_utils import (
to_bytes,
to_text,
)

from web3._utils.encoding import (
FriendlyJsonSerde,
)
from web3.types import (
MiddlewareOnion,
RPCEndpoint,
RPCResponse,
)

if TYPE_CHECKING:
from web3 import Web3 # noqa: F401


class AsyncBaseProvider:
def request_func(
self, web3: "Web3", outer_middlewares: MiddlewareOnion
) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]:
# Placeholder - manager calls self.provider.request_func
# Eventually this will handle caching and return make_request
# along with all the middleware
return self.make_request

async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
raise NotImplementedError("Providers must implement this method")

async def isConnected(self) -> bool:
raise NotImplementedError("Providers must implement this method")


class AsyncJSONBaseProvider(AsyncBaseProvider):
def __init__(self) -> None:
self.request_counter = itertools.count()

async def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:
rpc_dict = {
"jsonrpc": "2.0",
"method": method,
"params": params or [],
"id": next(self.request_counter),
}
encoded = FriendlyJsonSerde().json_encode(rpc_dict)
return to_bytes(text=encoded)

async def decode_rpc_response(self, raw_response: bytes) -> RPCResponse:
text_response = to_text(raw_response)
return cast(RPCResponse, FriendlyJsonSerde().json_decode(text_response))

async def isConnected(self) -> bool:
try:
response = await self.make_request(RPCEndpoint('web3_clientVersion'), [])
except IOError:
return False

assert response['jsonrpc'] == '2.0'
assert 'error' not in response

return True
55 changes: 55 additions & 0 deletions web3/providers/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
construct_user_agent,
)
from web3._utils.request import (
async_make_post_request,
cache_session,
make_post_request,
)
Expand All @@ -35,6 +36,9 @@
RPCResponse,
)

from .async_base import (
AsyncJSONBaseProvider,
)
from .base import (
JSONBaseProvider,
)
Expand Down Expand Up @@ -99,3 +103,54 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
"Method: %s, Response: %s",
self.endpoint_uri, method, response)
return response


class AsyncHTTPProvider(AsyncJSONBaseProvider):
logger = logging.getLogger("web3.providers.HTTPProvider")
endpoint_uri = None
_request_kwargs = None

def __init__(
self, endpoint_uri: Optional[Union[URI, str]] = None,
request_kwargs: Optional[Any] = None,
session: Optional[Any] = None
) -> None:
if endpoint_uri is None:
self.endpoint_uri = get_default_endpoint()
else:
self.endpoint_uri = URI(endpoint_uri)

self._request_kwargs = request_kwargs or {}

super().__init__()

def __str__(self) -> str:
return "RPC connection {0}".format(self.endpoint_uri)

@to_dict
def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]:
if 'headers' not in self._request_kwargs:
yield 'headers', self.get_request_headers()
for key, value in self._request_kwargs.items():
yield key, value

def get_request_headers(self) -> Dict[str, str]:
return {
'Content-Type': 'application/json',
'User-Agent': construct_user_agent(str(type(self))),
}

async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
self.logger.debug("Making request HTTP. URI: %s, Method: %s",
self.endpoint_uri, method)
request_data = await self.encode_rpc_request(method, params)
raw_response = await async_make_post_request(
self.endpoint_uri,
request_data,
**self.get_request_kwargs()
)
response = await self.decode_rpc_response(raw_response)
self.logger.debug("Getting response HTTP. URI: %s, "
"Method: %s, Response: %s",
self.endpoint_uri, method, response)
return response

0 comments on commit 9cee725

Please sign in to comment.