From 61b71d442438819fa3eb43a4d77094d814265528 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sun, 1 Oct 2023 18:47:17 +0200 Subject: [PATCH] fix: #1301 - Apply compression before caching (#2393) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: #1301 - Apply compression before caching Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> * fix typing Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --------- Signed-off-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com> --- litestar/_asgi/routing_trie/mapping.py | 6 ++ litestar/constants.py | 1 + litestar/middleware/compression.py | 11 ++- litestar/middleware/response_cache.py | 48 +++++++++++++ litestar/routes/http.py | 48 ++++--------- tests/e2e/test_response_caching.py | 70 ++++++++++++++++++- .../test_compression_middleware.py | 23 ++++++ 7 files changed, 171 insertions(+), 36 deletions(-) create mode 100644 litestar/middleware/response_cache.py diff --git a/litestar/_asgi/routing_trie/mapping.py b/litestar/_asgi/routing_trie/mapping.py index ffa58dda80..d0f0fe5b7f 100644 --- a/litestar/_asgi/routing_trie/mapping.py +++ b/litestar/_asgi/routing_trie/mapping.py @@ -186,6 +186,8 @@ def build_route_middleware_stack( from litestar.middleware.allowed_hosts import AllowedHostsMiddleware from litestar.middleware.compression import CompressionMiddleware from litestar.middleware.csrf import CSRFMiddleware + from litestar.middleware.response_cache import ResponseCacheMiddleware + from litestar.routes import HTTPRoute # we wrap the route.handle method in the ExceptionHandlerMiddleware asgi_handler = wrap_in_exception_handler( @@ -197,6 +199,10 @@ def build_route_middleware_stack( if app.compression_config: asgi_handler = CompressionMiddleware(app=asgi_handler, config=app.compression_config) + + if isinstance(route, HTTPRoute) and any(r.cache for r in route.route_handlers): + asgi_handler = ResponseCacheMiddleware(app=asgi_handler, config=app.response_cache_config) + if app.allowed_hosts: asgi_handler = AllowedHostsMiddleware(app=asgi_handler, config=app.allowed_hosts) diff --git a/litestar/constants.py b/litestar/constants.py index 9db278b4a2..b3bb2c5b2c 100644 --- a/litestar/constants.py +++ b/litestar/constants.py @@ -20,6 +20,7 @@ SCOPE_STATE_DEPENDENCY_CACHE: Final = "dependency_cache" SCOPE_STATE_NAMESPACE: Final = "__litestar__" SCOPE_STATE_RESPONSE_COMPRESSED: Final = "response_compressed" +SCOPE_STATE_IS_CACHED: Final = "is_cached" SKIP_VALIDATION_NAMES: Final = {"request", "socket", "scope", "receive", "send"} UNDEFINED_SENTINELS: Final = {Signature.empty, Empty, Ellipsis, MISSING, UnsetType} WEBSOCKET_CLOSE: Final = "websocket.close" diff --git a/litestar/middleware/compression.py b/litestar/middleware/compression.py index 4648087010..e6443f05b2 100644 --- a/litestar/middleware/compression.py +++ b/litestar/middleware/compression.py @@ -4,12 +4,12 @@ from io import BytesIO from typing import TYPE_CHECKING, Any, Literal, Optional -from litestar.constants import SCOPE_STATE_RESPONSE_COMPRESSED +from litestar.constants import SCOPE_STATE_IS_CACHED, SCOPE_STATE_RESPONSE_COMPRESSED from litestar.datastructures import Headers, MutableScopeHeaders from litestar.enums import CompressionEncoding, ScopeType from litestar.exceptions import MissingDependencyException from litestar.middleware.base import AbstractMiddleware -from litestar.utils import Ref, set_litestar_scope_state +from litestar.utils import Ref, get_litestar_scope_state, set_litestar_scope_state __all__ = ("CompressionFacade", "CompressionMiddleware") @@ -176,6 +176,8 @@ def create_compression_send_wrapper( initial_message = Ref[Optional["HTTPResponseStartEvent"]](None) started = Ref[bool](False) + _own_encoding = compression_encoding.encode("latin-1") + async def send_wrapper(message: Message) -> None: """Handle and compresses the HTTP Message with brotli. @@ -187,6 +189,11 @@ async def send_wrapper(message: Message) -> None: initial_message.value = message return + if initial_message.value and get_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED): + await send(initial_message.value) + await send(message) + return + if initial_message.value and message["type"] == "http.response.body": body = message["body"] more_body = message.get("more_body") diff --git a/litestar/middleware/response_cache.py b/litestar/middleware/response_cache.py new file mode 100644 index 0000000000..905a90f040 --- /dev/null +++ b/litestar/middleware/response_cache.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from msgspec.msgpack import encode as encode_msgpack + +from litestar.enums import ScopeType +from litestar.utils import get_litestar_scope_state + +from .base import AbstractMiddleware + +__all__ = ["ResponseCacheMiddleware"] + +from typing import TYPE_CHECKING, cast + +from litestar import Request +from litestar.constants import SCOPE_STATE_IS_CACHED + +if TYPE_CHECKING: + from litestar.config.response_cache import ResponseCacheConfig + from litestar.handlers import HTTPRouteHandler + from litestar.types import ASGIApp, Message, Receive, Scope, Send + + +class ResponseCacheMiddleware(AbstractMiddleware): + def __init__(self, app: ASGIApp, config: ResponseCacheConfig) -> None: + self.config = config + super().__init__(app=app, scopes={ScopeType.HTTP}) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + route_handler = cast("HTTPRouteHandler", scope["route_handler"]) + store = self.config.get_store_from_app(scope["app"]) + + expires_in: int | None = None + if route_handler.cache is True: + expires_in = self.config.default_expiration + elif route_handler.cache is not False and isinstance(route_handler.cache, int): + expires_in = route_handler.cache + + messages = [] + + async def wrapped_send(message: Message) -> None: + if not get_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED): + messages.append(message) + if message["type"] == "http.response.body" and not message["more_body"]: + key = (route_handler.cache_key_builder or self.config.key_builder)(Request(scope)) + await store.set(key, encode_msgpack(messages), expires_in=expires_in) + await send(message) + + await self.app(scope, receive, wrapped_send) diff --git a/litestar/routes/http.py b/litestar/routes/http.py index 1df06aa1f4..2582439f8c 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -1,10 +1,11 @@ from __future__ import annotations -import pickle from itertools import chain from typing import TYPE_CHECKING, Any, cast -from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS +from msgspec.msgpack import decode as _decode_msgpack_plain + +from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS, SCOPE_STATE_IS_CACHED from litestar.datastructures.headers import Headers from litestar.datastructures.upload_file import UploadFile from litestar.enums import HttpMethod, MediaType, ScopeType @@ -13,6 +14,7 @@ from litestar.response import Response from litestar.routes.base import BaseRoute from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST +from litestar.utils import set_litestar_scope_state if TYPE_CHECKING: from litestar._kwargs import KwargsModel @@ -128,19 +130,10 @@ async def _get_response_for_request( ): return response - response = await self._call_handler_function( + return await self._call_handler_function( scope=scope, request=request, parameter_model=parameter_model, route_handler=route_handler ) - if route_handler.cache: - await self._set_cached_response( - response=response, - request=request, - route_handler=route_handler, - ) - - return response - async def _call_handler_function( self, scope: Scope, request: Request, parameter_model: KwargsModel, route_handler: HTTPRouteHandler ) -> ASGIApp: @@ -225,30 +218,19 @@ async def _get_cached_response(request: Request, route_handler: HTTPRouteHandler cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request) store = cache_config.get_store_from_app(request.app) - cached_response = await store.get(key=cache_key) - - if cached_response: - return cast("ASGIApp", pickle.loads(cached_response)) # noqa: S301 + if not (cached_response_data := await store.get(key=cache_key)): + return None - return None + # we use the regular msgspec.msgpack.decode here since we don't need any of + # the added decoders + messages = _decode_msgpack_plain(cached_response_data) - @staticmethod - async def _set_cached_response( - response: Response | ASGIApp, request: Request, route_handler: HTTPRouteHandler - ) -> None: - """Pickles and caches a response object.""" - cache_config = request.app.response_cache_config - cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request) - - expires_in: int | None = None - if route_handler.cache is True: - expires_in = cache_config.default_expiration - elif route_handler.cache is not False and isinstance(route_handler.cache, int): - expires_in = route_handler.cache - - store = cache_config.get_store_from_app(request.app) + async def cached_response(scope: Scope, receive: Receive, send: Send) -> None: + set_litestar_scope_state(scope, SCOPE_STATE_IS_CACHED, True) + for message in messages: + await send(message) - await store.set(key=cache_key, value=pickle.dumps(response, pickle.HIGHEST_PROTOCOL), expires_in=expires_in) + return cached_response def create_options_handler(self, path: str) -> HTTPRouteHandler: """Args: diff --git a/tests/e2e/test_response_caching.py b/tests/e2e/test_response_caching.py index a8bf7fedc7..8a33e6ca37 100644 --- a/tests/e2e/test_response_caching.py +++ b/tests/e2e/test_response_caching.py @@ -1,13 +1,18 @@ +import gzip import random from datetime import timedelta -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Type, Union from unittest.mock import MagicMock from uuid import uuid4 +import msgspec import pytest from litestar import Litestar, Request, get +from litestar.config.compression import CompressionConfig from litestar.config.response_cache import CACHE_FOREVER, ResponseCacheConfig +from litestar.enums import CompressionEncoding +from litestar.middleware.response_cache import ResponseCacheMiddleware from litestar.stores.base import Store from litestar.stores.memory import MemoryStore from litestar.testing import TestClient, create_test_client @@ -180,3 +185,66 @@ def handler() -> str: assert response_two.text == mock.return_value assert mock.call_count == 1 + + +def test_does_not_apply_to_non_cached_routes(mock: MagicMock) -> None: + @get("/") + def handler() -> str: + return mock() # type: ignore[no-any-return] + + with create_test_client([handler]) as client: + first_response = client.get("/") + second_response = client.get("/") + + assert first_response.status_code == 200 + assert second_response.status_code == 200 + assert mock.call_count == 2 + + +@pytest.mark.parametrize( + "cache,expect_applied", + [ + (True, True), + (False, False), + (1, True), + (CACHE_FOREVER, True), + ], +) +def test_middleware_not_applied_to_non_cached_routes( + cache: Union[bool, int, Type[CACHE_FOREVER]], expect_applied: bool +) -> None: + @get(path="/", cache=cache) + def handler() -> None: + ... + + client = create_test_client(route_handlers=[handler]) + unpacked_middleware = [] + cur = client.app.asgi_router.root_route_map_node.children["/"].asgi_handlers["GET"][0] + while hasattr(cur, "app"): + unpacked_middleware.append(cur) + cur = cur.app + unpacked_middleware.append(cur) + + assert len([m for m in unpacked_middleware if isinstance(m, ResponseCacheMiddleware)]) == int(expect_applied) + + +async def test_compression_applies_before_cache() -> None: + return_value = "_litestar_" * 4000 + mock = MagicMock(return_value=return_value) + + @get(path="/", cache=True) + def handler_fn() -> str: + return mock() # type: ignore[no-any-return] + + app = Litestar( + route_handlers=[handler_fn], + compression_config=CompressionConfig(backend="gzip"), + ) + + with TestClient(app) as client: + client.get("/", headers={"Accept-Encoding": str(CompressionEncoding.GZIP.value)}) + + stored_value = await app.response_cache_config.get_store_from_app(app).get("/") + assert stored_value + stored_messages = msgspec.msgpack.decode(stored_value) + assert gzip.decompress(stored_messages[1]["body"]).decode() == return_value diff --git a/tests/unit/test_middleware/test_compression_middleware.py b/tests/unit/test_middleware/test_compression_middleware.py index 2694e3ab6b..e1c3ec08bb 100644 --- a/tests/unit/test_middleware/test_compression_middleware.py +++ b/tests/unit/test_middleware/test_compression_middleware.py @@ -193,3 +193,26 @@ async def fake_send(message: Message) -> None: # second body message with more_body=True will be empty if zlib buffers output and is not flushed await wrapped_send(HTTPResponseBodyEvent(type="http.response.body", body=b"abc", more_body=True)) assert mock.mock_calls[-1].args[0]["body"] + + +@pytest.mark.parametrize( + "backend, compression_encoding", (("brotli", CompressionEncoding.BROTLI), ("gzip", CompressionEncoding.GZIP)) +) +def test_dont_recompress_cached(backend: Literal["gzip", "brotli"], compression_encoding: CompressionEncoding) -> None: + mock = MagicMock(return_value="_litestar_" * 4000) + + @get(path="/", media_type=MediaType.TEXT, cache=True) + def handler_fn() -> str: + return mock() # type: ignore[no-any-return] + + with create_test_client( + route_handlers=[handler_fn], compression_config=CompressionConfig(backend=backend) + ) as client: + client.get("/", headers={"Accept-Encoding": str(compression_encoding.value)}) + response = client.get("/", headers={"Accept-Encoding": str(compression_encoding.value)}) + + assert mock.call_count == 1 + assert response.status_code == HTTP_200_OK + assert response.text == "_litestar_" * 4000 + assert response.headers["Content-Encoding"] == compression_encoding + assert int(response.headers["Content-Length"]) < 40000