From b79b8e8b6017193e6959e37be580c7ec0eb95e61 Mon Sep 17 00:00:00 2001 From: "yuxin.wang" Date: Tue, 9 Jan 2024 12:22:31 +0800 Subject: [PATCH] Hotfix (#14) * fix stream break --------- Co-authored-by: wangyuxin --- generate/chat_completion/model_output.py | 5 +++-- generate/chat_completion/models/bailian.py | 14 ++++++++++---- generate/chat_completion/models/minimax_pro.py | 16 +++++++++++++++- generate/chat_completion/models/openai.py | 16 +++++++++++----- generate/version.py | 2 +- pyproject.toml | 2 +- 6 files changed, 41 insertions(+), 14 deletions(-) diff --git a/generate/chat_completion/model_output.py b/generate/chat_completion/model_output.py index cbe4487..da9040e 100644 --- a/generate/chat_completion/model_output.py +++ b/generate/chat_completion/model_output.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Generic, Literal, Optional, TypeVar +from typing import Generic, Literal, Optional, TypeVar, cast from pydantic import BaseModel @@ -17,7 +17,8 @@ class ChatCompletionOutput(ModelOutput, Generic[M]): @property def reply(self) -> str: if self.message and isinstance(self.message, AssistantMessage): - return self.message.content + message = cast(AssistantMessage, self.message) + return message.content return '' @property diff --git a/generate/chat_completion/models/bailian.py b/generate/chat_completion/models/bailian.py index 682b4d0..037b881 100644 --- a/generate/chat_completion/models/bailian.py +++ b/generate/chat_completion/models/bailian.py @@ -182,12 +182,15 @@ def stream_generate( request_parameters = self._get_stream_request_parameters(messages, parameters) message = AssistantMessage(content='') is_start = True + is_finish = False for line in self.http_client.stream_post(request_parameters=request_parameters): + if is_finish: + continue + output = self._parse_stream_line(line, message, is_start) is_start = False + is_finish = output.is_finish yield output - if output.is_finish: - break @override async def async_stream_generate( @@ -198,12 +201,15 @@ async def async_stream_generate( request_parameters = self._get_stream_request_parameters(messages, parameters) message = AssistantMessage(content='') is_start = True + is_finish = False async for line in self.http_client.async_stream_post(request_parameters=request_parameters): + if is_finish: + continue + output = self._parse_stream_line(line, message, is_start) is_start = False + is_finish = output.is_finish yield output - if output.is_finish: - break def _parse_stream_line( self, line: str, message: AssistantMessage, is_start: bool diff --git a/generate/chat_completion/models/minimax_pro.py b/generate/chat_completion/models/minimax_pro.py index f6801c1..ca9c593 100644 --- a/generate/chat_completion/models/minimax_pro.py +++ b/generate/chat_completion/models/minimax_pro.py @@ -236,9 +236,13 @@ def initial_message(self, response: ResponseValue) -> MinimaxProAssistantMessage def update_existing_message(self, response: ResponseValue) -> str: output_messages = [_convert_to_message(i) for i in response['choices'][0]['messages']] + if len(output_messages) == 1 and not isinstance(self.message, AssistantGroupMessage): + return self.update_single_message(output_messages[0]) # type: ignore + if len(output_messages) > 1 and not isinstance(self.message, AssistantGroupMessage): self.message = AssistantGroupMessage(content=[self.message]) # type: ignore - messages = self.message.content if isinstance(self.message, AssistantGroupMessage) else [self.message] + self.message = cast(AssistantGroupMessage, self.message) + messages = self.message.content delta = '' for index, output_message in enumerate(output_messages, start=1): if index > len(messages): @@ -255,6 +259,16 @@ def update_existing_message(self, response: ResponseValue) -> str: raise ValueError(f'unknown message type: {output_message}') return delta + def update_single_message(self, message: FunctionCallMessage | AssistantMessage) -> str: + if isinstance(message, FunctionCallMessage): + delta = '' + self.message = message + return delta + + delta = message.content + self.message.content += message.content # type: ignore + return delta + def calculate_cost(usage: dict[str, int], num_web_search: int = 0) -> float: return 0.015 * (usage['total_tokens'] / 1000) + (0.03 * num_web_search) diff --git a/generate/chat_completion/models/openai.py b/generate/chat_completion/models/openai.py index 77cf0bb..163c2da 100644 --- a/generate/chat_completion/models/openai.py +++ b/generate/chat_completion/models/openai.py @@ -291,7 +291,7 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[OpenAIA finish_reason=finish_reason, cost=cost, extra=extra, - stream=Stream(delta=delta_dict.get('content', ''), control=stream_control), + stream=Stream(delta=delta_dict.get('content') or '', control=stream_control), ) def process_initial_message(self, delta_dict: dict[str, Any]) -> OpenAIAssistantMessage | None: @@ -405,13 +405,16 @@ def stream_generate( parameters = self.parameters.update_with_validate(**kwargs) request_parameters = self._get_stream_request_parameters(messages, parameters) stream_processor = _StreamResponseProcessor() + is_finish = False for line in self.http_client.stream_post(request_parameters=request_parameters): + if is_finish: + continue + output = stream_processor.process(json.loads(line)) if output is None: continue + is_finish = output.is_finish yield output - if output.is_finish: - break @override async def async_stream_generate( @@ -421,13 +424,16 @@ async def async_stream_generate( parameters = self.parameters.update_with_validate(**kwargs) request_parameters = self._get_stream_request_parameters(messages, parameters) stream_processor = _StreamResponseProcessor() + is_finish = False async for line in self.http_client.async_stream_post(request_parameters=request_parameters): + if is_finish: + continue + output = stream_processor.process(json.loads(line)) if output is None: continue + is_finish = output.is_finish yield output - if output.is_finish: - break @property @override diff --git a/generate/version.py b/generate/version.py index 020ed73..d93b5b2 100644 --- a/generate/version.py +++ b/generate/version.py @@ -1 +1 @@ -__version__ = '0.2.2' +__version__ = '0.2.3' diff --git a/pyproject.toml b/pyproject.toml index 486eab5..30afacb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "generate-core" -version = "0.2.2" +version = "0.2.3" description = "文本生成,图像生成,语音生成" authors = ["wangyuxin "] license = "MIT"