Skip to content

Commit

Permalink
sync with starlette StreamingResponse
Browse files Browse the repository at this point in the history
- adjust typing to latest starlette standard
- fix formatting
- fix mypy errors
  • Loading branch information
sysid committed Nov 19, 2023
1 parent 159bbec commit 415d647
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
3 changes: 2 additions & 1 deletion examples/load_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
# curl http://localhost:8000/stream | pv --line-mode --average-rate > /dev/null
################################################################################

import uvicorn
import json

import uvicorn
from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse

Expand Down
39 changes: 26 additions & 13 deletions sse_starlette/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@
import re
from datetime import datetime
from functools import partial
from typing import Any, AsyncIterable, Callable, Coroutine, Dict, Optional, Union
from typing import (
Any,
AsyncIterable,
Awaitable,
Callable,
Coroutine,
Iterator,
Mapping,
Optional,
Union,
)

import anyio
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.responses import Response
from starlette.responses import AsyncContentStream, ContentStream, Response
from starlette.types import Receive, Scope, Send

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -137,14 +147,16 @@ class EventSourceResponse(Response):
implementation based on Starlette StreamingResponse
"""

body_iterator: AsyncContentStream

DEFAULT_PING_INTERVAL = 15

# noinspection PyMissingConstructor
def __init__(
self,
content: Any,
content: ContentStream,
status_code: int = 200,
headers: Optional[Dict] = None,
headers: Optional[Mapping[str, str]] = None,
media_type: str = "text/event-stream",
background: Optional[BackgroundTask] = None,
ping: Optional[int] = None,
Expand All @@ -156,20 +168,21 @@ def __init__(
) -> None:
if sep is not None and sep not in ["\r\n", "\r", "\n"]:
raise ValueError(f"sep must be one of: \\r\\n, \\r, \\n, got: {sep}")
self.sep = sep
self.DEFAULT_SEPARATOR = "\r\n"
self.sep = sep if sep is not None else self.DEFAULT_SEPARATOR

self.ping_message_factory = ping_message_factory

if isinstance(content, AsyncIterable):
self.body_iterator = (
content
) # type: AsyncIterable[Union[Any,dict,ServerSentEvent]]
self.body_iterator = content
else:
self.body_iterator = iterate_in_threadpool(content) # type: ignore
self.body_iterator = iterate_in_threadpool(content)
self.status_code = status_code
self.media_type = self.media_type if media_type is None else media_type
self.background = background # type: ignore # follows https://github.com/encode/starlette/blob/master/starlette/responses.py
self.background = background
self.data_sender_callable = data_sender_callable

_headers = {}
_headers: dict[str, str] = {}
if headers is not None: # pragma: no cover
_headers.update(headers)

Expand Down Expand Up @@ -215,7 +228,7 @@ async def listen_for_exit_signal() -> None:
# Await the event
await AppStatus.should_exit_event.wait()

async def stream_response(self, send) -> None:
async def stream_response(self, send: Send) -> None:
await send(
{
"type": "http.response.start",
Expand All @@ -235,7 +248,7 @@ async def stream_response(self, send) -> None:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:
# https://trio.readthedocs.io/en/latest/reference-core.html#custom-supervisors
async def wrap(func: Callable[[], Coroutine[None, None, None]]) -> None:
async def wrap(func: Callable[[], Awaitable[None]]) -> None:
await func()
# noinspection PyAsyncCall
task_group.cancel_scope.cancel()
Expand Down

0 comments on commit 415d647

Please sign in to comment.