Skip to content

Commit

Permalink
Add streaming support
Browse files Browse the repository at this point in the history
  • Loading branch information
Graden Rea authored and stainless-app[bot] committed Feb 15, 2024
1 parent 36d47a3 commit 4baf6ef
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/groq/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __stream__(self) -> Iterator[_T]:
iterator = self._iter_events()

for sse in iterator:
if sse.data.startswith("[DONE]"):
break
yield process_data(data=sse.json(), cast_to=cast_to, response=response)

# Ensure the entire stream is consumed
Expand Down Expand Up @@ -106,6 +108,8 @@ async def __aiter__(self) -> AsyncIterator[_T]:

# Async generator over the server-sent events of the response: lines are read
# from self.response.aiter_lines() and decoded by self._decoder.aiter().
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self._decoder.aiter(self.response.aiter_lines()):
# An event whose data starts with "[DONE]" ends iteration here; that event
# itself is never yielded to the caller.
# NOTE(review): presumably "[DONE]" is the server's end-of-stream sentinel --
# confirm against the API's streaming documentation. The synchronous path in
# this same diff performs the equivalent check inside __stream__ rather than in
# its event iterator; verify the difference in placement is intentional.
if sse.data.startswith("[DONE]"):
break
yield sse

async def __stream__(self) -> AsyncIterator[_T]:
Expand Down
198 changes: 195 additions & 3 deletions src/groq/resources/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from __future__ import annotations

from typing import Dict, List, Union, Iterable, Optional
from typing import Dict, List, Literal, Union, Iterable, Optional, overload

import httpx

from ...lib.chat_completion_chunk import ChatCompletionChunk
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import maybe_transform
from ..._compat import cached_property
Expand All @@ -16,6 +17,7 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from ..._streaming import AsyncStream, Stream
from ...types.chat import ChatCompletion, completion_create_params
from ..._base_client import (
make_request_options,
Expand All @@ -33,6 +35,7 @@ def with_raw_response(self) -> CompletionsWithRawResponse:
def with_streaming_response(self) -> CompletionsWithStreamingResponse:
return CompletionsWithStreamingResponse(self)

@overload
def create(
self,
*,
Expand All @@ -47,7 +50,7 @@ def create(
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
Expand All @@ -61,6 +64,98 @@ def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
...

# Typing overload (synchronous): selected when the caller passes stream=True.
# The body is only `...`; it declares a signature and return type, not behavior.
@overload
def create(
self,
*,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
model: str | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
# Required (no default) and narrowed to the literal True, which is what ties
# this overload to the Stream[ChatCompletionChunk] return type below.
stream: Literal[True],
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Stream[ChatCompletionChunk]:
...

# Typing overload (synchronous): selected when `stream` is passed as a plain
# bool whose value is not statically known to be True or False.
# The body is only `...`; it declares a signature and return type, not behavior.
@overload
def create(
self,
*,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
model: str | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
# Required (no default) and typed as bool, so the return annotation below is
# the union of the non-streaming and streaming result types.
stream: bool,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
...

def create(
self,
*,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
model: str | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | Stream[ChatCompletionChunk]:
"""
Creates a completion for a chat prompt
Expand Down Expand Up @@ -105,6 +200,8 @@ def create(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=Stream[ChatCompletionChunk],
)


Expand All @@ -117,6 +214,7 @@ def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse:
return AsyncCompletionsWithStreamingResponse(self)

@overload
async def create(
self,
*,
Expand All @@ -131,7 +229,7 @@ async def create(
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: bool | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
Expand All @@ -145,6 +243,98 @@ async def create(
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion:
...

# Typing overload (asynchronous): selected when the caller passes stream=True.
# The body is only `...`; it declares a signature and return type, not behavior.
@overload
async def create(
self,
*,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
model: str | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
# Required (no default) and narrowed to the literal True, which is what ties
# this overload to the AsyncStream[ChatCompletionChunk] return type below.
stream: Literal[True],
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncStream[ChatCompletionChunk]:
...

# Typing overload (asynchronous): selected when `stream` is passed as a plain
# bool whose value is not statically known to be True or False.
# The body is only `...`; it declares a signature and return type, not behavior.
@overload
async def create(
self,
*,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
model: str | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
# Required (no default) and typed as bool, so the return annotation below is
# the union of the non-streaming and streaming result types.
stream: bool,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
...

async def create(
self,
*,
frequency_penalty: float | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, int] | NotGiven = NOT_GIVEN,
logprobs: bool | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
messages: Iterable[completion_create_params.Message] | NotGiven = NOT_GIVEN,
model: str | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
presence_penalty: float | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
tool_choice: completion_create_params.ToolChoice | NotGiven = NOT_GIVEN,
tools: Iterable[completion_create_params.Tool] | NotGiven = NOT_GIVEN,
top_logprobs: int | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
"""
Creates a completion for a chat prompt
Expand Down Expand Up @@ -189,6 +379,8 @@ async def create(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=ChatCompletion,
stream=stream or False,
stream_cls=AsyncStream[ChatCompletionChunk],
)


Expand Down

0 comments on commit 4baf6ef

Please sign in to comment.