Skip to content

Commit

Permalink
fix: fixed typing in llm_openai.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Oct 24, 2024
1 parent 28fa60a commit 4a11c8b
Showing 1 changed file with 29 additions and 12 deletions.
41 changes: 29 additions & 12 deletions gptme/llm_openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import logging
from collections.abc import Generator
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import (
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)

from .config import Config
from .constants import TEMPERATURE, TOP_P
Expand Down Expand Up @@ -99,7 +104,7 @@ def chat(messages: list[Message], model: str, tools) -> str:
def stream(messages: list[Message], model: str, tools) -> Generator[str, None, None]:
assert openai, "LLM not initialized"
stop_reason = None
for chunk in openai.chat.completions.create(
for chunk_raw in openai.chat.completions.create(
model=model,
messages=msgs2dicts(_prep_o1(messages), openai=True), # type: ignore
temperature=TEMPERATURE,
Expand All @@ -115,16 +120,28 @@ def stream(messages: list[Message], model: str, tools) -> Generator[str, None, N
openrouter_headers if "openrouter.ai" in str(openai.base_url) else {}
),
):
if not chunk.choices: # type: ignore
# Cast the chunk to the correct type
chunk = cast(ChatCompletionChunk, chunk_raw)

if not chunk.choices:
# Got a chunk with no choices, Azure always sends one of these at the start
continue
stop_reason = chunk.choices[0].finish_reason # type: ignore
if content := chunk.choices[0].delta.content:
yield content
# TODO: propagate tool calls back better, this is a bit hacky
for tool_call in chunk.choices[0].delta.tool_calls or ():
if name := tool_call.function.name:
yield "@" + name + ": "
if args := tool_call.function.arguments:
yield args

choice = chunk.choices[0]
stop_reason = choice.finish_reason
delta = choice.delta

if delta.content is not None:
yield delta.content

# Handle tool calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function:
func = tool_call.function
if isinstance(func, ChoiceDeltaToolCallFunction):
if func.name:
yield "@" + func.name + ": "
if func.arguments:
yield func.arguments
logger.debug(f"Stop reason: {stop_reason}")

0 comments on commit 4a11c8b

Please sign in to comment.