diff --git a/docs/providers.rst b/docs/providers.rst index 2b0b7a2069..d18e3ad827 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -483,5 +483,6 @@ Supported Middleware - :meth:`Gas Price Strategy ` - :meth:`Buffered Gas Estimate Middleware ` - :meth:`Stalecheck Middleware ` -- :meth:`Validation middleware ` +- :meth:`Validation Middleware ` - :ref:`Geth POA Middleware ` +- :meth:`Simple Cache Middleware ` diff --git a/newsfragments/2579.feature.rst b/newsfragments/2579.feature.rst new file mode 100644 index 0000000000..dcb55738d5 --- /dev/null +++ b/newsfragments/2579.feature.rst @@ -0,0 +1 @@ +Async support upport for caching certain methods via ``async_simple_cache_middleware`` as well as constructing custom async caching middleware via ``async_construct_simple_cache_middleware``. diff --git a/tests/core/middleware/test_simple_cache_middleware.py b/tests/core/middleware/test_simple_cache_middleware.py index 56061d83cd..0479bc5993 100644 --- a/tests/core/middleware/test_simple_cache_middleware.py +++ b/tests/core/middleware/test_simple_cache_middleware.py @@ -11,9 +11,22 @@ construct_result_generator_middleware, construct_simple_cache_middleware, ) +from web3.middleware.async_middleware.async_cache import ( + async_construct_simple_cache_middleware, +) +from web3.middleware.fixture import ( + async_construct_error_generator_middleware, + async_construct_result_generator_middleware, +) from web3.providers.base import ( BaseProvider, ) +from web3.providers.eth_tester import ( + AsyncEthereumTesterProvider, +) +from web3.types import ( + RPCEndpoint, +) @pytest.fixture @@ -133,3 +146,155 @@ def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3): result_b = w3.manager.request_blocking("not_whitelisted", []) assert result_a != result_b + + +# -- async -- # + + +async def _async_simple_cache_middleware_for_testing(make_request, async_w3): + middleware = await async_construct_simple_cache_middleware( + cache_class=dict, + rpc_whitelist={RPCEndpoint("fake_endpoint")}, + ) + return await middleware(make_request, async_w3) + + +@pytest.fixture +def async_w3(): + return Web3( + provider=AsyncEthereumTesterProvider(), + middlewares=[ + (_async_simple_cache_middleware_for_testing, "simple_cache"), + ], + ) + + +@pytest.mark.asyncio +async def test_async_simple_cache_middleware_pulls_from_cache(async_w3): + # remove the pre-loaded simple cache middleware to replace with test-specific: + async_w3.middleware_onion.remove("simple_cache") + + def cache_class(): + return { + generate_cache_key(("fake_endpoint", [1])): {"result": "value-a"}, + } + + async def _properly_awaited_middleware(make_request, _async_w3): + middleware = await async_construct_simple_cache_middleware( + cache_class=cache_class, + rpc_whitelist={RPCEndpoint("fake_endpoint")}, + ) + return await middleware(make_request, _async_w3) + + async_w3.middleware_onion.inject( + _properly_awaited_middleware, + "for_this_test_only", + layer=0, + ) + + _result = await async_w3.manager.coro_request("fake_endpoint", [1]) + assert _result == "value-a" + + # cleanup + async_w3.middleware_onion.remove("for_this_test_only") + # add back the pre-loaded simple cache middleware: + async_w3.middleware_onion.add( + _async_simple_cache_middleware_for_testing, "simple_cache" + ) + + +@pytest.mark.asyncio +async def test_async_simple_cache_middleware_populates_cache(async_w3): + async_w3.middleware_onion.inject( + await async_construct_result_generator_middleware( + { + RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()), + } + ), + "result_generator", + layer=0, + ) + + result = await async_w3.manager.coro_request("fake_endpoint", []) + + _empty_params = await async_w3.manager.coro_request("fake_endpoint", []) + _non_empty_params = await async_w3.manager.coro_request("fake_endpoint", [1]) + + assert _empty_params == result + assert _non_empty_params != result + + # cleanup + async_w3.middleware_onion.remove("result_generator") + + +@pytest.mark.asyncio +async def test_async_simple_cache_middleware_does_not_cache_none_responses(async_w3): + counter = itertools.count() + + def result_cb(method, params): + next(counter) + return None + + async_w3.middleware_onion.inject( + await async_construct_result_generator_middleware( + { + RPCEndpoint("fake_endpoint"): result_cb, + }, + ), + "result_generator", + layer=0, + ) + + await async_w3.manager.coro_request("fake_endpoint", []) + await async_w3.manager.coro_request("fake_endpoint", []) + + assert next(counter) == 2 + + # cleanup + async_w3.middleware_onion.remove("result_generator") + + +@pytest.mark.asyncio +async def test_async_simple_cache_middleware_does_not_cache_error_responses(async_w3): + async_w3.middleware_onion.inject( + await async_construct_error_generator_middleware( + { + RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}", + } + ), + "error_generator", + layer=0, + ) + + with pytest.raises(ValueError) as err_a: + await async_w3.manager.coro_request("fake_endpoint", []) + with pytest.raises(ValueError) as err_b: + await async_w3.manager.coro_request("fake_endpoint", []) + + assert str(err_a) != str(err_b) + + # cleanup + async_w3.middleware_onion.remove("error_generator") + + +@pytest.mark.asyncio +async def test_async_simple_cache_middleware_does_not_cache_non_whitelist_endpoints( + async_w3, +): + async_w3.middleware_onion.inject( + await async_construct_result_generator_middleware( + { + RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()), + } + ), + "result_generator", + layer=0, + ) + + result_a = await async_w3.manager.coro_request("not_whitelisted", []) + result_b = await async_w3.manager.coro_request("not_whitelisted", []) + + assert result_a != result_b + + # cleanup + async_w3.middleware_onion.remove("result_generator") diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 0c1fbbc04d..49b72346d7 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -14,6 +14,9 @@ from .abi import ( # noqa: F401 abi_middleware, ) +from .async_middleware.async_cache import ( # noqa: F401 + _async_simple_cache_middleware as async_simple_cache_middleware, +) from .attrdict import ( # noqa: F401 attrdict_middleware, ) diff --git a/web3/middleware/async_middleware/async_cache.py b/web3/middleware/async_middleware/async_cache.py new file mode 100644 index 0000000000..efd0b8b2f5 --- /dev/null +++ b/web3/middleware/async_middleware/async_cache.py @@ -0,0 +1,104 @@ +import functools +import threading +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + Tuple, + Type, + cast, +) + +import lru + +from web3._utils.caching import ( + generate_cache_key, +) +from web3.types import ( + AsyncMiddleware, + Middleware, + RPCEndpoint, + RPCResponse, +) + +if TYPE_CHECKING: + from web3 import Web3 # noqa: F401 + + +SIMPLE_CACHE_RPC_WHITELIST = cast( + Tuple[RPCEndpoint], + ( + "web3_clientVersion", + "eth_getBlockTransactionCountByHash", + "eth_getUncleCountByBlockHash", + "eth_getBlockByHash", + "eth_getTransactionByHash", + "eth_getTransactionByBlockHashAndIndex", + "eth_getRawTransactionByHash", + "eth_getUncleByBlockHashAndIndex", + "eth_chainId", + ), +) + + +def _should_cache_response( + _method: RPCEndpoint, _params: Any, response: RPCResponse +) -> bool: + return ( + "error" not in response + and "result" in response + and response["result"] is not None + ) + + +async def async_construct_simple_cache_middleware( + cache_class: Type[Dict[Any, Any]], + rpc_whitelist: Collection[RPCEndpoint] = SIMPLE_CACHE_RPC_WHITELIST, + should_cache_fn: Callable[ + [RPCEndpoint, Any, RPCResponse], bool + ] = _should_cache_response, +) -> Middleware: + """ + Constructs a middleware which caches responses based on the request + ``method`` and ``params`` + + :param cache_class: Any dictionary-like object + :param rpc_whitelist: A set of RPC methods which may have their responses cached. + :param should_cache_fn: A callable which accepts ``method`` ``params`` and + ``response`` and returns a boolean as to whether the response should be + cached. + """ + + async def async_simple_cache_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "Web3" + ) -> AsyncMiddleware: + cache = cache_class() + lock = threading.Lock() + + async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + if method in rpc_whitelist: + with lock: + cache_key = generate_cache_key((method, params)) + if cache_key not in cache: + response = await make_request(method, params) + if should_cache_fn(method, params, response): + cache[cache_key] = response + return response + return cache[cache_key] + else: + return await make_request(method, params) + + return middleware + + return async_simple_cache_middleware + + +async def _async_simple_cache_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "Web3" +): + middleware = await async_construct_simple_cache_middleware( + cache_class=cast(Type[Dict[Any, Any]], functools.partial(lru.LRU, 256)), + ) + return await middleware(make_request, async_w3)