diff --git a/starlette/responses.py b/starlette/responses.py index bbd205b13a..dae0a9dd17 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -253,7 +253,7 @@ async def stream_response(self, send: Send) -> None: # We got an ASGI message which is not response body (eg: pathsend) should_close_body = False await send(chunk) - break + continue if isinstance(chunk, str): chunk = chunk.encode(self.charset) await send({"type": "http.response.body", "body": chunk, "more_body": True}) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 2176404d82..34fc27c79d 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -2,6 +2,7 @@ import contextvars from contextlib import AsyncExitStack +from pathlib import Path from typing import ( Any, AsyncGenerator, @@ -18,7 +19,12 @@ from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from starlette.responses import PlainTextResponse, Response, StreamingResponse +from starlette.responses import ( + FileResponse, + PlainTextResponse, + Response, + StreamingResponse, +) from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -1035,3 +1041,54 @@ async def endpoint(request: Request) -> Response: resp.raise_for_status() assert bodies == [b"Hello, World!-foo"] + + +@pytest.mark.anyio +async def test_asgi_pathsend_events(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + request_body_sent = False + response_complete = anyio.Event() + events: list[Message] = [] + + async def endpoint_with_pathsend(_: Request) -> FileResponse: + return FileResponse(path) + + async def passthrough( + request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_pathsend)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + events.append(message) + if message["type"] == "http.response.pathsend": + response_complete.set() + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend" diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index 5bfecadb72..83c9f19f3e 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,13 +1,24 @@ +from __future__ import annotations + +from pathlib import Path from typing import Callable +import anyio +import pytest + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request -from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse +from starlette.responses import ( + ContentStream, + FileResponse, + PlainTextResponse, + StreamingResponse, +) from starlette.routing import Route from starlette.testclient import TestClient -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Message TestClientFactory = Callable[[ASGIApp], TestClient] @@ -111,3 +122,50 @@ async def generator(bytes: bytes, count: int) -> ContentStream: assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "text" assert "Content-Length" not in response.headers + + +@pytest.mark.anyio +async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + request_body_sent = False + response_complete = anyio.Event() + events: list[Message] = [] + + async def endpoint_with_pathsend(_: Request) -> FileResponse: + return FileResponse(path) + + app = Starlette( + routes=[Route("/", endpoint=endpoint_with_pathsend)], + middleware=[Middleware(GZipMiddleware)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "headers": [(b"accept-encoding", b"gzip, text")], + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + events.append(message) + if message["type"] == "http.response.pathsend": + response_complete.set() + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend"