From 54425703371c1553d00676e35d8e26709368e853 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 10:08:57 +0000 Subject: [PATCH 01/14] fix race condition in queue shutdown --- starlette/testclient.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 2c096aa22..11a8940b0 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import enum import inspect import io import json @@ -85,6 +86,14 @@ class WebSocketDenialResponse( # type: ignore[misc] """ +class _Eof(enum.Enum): + EOF = enum.auto() + + +EOF: typing.Final = _Eof.EOF +Eof = typing.Literal[_Eof.EOF] + + class WebSocketTestSession: def __init__( self, @@ -97,7 +106,7 @@ def __init__( self.accepted_subprotocol = None self.portal_factory = portal_factory self._receive_queue: queue.Queue[Message] = queue.Queue() - self._send_queue: queue.Queue[Message | BaseException] = queue.Queue() + self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue() self.extra_headers = None def __enter__(self) -> WebSocketTestSession: @@ -129,8 +138,11 @@ def __exit__(self, *args: typing.Any) -> None: finally: self.portal.start_task_soon(self._notify_close) self.exit_stack.close() - while not self._send_queue.empty(): + + while True: message = self._send_queue.get() + if message is EOF: + break if isinstance(message, BaseException): raise message @@ -150,10 +162,13 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None: finally: tg.cancel_scope.cancel() - async with anyio.create_task_group() as tg: - tg.start_soon(run_app, tg) - await self.should_close.wait() - tg.cancel_scope.cancel() + try: + async with anyio.create_task_group() as tg: + tg.start_soon(run_app, tg) + await self.should_close.wait() + tg.cancel_scope.cancel() + finally: + self._send_queue.put(EOF) async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): @@ -202,6 +217,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None: def receive(self) -> Message: message = self._send_queue.get() + assert message is not EOF if isinstance(message, BaseException): raise message return message From 3978a1728882df86cb4bf93db30ef76fc22afefb Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 10:13:48 +0000 Subject: [PATCH 02/14] Update starlette/testclient.py --- starlette/testclient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 11a8940b0..2f540e99b 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -168,7 +168,7 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None: await self.should_close.wait() tg.cancel_scope.cancel() finally: - self._send_queue.put(EOF) + self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+ async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): From 8eec43514753e6d30f061af1625335d2128dc056 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 10:25:43 +0000 Subject: [PATCH 03/14] Update testclient.py --- starlette/testclient.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 2f540e99b..ac9b46f81 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -17,7 +17,6 @@ import anyio import anyio.abc import anyio.from_thread -from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable From 44c1bc2bf786b66b5a42ea4b780fd40cd9608c32 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 11:06:31 +0000 Subject: [PATCH 04/14] refactor exit stack for test client --- starlette/testclient.py | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index ac9b46f81..69a930bad 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -19,7 +19,7 @@ import anyio.from_thread from anyio.streams.stapled import StapledObjectStream -from starlette._utils import is_async_callable +from starlette._utils import collapse_excgroups, is_async_callable from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -109,20 +109,20 @@ def __init__( self.extra_headers = None def __enter__(self) -> WebSocketTestSession: - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context(self.portal_factory()) + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context(self.portal_factory()) - try: - _: Future[None] = self.portal.start_task_soon(self._run) + fut: Future[None] = self.portal.start_task_soon(self._run) self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) - except Exception: - self.exit_stack.close() - raise - self.accepted_subprotocol = message.get("subprotocol", None) - self.extra_headers = message.get("headers", None) - return self + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + stack.callback(fut.result) + stack.callback(portal.call, self._notify_close) + stack.callback(self.close, 1000) + self.exit_stack = stack.pop_all() + return self @cached_property def should_close(self) -> anyio.Event: @@ -132,18 +132,14 @@ async def _notify_close(self) -> None: self.should_close.set() def __exit__(self, *args: typing.Any) -> None: - try: - self.close(1000) - finally: - self.portal.start_task_soon(self._notify_close) - self.exit_stack.close() + self.exit_stack.close() while True: message = self._send_queue.get() if message is EOF: break if isinstance(message, BaseException): - raise message + raise message # pragma: no cover (defensive, should be impossible) async def _run(self) -> None: """ @@ -154,7 +150,7 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None: try: await self.app(self.scope, self._asgi_receive, self._asgi_send) except anyio.get_cancelled_exc_class(): - ... + raise except BaseException as exc: self._send_queue.put(exc) raise @@ -162,10 +158,11 @@ async def run_app(tg: anyio.abc.TaskGroup) -> None: tg.cancel_scope.cancel() try: - async with anyio.create_task_group() as tg: - tg.start_soon(run_app, tg) - await self.should_close.wait() - tg.cancel_scope.cancel() + with collapse_excgroups(): + async with anyio.create_task_group() as tg: + tg.start_soon(run_app, tg) + await self.should_close.wait() + tg.cancel_scope.cancel() finally: self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+ From c210c27c7fa9debe6125e29367d3aab9f22d1a1f Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Thu, 26 Dec 2024 11:12:11 +0000 Subject: [PATCH 05/14] run tests a bunch of times --- .github/workflows/test-suite.yml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index fac946039..08f38b0ee 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -30,5 +30,23 @@ jobs: run: "scripts/build" - name: "Run tests" run: "scripts/test" + - name: "Run tests" + run: "scripts/test" + - name: "Run tests" + run: "scripts/test" + - name: "Run tests" + run: "scripts/test" + - name: "Run tests" + run: "scripts/test" + - name: "Run tests" + run: "scripts/test" + - name: "Run tests" + run: "scripts/test" + - name: "Run tests" + run: "scripts/test" + - name: "Run tests" + run: "scripts/test" + - name: "Run tests" + run: "scripts/test" - name: "Enforce coverage" run: "scripts/coverage" From 3fde98b8002e23422beb2c18410489e9f120ec2e Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 27 Dec 2024 09:59:11 +0000 Subject: [PATCH 06/14] Revert "run tests a bunch of times" This reverts commit c210c27c7fa9debe6125e29367d3aab9f22d1a1f. --- .github/workflows/test-suite.yml | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 08f38b0ee..fac946039 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -30,23 +30,5 @@ jobs: run: "scripts/build" - name: "Run tests" run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - - name: "Run tests" - run: "scripts/test" - name: "Enforce coverage" run: "scripts/coverage" From 4d2ddbcc71b7d461e92502004034a1f42ab87e85 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 27 Dec 2024 10:57:57 +0000 Subject: [PATCH 07/14] Update starlette/testclient.py --- starlette/testclient.py | 1 + 1 file changed, 1 insertion(+) diff --git a/starlette/testclient.py b/starlette/testclient.py index 69a930bad..020f0a0d5 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -17,6 +17,7 @@ import anyio import anyio.abc import anyio.from_thread +from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.stapled import StapledObjectStream from starlette._utils import collapse_excgroups, is_async_callable From 57701e7fe2a15e29c479a94df429fee444bcc29c Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 27 Dec 2024 11:04:11 +0000 Subject: [PATCH 08/14] avoid a task group --- starlette/testclient.py | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 020f0a0d5..c87a4ae36 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -10,7 +10,6 @@ import sys import typing from concurrent.futures import Future -from functools import cached_property from types import GeneratorType from urllib.parse import unquote, urljoin @@ -20,7 +19,7 @@ from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.stapled import StapledObjectStream -from starlette._utils import collapse_excgroups, is_async_callable +from starlette._utils import is_async_callable from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect @@ -108,30 +107,24 @@ def __init__( self._receive_queue: queue.Queue[Message] = queue.Queue() self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue() self.extra_headers = None + self.should_close: anyio.Event def __enter__(self) -> WebSocketTestSession: with contextlib.ExitStack() as stack: self.portal = portal = stack.enter_context(self.portal_factory()) - fut: Future[None] = self.portal.start_task_soon(self._run) + fut, cs = self.portal.start_task(self._run) self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) self.accepted_subprotocol = message.get("subprotocol", None) self.extra_headers = message.get("headers", None) stack.callback(fut.result) - stack.callback(portal.call, self._notify_close) + stack.callback(portal.call, cs.cancel) stack.callback(self.close, 1000) self.exit_stack = stack.pop_all() return self - @cached_property - def should_close(self) -> anyio.Event: - return anyio.Event() - - async def _notify_close(self) -> None: - self.should_close.set() - def __exit__(self, *args: typing.Any) -> None: self.exit_stack.close() @@ -142,28 +135,21 @@ def __exit__(self, *args: typing.Any) -> None: if isinstance(message, BaseException): raise message # pragma: no cover (defensive, should be impossible) - async def _run(self) -> None: + async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: """ The sub-thread in which the websocket session runs. """ - - async def run_app(tg: anyio.abc.TaskGroup) -> None: + try: try: - await self.app(self.scope, self._asgi_receive, self._asgi_send) - except anyio.get_cancelled_exc_class(): - raise + self.should_close = anyio.Event() + with anyio.CancelScope() as cs: + task_status.started(cs) + await self.app(self.scope, self._asgi_receive, self._asgi_send) except BaseException as exc: self._send_queue.put(exc) raise finally: - tg.cancel_scope.cancel() - - try: - with collapse_excgroups(): - async with anyio.create_task_group() as tg: - tg.start_soon(run_app, tg) - await self.should_close.wait() - tg.cancel_scope.cancel() + self.should_close.set() finally: self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+ From afc65e7fb4d259f07beb1863502be4bd41fab6fd Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 28 Dec 2024 07:57:35 +0000 Subject: [PATCH 09/14] move task cancelling/future handling up --- starlette/testclient.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 8d6f83a29..f6efda0d0 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -114,13 +114,14 @@ def __enter__(self) -> WebSocketTestSession: self.portal = portal = stack.enter_context(self.portal_factory()) fut, cs = self.portal.start_task(self._run) + stack.callback(fut.result) + stack.callback(portal.call, cs.cancel) + self.send({"type": "websocket.connect"}) message = self.receive() self._raise_on_close(message) self.accepted_subprotocol = message.get("subprotocol", None) self.extra_headers = message.get("headers", None) - stack.callback(fut.result) - stack.callback(portal.call, cs.cancel) stack.callback(self.close, 1000) self.exit_stack = stack.pop_all() return self From 6e60cda50491f90601bb01f7793d68b21ddbd299 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 28 Dec 2024 08:08:07 +0000 Subject: [PATCH 10/14] add workaround for py3.8 bug --- starlette/testclient.py | 54 +++++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 16 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index f6efda0d0..185cfaf60 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -93,6 +93,12 @@ class _Eof(enum.Enum): Eof = typing.Literal[_Eof.EOF] +class _HadException(Exception): + def __init__(self, wrapped: BaseException, /, *args: object): + super().__init__(wrapped, *args) + self.wrapped: typing.Final = wrapped + + class WebSocketTestSession: def __init__( self, @@ -110,24 +116,40 @@ def __init__( self.should_close: anyio.Event def __enter__(self) -> WebSocketTestSession: - with contextlib.ExitStack() as stack: - self.portal = portal = stack.enter_context(self.portal_factory()) - - fut, cs = self.portal.start_task(self._run) - stack.callback(fut.result) - stack.callback(portal.call, cs.cancel) - - self.send({"type": "websocket.connect"}) - message = self.receive() - self._raise_on_close(message) - self.accepted_subprotocol = message.get("subprotocol", None) - self.extra_headers = message.get("headers", None) - stack.callback(self.close, 1000) - self.exit_stack = stack.pop_all() - return self + try: + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context(self.portal_factory()) + + fut, cs = self.portal.start_task(self._run) + + @stack.callback + def handle_task() -> None: + portal.call(cs.cancel) + e = fut.exception() + if e is None: + return + # work-around for https://github.com/python/cpython/issues/69968 + try: + raise _HadException(e) + finally: + del e + + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + stack.callback(self.close, 1000) + self.exit_stack = stack.pop_all() + return self + except _HadException as e: + raise e.wrapped def __exit__(self, *args: typing.Any) -> None: - self.exit_stack.close() + try: + self.exit_stack.close() + except _HadException as e: + raise e.wrapped while True: message = self._send_queue.get() From 3c8071f4c903dbb62158cb6cfa57623154c4f4ee Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 28 Dec 2024 08:26:10 +0000 Subject: [PATCH 11/14] nicer work-around --- starlette/testclient.py | 62 +++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 185cfaf60..1e3614378 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -92,11 +92,20 @@ class _Eof(enum.Enum): EOF: typing.Final = _Eof.EOF Eof = typing.Literal[_Eof.EOF] +_T_co = typing.TypeVar("_T_co", covariant=True) -class _HadException(Exception): - def __init__(self, wrapped: BaseException, /, *args: object): - super().__init__(wrapped, *args) - self.wrapped: typing.Final = wrapped + +class StartableAsyncFn(typing.Generic[_T_co], typing.Protocol): + async def __call__(self, /, *, task_status: anyio.abc.TaskStatus[_T_co]) -> None: ... + + +@contextlib.contextmanager +def _handle_task(portal: anyio.abc.BlockingPortal, async_fn: StartableAsyncFn[_T_co]) -> typing.Generator[_T_co]: + fut, result = portal.start_task(async_fn) + try: + yield result + finally: + fut.result() class WebSocketTestSession: @@ -116,40 +125,21 @@ def __init__( self.should_close: anyio.Event def __enter__(self) -> WebSocketTestSession: - try: - with contextlib.ExitStack() as stack: - self.portal = portal = stack.enter_context(self.portal_factory()) - - fut, cs = self.portal.start_task(self._run) - - @stack.callback - def handle_task() -> None: - portal.call(cs.cancel) - e = fut.exception() - if e is None: - return - # work-around for https://github.com/python/cpython/issues/69968 - try: - raise _HadException(e) - finally: - del e - - self.send({"type": "websocket.connect"}) - message = self.receive() - self._raise_on_close(message) - self.accepted_subprotocol = message.get("subprotocol", None) - self.extra_headers = message.get("headers", None) - stack.callback(self.close, 1000) - self.exit_stack = stack.pop_all() - return self - except _HadException as e: - raise e.wrapped + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context(self.portal_factory()) + cs = stack.enter_context(_handle_task(portal, self._run)) + stack.callback(portal.call, cs.cancel) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + stack.callback(self.close, 1000) + self.exit_stack = stack.pop_all() + return self def __exit__(self, *args: typing.Any) -> None: - try: - self.exit_stack.close() - except _HadException as e: - raise e.wrapped + self.exit_stack.close() while True: message = self._send_queue.get() From 2ef7a53fe2db600baa730de0e80beb8023392ba6 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 28 Dec 2024 08:43:32 +0000 Subject: [PATCH 12/14] remove WebSocketTestSession should_close --- starlette/testclient.py | 17 ++++++----------- tests/test_testclient.py | 18 ++++++++++++------ 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 1e3614378..371d426f8 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -122,7 +122,6 @@ def __init__( self._receive_queue: queue.Queue[Message] = queue.Queue() self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue() self.extra_headers = None - self.should_close: anyio.Event def __enter__(self) -> WebSocketTestSession: with contextlib.ExitStack() as stack: @@ -153,16 +152,12 @@ async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> The sub-thread in which the websocket session runs. """ try: - try: - self.should_close = anyio.Event() - with anyio.CancelScope() as cs: - task_status.started(cs) - await self.app(self.scope, self._asgi_receive, self._asgi_send) - except BaseException as exc: - self._send_queue.put(exc) - raise - finally: - self.should_close.set() + with anyio.CancelScope() as cs: + task_status.started(cs) + await self.app(self.scope, self._asgi_receive, self._asgi_send) + except BaseException as exc: + self._send_queue.put(exc) + raise finally: self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+ diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 279b81d91..589b88315 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -254,19 +254,25 @@ async def asgi(receive: Receive, send: Send) -> None: def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None: + cancelled = False + def app(scope: Scope) -> ASGIInstance: async def asgi(receive: Receive, send: Send) -> None: - websocket = WebSocket(scope, receive=receive, send=send) - await websocket.accept() - while True: - await anyio.sleep(0.1) + nonlocal cancelled + try: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await anyio.sleep_forever() + except anyio.get_cancelled_exc_class(): + cancelled = True + raise return asgi client = test_client_factory(app) # type: ignore - with client.websocket_connect("/") as websocket: + with client.websocket_connect("/"): ... - assert websocket.should_close.is_set() + assert cancelled def test_client(test_client_factory: TestClientFactory) -> None: From 509b9c0e31d388a27c5829023b550617547b0029 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 28 Dec 2024 08:45:07 +0000 Subject: [PATCH 13/14] Update starlette/testclient.py --- starlette/testclient.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/starlette/testclient.py b/starlette/testclient.py index 371d426f8..5daa15ff4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -105,6 +105,8 @@ def _handle_task(portal: anyio.abc.BlockingPortal, async_fn: StartableAsyncFn[_T try: yield result finally: + # can't raise an exception from stack.callback on Python 3.8 + # due to https://github.com/python/cpython/issues/69968 fut.result() From 82dad82e0eecf6349cb4bc71cda63e33d92f4428 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 29 Dec 2024 09:32:11 +0000 Subject: [PATCH 14/14] remove 3.8 work-around --- starlette/testclient.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 5daa15ff4..a14f646d4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -92,23 +92,6 @@ class _Eof(enum.Enum): EOF: typing.Final = _Eof.EOF Eof = typing.Literal[_Eof.EOF] -_T_co = typing.TypeVar("_T_co", covariant=True) - - -class StartableAsyncFn(typing.Generic[_T_co], typing.Protocol): - async def __call__(self, /, *, task_status: anyio.abc.TaskStatus[_T_co]) -> None: ... - - -@contextlib.contextmanager -def _handle_task(portal: anyio.abc.BlockingPortal, async_fn: StartableAsyncFn[_T_co]) -> typing.Generator[_T_co]: - fut, result = portal.start_task(async_fn) - try: - yield result - finally: - # can't raise an exception from stack.callback on Python 3.8 - # due to https://github.com/python/cpython/issues/69968 - fut.result() - class WebSocketTestSession: def __init__( @@ -128,7 +111,8 @@ def __init__( def __enter__(self) -> WebSocketTestSession: with contextlib.ExitStack() as stack: self.portal = portal = stack.enter_context(self.portal_factory()) - cs = stack.enter_context(_handle_task(portal, self._run)) + fut, cs = portal.start_task(self._run) + stack.callback(fut.result) stack.callback(portal.call, cs.cancel) self.send({"type": "websocket.connect"}) message = self.receive()