From 4baf6ef6911f814731f3494b4cc317ab68ecfc36 Mon Sep 17 00:00:00 2001 From: Graden Rea Date: Wed, 14 Feb 2024 15:40:49 -0800 Subject: [PATCH] Add streaming support --- src/groq/_streaming.py | 4 + src/groq/resources/chat/completions.py | 198 ++++++++++++++++++++++++- 2 files changed, 199 insertions(+), 3 deletions(-) diff --git a/src/groq/_streaming.py b/src/groq/_streaming.py index ac0ea8a..2769874 100644 --- a/src/groq/_streaming.py +++ b/src/groq/_streaming.py @@ -53,6 +53,8 @@ def __stream__(self) -> Iterator[_T]: iterator = self._iter_events() for sse in iterator: + if sse.data.startswith("[DONE]"): + break yield process_data(data=sse.json(), cast_to=cast_to, response=response) # Ensure the entire stream is consumed @@ -106,6 +108,8 @@ async def __aiter__(self) -> AsyncIterator[_T]: async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: async for sse in self._decoder.aiter(self.response.aiter_lines()): + if sse.data.startswith("[DONE]"): + break yield sse async def __stream__(self) -> AsyncIterator[_T]: diff --git a/src/groq/resources/chat/completions.py b/src/groq/resources/chat/completions.py index 9f332f5..2661803 100644 --- a/src/groq/resources/chat/completions.py +++ b/src/groq/resources/chat/completions.py @@ -2,10 +2,11 @@ from __future__ import annotations -from typing import Dict, List, Union, Iterable, Optional +from typing import Dict, List, Literal, Union, Iterable, Optional, overload import httpx +from ...lib.chat_completion_chunk import ChatCompletionChunk from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import maybe_transform from ..._compat import cached_property @@ -16,6 +17,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..._streaming import AsyncStream, Stream from ...types.chat import ChatCompletion, completion_create_params from ..._base_client import ( make_request_options, @@ -33,6 +35,7 @@ def with_raw_response(self) -> CompletionsWithRawResponse: def with_streaming_response(self) -> CompletionsWithStreamingResponse: return CompletionsWithStreamingResponse(self) + @overload def create( self, *, @@ -47,7 +50,7 @@ def create( response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -61,6 +64,98 @@ def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletion: + ... + + @overload + def create( + self, + *, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN, + model: str | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Literal[True], + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[ChatCompletionChunk]: + ... + + @overload + def create( + self, + *, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN, + model: str | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: bool, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | Stream[ChatCompletionChunk]: + ... + + def create( + self, + *, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN, + model: str | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | Stream[ChatCompletionChunk]: """ Creates a completion for a chat prompt @@ -105,6 +200,8 @@ def create( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=ChatCompletion, + stream=stream or False, + stream_cls=Stream[ChatCompletionChunk], ) @@ -117,6 +214,7 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse: def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse: return AsyncCompletionsWithStreamingResponse(self) + @overload async def create( self, *, @@ -131,7 +229,7 @@ async def create( response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: bool | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, temperature: float | NotGiven = NOT_GIVEN, tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, @@ -145,6 +243,98 @@ async def create( extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ChatCompletion: + ... + + @overload + async def create( + self, + *, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN, + model: str | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Literal[True], + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[ChatCompletionChunk]: + ... + + @overload + async def create( + self, + *, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN, + model: str | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: bool, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: + ... + + async def create( + self, + *, + frequency_penalty: float | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN, + logprobs: bool | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN, + model: str | NotGiven = NOT_GIVEN, + n: int | NotGiven = NOT_GIVEN, + presence_penalty: float | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN, + tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN, + top_logprobs: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: """ Creates a completion for a chat prompt @@ -189,6 +379,8 @@ async def create( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=ChatCompletion, + stream=stream or False, + stream_cls=AsyncStream[ChatCompletionChunk], )