Skip to content

Commit

Permalink
Hotfix (#14)
Browse files Browse the repository at this point in the history
* fix stream break
---------

Co-authored-by: wangyuxin <[email protected]>
  • Loading branch information
wangyuxinwhy and wangyuxin authored Jan 9, 2024
1 parent e4038fe commit b79b8e8
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 14 deletions.
5 changes: 3 additions & 2 deletions generate/chat_completion/model_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Generic, Literal, Optional, TypeVar
from typing import Generic, Literal, Optional, TypeVar, cast

from pydantic import BaseModel

Expand All @@ -17,7 +17,8 @@ class ChatCompletionOutput(ModelOutput, Generic[M]):
@property
def reply(self) -> str:
    """Return the text content of the assistant's reply.

    Returns an empty string when no message has been produced yet or
    when the message is not an ``AssistantMessage``.
    """
    if self.message and isinstance(self.message, AssistantMessage):
        # NOTE: ``self.message`` is a generic attribute (Generic[M]), so some
        # type checkers do not narrow it via isinstance(); the cast makes the
        # narrowing explicit without changing runtime behavior.
        message = cast(AssistantMessage, self.message)
        return message.content
    return ''

@property
Expand Down
14 changes: 10 additions & 4 deletions generate/chat_completion/models/bailian.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,15 @@ def stream_generate(
request_parameters = self._get_stream_request_parameters(messages, parameters)
message = AssistantMessage(content='')
is_start = True
is_finish = False
for line in self.http_client.stream_post(request_parameters=request_parameters):
if is_finish:
continue

output = self._parse_stream_line(line, message, is_start)
is_start = False
is_finish = output.is_finish
yield output
if output.is_finish:
break

@override
async def async_stream_generate(
Expand All @@ -198,12 +201,15 @@ async def async_stream_generate(
request_parameters = self._get_stream_request_parameters(messages, parameters)
message = AssistantMessage(content='')
is_start = True
is_finish = False
async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
if is_finish:
continue

output = self._parse_stream_line(line, message, is_start)
is_start = False
is_finish = output.is_finish
yield output
if output.is_finish:
break

def _parse_stream_line(
self, line: str, message: AssistantMessage, is_start: bool
Expand Down
16 changes: 15 additions & 1 deletion generate/chat_completion/models/minimax_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,13 @@ def initial_message(self, response: ResponseValue) -> MinimaxProAssistantMessage

def update_existing_message(self, response: ResponseValue) -> str:
output_messages = [_convert_to_message(i) for i in response['choices'][0]['messages']]
if len(output_messages) == 1 and not isinstance(self.message, AssistantGroupMessage):
return self.update_single_message(output_messages[0]) # type: ignore

if len(output_messages) > 1 and not isinstance(self.message, AssistantGroupMessage):
self.message = AssistantGroupMessage(content=[self.message]) # type: ignore
messages = self.message.content if isinstance(self.message, AssistantGroupMessage) else [self.message]
self.message = cast(AssistantGroupMessage, self.message)
messages = self.message.content
delta = ''
for index, output_message in enumerate(output_messages, start=1):
if index > len(messages):
Expand All @@ -255,6 +259,16 @@ def update_existing_message(self, response: ResponseValue) -> str:
raise ValueError(f'unknown message type: {output_message}')
return delta

def update_single_message(self, message: FunctionCallMessage | AssistantMessage) -> str:
    """Fold one streamed message into ``self.message`` and return the text delta.

    A function-call message replaces the current message outright and yields
    no text; a plain assistant message has its content appended to the
    accumulated message, and that content is returned as the delta.
    """
    if isinstance(message, FunctionCallMessage):
        # Function calls carry no streamable text — swap the message in place.
        self.message = message
        return ''

    # Plain text: extend the accumulated content by exactly this chunk.
    self.message.content += message.content  # type: ignore
    return message.content


def calculate_cost(usage: dict[str, int], num_web_search: int = 0) -> float:
    """Compute the request cost: 0.015 per 1k total tokens plus 0.03 per web search.

    ``usage`` must contain a ``'total_tokens'`` entry; ``num_web_search``
    counts billable web-search invocations (default 0).
    """
    token_cost = 0.015 * (usage['total_tokens'] / 1000)
    search_cost = 0.03 * num_web_search
    return token_cost + search_cost
Expand Down
16 changes: 11 additions & 5 deletions generate/chat_completion/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[OpenAIA
finish_reason=finish_reason,
cost=cost,
extra=extra,
stream=Stream(delta=delta_dict.get('content', ''), control=stream_control),
stream=Stream(delta=delta_dict.get('content') or '', control=stream_control),
)

def process_initial_message(self, delta_dict: dict[str, Any]) -> OpenAIAssistantMessage | None:
Expand Down Expand Up @@ -405,13 +405,16 @@ def stream_generate(
parameters = self.parameters.update_with_validate(**kwargs)
request_parameters = self._get_stream_request_parameters(messages, parameters)
stream_processor = _StreamResponseProcessor()
is_finish = False
for line in self.http_client.stream_post(request_parameters=request_parameters):
if is_finish:
continue

output = stream_processor.process(json.loads(line))
if output is None:
continue
is_finish = output.is_finish
yield output
if output.is_finish:
break

@override
async def async_stream_generate(
Expand All @@ -421,13 +424,16 @@ async def async_stream_generate(
parameters = self.parameters.update_with_validate(**kwargs)
request_parameters = self._get_stream_request_parameters(messages, parameters)
stream_processor = _StreamResponseProcessor()
is_finish = False
async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
if is_finish:
continue

output = stream_processor.process(json.loads(line))
if output is None:
continue
is_finish = output.is_finish
yield output
if output.is_finish:
break

@property
@override
Expand Down
2 changes: 1 addition & 1 deletion generate/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.2'
__version__ = '0.2.3'
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "generate-core"
version = "0.2.2"
version = "0.2.3"
description = "文本生成,图像生成,语音生成"
authors = ["wangyuxin <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit b79b8e8

Please sign in to comment.