Skip to content

Commit

Permalink
Revert "Upgraded to AnyIO 4.0 (#2211)"
Browse files Browse the repository at this point in the history
This reverts commit 1a71441.
  • Loading branch information
Kludex committed Aug 18, 2023
1 parent a8b8856 commit e9d4fc4
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 54 deletions.
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ types-contextvars==2.4.7.2
types-PyYAML==6.0.12.10
types-dataclasses==0.6.6
pytest==7.4.0
trio==0.22.1
anyio@git+https://github.com/agronholm/anyio.git
trio==0.21.0

# Documentation
mkdocs==1.4.3
Expand Down
28 changes: 4 additions & 24 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,19 @@
import sys
import typing
from contextlib import contextmanager

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import BaseExceptionGroup

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


@contextmanager
def _convert_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0]

raise exc


class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
Expand Down Expand Up @@ -124,8 +107,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
Expand Down Expand Up @@ -201,11 +182,10 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response.raw_headers = message["headers"]
return response

with _convert_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
4 changes: 0 additions & 4 deletions starlette/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import warnings

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -73,9 +72,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:


class WSGIResponder:
stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]

def __init__(self, app: typing.Callable[..., typing.Any], scope: Scope) -> None:
self.app = app
self.scope = scope
Expand Down
19 changes: 6 additions & 13 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import anyio
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import is_async_callable
Expand Down Expand Up @@ -738,18 +737,12 @@ def __enter__(self) -> "TestClient":
def reset_portal() -> None:
self.portal = None

send1: ObjectSendStream[
typing.Optional[typing.MutableMapping[str, typing.Any]]
]
receive1: ObjectReceiveStream[
typing.Optional[typing.MutableMapping[str, typing.Any]]
]
send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send1, receive1 = anyio.create_memory_object_stream(math.inf)
send2, receive2 = anyio.create_memory_object_stream(math.inf)
self.stream_send = StapledObjectStream(send1, receive1)
self.stream_receive = StapledObjectStream(send2, receive2)
self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)

Expand Down
8 changes: 1 addition & 7 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

from starlette.middleware.wsgi import WSGIMiddleware, build_environ

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import ExceptionGroup


def hello_world(environ, start_response):
status = "200 OK"
Expand Down Expand Up @@ -69,12 +66,9 @@ def test_wsgi_exception(test_client_factory):
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
client = test_client_factory(app)
with pytest.raises(ExceptionGroup) as exc:
with pytest.raises(RuntimeError):
client.get("/")

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], RuntimeError)


def test_wsgi_exc_info(test_client_factory):
# Note that we're testing the WSGI app directly here.
Expand Down
4 changes: 0 additions & 4 deletions tests/test_websockets.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import sys
from typing import Any, MutableMapping

import anyio
import pytest
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette import status
from starlette.types import Receive, Scope, Send
Expand Down Expand Up @@ -180,8 +178,6 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:


def test_websocket_concurrency_pattern(test_client_factory):
stream_send: ObjectSendStream[MutableMapping[str, Any]]
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
stream_send, stream_receive = anyio.create_memory_object_stream()

async def reader(websocket):
Expand Down

0 comments on commit e9d4fc4

Please sign in to comment.