From 27b6f4c0e9ae613e4145797ad54f8036c6fb301c Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 11:00:04 +0000 Subject: [PATCH] collect errors more reliably from websocket test client (#2814) --- starlette/testclient.py | 79 ++++++++++++++++++---------------------- tests/test_testclient.py | 18 ++++++--- 2 files changed, 48 insertions(+), 49 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 8e908d36f..a14f646d4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import enum import inspect import io import json @@ -9,7 +10,6 @@ import sys import typing from concurrent.futures import Future -from functools import cached_property from types import GeneratorType from urllib.parse import unquote, urljoin @@ -85,6 +85,14 @@ class WebSocketDenialResponse( # type: ignore[misc] """ +class _Eof(enum.Enum): + EOF = enum.auto() + + +EOF: typing.Final = _Eof.EOF +Eof = typing.Literal[_Eof.EOF] + + class WebSocketTestSession: def __init__( self, @@ -97,63 +105,47 @@ def __init__( self.accepted_subprotocol = None self.portal_factory = portal_factory self._receive_queue: queue.Queue[Message] = queue.Queue() - self._send_queue: queue.Queue[Message | BaseException] = queue.Queue() + self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue() self.extra_headers = None def __enter__(self) -> WebSocketTestSession: - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context(self.portal_factory()) - - try: - _: Future[None] = self.portal.start_task_soon(self._run) + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context(self.portal_factory()) + fut, cs = portal.start_task(self._run) + stack.callback(fut.result) + stack.callback(portal.call, cs.cancel) self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) - except Exception: - self.exit_stack.close() - raise - self.accepted_subprotocol = message.get("subprotocol", None) - self.extra_headers = message.get("headers", None) - return self - - @cached_property - def should_close(self) -> anyio.Event: - return anyio.Event() - - async def _notify_close(self) -> None: - self.should_close.set() + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + stack.callback(self.close, 1000) + self.exit_stack = stack.pop_all() + return self def __exit__(self, *args: typing.Any) -> None: - try: - self.close(1000) - finally: - self.portal.start_task_soon(self._notify_close) - self.exit_stack.close() - while not self._send_queue.empty(): + self.exit_stack.close() + + while True: message = self._send_queue.get() + if message is EOF: + break if isinstance(message, BaseException): - raise message + raise message # pragma: no cover (defensive, should be impossible) - async def _run(self) -> None: + async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: """ The sub-thread in which the websocket session runs. """ - - async def run_app(tg: anyio.abc.TaskGroup) -> None: - try: + try: + with anyio.CancelScope() as cs: + task_status.started(cs) await self.app(self.scope, self._asgi_receive, self._asgi_send) - except anyio.get_cancelled_exc_class(): - ... - except BaseException as exc: - self._send_queue.put(exc) - raise - finally: - tg.cancel_scope.cancel() - - async with anyio.create_task_group() as tg: - tg.start_soon(run_app, tg) - await self.should_close.wait() - tg.cancel_scope.cancel() + except BaseException as exc: + self._send_queue.put(exc) + raise + finally: + self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+ async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): @@ -202,6 +194,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None: def receive(self) -> Message: message = self._send_queue.get() + assert message is not EOF if isinstance(message, BaseException): raise message return message diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 58ab6f6f2..478dbca46 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -255,19 +255,25 @@ async def asgi(receive: Receive, send: Send) -> None: def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None: + cancelled = False + def app(scope: Scope) -> ASGIInstance: async def asgi(receive: Receive, send: Send) -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - while True: - await anyio.sleep(0.1) + nonlocal cancelled + try: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await anyio.sleep_forever() + except anyio.get_cancelled_exc_class(): + cancelled = True + raise return asgi client = test_client_factory(app) # type: ignore - with client.websocket_connect("/") as websocket: + with client.websocket_connect("/"): ... - assert websocket.should_close.is_set() + assert cancelled def test_client(test_client_factory: TestClientFactory) -> None: