From 2b6f5ccc36e8a8cd2c2f4344c11999b20ef1106a Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Tue, 5 Nov 2024 20:22:13 -0800 Subject: [PATCH] Refactor Chat and AsyncChat to use _Shared base class Refs https://github.com/simonw/llm/issues/507#issuecomment-2458692338 --- llm/default_plugins/openai_models.py | 79 +++++++++++++--------------- 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/llm/default_plugins/openai_models.py b/llm/default_plugins/openai_models.py index 301481c7..d9f1b15f 100644 --- a/llm/default_plugins/openai_models.py +++ b/llm/default_plugins/openai_models.py @@ -270,15 +270,11 @@ def validate_logit_bias(cls, logit_bias): return validated_logit_bias -class Chat(Model): +class _Shared: needs_key = "openai" key_env_var = "OPENAI_API_KEY" - default_max_tokens = None - def get_async_model(self): - return AsyncChat(self.model_id, self.key) - class Options(SharedOptions): json_object: Optional[bool] = Field( description="Output a valid JSON object {...}. Prompt must mention JSON.", @@ -370,40 +366,6 @@ def build_messages(self, prompt, conversation): messages.append({"role": "user", "content": attachment_message}) return messages - def execute(self, prompt, stream, response, conversation=None): - if prompt.system and not self.allows_system_prompt: - raise NotImplementedError("Model does not support system prompts") - messages = self.build_messages(prompt, conversation) - kwargs = self.build_kwargs(prompt, stream) - client = self.get_client() - if stream: - completion = client.chat.completions.create( - model=self.model_name or self.model_id, - messages=messages, - stream=True, - **kwargs, - ) - chunks = [] - for chunk in completion: - chunks.append(chunk) - try: - content = chunk.choices[0].delta.content - except IndexError: - content = None - if content is not None: - yield content - response.response_json = remove_dict_none_values(combine_chunks(chunks)) - else: - completion = client.chat.completions.create( - model=self.model_name or self.model_id, - messages=messages, - stream=False, - **kwargs, - ) - response.response_json = remove_dict_none_values(completion.model_dump()) - yield completion.choices[0].message.content - response._prompt_json = redact_data_urls({"messages": messages}) - def get_client(self, async_=False): kwargs = {} if self.api_base: @@ -441,10 +403,43 @@ def build_kwargs(self, prompt, stream): return kwargs -class AsyncChat(AsyncModel, Chat): - needs_key = "openai" - key_env_var = "OPENAI_API_KEY" +class Chat(_Shared, Model): + def execute(self, prompt, stream, response, conversation=None): + if prompt.system and not self.allows_system_prompt: + raise NotImplementedError("Model does not support system prompts") + messages = self.build_messages(prompt, conversation) + kwargs = self.build_kwargs(prompt, stream) + client = self.get_client() + if stream: + completion = client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=True, + **kwargs, + ) + chunks = [] + for chunk in completion: + chunks.append(chunk) + try: + content = chunk.choices[0].delta.content + except IndexError: + content = None + if content is not None: + yield content + response.response_json = remove_dict_none_values(combine_chunks(chunks)) + else: + completion = client.chat.completions.create( + model=self.model_name or self.model_id, + messages=messages, + stream=False, + **kwargs, + ) + response.response_json = remove_dict_none_values(completion.model_dump()) + yield completion.choices[0].message.content + response._prompt_json = redact_data_urls({"messages": messages}) + +class AsyncChat(_Shared, AsyncModel): async def execute(self, prompt, stream, response, conversation=None): if prompt.system and not self.allows_system_prompt: raise NotImplementedError("Model does not support system prompts")