Skip to content

Commit

Permalink
Add support for async_simple_cache_middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
fselmo committed Jul 18, 2022
1 parent 3c00810 commit aae75ab
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 1 deletion.
3 changes: 2 additions & 1 deletion docs/providers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -483,5 +483,6 @@ Supported Middleware
- :meth:`Gas Price Strategy <web3.middleware.gas_price_strategy_middleware>`
- :meth:`Buffered Gas Estimate Middleware <web3.middleware.buffered_gas_estimate_middleware>`
- :meth:`Stalecheck Middleware <web3.middleware.make_stalecheck_middleware>`
- :meth:`Validation middleware <web3.middleware.validation>`
- :meth:`Validation Middleware <web3.middleware.validation>`
- :ref:`Geth POA Middleware <geth-poa>`
- :meth:`Simple Cache Middleware <web3.middleware.simple_cache_middleware>`
1 change: 1 addition & 0 deletions newsfragments/2579.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Async support upport for caching certain methods via ``async_simple_cache_middleware`` as well as constructing custom async caching middleware via ``async_construct_simple_cache_middleware``.
165 changes: 165 additions & 0 deletions tests/core/middleware/test_simple_cache_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,22 @@
construct_result_generator_middleware,
construct_simple_cache_middleware,
)
from web3.middleware.async_middleware.async_cache import (
async_construct_simple_cache_middleware,
)
from web3.middleware.fixture import (
async_construct_error_generator_middleware,
async_construct_result_generator_middleware,
)
from web3.providers.base import (
BaseProvider,
)
from web3.providers.eth_tester import (
AsyncEthereumTesterProvider,
)
from web3.types import (
RPCEndpoint,
)


@pytest.fixture
Expand Down Expand Up @@ -133,3 +146,155 @@ def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3):
result_b = w3.manager.request_blocking("not_whitelisted", [])

assert result_a != result_b


# -- async -- #


async def _async_simple_cache_middleware_for_testing(make_request, async_w3):
middleware = await async_construct_simple_cache_middleware(
cache_class=dict,
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
return await middleware(make_request, async_w3)


@pytest.fixture
def async_w3():
return Web3(
provider=AsyncEthereumTesterProvider(),
middlewares=[
(_async_simple_cache_middleware_for_testing, "simple_cache"),
],
)


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_pulls_from_cache(async_w3):
# remove the pre-loaded simple cache middleware to replace with test-specific:
async_w3.middleware_onion.remove("simple_cache")

def cache_class():
return {
generate_cache_key(("fake_endpoint", [1])): {"result": "value-a"},
}

async def _properly_awaited_middleware(make_request, _async_w3):
middleware = await async_construct_simple_cache_middleware(
cache_class=cache_class,
rpc_whitelist={RPCEndpoint("fake_endpoint")},
)
return await middleware(make_request, _async_w3)

async_w3.middleware_onion.inject(
_properly_awaited_middleware,
"for_this_test_only",
layer=0,
)

_result = await async_w3.manager.coro_request("fake_endpoint", [1])
assert _result == "value-a"

# cleanup
async_w3.middleware_onion.remove("for_this_test_only")
# add back the pre-loaded simple cache middleware:
async_w3.middleware_onion.add(
_async_simple_cache_middleware_for_testing, "simple_cache"
)


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_populates_cache(async_w3):
async_w3.middleware_onion.inject(
await async_construct_result_generator_middleware(
{
RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()),
}
),
"result_generator",
layer=0,
)

result = await async_w3.manager.coro_request("fake_endpoint", [])

_empty_params = await async_w3.manager.coro_request("fake_endpoint", [])
_non_empty_params = await async_w3.manager.coro_request("fake_endpoint", [1])

assert _empty_params == result
assert _non_empty_params != result

# cleanup
async_w3.middleware_onion.remove("result_generator")


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_does_not_cache_none_responses(async_w3):
counter = itertools.count()

def result_cb(method, params):
next(counter)
return None

async_w3.middleware_onion.inject(
await async_construct_result_generator_middleware(
{
RPCEndpoint("fake_endpoint"): result_cb,
},
),
"result_generator",
layer=0,
)

await async_w3.manager.coro_request("fake_endpoint", [])
await async_w3.manager.coro_request("fake_endpoint", [])

assert next(counter) == 2

# cleanup
async_w3.middleware_onion.remove("result_generator")


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_does_not_cache_error_responses(async_w3):
async_w3.middleware_onion.inject(
await async_construct_error_generator_middleware(
{
RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}",
}
),
"error_generator",
layer=0,
)

with pytest.raises(ValueError) as err_a:
await async_w3.manager.coro_request("fake_endpoint", [])
with pytest.raises(ValueError) as err_b:
await async_w3.manager.coro_request("fake_endpoint", [])

assert str(err_a) != str(err_b)

# cleanup
async_w3.middleware_onion.remove("error_generator")


@pytest.mark.asyncio
async def test_async_simple_cache_middleware_does_not_cache_non_whitelist_endpoints(
async_w3,
):
async_w3.middleware_onion.inject(
await async_construct_result_generator_middleware(
{
RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()),
}
),
"result_generator",
layer=0,
)

result_a = await async_w3.manager.coro_request("not_whitelisted", [])
result_b = await async_w3.manager.coro_request("not_whitelisted", [])

assert result_a != result_b

# cleanup
async_w3.middleware_onion.remove("result_generator")
3 changes: 3 additions & 0 deletions web3/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from .abi import ( # noqa: F401
abi_middleware,
)
from .async_middleware.async_cache import ( # noqa: F401
_async_simple_cache_middleware as async_simple_cache_middleware,
)
from .attrdict import ( # noqa: F401
attrdict_middleware,
)
Expand Down
104 changes: 104 additions & 0 deletions web3/middleware/async_middleware/async_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import functools
import threading
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Tuple,
Type,
cast,
)

import lru

from web3._utils.caching import (
generate_cache_key,
)
from web3.types import (
AsyncMiddleware,
Middleware,
RPCEndpoint,
RPCResponse,
)

if TYPE_CHECKING:
from web3 import Web3 # noqa: F401


SIMPLE_CACHE_RPC_WHITELIST = cast(
Tuple[RPCEndpoint],
(
"web3_clientVersion",
"eth_getBlockTransactionCountByHash",
"eth_getUncleCountByBlockHash",
"eth_getBlockByHash",
"eth_getTransactionByHash",
"eth_getTransactionByBlockHashAndIndex",
"eth_getRawTransactionByHash",
"eth_getUncleByBlockHashAndIndex",
"eth_chainId",
),
)


def _should_cache_response(
_method: RPCEndpoint, _params: Any, response: RPCResponse
) -> bool:
return (
"error" not in response
and "result" in response
and response["result"] is not None
)


async def async_construct_simple_cache_middleware(
cache_class: Type[Dict[Any, Any]],
rpc_whitelist: Collection[RPCEndpoint] = SIMPLE_CACHE_RPC_WHITELIST,
should_cache_fn: Callable[
[RPCEndpoint, Any, RPCResponse], bool
] = _should_cache_response,
) -> Middleware:
"""
Constructs a middleware which caches responses based on the request
``method`` and ``params``
:param cache_class: Any dictionary-like object
:param rpc_whitelist: A set of RPC methods which may have their responses cached.
:param should_cache_fn: A callable which accepts ``method`` ``params`` and
``response`` and returns a boolean as to whether the response should be
cached.
"""

async def async_simple_cache_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "Web3"
) -> AsyncMiddleware:
cache = cache_class()
lock = threading.Lock()

async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
if method in rpc_whitelist:
with lock:
cache_key = generate_cache_key((method, params))
if cache_key not in cache:
response = await make_request(method, params)
if should_cache_fn(method, params, response):
cache[cache_key] = response
return response
return cache[cache_key]
else:
return await make_request(method, params)

return middleware

return async_simple_cache_middleware


async def _async_simple_cache_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "Web3"
):
middleware = await async_construct_simple_cache_middleware(
cache_class=cast(Type[Dict[Any, Any]], functools.partial(lru.LRU, 256)),
)
return await middleware(make_request, async_w3)

0 comments on commit aae75ab

Please sign in to comment.