Skip to content

Commit

Permalink
chore(internal): restructure streaming implementation to use composition
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed May 24, 2024
1 parent 472b831 commit b1a1c03
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 41 deletions.
152 changes: 115 additions & 37 deletions src/anthropic/lib/streaming/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from types import TracebackType
from typing import TYPE_CHECKING, Generic, TypeVar, Callable
from typing_extensions import Iterator, Awaitable, AsyncIterator, override, assert_never
from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never

import httpx

Expand All @@ -15,7 +15,7 @@
from ..._client import Anthropic, AsyncAnthropic


class MessageStream(Stream[MessageStreamEvent]):
class MessageStream:
text_stream: Iterator[str]
"""Iterator over just the text deltas in the stream.
Expand All @@ -26,18 +26,52 @@ class MessageStream(Stream[MessageStreamEvent]):
```
"""

response: httpx.Response

def __init__(
self,
*,
cast_to: type[MessageStreamEvent],
response: httpx.Response,
client: Anthropic,
) -> None:
super().__init__(cast_to=cast_to, response=response, client=client)
self.response = response
self._cast_to = cast_to
self._client = client

self.text_stream = self.__stream_text__()
self.__final_message_snapshot: Message | None = None

self._iterator = self.__stream__()
self._raw_stream: Stream[MessageStreamEvent] = Stream(cast_to=cast_to, response=response, client=client)

def __next__(self) -> MessageStreamEvent:
return self._iterator.__next__()

def __iter__(self) -> Iterator[MessageStreamEvent]:
for item in self._iterator:
yield item

def __enter__(self) -> Self:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()

def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.response.close()
self.on_end()

def get_final_message(self) -> Message:
"""Waits until the stream has been read to completion and returns
the accumulated `Message` object.
Expand Down Expand Up @@ -69,11 +103,6 @@ def until_done(self) -> None:
"""Blocks until the stream has been consumed"""
consume_sync_iterator(self)

@override
def close(self) -> None:
super().close()
self.on_end()

# properties
@property
def current_message_snapshot(self) -> Message:
Expand Down Expand Up @@ -118,17 +147,16 @@ def on_end(self) -> None:
def on_timeout(self) -> None:
"""Fires if the request times out"""

@override
def __stream__(self) -> Iterator[MessageStreamEvent]:
try:
for event in super().__stream__():
for sse_event in self._raw_stream:
self.__final_message_snapshot = accumulate_event(
event=event,
event=sse_event,
current_snapshot=self.__final_message_snapshot,
)
self._emit_sse_event(event)
self._emit_sse_event(sse_event)

yield event
yield sse_event
except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
self.on_timeout()
self.on_exception(exc)
Expand Down Expand Up @@ -184,25 +212,35 @@ class MessageStreamManager(Generic[MessageStreamT]):
```
"""

def __init__(self, api_request: Callable[[], MessageStreamT]) -> None:
self.__stream: MessageStreamT | None = None
def __init__(
self, api_request: Callable[[], Stream[MessageStreamEvent]], event_handler_cls: type[MessageStreamT]
) -> None:
self.__event_handler: MessageStreamT | None = None
self.__event_handler_cls: type[MessageStreamT] = event_handler_cls
self.__api_request = api_request

def __enter__(self) -> MessageStreamT:
self.__stream = self.__api_request()
return self.__stream
raw_stream = self.__api_request()

self.__event_handler = self.__event_handler_cls(
cast_to=raw_stream._cast_to,
response=raw_stream.response,
client=raw_stream._client,
)

return self.__event_handler

def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
self.__stream.close()
if self.__event_handler is not None:
self.__event_handler.close()


class AsyncMessageStream(AsyncStream[MessageStreamEvent]):
class AsyncMessageStream:
text_stream: AsyncIterator[str]
"""Async iterator over just the text deltas in the stream.
Expand All @@ -213,18 +251,54 @@ class AsyncMessageStream(AsyncStream[MessageStreamEvent]):
```
"""

response: httpx.Response

def __init__(
self,
*,
cast_to: type[MessageStreamEvent],
response: httpx.Response,
client: AsyncAnthropic,
) -> None:
super().__init__(cast_to=cast_to, response=response, client=client)
self.response = response
self._cast_to = cast_to
self._client = client

self.text_stream = self.__stream_text__()
self.__final_message_snapshot: Message | None = None

self._iterator = self.__stream__()
self._raw_stream: AsyncStream[MessageStreamEvent] = AsyncStream(
cast_to=cast_to, response=response, client=client
)

async def __anext__(self) -> MessageStreamEvent:
return await self._iterator.__anext__()

async def __aiter__(self) -> AsyncIterator[MessageStreamEvent]:
async for item in self._iterator:
yield item

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()

async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self.response.aclose()
await self.on_end()

async def get_final_message(self) -> Message:
"""Waits until the stream has been read to completion and returns
the accumulated `Message` object.
Expand Down Expand Up @@ -256,11 +330,6 @@ async def until_done(self) -> None:
"""Waits until the stream has been consumed"""
await consume_async_iterator(self)

@override
async def close(self) -> None:
await super().close()
await self.on_end()

# properties
@property
def current_message_snapshot(self) -> Message:
Expand Down Expand Up @@ -311,17 +380,16 @@ async def on_end(self) -> None:
async def on_timeout(self) -> None:
"""Fires if the request times out"""

@override
async def __stream__(self) -> AsyncIterator[MessageStreamEvent]:
try:
async for event in super().__stream__():
async for sse_event in self._raw_stream:
self.__final_message_snapshot = accumulate_event(
event=event,
event=sse_event,
current_snapshot=self.__final_message_snapshot,
)
await self._emit_sse_event(event)
await self._emit_sse_event(sse_event)

yield event
yield sse_event
except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
await self.on_timeout()
await self.on_exception(exc)
Expand Down Expand Up @@ -382,22 +450,32 @@ class AsyncMessageStreamManager(Generic[AsyncMessageStreamT]):
```
"""

def __init__(self, api_request: Awaitable[AsyncMessageStreamT]) -> None:
self.__stream: AsyncMessageStreamT | None = None
def __init__(
self, api_request: Awaitable[AsyncStream[MessageStreamEvent]], event_handler_cls: type[AsyncMessageStreamT]
) -> None:
self.__event_handler: AsyncMessageStreamT | None = None
self.__event_handler_cls: type[AsyncMessageStreamT] = event_handler_cls
self.__api_request = api_request

async def __aenter__(self) -> AsyncMessageStreamT:
self.__stream = await self.__api_request
return self.__stream
raw_stream = await self.__api_request

self.__event_handler = self.__event_handler_cls(
cast_to=raw_stream._cast_to,
response=raw_stream.response,
client=raw_stream._client,
)

return self.__event_handler

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
await self.__stream.close()
if self.__event_handler is not None:
await self.__event_handler.close()


def accumulate_event(*, event: MessageStreamEvent, current_snapshot: Message | None) -> Message:
Expand Down
8 changes: 4 additions & 4 deletions src/anthropic/resources/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,9 +829,9 @@ def stream( # pyright: ignore[reportInconsistentOverload]
),
cast_to=Message,
stream=True,
stream_cls=event_handler,
stream_cls=Stream[MessageStreamEvent],
)
return MessageStreamManager(make_request)
return MessageStreamManager(make_request, event_handler)


class AsyncMessages(AsyncAPIResource):
Expand Down Expand Up @@ -1624,9 +1624,9 @@ def stream( # pyright: ignore[reportInconsistentOverload]
),
cast_to=Message,
stream=True,
stream_cls=event_handler,
stream_cls=AsyncStream[MessageStreamEvent],
)
return AsyncMessageStreamManager(request)
return AsyncMessageStreamManager(request, event_handler)


class MessagesWithRawResponse:
Expand Down

0 comments on commit b1a1c03

Please sign in to comment.