From df2d4fdf48a297948e336bba4849cc8832978558 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Wed, 3 Apr 2024 09:36:00 -0400 Subject: [PATCH] fix(client): correct logic for line decoding in streaming (#433) --- src/anthropic/_streaming.py | 73 ++++++---- tests/test_streaming.py | 268 +++++++++++++++++++++++++++--------- 2 files changed, 253 insertions(+), 88 deletions(-) diff --git a/src/anthropic/_streaming.py b/src/anthropic/_streaming.py index e267d875..73474c33 100644 --- a/src/anthropic/_streaming.py +++ b/src/anthropic/_streaming.py @@ -23,7 +23,7 @@ class Stream(Generic[_T]): response: httpx.Response - _decoder: SSEDecoder | SSEBytesDecoder + _decoder: SSEBytesDecoder def __init__( self, @@ -46,10 +46,7 @@ def __iter__(self) -> Iterator[_T]: yield item def _iter_events(self) -> Iterator[ServerSentEvent]: - if isinstance(self._decoder, SSEBytesDecoder): - yield from self._decoder.iter_bytes(self.response.iter_bytes()) - else: - yield from self._decoder.iter(self.response.iter_lines()) + yield from self._decoder.iter_bytes(self.response.iter_bytes()) def __stream__(self) -> Iterator[_T]: cast_to = cast(Any, self._cast_to) @@ -145,12 +142,8 @@ async def __aiter__(self) -> AsyncIterator[_T]: yield item async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: - if isinstance(self._decoder, SSEBytesDecoder): - async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): - yield sse - else: - async for sse in self._decoder.aiter(self.response.aiter_lines()): - yield sse + async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): + yield sse async def __stream__(self) -> AsyncIterator[_T]: cast_to = cast(Any, self._cast_to) @@ -271,21 +264,49 @@ def __init__(self) -> None: self._last_event_id = None self._retry = None - def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]: - """Given an iterator that yields lines, iterate over it & yield every event encountered""" - for line in iterator: - line = line.rstrip("\n") - sse = self.decode(line) - if sse is not None: - yield sse - - async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]: - """Given an async iterator that yields lines, iterate over it & yield every event encountered""" - async for line in iterator: - line = line.rstrip("\n") - sse = self.decode(line) - if sse is not None: - yield sse + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + for chunk in self._iter_chunks(iterator): + # Split before decoding so splitlines() only uses \r and \n + for raw_line in chunk.splitlines(): + line = raw_line.decode("utf-8") + sse = self.decode(line) + if sse: + yield sse + + def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]: + """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" + data = b"" + for chunk in iterator: + for line in chunk.splitlines(keepends=True): + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data + + async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + async for chunk in self._aiter_chunks(iterator): + # Split before decoding so splitlines() only uses \r and \n + for raw_line in chunk.splitlines(): + line = raw_line.decode("utf-8") + sse = self.decode(line) + if sse: + yield sse + + async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]: + """Given an iterator that yields raw binary data, iterate over it and yield individual SSE chunks""" + data = b"" + async for chunk in iterator: + for line in chunk.splitlines(keepends=True): + data += line + if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")): + yield data + data = b"" + if data: + yield data def decode(self, line: str) -> ServerSentEvent | None: # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501 diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 1beb12e2..f9ad7182 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1,104 +1,248 @@ +from __future__ import annotations + from typing import Iterator, AsyncIterator +import httpx import pytest -from anthropic._streaming import SSEDecoder +from anthropic import Anthropic, AsyncAnthropic +from anthropic._streaming import Stream, AsyncStream, ServerSentEvent @pytest.mark.asyncio -async def test_basic_async() -> None: - async def body() -> AsyncIterator[str]: - yield "event: completion" - yield 'data: {"foo":true}' - yield "" - - async for sse in SSEDecoder().aiter(body()): - assert sse.event == "completion" - assert sse.json() == {"foo": True} +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_basic(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None: + def body() -> Iterator[bytes]: + yield b"event: completion\n" + yield b'data: {"foo":true}\n' + yield b"\n" + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) -def test_basic() -> None: - def body() -> Iterator[str]: - yield "event: completion" - yield 'data: {"foo":true}' - yield "" - - it = SSEDecoder().iter(body()) - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "completion" assert sse.json() == {"foo": True} - with pytest.raises(StopIteration): - next(it) + await assert_empty_iter(iterator) -def test_data_missing_event() -> None: - def body() -> Iterator[str]: - yield 'data: {"foo":true}' - yield "" +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_data_missing_event(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"foo":true}\n' + yield b"\n" - it = SSEDecoder().iter(body()) - sse = next(it) + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) assert sse.event is None assert sse.json() == {"foo": True} - with pytest.raises(StopIteration): - next(it) + await assert_empty_iter(iterator) + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_event_missing_data(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"\n" -def test_event_missing_data() -> None: - def body() -> Iterator[str]: - yield "event: ping" - yield "" + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) - it = SSEDecoder().iter(body()) - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "ping" assert sse.data == "" - with pytest.raises(StopIteration): - next(it) + await assert_empty_iter(iterator) -def test_multiple_events() -> None: - def body() -> Iterator[str]: - yield "event: ping" - yield "" - yield "event: completion" - yield "" +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_events(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"\n" + yield b"event: completion\n" + yield b"\n" - it = SSEDecoder().iter(body()) + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "ping" assert sse.data == "" - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "completion" assert sse.data == "" - with pytest.raises(StopIteration): - next(it) - - -def test_multiple_events_with_data() -> None: - def body() -> Iterator[str]: - yield "event: ping" - yield 'data: {"foo":true}' - yield "" - yield "event: completion" - yield 'data: {"bar":false}' - yield "" + await assert_empty_iter(iterator) - it = SSEDecoder().iter(body()) - sse = next(it) +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_events_with_data(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b'data: {"foo":true}\n' + yield b"\n" + yield b"event: completion\n" + yield b'data: {"bar":false}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) assert sse.event == "ping" assert sse.json() == {"foo": True} - sse = next(it) + sse = await iter_next(iterator) assert sse.event == "completion" assert sse.json() == {"bar": False} - with pytest.raises(StopIteration): - next(it) + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_data_lines_with_empty_line(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"data: {\n" + yield b'data: "foo":\n' + yield b"data: \n" + yield b"data:\n" + yield b"data: true}\n" + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + assert sse.data == '{\n"foo":\n\n\ntrue}' + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_data_json_escaped_double_new_line(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b'data: {"foo": "my long\\n\\ncontent"}' + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": "my long\n\ncontent"} + + await assert_empty_iter(iterator) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multiple_data_lines(sync: bool, client: Anthropic, async_client: AsyncAnthropic) -> None: + def body() -> Iterator[bytes]: + yield b"event: ping\n" + yield b"data: {\n" + yield b'data: "foo":\n' + yield b"data: true}\n" + yield b"\n\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event == "ping" + assert sse.json() == {"foo": True} + + await assert_empty_iter(iterator) + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_special_new_line_character( + sync: bool, + client: Anthropic, + async_client: AsyncAnthropic, +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"content":" culpa"}\n' + yield b"\n" + yield b'data: {"content":" \xe2\x80\xa8"}\n' + yield b"\n" + yield b'data: {"content":"foo"}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": " culpa"} + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": " 
"} + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": "foo"} + + await assert_empty_iter(iterator) + + +@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"]) +async def test_multi_byte_character_multiple_chunks( + sync: bool, + client: Anthropic, + async_client: AsyncAnthropic, +) -> None: + def body() -> Iterator[bytes]: + yield b'data: {"content":"' + # bytes taken from the string 'известни' and arbitrarily split + # so that some multi-byte characters span multiple chunks + yield b"\xd0" + yield b"\xb8\xd0\xb7\xd0" + yield b"\xb2\xd0\xb5\xd1\x81\xd1\x82\xd0\xbd\xd0\xb8" + yield b'"}\n' + yield b"\n" + + iterator = make_event_iterator(content=body(), sync=sync, client=client, async_client=async_client) + + sse = await iter_next(iterator) + assert sse.event is None + assert sse.json() == {"content": "известни"} + + +async def to_aiter(iter: Iterator[bytes]) -> AsyncIterator[bytes]: + for chunk in iter: + yield chunk + + +async def iter_next(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> ServerSentEvent: + if isinstance(iter, AsyncIterator): + return await iter.__anext__() + + return next(iter) + + +async def assert_empty_iter(iter: Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]) -> None: + with pytest.raises((StopAsyncIteration, RuntimeError)): + await iter_next(iter) + + +def make_event_iterator( + content: Iterator[bytes], + *, + sync: bool, + client: Anthropic, + async_client: AsyncAnthropic, +) -> Iterator[ServerSentEvent] | AsyncIterator[ServerSentEvent]: + if sync: + return Stream(cast_to=object, client=client, response=httpx.Response(200, content=content))._iter_events() + + return AsyncStream( + cast_to=object, client=async_client, response=httpx.Response(200, content=to_aiter(content)) + )._iter_events()