Refactor model update methods (#18)
Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin authored Jan 29, 2024
1 parent 8e4c6c2 commit 48434e6
Showing 16 changed files with 71 additions and 57 deletions.
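
The change is mechanical: every call site swaps self.parameters.update_with_validate(**kwargs) for self.parameters.clone_with_changes(**kwargs). The implementation of the renamed method is not part of the hunks shown on this page, so the following is only a sketch of the implied semantics, assuming the parameters classes are Pydantic models and that ModelParameters is the shared base class (both assumptions, not confirmed by this diff): clone_with_changes would return a validated copy with the per-call overrides applied, leaving self.parameters untouched.

from typing import Any

from pydantic import BaseModel
from typing_extensions import Self


class ModelParameters(BaseModel):
    """Assumed base class for request parameters; not shown in this diff."""

    def clone_with_changes(self, **changes: Any) -> Self:
        # Merge the per-call overrides into the explicitly-set fields, then
        # re-validate, so an invalid override raises immediately and the
        # original instance is never mutated.
        merged = {**self.model_dump(exclude_unset=True), **changes}
        return type(self).model_validate(merged)

Under that reading, params.clone_with_changes(temperature=0.2) returns a fresh validated instance and leaves params unchanged, which matches how every generate and stream_generate method below derives per-request parameters from self.parameters; the old name suggested an in-place update rather than a copy.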
4 changes: 2 additions & 2 deletions generate/chat_completion/models/azure.py
@@ -53,7 +53,7 @@ def _get_request_parameters(self, messages: Messages, parameters: OpenAIChatPara
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters)
         output = parse_openai_model_reponse(response.json())
@@ -63,7 +63,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -
     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         output = parse_openai_model_reponse(response.json())
8 changes: 4 additions & 4 deletions generate/chat_completion/models/baichuan.py
@@ -119,15 +119,15 @@ def _get_request_parameters(self, messages: Messages, parameters: BaichuanChatPa
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())

     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -164,7 +164,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -178,7 +178,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
8 changes: 4 additions & 4 deletions generate/chat_completion/models/bailian.py
@@ -137,15 +137,15 @@ def _get_request_parameters(self, messages: Messages, parameters: BailianChatPar
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())

     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -176,7 +176,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -195,7 +195,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
8 changes: 4 additions & 4 deletions generate/chat_completion/models/hunyuan.py
@@ -93,15 +93,15 @@ def _get_request_parameters(self, messages: Messages, parameters: HunyuanChatPar
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())

     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -136,7 +136,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -150,7 +150,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
8 changes: 4 additions & 4 deletions generate/chat_completion/models/minimax.py
@@ -130,15 +130,15 @@ def _get_request_parameters(self, messages: Messages, parameters: MinimaxChatPar
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())

     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -171,7 +171,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
@@ -185,7 +185,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         message = AssistantMessage(content='')
         is_start = True
10 changes: 5 additions & 5 deletions generate/chat_completion/models/minimax_pro.py
@@ -199,7 +199,7 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput:
             model_info=self.model_info,
             message=self.message,
             finish_reason=response['choices'][0]['finish_reason'],
-            cost=calculate_cost(model_name=self.model_info.name , usage=response['usage']),
+            cost=calculate_cost(model_name=self.model_info.name, usage=response['usage']),
             extra={
                 'input_sensitive': response['input_sensitive'],
                 'output_sensitive': response['output_sensitive'],
@@ -299,15 +299,15 @@ def _get_request_parameters(self, messages: Messages, parameters: MinimaxProChat
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())

     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -344,7 +344,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor(model_info=self.model_info)
         for line in self.http_client.stream_post(request_parameters=request_parameters):
@@ -355,7 +355,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor(model_info=self.model_info)
         async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
8 changes: 4 additions & 4 deletions generate/chat_completion/models/openai.py
@@ -375,15 +375,15 @@ def _get_request_parameters(self, messages: Messages, parameters: OpenAIChatPara
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters)
         return parse_openai_model_reponse(response.json())

     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return parse_openai_model_reponse(response.json())
@@ -398,7 +398,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor()
         is_finish = False
@@ -417,7 +417,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
         stream_processor = _StreamResponseProcessor()
         is_finish = False
8 changes: 4 additions & 4 deletions generate/chat_completion/models/wenxin.py
@@ -160,15 +160,15 @@ def _get_request_parameters(self, messages: Messages, parameters: WenxinChatPara
     @override
     def generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters)
         return self._parse_reponse(response.json())

     @override
     async def async_generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
@@ -206,7 +206,7 @@ def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]
     ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         if parameters.functions:
             raise ValueError('stream_generate does not support functions')
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -222,7 +222,7 @@ async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]
     ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
-        parameters = self.parameters.update_with_validate(**kwargs)
+        parameters = self.parameters.clone_with_changes(**kwargs)
         if parameters.functions:
             raise ValueError('stream_generate does not support functions')
         request_parameters = self._get_stream_request_parameters(messages, parameters)