From 4a11c8b7130750b3d017dc9d0419805dfd61f66b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?=
Date: Thu, 24 Oct 2024 19:23:20 +0200
Subject: [PATCH] fix: fixed typing in llm_openai.py

---
 gptme/llm_openai.py | 41 +++++++++++++++++++++++++++++------------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/gptme/llm_openai.py b/gptme/llm_openai.py
index 03890033..af307409 100644
--- a/gptme/llm_openai.py
+++ b/gptme/llm_openai.py
@@ -1,6 +1,11 @@
 import logging
 from collections.abc import Generator
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
+from openai.types.chat import ChatCompletionChunk
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
+)
 
 from .config import Config
 from .constants import TEMPERATURE, TOP_P
@@ -99,7 +104,7 @@ def chat(messages: list[Message], model: str, tools) -> str:
 def stream(messages: list[Message], model: str, tools) -> Generator[str, None, None]:
     assert openai, "LLM not initialized"
     stop_reason = None
-    for chunk in openai.chat.completions.create(
+    for chunk_raw in openai.chat.completions.create(
         model=model,
         messages=msgs2dicts(_prep_o1(messages), openai=True),  # type: ignore
         temperature=TEMPERATURE,
@@ -115,16 +120,28 @@ def stream(messages: list[Message], model: str, tools) -> Generator[str, None, N
             openrouter_headers if "openrouter.ai" in str(openai.base_url) else {}
         ),
     ):
-        if not chunk.choices:  # type: ignore
+        # Cast the chunk to the correct type
+        chunk = cast(ChatCompletionChunk, chunk_raw)
+
+        if not chunk.choices:  # Got a chunk with no choices, Azure always sends one of these at the start
             continue
-        stop_reason = chunk.choices[0].finish_reason  # type: ignore
-        if content := chunk.choices[0].delta.content:
-            yield content
-        # TODO: propagate tool calls back better, this is a bit hacky
-        for tool_call in chunk.choices[0].delta.tool_calls or ():
-            if name := tool_call.function.name:
-                yield "@" + name + ": "
-            if args := tool_call.function.arguments:
-                yield args
+
+        choice = chunk.choices[0]
+        stop_reason = choice.finish_reason
+        delta = choice.delta
+
+        if delta.content is not None:
+            yield delta.content
+
+        # Handle tool calls
+        if delta.tool_calls:
+            for tool_call in delta.tool_calls:
+                if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function:
+                    func = tool_call.function
+                    if isinstance(func, ChoiceDeltaToolCallFunction):
+                        if func.name:
+                            yield "@" + func.name + ": "
+                        if func.arguments:
+                            yield func.arguments
     logger.debug(f"Stop reason: {stop_reason}")
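
Note (not part of the patch): below is a minimal standalone sketch of the typed streaming pattern the patch introduces, i.e. casting each raw chunk to ChatCompletionChunk, skipping empty-choice chunks, and narrowing tool-call deltas with isinstance before reading their fields. It assumes openai>=1.x is installed and OPENAI_API_KEY is set; the model name and prompt are placeholders, and the tool-call branch only fires if tools are actually passed to the API.

# sketch.py -- illustrative only, not gptme code
from collections.abc import Generator
from typing import cast

from openai import OpenAI
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import (
    ChoiceDeltaToolCall,
    ChoiceDeltaToolCallFunction,
)

client = OpenAI()  # reads OPENAI_API_KEY from the environment


def stream_text(prompt: str, model: str = "gpt-4o-mini") -> Generator[str, None, None]:
    """Yield text (and any tool-call fragments) from a streaming completion."""
    for chunk_raw in client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    ):
        # The SDK yields chunks whose static type can be loose; cast for type-checkers.
        chunk = cast(ChatCompletionChunk, chunk_raw)

        if not chunk.choices:
            # Some backends (e.g. Azure) send an initial chunk with no choices.
            continue

        delta = chunk.choices[0].delta
        if delta.content is not None:
            yield delta.content

        # Tool-call deltas arrive incrementally; narrow types before accessing fields.
        for tool_call in delta.tool_calls or []:
            if isinstance(tool_call, ChoiceDeltaToolCall) and tool_call.function:
                func = tool_call.function
                if isinstance(func, ChoiceDeltaToolCallFunction):
                    if func.name:
                        yield "@" + func.name + ": "
                    if func.arguments:
                        yield func.arguments


if __name__ == "__main__":
    for piece in stream_text("Say hello in one word."):
        print(piece, end="", flush=True)
    print()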