diff --git a/generate/chat_completion/function_call.py b/generate/chat_completion/function_call.py
index 4eed2c1..05efff1 100644
--- a/generate/chat_completion/function_call.py
+++ b/generate/chat_completion/function_call.py
@@ -1,13 +1,11 @@
 from __future__ import annotations
 
-import json
 from typing import Any, Callable, Generic, TypeVar
 
 from docstring_parser import parse
 from pydantic import TypeAdapter, validate_call
 from typing_extensions import NotRequired, ParamSpec, TypedDict
 
-from generate.chat_completion.message import FunctionCallMessage, Message
 from generate.types import JsonSchema
 
 P = ParamSpec('P')
@@ -58,13 +56,6 @@ def parameters(self) -> JsonSchema:
     def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
         return self.function(*args, **kwargs)
 
-    def call_with_message(self, message: Message) -> T:
-        if isinstance(message, FunctionCallMessage):
-            function_call = message.content
-            arguments = json.loads(function_call.arguments, strict=False)
-            return self.function(**arguments)  # type: ignore
-        raise ValueError(f'message is not a function call: {message}')
-
 
 def recusive_remove(obj: Any, remove_key: str) -> None:
     """
diff --git a/generate/chat_completion/message/__init__.py b/generate/chat_completion/message/__init__.py
index c14562a..cfc3a50 100644
--- a/generate/chat_completion/message/__init__.py
+++ b/generate/chat_completion/message/__init__.py
@@ -1,8 +1,6 @@
 from generate.chat_completion.message.core import (
-    AssistantGroupMessage,
     AssistantMessage,
     FunctionCall,
-    FunctionCallMessage,
     FunctionMessage,
     ImageUrl,
     ImageUrlPart,
@@ -13,9 +11,7 @@
     SystemMessage,
     TextPart,
     ToolCall,
-    ToolCallsMessage,
     ToolMessage,
-    UnionAssistantMessage,
     UnionMessage,
     UnionUserMessage,
     UserMessage,
@@ -30,13 +26,9 @@
 from generate.chat_completion.message.utils import ensure_messages
 
 __all__ = [
-    'UnionAssistantMessage',
     'UnionUserMessage',
     'UnionMessage',
-    'FunctionCallMessage',
     'AssistantMessage',
-    'AssistantGroupMessage',
-    'ToolCallsMessage',
     'ensure_messages',
     'FunctionCall',
     'FunctionMessage',
diff --git a/generate/chat_completion/message/core.py b/generate/chat_completion/message/core.py
index 6b27250..8dc4ebc 100644
--- a/generate/chat_completion/message/core.py
+++ b/generate/chat_completion/message/core.py
@@ -51,44 +51,32 @@ class ToolMessage(Message):
     content: Optional[str] = None
 
 
-class AssistantMessage(Message):
-    role: Literal['assistant'] = 'assistant'
-    content: str
-
-
 class FunctionCall(BaseModel):
     name: str
     arguments: str
     thoughts: Optional[str] = None
 
 
-class FunctionCallMessage(Message):
-    role: Literal['assistant'] = 'assistant'
-    content: FunctionCall
-
-
 class ToolCall(BaseModel):
     id: str  # noqa: A003
     type: Literal['function'] = 'function'  # noqa: A003
     function: FunctionCall
 
 
-class ToolCallsMessage(Message):
+class AssistantMessage(Message):
     role: Literal['assistant'] = 'assistant'
-    name: Optional[str] = None
-    content: List[ToolCall]
+    content: str = ''
+    function_call: Optional[FunctionCall] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
-
-class AssistantGroupMessage(Message):
-    role: Literal['assistant'] = 'assistant'
-    name: Optional[str] = None
-    content: List[Union[AssistantMessage, FunctionMessage, FunctionCallMessage]]
+    @property
+    def is_over(self) -> bool:
+        return self.function_call is None and self.tool_calls is None
 
 
-UnionAssistantMessage = Union[AssistantMessage, FunctionCallMessage, ToolCallsMessage, AssistantGroupMessage]
 UnionUserMessage = Union[UserMessage, UserMultiPartMessage]
 UnionUserPart = Union[TextPart, ImageUrlPart]
-UnionMessage = Union[SystemMessage, FunctionMessage, ToolMessage, UnionAssistantMessage, UnionUserMessage]
+UnionMessage = Union[SystemMessage, FunctionMessage, ToolMessage, AssistantMessage, UnionUserMessage]
 Messages = List[UnionMessage]
 MessageDict = Dict[str, Any]
 MessageDicts = Sequence[MessageDict]
diff --git a/generate/chat_completion/model_output.py b/generate/chat_completion/model_output.py
index da9040e..055356a 100644
--- a/generate/chat_completion/model_output.py
+++ b/generate/chat_completion/model_output.py
@@ -1,25 +1,20 @@
 from __future__ import annotations
 
-from typing import Generic, Literal, Optional, TypeVar, cast
+from typing import Literal, Optional
 
 from pydantic import BaseModel
 
-from generate.chat_completion.message import AssistantMessage, UnionAssistantMessage
+from generate.chat_completion.message import AssistantMessage
 from generate.model import ModelOutput
 
-M = TypeVar('M', bound=UnionAssistantMessage)
 
-
-class ChatCompletionOutput(ModelOutput, Generic[M]):
-    message: M
+class ChatCompletionOutput(ModelOutput):
+    message: AssistantMessage
     finish_reason: Optional[str] = None
 
     @property
     def reply(self) -> str:
-        if self.message and isinstance(self.message, AssistantMessage):
-            message = cast(AssistantMessage, self.message)
-            return message.content
-        return ''
+        return self.message.content
 
     @property
     def is_finish(self) -> bool:
@@ -31,5 +26,5 @@ class Stream(BaseModel):
     control: Literal['start', 'continue', 'finish']
 
 
-class ChatCompletionStreamOutput(ChatCompletionOutput, Generic[M]):
+class ChatCompletionStreamOutput(ChatCompletionOutput):
     stream: Stream
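For review context, a minimal sketch of the unified model this diff converges on (assuming the post-refactor exports shown above; `get_weather` is a made-up function name): plain replies, function calls, and tool calls are now three shapes of one `AssistantMessage`, and `is_over` separates a final reply from one that still requests a call.

```python
# Sketch only: exercises the unified AssistantMessage defined in core.py above.
from generate.chat_completion.message import AssistantMessage, FunctionCall, ToolCall

plain = AssistantMessage(content='Hello!')
assert plain.is_over  # no pending function_call or tool_calls

fn_reply = AssistantMessage(function_call=FunctionCall(name='get_weather', arguments='{"city": "Beijing"}'))
tool_reply = AssistantMessage(tool_calls=[ToolCall(id='call_0', function=FunctionCall(name='get_weather', arguments='{}'))])
assert not fn_reply.is_over and not tool_reply.is_over  # caller still has calls to run
```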
diff --git a/generate/chat_completion/models/baichuan.py b/generate/chat_completion/models/baichuan.py
index df013df..b474ce6 100644
--- a/generate/chat_completion/models/baichuan.py
+++ b/generate/chat_completion/models/baichuan.py
@@ -117,7 +117,7 @@ def _get_request_parameters(self, messages: Messages, parameters: BaichuanChatPa
         }
 
     @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -125,16 +125,14 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict])
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         try:
             text = response['data']['messages'][-1]['content']
             finish_reason = response['data']['messages'][-1]['finish_reason']
@@ -164,7 +162,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: Baichua
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -178,7 +176,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -189,9 +187,7 @@ async def async_stream_generate(
             is_start = False
             yield output
 
-    def _parse_stream_line(
-        self, line: str, message: AssistantMessage, is_start: bool
-    ) -> ChatCompletionStreamOutput[AssistantMessage]:
+    def _parse_stream_line(self, line: str, message: AssistantMessage, is_start: bool) -> ChatCompletionStreamOutput:
         output = self._parse_reponse(json.loads(line))
         output_message = output.message
         if is_start:
@@ -199,7 +195,7 @@ def _parse_stream_line(
         else:
             stream = Stream(delta=output_message.content, control='finish' if output.is_finish else 'continue')
             message.content += output_message.content
-        return ChatCompletionStreamOutput[AssistantMessage](
+        return ChatCompletionStreamOutput(
             model_info=output.model_info,
             message=message,
             finish_reason=output.finish_reason,
diff --git a/generate/chat_completion/models/bailian.py b/generate/chat_completion/models/bailian.py
index 037b881..69050f1 100644
--- a/generate/chat_completion/models/bailian.py
+++ b/generate/chat_completion/models/bailian.py
@@ -135,7 +135,7 @@ def _get_request_parameters(self, messages: Messages, parameters: BailianChatPar
         }
 
     @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -143,16 +143,14 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict])
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         if not response['Success']:
             raise UnexpectedResponseError(response)
         return ChatCompletionOutput(
@@ -176,7 +174,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: Bailian
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -195,7 +193,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -211,9 +209,7 @@ async def async_stream_generate(
             is_finish = output.is_finish
             yield output
 
-    def _parse_stream_line(
-        self, line: str, message: AssistantMessage, is_start: bool
-    ) -> ChatCompletionStreamOutput[AssistantMessage]:
+    def _parse_stream_line(self, line: str, message: AssistantMessage, is_start: bool) -> ChatCompletionStreamOutput:
         parsed_line = json.loads(line)
         reply: str = parsed_line['Data']['Text']
         extra = {
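Every provider above now yields plain `ChatCompletionStreamOutput` values with the same `start`/`continue`/`finish` control protocol, so consumers can be written once. A hedged sketch (`model` stands for any constructed chat model from this repo, e.g. the Baichuan or Bailian classes):

```python
# Sketch: provider-agnostic consumption of stream_generate() after the refactor.
def print_stream(model, prompt: str) -> str:
    reply = ''
    for output in model.stream_generate(prompt):
        print(output.stream.delta, end='', flush=True)  # incremental text
        if output.stream.control == 'finish':
            reply = output.reply  # message is complete once 'finish' arrives
    print()
    return reply
```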
diff --git a/generate/chat_completion/models/hunyuan.py b/generate/chat_completion/models/hunyuan.py
index ac8bc4f..7e12226 100644
--- a/generate/chat_completion/models/hunyuan.py
+++ b/generate/chat_completion/models/hunyuan.py
@@ -91,7 +91,7 @@ def _get_request_parameters(self, messages: Messages, parameters: HunyuanChatPar
         }
 
     @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -99,16 +99,14 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict])
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         if response.get('error'):
             raise UnexpectedResponseError(response)
         return ChatCompletionOutput(
@@ -136,7 +134,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: Hunyuan
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -150,7 +148,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -161,9 +159,7 @@ async def async_stream_generate(
             is_start = False
             yield output
 
-    def _parse_stream_line(
-        self, line: str, message: AssistantMessage, is_start: bool
-    ) -> ChatCompletionStreamOutput[AssistantMessage]:
+    def _parse_stream_line(self, line: str, message: AssistantMessage, is_start: bool) -> ChatCompletionStreamOutput:
         parsed_line = json.loads(line)
         message_dict = parsed_line['choices'][0]
         delta = message_dict['delta']['content']
diff --git a/generate/chat_completion/models/minimax.py b/generate/chat_completion/models/minimax.py
index a02fdb3..acde1e6 100644
--- a/generate/chat_completion/models/minimax.py
+++ b/generate/chat_completion/models/minimax.py
@@ -128,7 +128,7 @@ def _get_request_parameters(self, messages: Messages, parameters: MinimaxChatPar
         }
 
     @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -136,18 +136,16 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict])
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         try:
-            return ChatCompletionOutput[AssistantMessage](
+            return ChatCompletionOutput(
                 model_info=self.model_info,
                 message=AssistantMessage(content=response['choices'][0]['text']),
                 finish_reason=response['choices'][0]['finish_reason'],
@@ -171,7 +169,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: Minimax
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -185,7 +183,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -196,14 +194,12 @@ async def async_stream_generate(
             is_start = False
             yield output
 
-    def _parse_stream_line(
-        self, line: str, message: AssistantMessage, is_start: bool
-    ) -> ChatCompletionStreamOutput[AssistantMessage]:
+    def _parse_stream_line(self, line: str, message: AssistantMessage, is_start: bool) -> ChatCompletionStreamOutput:
         parsed_line = json.loads(line)
         delta = parsed_line['choices'][0]['delta']
         message.content += delta
         if parsed_line['reply']:
-            return ChatCompletionStreamOutput[AssistantMessage](
+            return ChatCompletionStreamOutput(
                 model_info=self.model_info,
                 finish_reason=parsed_line['choices'][0]['finish_reason'],
                 message=message,
@@ -216,7 +212,7 @@ def _parse_stream_line(
                 },
                 stream=Stream(delta=delta, control='finish'),
             )
-        return ChatCompletionStreamOutput[AssistantMessage](
+        return ChatCompletionStreamOutput(
             model_info=self.model_info,
             finish_reason=None,
             message=message,
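Since every `_parse_reponse` above now returns the same non-generic output, downstream code can rely on `reply` directly. A small sketch (the `ModelInfo` import path and field values are assumptions for illustration):

```python
# Sketch: the unified output type shared by all providers after this change.
from generate.chat_completion import ChatCompletionOutput
from generate.chat_completion.message import AssistantMessage
from generate.model import ModelInfo  # import path assumed

output = ChatCompletionOutput(
    model_info=ModelInfo(task='chat_completion', type='minimax', name='abab5.5-chat'),  # placeholder values
    message=AssistantMessage(content='你好'),
    finish_reason='stop',
)
assert output.reply == '你好'  # no isinstance() check needed any more
```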
diff --git a/generate/chat_completion/models/minimax_pro.py b/generate/chat_completion/models/minimax_pro.py
index ca9c593..5b5b5c2 100644
--- a/generate/chat_completion/models/minimax_pro.py
+++ b/generate/chat_completion/models/minimax_pro.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import json
-from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional, Union, cast
+from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional, cast
 
 from pydantic import Field, PositiveInt, model_validator
 from typing_extensions import Annotated, NotRequired, Self, TypedDict, Unpack, override
@@ -9,10 +9,8 @@
 from generate.chat_completion.base import ChatCompletionModel
 from generate.chat_completion.function_call import FunctionJsonSchema
 from generate.chat_completion.message import (
-    AssistantGroupMessage,
     AssistantMessage,
     FunctionCall,
-    FunctionCallMessage,
     FunctionMessage,
     Message,
     Messages,
@@ -34,8 +32,6 @@
 from generate.platforms.minimax import MinimaxSettings
 from generate.types import Probability, Temperature
 
-MinimaxProAssistantMessage = Union[AssistantMessage, FunctionCallMessage, AssistantGroupMessage]
-
 
 class BotSettingDict(TypedDict):
     bot_name: str
@@ -145,23 +141,19 @@ def _convert_to_minimax_pro_message(
         sender_name = message.name or default_bot_name
         if sender_name is None:
             raise MessageValueError(message, 'bot name is required')
+        if message.function_call is None:
+            return {
+                'sender_type': 'BOT',
+                'sender_name': sender_name,
+                'text': message.content,
+            }
         return {
             'sender_type': 'BOT',
             'sender_name': sender_name,
             'text': message.content,
-        }
-
-    if isinstance(message, FunctionCallMessage):
-        sender_name = message.name or default_bot_name
-        if sender_name is None:
-            raise MessageValueError(message, 'bot name is required')
-        return {
-            'sender_type': 'BOT',
-            'sender_name': sender_name,
-            'text': '',
             'function_call': {
-                'name': message.content.name,
-                'arguments': message.content.arguments,
+                'name': message.function_call.name,
+                'arguments': message.function_call.arguments,
             },
         }
 
@@ -172,14 +164,15 @@ def _convert_to_minimax_pro_message(
             'text': message.content,
         }
 
-    raise MessageTypeError(message, allowed_message_type=(UserMessage, AssistantMessage, FunctionMessage, FunctionCallMessage))
+    raise MessageTypeError(message, allowed_message_type=(UserMessage, AssistantMessage, FunctionMessage))
 
 
-def _convert_to_message(message: MinimaxProMessage) -> FunctionCallMessage | AssistantMessage | FunctionMessage:
+def _convert_to_message(message: MinimaxProMessage) -> AssistantMessage | FunctionMessage:
     if 'function_call' in message:
-        return FunctionCallMessage(
+        return AssistantMessage(
             name=message['sender_name'],
-            content=FunctionCall(name=message['function_call']['name'], arguments=message['function_call']['arguments']),
+            content=message['text'],
+            function_call=FunctionCall(name=message['function_call']['name'], arguments=message['function_call']['arguments']),
         )
     if message['sender_type'] == 'BOT':
         return AssistantMessage(
@@ -196,12 +189,13 @@
 class _StreamResponseProcessor:
     def __init__(self, model_info: ModelInfo) -> None:
-        self.message: MinimaxProAssistantMessage | None = None
+        self.message: AssistantMessage | None = None
         self.model_info = model_info
 
-    def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[MinimaxProAssistantMessage]:
+    def process(self, response: ResponseValue) -> ChatCompletionStreamOutput:
         if response.get('usage'):
-            return ChatCompletionStreamOutput[MinimaxProAssistantMessage](
+            assert self.message is not None
+            return ChatCompletionStreamOutput(
                 model_info=self.model_info,
                 message=self.message,
                 finish_reason=response['choices'][0]['finish_reason'],
@@ -222,45 +216,25 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[Minimax
             delta = self.update_existing_message(response)
             control = 'continue'
 
-        return ChatCompletionStreamOutput[MinimaxProAssistantMessage](
+        return ChatCompletionStreamOutput(
             model_info=self.model_info,
             message=self.message,
             finish_reason=None,
            stream=Stream(delta=delta, control=control),
         )
 
-    def initial_message(self, response: ResponseValue) -> MinimaxProAssistantMessage:
+    def initial_message(self, response: ResponseValue) -> AssistantMessage:
         output_messages = [_convert_to_message(i) for i in response['choices'][0]['messages']]
-        message = output_messages[0] if len(output_messages) == 1 else AssistantGroupMessage(content=output_messages)
-        return cast(MinimaxProAssistantMessage, message)
+        message = output_messages[-1]
+        return cast(AssistantMessage, message)
 
     def update_existing_message(self, response: ResponseValue) -> str:
         output_messages = [_convert_to_message(i) for i in response['choices'][0]['messages']]
-        if len(output_messages) == 1 and not isinstance(self.message, AssistantGroupMessage):
-            return self.update_single_message(output_messages[0])  # type: ignore
-
-        if len(output_messages) > 1 and not isinstance(self.message, AssistantGroupMessage):
-            self.message = AssistantGroupMessage(content=[self.message])  # type: ignore
-        self.message = cast(AssistantGroupMessage, self.message)
-        messages = self.message.content
-        delta = ''
-        for index, output_message in enumerate(output_messages, start=1):
-            if index > len(messages):
-                messages.append(output_message)  # type: ignore
-                if isinstance(output_message, AssistantMessage):
-                    delta = output_message.content
-            elif isinstance(output_message, FunctionCallMessage):
-                messages[index - 1] = output_message
-            elif isinstance(output_message, AssistantMessage):
-                message = cast(AssistantMessage, messages[index - 1])
-                delta = output_message.content
-                message.content += output_message.content
-            else:
-                raise ValueError(f'unknown message type: {output_message}')
-        return delta
+        message = output_messages[-1]
+        if not isinstance(message, AssistantMessage):
+            return ''
 
-    def update_single_message(self, message: FunctionCallMessage | AssistantMessage) -> str:
-        if isinstance(message, FunctionCallMessage):
+        if message.function_call is not None:
             delta = ''
             self.message = message
             return delta
@@ -315,9 +289,7 @@ def _get_request_parameters(self, messages: Messages, parameters: MinimaxProChat
         }
 
     @override
-    def generate(
-        self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]
-    ) -> ChatCompletionOutput[MinimaxProAssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -325,25 +297,22 @@ def generate(
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]
-    ) -> ChatCompletionOutput[MinimaxProAssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[MinimaxProAssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         try:
-            messages: list[FunctionCallMessage | AssistantMessage | FunctionMessage] = [
+            messages: list[AssistantMessage | FunctionMessage] = [
                 _convert_to_message(i) for i in response['choices'][0]['messages']
             ]
-            message = messages[0] if len(messages) == 1 else AssistantGroupMessage(content=messages)
-            message = cast(MinimaxProAssistantMessage, message)
+            message = cast(AssistantMessage, messages[-1])
             finish_reason = response['choices'][0]['finish_reason']
             num_web_search = sum([1 for i in response['choices'][0]['messages'] if i['sender_name'] == 'plugin_web_search'])
-            return ChatCompletionOutput[MinimaxProAssistantMessage](
+            return ChatCompletionOutput(
                 model_info=self.model_info,
                 message=message,
                 finish_reason=finish_reason,
@@ -365,7 +334,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: Minimax
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[MinimaxProAssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -376,7 +345,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[MinimaxProChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[MinimaxProAssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
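The rewritten BOT branch of `_convert_to_minimax_pro_message` now emits one shape with an optional `function_call` key; roughly (a sketch with placeholder names):

```python
# Sketch of the BOT mapping above: content and function_call travel together now.
from generate.chat_completion.message import AssistantMessage, FunctionCall

message = AssistantMessage(
    name='bot',  # used as sender_name; falls back to the default bot name when unset
    function_call=FunctionCall(name='get_weather', arguments='{"city": "Beijing"}'),
)
expected = {
    'sender_type': 'BOT',
    'sender_name': 'bot',
    'text': '',  # message.content defaults to ''
    'function_call': {'name': 'get_weather', 'arguments': '{"city": "Beijing"}'},
}
```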
diff --git a/generate/chat_completion/models/openai.py b/generate/chat_completion/models/openai.py
index 163c2da..03ff15b 100644
--- a/generate/chat_completion/models/openai.py
+++ b/generate/chat_completion/models/openai.py
@@ -12,7 +12,6 @@
 from generate.chat_completion.message import (
     AssistantMessage,
     FunctionCall,
-    FunctionCallMessage,
     FunctionMessage,
     Message,
     Messages,
@@ -21,7 +20,6 @@
     SystemMessage,
     TextPart,
     ToolCall,
-    ToolCallsMessage,
     ToolMessage,
     UserMessage,
     UserMultiPartMessage,
@@ -37,8 +35,6 @@
 from generate.platforms.openai import OpenAISettings
 from generate.types import Probability, Temperature
 
-OpenAIAssistantMessage = Union[AssistantMessage, FunctionCallMessage, ToolCallsMessage]
-
 
 class FunctionCallName(TypedDict):
     name: str
@@ -150,11 +146,13 @@ def _to_tool_message_dict(message: ToolMessage) -> OpenAIMessage:
     }
 
 
-def _to_tool_calls_message_dict(message: ToolCallsMessage) -> OpenAIMessage:
-    return {
+def _to_assistant_message_dict(message: AssistantMessage) -> OpenAIMessage:
+    base_dict = {
         'role': 'assistant',
-        'content': None,
-        'tool_calls': [
+        'content': message.content or None,
+    }
+    if message.tool_calls:
+        tool_calls = [
             {
                 'id': tool_call.id,
                 'type': 'function',
@@ -163,9 +161,15 @@ def _to_tool_calls_message_dict(message: ToolCallsMessage) -> OpenAIMessage:
                     'arguments': tool_call.function.arguments,
                 },
             }
-            for tool_call in message.content
-        ],
-    }
+            for tool_call in message.tool_calls
+        ]
+        base_dict['tool_calls'] = tool_calls
+    if message.function_call:
+        base_dict['function_call'] = {
+            'name': message.function_call.name,
+            'arguments': message.function_call.arguments,
+        }
+    return cast(OpenAIMessage, base_dict)
 
 
 def _to_function_message_dict(message: FunctionMessage) -> OpenAIMessage:
@@ -176,27 +180,14 @@
     return {
         'role': 'function',
         'name': message.name,
         'content': message.content,
     }
 
 
-def _to_function_call_message_dict(message: FunctionCallMessage) -> OpenAIMessage:
-    return {
-        'role': 'assistant',
-        'function_call': {
-            'name': message.content.name,
-            'arguments': message.content.arguments,
-        },
-        'content': None,
-    }
-
-
 def convert_to_openai_message(message: Message) -> OpenAIMessage:
     to_function_map: dict[Type[Message], Callable[[Any], OpenAIMessage]] = {
         SystemMessage: partial(_to_text_message_dict, 'system'),
         UserMessage: partial(_to_text_message_dict, 'user'),
-        AssistantMessage: partial(_to_text_message_dict, 'assistant'),
+        AssistantMessage: _to_assistant_message_dict,
         UserMultiPartMessage: _to_user_multipart_message_dict,
         ToolMessage: _to_tool_message_dict,
-        ToolCallsMessage: _to_tool_calls_message_dict,
         FunctionMessage: _to_function_message_dict,
-        FunctionCallMessage: _to_function_call_message_dict,
     }
     if to_function := to_function_map.get(type(message)):
         return to_function(message)
@@ -219,35 +210,32 @@ def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> fl
     return None
 
 
-def convert_openai_message_to_generate_message(
-    message: dict[str, Any]
-) -> FunctionCallMessage | ToolCallsMessage | AssistantMessage:
-    if function_call := message.get('function_call'):
-        function_call = cast(OpenAIFunctionCall, function_call)
-        return FunctionCallMessage(
-            content=FunctionCall(
-                name=function_call.get('name') or '',
-                arguments=function_call['arguments'],
-            ),
-        )
-    if tool_calls := message.get('tool_calls'):
-        tool_calls = cast(List[OpenAIToolCall], tool_calls)
-        return ToolCallsMessage(
-            content=[
-                ToolCall(
-                    id=tool_call['id'],
-                    function=FunctionCall(
-                        name=tool_call['function'].get('name') or '',
-                        arguments=tool_call['function']['arguments'],
-                    ),
-                )
-                for tool_call in tool_calls
-            ],
+def convert_openai_message_to_generate_message(message: dict[str, Any]) -> AssistantMessage:
+    if function_call_dict := message.get('function_call'):
+        function_call = FunctionCall(
+            name=function_call_dict.get('name') or '',
+            arguments=function_call_dict['arguments'],
         )
-    return AssistantMessage(content=message['content'] or '')
-
-
-def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput[OpenAIAssistantMessage]:
+    else:
+        function_call = None
+
+    if tool_calls_dict := message.get('tool_calls'):
+        tool_calls = [
+            ToolCall(
+                id=tool_call['id'],
+                function=FunctionCall(
+                    name=tool_call['function'].get('name') or '',
+                    arguments=tool_call['function']['arguments'],
+                ),
+            )
+            for tool_call in tool_calls_dict
+        ]
+    else:
+        tool_calls = None
+    return AssistantMessage(content=message.get('content') or '', function_call=function_call, tool_calls=tool_calls)
+
+
+def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput:
     message = convert_openai_message_to_generate_message(response['choices'][0]['message'])
     extra = {'usage': response['usage']}
     if system_fingerprint := response.get('system_fingerprint'):
@@ -257,7 +245,7 @@ def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput[
     if (finish_reason := choice.get('finish_reason')) is None:
         finish_reason = finish_details['type'] if (finish_details := choice.get('finish_details')) else None
 
-    return ChatCompletionOutput[OpenAIAssistantMessage](
+    return ChatCompletionOutput(
         model_info=ModelInfo(task='chat_completion', type='openai', name=response['model']),
         message=message,
         finish_reason=finish_reason or '',
@@ -268,10 +256,10 @@ def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput[
 
 class _StreamResponseProcessor:
     def __init__(self) -> None:
-        self.message: OpenAIAssistantMessage | None = None
+        self.message: AssistantMessage | None = None
         self.is_start = True
 
-    def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[OpenAIAssistantMessage] | None:
+    def process(self, response: ResponseValue) -> ChatCompletionStreamOutput | None:
         delta_dict = response['choices'][0]['delta']
 
         if self.message is None:
@@ -285,7 +273,7 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[OpenAIA
         finish_reason = self.determine_finish_reason(response)
         stream_control = 'finish' if finish_reason else 'start' if self.is_start else 'continue'
         self.is_start = False
-        return ChatCompletionStreamOutput[OpenAIAssistantMessage](
+        return ChatCompletionStreamOutput(
             model_info=ModelInfo(task='chat_completion', type='openai', name=response['model']),
             message=self.message,
             finish_reason=finish_reason,
@@ -294,7 +282,7 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[OpenAIA
             stream=Stream(delta=delta_dict.get('content') or '', control=stream_control),
         )
 
-    def process_initial_message(self, delta_dict: dict[str, Any]) -> OpenAIAssistantMessage | None:
+    def process_initial_message(self, delta_dict: dict[str, Any]) -> AssistantMessage | None:
         if (
             delta_dict.get('content') is None
             and delta_dict.get('tool_calls') is None
@@ -306,19 +294,30 @@ def process_initial_message(self, delta_dict: dict[str, Any]) -> OpenAIAssistant
     def update_existing_message(self, delta_dict: dict[str, Any]) -> None:
         if not delta_dict:
             return
+        assert self.message is not None
 
-        if isinstance(self.message, AssistantMessage):
-            delta = delta_dict['content']
-            self.message.content += delta
-        elif isinstance(self.message, FunctionCallMessage):
-            self.message.content.arguments += delta_dict['function_call']['arguments']
-        elif isinstance(self.message, ToolCallsMessage):
+        delta_content = delta_dict.get('content') or ''
+        self.message.content += delta_content
+
+        if delta_dict.get('tool_calls'):
             index = delta_dict['tool_calls'][0]['index']
-            if index >= len(self.message.content):
-                new_tool_calls_message = cast(ToolCallsMessage, convert_openai_message_to_generate_message(delta_dict))
-                self.message.content.append(new_tool_calls_message.content[0])
+            if index >= len(self.message.tool_calls or []):
+                new_tool_calls = convert_openai_message_to_generate_message(delta_dict).tool_calls
+                assert new_tool_calls is not None
+                if self.message.tool_calls is None:
+                    self.message.tool_calls = []
+                self.message.tool_calls.append(new_tool_calls[0])
             else:
-                self.message.content[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments']
+                assert self.message.tool_calls is not None
+                self.message.tool_calls[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments']
+
+        if delta_dict.get('function_call'):
+            if self.message.function_call is None:
+                self.message.function_call = FunctionCall(name='', arguments='')
+            if function_name := delta_dict['function_call'].get('name'):
+                self.message.function_call.name = function_name
+            arguments = delta_dict['function_call'].get('arguments') or ''
+            self.message.function_call.arguments += arguments
 
     def extract_extra_info(self, response: ResponseValue) -> dict[str, Any]:
         extra = {}
@@ -373,9 +372,7 @@ def _get_request_parameters(self, messages: Messages, parameters: OpenAIChatPara
         }
 
     @override
-    def generate(
-        self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]
-    ) -> ChatCompletionOutput[OpenAIAssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -383,9 +380,7 @@ def generate(
         return parse_openai_model_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]
-    ) -> ChatCompletionOutput[OpenAIAssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -400,7 +395,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: OpenAIC
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[OpenAIAssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -419,7 +414,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[OpenAIAssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
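The tool-call branch of `update_existing_message` above merges streamed deltas by index: the first chunk for an index appends a new `ToolCall`, later chunks only extend its `arguments`. A sketch of that invariant with hand-written stand-ins for the delta payloads:

```python
# Sketch: accumulation of streamed OpenAI tool_call fragments.
from generate.chat_completion.message import AssistantMessage, FunctionCall, ToolCall

message = AssistantMessage()
# First chunk for index 0 carries the id and the function name.
message.tool_calls = [ToolCall(id='call_0', function=FunctionCall(name='get_weather', arguments='{"ci'))]
# Later chunks for the same index carry only argument fragments.
message.tool_calls[0].function.arguments += 'ty": "Beijing"}'
assert message.tool_calls[0].function.arguments == '{"city": "Beijing"}'
```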
diff --git a/generate/chat_completion/models/wenxin.py b/generate/chat_completion/models/wenxin.py
index 5363ca4..3152ba2 100644
--- a/generate/chat_completion/models/wenxin.py
+++ b/generate/chat_completion/models/wenxin.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import json
-from typing import Any, AsyncIterator, ClassVar, Iterator, List, Literal, Optional, Union
+from typing import Any, AsyncIterator, ClassVar, Iterator, List, Literal, Optional
 
 from pydantic import Field, model_validator
 from typing_extensions import Annotated, NotRequired, Self, TypedDict, Unpack, override
@@ -10,7 +10,6 @@
 from generate.chat_completion.message import (
     AssistantMessage,
     FunctionCall,
-    FunctionCallMessage,
     FunctionMessage,
     Message,
     Messages,
@@ -31,8 +30,6 @@
 from generate.platforms.baidu import QianfanSettings, QianfanTokenManager
 from generate.types import JsonSchema, Probability, Temperature
 
-WenxinAssistantMessage = Union[FunctionCallMessage, AssistantMessage]
-
 
 class WenxinMessage(TypedDict):
     role: Literal['user', 'assistant', 'function']
@@ -63,21 +60,21 @@ def _convert_to_wenxin_message(message: Message) -> WenxinMessage:
         }
 
     if isinstance(message, AssistantMessage):
+        if message.function_call:
+            return {
+                'role': 'assistant',
+                'function_call': {
+                    'name': message.function_call.name,
+                    'arguments': message.function_call.arguments,
+                    'thoughts': message.function_call.thoughts or '',
+                },
+                'content': message.content,
+            }
         return {
             'role': 'assistant',
             'content': message.content,
         }
 
-    if isinstance(message, FunctionCallMessage):
-        return {
-            'role': 'assistant',
-            'function_call': {
-                'name': message.content.name,
-                'arguments': message.content.arguments,
-                'thoughts': message.content.thoughts or '',
-            },
-            'content': '',
-        }
 
     if isinstance(message, FunctionMessage):
         return {
             'role': 'function',
@@ -85,7 +82,7 @@ def _convert_to_wenxin_message(message: Message) -> WenxinMessage:
             'content': message.content,
         }
 
-    raise MessageTypeError(message, allowed_message_type=(UserMessage, AssistantMessage, FunctionMessage, FunctionCallMessage))
+    raise MessageTypeError(message, allowed_message_type=(UserMessage, AssistantMessage, FunctionMessage))
 
 
 def _convert_messages(messages: Messages) -> list[WenxinMessage]:
@@ -161,9 +158,7 @@ def _get_request_parameters(self, messages: Messages, parameters: WenxinChatPara
         }
 
     @override
-    def generate(
-        self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]
-    ) -> ChatCompletionOutput[WenxinAssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -171,29 +166,26 @@ def generate(
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]
-    ) -> ChatCompletionOutput[WenxinAssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[WenxinAssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         if response.get('error_msg'):
             raise UnexpectedResponseError(response)
 
         if response.get('function_call'):
-            message = FunctionCallMessage(
-                content=FunctionCall(
-                    name=response['function_call']['name'],
-                    arguments=response['function_call']['arguments'],
-                    thoughts=response['function_call']['thoughts'],
-                ),
+            function_call = FunctionCall(
+                name=response['function_call']['name'],
+                arguments=response['function_call']['arguments'],
+                thoughts=response['function_call']['thoughts'],
             )
         else:
-            message = AssistantMessage(content=response['result'])
-        return ChatCompletionOutput[WenxinAssistantMessage](
+            function_call = None
+        message = AssistantMessage(content=response['result'], function_call=function_call)
+        return ChatCompletionOutput(
             model_info=self.model_info,
             message=message,
             cost=self.calculate_cost(response['usage']),
@@ -212,7 +204,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: WenxinC
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[WenxinAssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         if parameters.functions:
@@ -228,7 +220,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[WenxinChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[WenxinAssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         if parameters.functions:
@@ -241,14 +233,12 @@ async def async_stream_generate(
             is_start = False
             yield output
 
-    def _parse_stream_line(
-        self, line: str, message: AssistantMessage, is_start: bool
-    ) -> ChatCompletionStreamOutput[WenxinAssistantMessage]:
+    def _parse_stream_line(self, line: str, message: AssistantMessage, is_start: bool) -> ChatCompletionStreamOutput:
         parsed_line = json.loads(line)
         delta = parsed_line['result']
         message.content += delta
         if parsed_line['is_end']:
-            return ChatCompletionStreamOutput[WenxinAssistantMessage](
+            return ChatCompletionStreamOutput(
                 model_info=self.model_info,
                 cost=self.calculate_cost(parsed_line['usage']),
                 extra={
@@ -260,7 +250,7 @@ def _parse_stream_line(
                 finish_reason='stop',
                 stream=Stream(delta=delta, control='finish'),
             )
-        return ChatCompletionStreamOutput[WenxinAssistantMessage](
+        return ChatCompletionStreamOutput(
             model_info=self.model_info,
             message=message,
             finish_reason=None,
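For reference, the assistant branch added to `_convert_to_wenxin_message` above produces a wire dict roughly like this sketch (placeholder function name; note how `None` thoughts fall back to `''`):

```python
# Sketch of the assistant-with-function_call mapping implemented above.
from generate.chat_completion.message import AssistantMessage, FunctionCall

message = AssistantMessage(function_call=FunctionCall(name='get_weather', arguments='{}'))
expected = {
    'role': 'assistant',
    'function_call': {'name': 'get_weather', 'arguments': '{}', 'thoughts': ''},  # None thoughts -> ''
    'content': '',  # content defaults to ''
}
```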
diff --git a/generate/chat_completion/models/zhipu.py b/generate/chat_completion/models/zhipu.py
index ed33dbf..11e83fc 100644
--- a/generate/chat_completion/models/zhipu.py
+++ b/generate/chat_completion/models/zhipu.py
@@ -156,10 +156,10 @@ def _get_request_parameters(self, messages: Messages, parameters: ModelParameter
     def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]:
         return [convert_to_zhipu_message(message) for message in messages]
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         if response['success']:
             text = response['data']['choices'][0]['content']
-            return ChatCompletionOutput[AssistantMessage](
+            return ChatCompletionOutput(
                 model_info=self.model_info,
                 message=AssistantMessage(content=text),
                 cost=self.calculate_cost(response['data']['usage']),
@@ -201,7 +201,7 @@ def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]:
         return super()._convert_messages(messages)
 
     @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -209,9 +209,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) ->
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -221,7 +219,7 @@ async def async_generate(
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -229,13 +227,13 @@ def stream_generate(
         is_start = True
         for line in self.http_client.stream_post(request_parameters=request_parameters):
             message.content += line
-            yield ChatCompletionStreamOutput[AssistantMessage](
+            yield ChatCompletionStreamOutput(
                 model_info=self.model_info,
                 message=message,
                 stream=Stream(delta=line, control='start' if is_start else 'continue'),
             )
             is_start = False
-        yield ChatCompletionStreamOutput[AssistantMessage](
+        yield ChatCompletionStreamOutput(
             model_info=self.model_info,
             message=message,
             finish_reason='stop',
@@ -245,7 +243,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -253,13 +251,13 @@ async def async_stream_generate(
         is_start = True
         async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
             message.content += line
-            yield ChatCompletionStreamOutput[AssistantMessage](
+            yield ChatCompletionStreamOutput(
                 model_info=self.model_info,
                 message=message,
                 stream=Stream(delta=line, control='start' if is_start else 'continue'),
             )
             is_start = False
-        yield ChatCompletionStreamOutput[AssistantMessage](
+        yield ChatCompletionStreamOutput(
             model_info=self.model_info,
             message=message,
             finish_reason='stop',
@@ -291,9 +289,7 @@ def __init__(
         super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client)
 
     @override
-    def generate(
-        self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -301,9 +297,7 @@ def generate(
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
@@ -313,7 +307,7 @@ async def async_generate(
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -321,13 +315,13 @@ def stream_generate(
         is_start = True
         for line in self.http_client.stream_post(request_parameters=request_parameters):
             message.content += line
-            yield ChatCompletionStreamOutput[AssistantMessage](
+            yield ChatCompletionStreamOutput(
                 model_info=self.model_info,
                 message=message,
                 stream=Stream(delta=line, control='start' if is_start else 'continue'),
             )
             is_start = False
-        yield ChatCompletionStreamOutput[AssistantMessage](
+        yield ChatCompletionStreamOutput(
             model_info=self.model_info,
             message=message,
             finish_reason='stop',
@@ -337,7 +331,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -345,13 +339,13 @@ async def async_stream_generate(
         is_start = True
         async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
             message.content += line
-            yield ChatCompletionStreamOutput[AssistantMessage](
+            yield ChatCompletionStreamOutput(
                 model_info=self.model_info,
                 message=message,
                 stream=Stream(delta=line, control='start' if is_start else 'continue'),
             )
             is_start = False
-        yield ChatCompletionStreamOutput[AssistantMessage](
+        yield ChatCompletionStreamOutput(
             model_info=self.model_info,
             message=message,
             finish_reason='stop',
diff --git a/generate/chat_completion/printer.py b/generate/chat_completion/printer.py
index 8b2107e..0852573 100644
--- a/generate/chat_completion/printer.py
+++ b/generate/chat_completion/printer.py
@@ -4,10 +4,8 @@
 from generate.chat_completion.message import (
     AssistantMessage,
-    FunctionCallMessage,
     FunctionMessage,
     Message,
-    ToolCallsMessage,
     ToolMessage,
     UserMessage,
 )
@@ -48,15 +46,18 @@ def __init__(self, smooth: bool = True, interval: float = 0.03) -> None:
         self.interval = interval
 
     def print_message(self, message: Message) -> None:
-        if isinstance(message, (UserMessage, AssistantMessage, FunctionMessage, ToolMessage)):
+        if isinstance(message, (UserMessage, FunctionMessage, ToolMessage)):
             print(f'{message.role}: {message.content}')
-        elif isinstance(message, FunctionCallMessage):
-            print(f'Function call: {message.content.name}\nArguments: {message.content.arguments}')
-        elif isinstance(message, ToolCallsMessage):
-            for tool_call in message.content:
-                print(
-                    f'Tool call: {tool_call.id}\nFunction: {tool_call.function.name}\nArguments: {tool_call.function.arguments}'
-                )
+        elif isinstance(message, AssistantMessage):
+            if message.content:
+                print(f'assistant: {message.content}')
+            if message.function_call:
+                print(f'Function call: {message.function_call.name}\nArguments: {message.function_call.arguments}')
+            if message.tool_calls:
+                for tool_call in message.tool_calls:
+                    print(
+                        f'Tool call: {tool_call.id}\nFunction: {tool_call.function.name}\nArguments: {tool_call.function.arguments}'
+                    )
         else:
             raise TypeError(f'Invalid message type: {type(message)}')
diff --git a/generate/chat_engine.py b/generate/chat_engine.py
index c934a68..5fc87d8 100644
--- a/generate/chat_engine.py
+++ b/generate/chat_engine.py
@@ -7,12 +7,9 @@
 from generate.chat_completion import ChatCompletionModel, ChatCompletionOutput
 from generate.chat_completion.message import (
-    AssistantMessage,
     FunctionCall,
-    FunctionCallMessage,
     FunctionMessage,
     ToolCall,
-    ToolCallsMessage,
     ToolMessage,
     UnionMessage,
     UserMessage,
@@ -88,26 +85,24 @@ def chat(self, user_input: str, **kwargs: Any) -> str:
             model_output = self._chat_model.generate(self.history, **kwargs)
             self.printer.print_message(model_output.message)
             self._handle_model_output(model_output)
-            if isinstance(model_output.message, AssistantMessage):
+            if model_output.message.is_over:
                 return model_output.reply
 
     def _handle_model_output(self, model_output: ChatCompletionOutput, **kwargs: Any) -> None:
         self.model_ouptuts.append(model_output)
         self.history.append(model_output.message)
-        if isinstance(model_output.message, FunctionCallMessage):
+        if model_output.message.function_call:
             self._call_count += 1
             if self._call_count > self.max_calls_per_turn:
                 raise RuntimeError('Maximum number of function calls reached.')
-            function_call = model_output.message.content
-            self._handle_function_call(function_call)
+            self._handle_function_call(model_output.message.function_call)
 
-        if isinstance(model_output.message, ToolCallsMessage):
+        if model_output.message.tool_calls:
             self._call_count += 1
             if self._call_count > self.max_calls_per_turn:
                 raise RuntimeError('Maximum number of tool calls reached.')
-            tool_calls = model_output.message.content
-            self._handle_tool_calls(tool_calls, **kwargs)
+            self._handle_tool_calls(model_output.message.tool_calls, **kwargs)
 
     def _handle_function_call(self, function_call: FunctionCall) -> None:
         function_output = self._run_function_call(function_call)
diff --git a/generate/version.py b/generate/version.py
index d93b5b2..0404d81 100644
--- a/generate/version.py
+++ b/generate/version.py
@@ -1 +1 @@
-__version__ = '0.2.3'
+__version__ = '0.3.0'
diff --git a/pyproject.toml b/pyproject.toml
index 30afacb..5130452 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "generate-core"
-version = "0.2.3"
+version = "0.3.0"
 description = "文本生成,图像生成,语音生成"
 authors = ["wangyuxin "]
 license = "MIT"
diff --git a/tests/test_message.py b/tests/test_message.py
index 11c8fc4..3b98a5c 100644
--- a/tests/test_message.py
+++ b/tests/test_message.py
@@ -1,7 +1,6 @@
 from generate.chat_completion.message import (
     AssistantMessage,
     FunctionCall,
-    FunctionCallMessage,
     UserMessage,
     ensure_messages,
 )
@@ -18,8 +17,8 @@ def test_ensure_messages_with_dict() -> None:
     expected_messages = [UserMessage(content='Hello, how can I help you?')]
     assert ensure_messages(prompt) == expected_messages
 
-    prompt = {'role': 'assistant', 'name': 'bot', 'content': {'name': 'test', 'arguments': 'test'}}
-    expected_messages = [FunctionCallMessage(name='bot', content=FunctionCall(name='test', arguments='test'))]
+    prompt = {'role': 'assistant', 'name': 'bot', 'function_call': {'name': 'test', 'arguments': 'test'}}
+    expected_messages = [AssistantMessage(name='bot', function_call=FunctionCall(name='test', arguments='test'))]
     assert ensure_messages(prompt) == expected_messages
 
     prompt = {'role': 'assistant', 'name': 'bot', 'content': 'test'}
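Taken together, the engine's dispatch reduces to field checks on a single message type. A simplified sketch of the loop shape in `ChatEngine.chat` above (`run_call` is a hypothetical stand-in for the engine's `_handle_function_call`/`_handle_tool_calls` plumbing):

```python
# Sketch: the chat loop enabled by AssistantMessage.is_over (simplified).
def chat_once(model, history: list, run_call) -> str:
    while True:
        output = model.generate(history)
        history.append(output.message)
        if output.message.is_over:  # plain reply: the turn is finished
            return output.reply
        if output.message.function_call:
            history.append(run_call(output.message.function_call))  # FunctionMessage result
        for tool_call in output.message.tool_calls or []:
            history.append(run_call(tool_call))  # ToolMessage result
```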