From 48434e641c995521bcbacd907a0ea693d6b53dcd Mon Sep 17 00:00:00 2001
From: "yuxin.wang"
Date: Mon, 29 Jan 2024 16:51:10 +0800
Subject: [PATCH] Refactor model update methods (#18)

Co-authored-by: wangyuxin
---
 generate/chat_completion/models/azure.py    |  4 ++--
 generate/chat_completion/models/baichuan.py |  8 +++----
 generate/chat_completion/models/bailian.py  |  8 +++----
 generate/chat_completion/models/hunyuan.py  |  8 +++----
 generate/chat_completion/models/minimax.py  |  8 +++----
 .../chat_completion/models/minimax_pro.py   | 10 ++++----
 generate/chat_completion/models/openai.py   |  8 +++----
 generate/chat_completion/models/wenxin.py   |  8 +++----
 generate/chat_completion/models/zhipu.py    | 24 ++++++++++++-------
 generate/image_generation/models/baidu.py   |  4 ++--
 generate/image_generation/models/openai.py  |  4 ++--
 generate/image_generation/models/qianfan.py |  4 ++--
 generate/model.py                           | 10 ++++++--
 generate/text_to_speech/models/minimax.py   |  8 +++----
 generate/text_to_speech/models/openai.py    |  4 ++--
 tests/test_completion.py                    |  8 +++----
 16 files changed, 71 insertions(+), 57 deletions(-)

diff --git a/generate/chat_completion/models/azure.py b/generate/chat_completion/models/azure.py
index 266cb13..a79a135 100644
--- a/generate/chat_completion/models/azure.py
+++ b/generate/chat_completion/models/azure.py
@@ -53,7 +53,7 @@ def _get_request_parameters(self, messages: Messages, parameters: OpenAIChatPara
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters)
         output = parse_openai_model_reponse(response.json())
@@ -63,7 +63,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         output = parse_openai_model_reponse(response.json())
diff --git a/generate/chat_completion/models/baichuan.py b/generate/chat_completion/models/baichuan.py
index b474ce6..7da141e 100644
--- a/generate/chat_completion/models/baichuan.py
+++ b/generate/chat_completion/models/baichuan.py
@@ -119,7 +119,7 @@ def _get_request_parameters(self, messages: Messages, parameters: BaichuanChatPa
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -127,7 +127,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict])
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -164,7 +164,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -178,7 +178,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
diff --git a/generate/chat_completion/models/bailian.py b/generate/chat_completion/models/bailian.py
index 69050f1..56783e3 100644
--- a/generate/chat_completion/models/bailian.py
+++ b/generate/chat_completion/models/bailian.py
@@ -137,7 +137,7 @@ def _get_request_parameters(self, messages: Messages, parameters: BailianChatPar
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -145,7 +145,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict])
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -176,7 +176,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -195,7 +195,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
diff --git a/generate/chat_completion/models/hunyuan.py b/generate/chat_completion/models/hunyuan.py
index 7e12226..a450cd2 100644
--- a/generate/chat_completion/models/hunyuan.py
+++ b/generate/chat_completion/models/hunyuan.py
@@ -93,7 +93,7 @@ def _get_request_parameters(self, messages: Messages, parameters: HunyuanChatPar
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -101,7 +101,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict])
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -136,7 +136,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -150,7 +150,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
diff --git a/generate/chat_completion/models/minimax.py b/generate/chat_completion/models/minimax.py
index acde1e6..e982342 100644
--- a/generate/chat_completion/models/minimax.py
+++ b/generate/chat_completion/models/minimax.py
@@ -130,7 +130,7 @@ def _get_request_parameters(self, messages: Messages, parameters: MinimaxChatPar
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -138,7 +138,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict])
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -171,7 +171,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -185,7 +185,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
diff --git a/generate/chat_completion/models/minimax_pro.py b/generate/chat_completion/models/minimax_pro.py
index bb02dd9..7767022 100644
--- a/generate/chat_completion/models/minimax_pro.py
+++ b/generate/chat_completion/models/minimax_pro.py
@@ -199,7 +199,7 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput:
             model_info=self.model_info,
             message=self.message,
             finish_reason=response['choices'][0]['finish_reason'],
-            cost=calculate_cost(model_name=self.model_info.name , usage=response['usage']),
+            cost=calculate_cost(model_name=self.model_info.name, usage=response['usage']),
             extra={
                 'input_sensitive': response['input_sensitive'],
                 'output_sensitive': response['output_sensitive'],
@@ -299,7 +299,7 @@ def _get_request_parameters(self, messages: Messages, parameters: MinimaxProChat
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -307,7 +307,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -344,7 +344,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor(model_info=self.model_info)
         for line in self.http_client.stream_post(request_parameters=request_parameters):
@@ -355,7 +355,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor(model_info=self.model_info)
         async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
diff --git a/generate/chat_completion/models/openai.py b/generate/chat_completion/models/openai.py
index f9518a7..4b1f9fd 100644
--- a/generate/chat_completion/models/openai.py
+++ b/generate/chat_completion/models/openai.py
@@ -375,7 +375,7 @@ def _get_request_parameters(self, messages: Messages, parameters: OpenAIChatPara
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters)
         return parse_openai_model_reponse(response.json())
@@ -383,7 +383,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return parse_openai_model_reponse(response.json())
@@ -398,7 +398,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor()
         is_finish = False
@@ -417,7 +417,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor()
         is_finish = False
diff --git a/generate/chat_completion/models/wenxin.py b/generate/chat_completion/models/wenxin.py
index 3152ba2..f1e7a80 100644
--- a/generate/chat_completion/models/wenxin.py
+++ b/generate/chat_completion/models/wenxin.py
@@ -160,7 +160,7 @@ def _get_request_parameters(self, messages: Messages, parameters: WenxinChatPara
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters)
         return self._parse_reponse(response.json())
@@ -168,7 +168,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -206,7 +206,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         if parameters.functions:
             raise ValueError('stream_generate does not support functions')
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -222,7 +222,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         if parameters.functions:
             raise ValueError('stream_generate does not support functions')
         request_parameters = self._get_stream_request_parameters(messages, parameters)
diff --git a/generate/chat_completion/models/zhipu.py b/generate/chat_completion/models/zhipu.py
index 14efe9d..ab3fefb 100644
--- a/generate/chat_completion/models/zhipu.py
+++ b/generate/chat_completion/models/zhipu.py
@@ -3,6 +3,7 @@
 import json
 from typing import Any, AsyncIterator, ClassVar, Iterator, List, Literal, Optional, Union
 
+from pydantic import field_validator
 from typing_extensions import NotRequired, Self, TypedDict, Unpack, override
 
 from generate.chat_completion.base import ChatCompletionModel
@@ -75,6 +76,13 @@ class ZhipuChatParameters(ModelParameters):
     tools: Optional[List[ZhipuTool]] = None
     tool_choice: Optional[str] = None
 
+    @field_validator('temperature')
+    @classmethod
+    def can_not_equal_zero(cls, v: Optional[Temperature]) -> Optional[Temperature]:
+        if v == 0:
+            return 0.01
+        return v
+
 
 class ZhipuChatParametersDict(ModelParametersDict, total=False):
     temperature: Optional[Temperature]
@@ -331,7 +339,7 @@ def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]:
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -339,7 +347,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) ->
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -349,7 +357,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor()
         is_finish = False
@@ -367,7 +375,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor()
         is_finish = False
@@ -468,7 +476,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: ModelPa
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -476,7 +484,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParameters
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -486,7 +494,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -510,7 +518,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
diff --git a/generate/image_generation/models/baidu.py b/generate/image_generation/models/baidu.py
index 5dd908e..c5b7f2f 100644
--- a/generate/image_generation/models/baidu.py
+++ b/generate/image_generation/models/baidu.py
@@ -88,7 +88,7 @@ def _get_request_parameters(self, prompt: str, parameters: BaiduImageGenerationP
 
     @override
     def generate(self, prompt: str, **kwargs: Unpack[BaiduImageGenerationParametersDict]) -> ImageGenerationOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         task_id = response.json()['data']['task_id']
@@ -106,7 +106,7 @@ def generate(self, prompt: str, **kwargs: Unpack[BaiduImageGenerationParametersD
 
     @override
     async def async_generate(self, prompt: str, **kwargs: Unpack[BaiduImageGenerationParametersDict]) -> ImageGenerationOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         image_urls = await self._async_get_image_urls(response.json()['data']['task_id'])
diff --git a/generate/image_generation/models/openai.py b/generate/image_generation/models/openai.py
index d670d13..728a42d 100644
--- a/generate/image_generation/models/openai.py
+++ b/generate/image_generation/models/openai.py
@@ -113,7 +113,7 @@ def _get_request_parameters(self, prompt: str, parameters: OpenAIImageGeneration
     @override
     def generate(self, prompt: str, **kwargs: Unpack[OpenAIImageGenerationParametersDict]) -> ImageGenerationOutput:
         self._check_prompt(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._construct_model_output(prompt, parameters, response)
@@ -121,7 +121,7 @@ def generate(self, prompt: str, **kwargs: Unpack[OpenAIImageGenerationParameters
     @override
     async def async_generate(self, prompt: str, **kwargs: Unpack[OpenAIImageGenerationParametersDict]) -> ImageGenerationOutput:
         self._check_prompt(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._construct_model_output(prompt, parameters, response)
diff --git a/generate/image_generation/models/qianfan.py b/generate/image_generation/models/qianfan.py
index c054f69..c71a002 100644
--- a/generate/image_generation/models/qianfan.py
+++ b/generate/image_generation/models/qianfan.py
@@ -57,7 +57,7 @@ def __init__(
 
     @override
     def generate(self, prompt: str, **kwargs: Unpack[QianfanImageGenerationParametersDict]) -> ImageGenerationOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._construct_model_output(prompt, response.json())
@@ -66,7 +66,7 @@ def generate(self, prompt: str, **kwargs: Unpack[QianfanImageGenerationParameter
     @override
     async def async_generate(
         self, prompt: str, **kwargs: Unpack[QianfanImageGenerationParametersDict]
     ) -> ImageGenerationOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._construct_model_output(prompt, response.json())
diff --git a/generate/model.py b/generate/model.py
index 06281d5..891c64f 100644
--- a/generate/model.py
+++ b/generate/model.py
@@ -8,11 +8,17 @@
 class ModelParameters(BaseModel):
+    model_config = ConfigDict(validate_assignment=True)
+
     def custom_model_dump(self) -> dict[str, Any]:
         return {**self.model_dump(exclude_none=True, by_alias=True), **self.model_dump(exclude_unset=True, by_alias=True)}
 
-    def update_with_validate(self, **kwargs: Any) -> Self:
-        return self.__class__.model_validate({**self.model_dump(exclude_unset=True), **kwargs})  # type: ignore
+    def clone_with_changes(self, **changes: Any) -> Self:
+        return self.__class__.model_validate({**self.model_dump(exclude_unset=True), **changes})  # type: ignore
+
+    def model_update(self, **kwargs: Any) -> None:
+        for k, v in kwargs.items():
+            setattr(self, k, v)
 
 
 class ModelParametersDict(TypedDict, total=False):
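A minimal sketch of the difference between the two update styles introduced in generate/model.py above, assuming the patch is applied; ExampleParameters and its fields are hypothetical, used only for illustration:

# Sketch only: ExampleParameters is a hypothetical ModelParameters subclass.
from typing import Optional

from generate.model import ModelParameters


class ExampleParameters(ModelParameters):
    temperature: Optional[float] = None
    top_p: Optional[float] = None


params = ExampleParameters(temperature=0.5)

# clone_with_changes returns a new, re-validated instance; the original is untouched.
clone = params.clone_with_changes(top_p=0.9)
assert params.top_p is None and clone.top_p == 0.9

# model_update mutates in place; because model_config now sets
# validate_assignment=True, pydantic re-validates each field as it is assigned.
params.model_update(temperature=0.7)
assert params.temperature == 0.7

The per-request call sites below keep the copying style (clone_with_changes), so per-call kwargs never leak into a model's long-lived default parameters.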
diff --git a/generate/text_to_speech/models/minimax.py b/generate/text_to_speech/models/minimax.py
index fbcd490..8211fd5 100644
--- a/generate/text_to_speech/models/minimax.py
+++ b/generate/text_to_speech/models/minimax.py
@@ -81,7 +81,7 @@ def _get_request_parameters(self, text: str, parameters: MinimaxSpeechParameters
         }
 
     def generate(self, prompt: str, **kwargs: Unpack[MinimaxSpeechParametersDict]) -> TextToSpeechOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return TextToSpeechOutput(
@@ -92,7 +92,7 @@ def generate(self, prompt: str, **kwargs: Unpack[MinimaxSpeechParametersDict]) -
         )
 
     async def async_generate(self, prompt: str, **kwargs: Unpack[MinimaxSpeechParametersDict]) -> TextToSpeechOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return TextToSpeechOutput(
@@ -152,7 +152,7 @@ def _get_request_parameters(self, text: str, parameters: MinimaxProSpeechParamet
     @override
     def generate(self, prompt: str, **kwargs: Unpack[MinimaxProSpeechParametersDict]) -> TextToSpeechOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         response_data = response.json()
@@ -171,7 +171,7 @@ def generate(self, prompt: str, **kwargs: Unpack[MinimaxProSpeechParametersDict]
     @override
     async def async_generate(self, prompt: str, **kwargs: Unpack[MinimaxProSpeechParametersDict]) -> TextToSpeechOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         response_data = response.json()
diff --git a/generate/text_to_speech/models/openai.py b/generate/text_to_speech/models/openai.py
index 4d3bf9c..ba3db05 100644
--- a/generate/text_to_speech/models/openai.py
+++ b/generate/text_to_speech/models/openai.py
@@ -56,7 +56,7 @@ def _get_request_parameters(self, text: str, parameters: OpenAISpeechParameters)
     @override
     def generate(self, prompt: str, **kwargs: Unpack[OpenAISpeechParametersDict]) -> TextToSpeechOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return TextToSpeechOutput(
@@ -68,7 +68,7 @@ def generate(self, prompt: str, **kwargs: Unpack[OpenAISpeechParametersDict]) ->
         )
 
     @override
     async def async_generate(self, prompt: str, **kwargs: Unpack[OpenAISpeechParametersDict]) -> TextToSpeechOutput:
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(prompt, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return TextToSpeechOutput(
diff --git a/tests/test_completion.py b/tests/test_completion.py
index 7e46e60..2f8d0db 100644
--- a/tests/test_completion.py
+++ b/tests/test_completion.py
@@ -33,19 +33,19 @@ def __init__(self, parameters: FakeChatParameters | None = None) -> None:
     def generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         content = f'{parameters.prefix}{messages[-1].content}'
         return ChatCompletionOutput(model_info=self.model_info, message=AssistantMessage(content=content))
 
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         content = f'{parameters.prefix}{messages[-1].content}'
         return ChatCompletionOutput(model_info=self.model_info, message=AssistantMessage(content=content))
 
     def stream_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         content = f'{parameters.prefix}{messages[-1].content}'
         yield ChatCompletionStreamOutput(
             model_info=self.model_info,
@@ -62,7 +62,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         content = f'{parameters.prefix}{messages[-1].content}'
         yield ChatCompletionStreamOutput(
             model_info=self.model_info,
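The new ZhipuChatParameters validator coerces temperature=0 to 0.01 rather than raising. A sketch of the intended behavior, assuming the patched package is importable and that the Temperature type itself accepts 0 as input (the field_validator runs in pydantic's default "after" mode):

from generate.chat_completion.models.zhipu import ZhipuChatParameters

# temperature=0 is nudged to the smallest accepted value instead of failing.
params = ZhipuChatParameters(temperature=0)
assert params.temperature == 0.01

# clone_with_changes re-runs model_validate, so the same coercion applies
# when deriving new parameters from an existing instance.
assert params.clone_with_changes(temperature=0).temperature == 0.01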