diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb index 3faf0c4..cedd277 100644 --- a/examples/tutorial.ipynb +++ b/examples/tutorial.ipynb @@ -9,7 +9,7 @@ "👏🏻 欢迎来到 Generate 的教程,在这里您将学习到:\n", "\n", "1. 使用统一简洁的 API 替代不同平台杂乱的 SDK\n", - "2. 使用 `Generate` 生成文本,图像以及音频" + "2. 使用 `Generate` 集成的模型生成文本,图像以及音频" ] }, { @@ -66,7 +66,7 @@ "source": [ "### 配置 OpenAI Key\n", "\n", - "在使用 `generate` 之前,需要先配置 OpenAI API Key,这样才能使用 OpenAI 的 API。 \n", + "在使用 `generate` 之前,需要先配置 OpenAI API Key。 \n", "\n", "`generate` 库使用 [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/) 管理不同平台的配置,Pydantic Settings 会从 `.env` 文件,环境变量或者 **运行时** 获取相关配置。\n", "\n", @@ -150,12 +150,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "我们可以看到模型输出了一个结构化的对象 `ChatCompletionOutput`,其中包括了一些有用的信息,比如:\n", - "- 模型信息,在 model_info 字段中,包含了任务种类,平台以及模型名称\n", - "- 生成消息,在 messages 字段中,包含了 gpt 模型生成的消息\n", - "- 计费信息,在 cost 字段中,包含了此次任务的花销,单位是元\n", - "- 结束原因,在 finish_reason 字段中,展示了此次任务完成的原因\n", - "- 额外信息,在 extra 字段中,包含了一些可能会有用的额外信息,比如此次任务使用了多少 token 等" + "我们可以看到模型输出了一个结构化的对象 `ChatCompletionOutput`,其中包括了一些有用的信息\n", + "\n", + "| field | description |\n", + "| --- | --- |\n", + "| model_info | 包含任务种类,平台及模型名称 |\n", + "| message | 模型生成的消息 |\n", + "| cost | 此次任务的花销,单位是元 |\n", + "| finish_reason | 任务完成的原因 |\n", + "| extra | 包含可能会有用的额外信息 |\n" ] }, { @@ -164,7 +167,7 @@ "source": [ "`ChatCompletionOutput` 对象的基类是 [Pydantic BaseModel](https://docs.pydantic.dev/latest/concepts/models/),因此我们可以通过访问属性的方式访问这些字段。\n", "\n", - "除此之外,`ChatCompletionOutput` 还提供了一些常用的计算属性,比如 `reply` 和 `last_message`。就像下面这样" + "除此之外,`ChatCompletionOutput` 还提供了一些常用的计算属性,比如 `reply`。就像下面这样" ] }, { @@ -175,10 +178,10 @@ "source": [ "cost = model_output.cost\n", "reply = model_output.reply\n", - "last_message = model_output.last_message\n", + "message = model_output.message\n", "rich.print(f'{cost=}')\n", "rich.print(f'{reply=}')\n", - "rich.print(f'{last_message=}')" + "rich.print(f'{message=}')" ] }, { @@ -187,9 +190,9 @@ "source": [ "### 设置模型及其参数\n", "\n", - "当然,我们也可以不使用默认的模型和参数,而是自定义他们。\n", + "在上一个示例中,我们没有设置模型类型和参数,而是使用默认值。现在,让我们学习一下如何指定模型类型和模型参数。\n", "\n", - "模型的参数可以在模型初始化的时候设置,以作为模型的默认参数。也可以在调用 `generate` 方法的时候设置,以作为此次调用的参数。\n", + "模型的参数可以在模型初始化的时候设置,以作为模型的默认参数。也可以在调用 `generate` 方法的时候设置,作为此次调用的参数。\n", "\n", "- 初始化时的参数,必须显式声明,以 `OpenAIChat` 为例,它的参数为 `OpenAIChatParameters` 实例。\n", "- 调用时的参数,无须显式声明,直接传入关键字参数即可,比如 `model.generate('你好', temperature=0.5)`\n", diff --git a/generate/chat_completion/function_call.py b/generate/chat_completion/function_call.py index f208432..4eed2c1 100644 --- a/generate/chat_completion/function_call.py +++ b/generate/chat_completion/function_call.py @@ -39,26 +39,21 @@ def get_json_schema(function: Callable[..., Any]) -> FunctionJsonSchema: class function(Generic[P, T]): # noqa: N801 - """ - A decorator class that wraps a callable function and provides additional functionality. - - Args: - function (Callable[P, T]): The function to be wrapped. + def __init__(self, function: Callable[P, T]) -> None: + self.function: Callable[P, T] = validate_call(function) + self.json_schema: FunctionJsonSchema = get_json_schema(function) - Attributes: - function (Callable[P, T]): The wrapped function. - name (str): The name of the wrapped function. - docstring (ParsedDocstring): The parsed docstring of the wrapped function. - json_schema (Function): The JSON schema of the wrapped function. 
+ @property + def name(self) -> str: + return self.json_schema['name'] - Methods: - __call__(self, *args: Any, **kwargs: Any) -> Any: Calls the wrapped function with the provided arguments. - call_with_message(self, message: Message) -> T: Calls the wrapped function with the arguments provided in the message. - """ + @property + def description(self) -> str: + return self.json_schema.get('description', '') - def __init__(self, function: Callable[P, T]) -> None: - self.function: Callable[P, T] = validate_call(function) - self.json_schema = get_json_schema(function) + @property + def parameters(self) -> JsonSchema: + return self.json_schema['parameters'] def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return self.function(*args, **kwargs) diff --git a/generate/chat_completion/message/__init__.py b/generate/chat_completion/message/__init__.py index cc4da15..c14562a 100644 --- a/generate/chat_completion/message/__init__.py +++ b/generate/chat_completion/message/__init__.py @@ -1,4 +1,5 @@ from generate.chat_completion.message.core import ( + AssistantGroupMessage, AssistantMessage, FunctionCall, FunctionCallMessage, @@ -34,6 +35,7 @@ 'UnionMessage', 'FunctionCallMessage', 'AssistantMessage', + 'AssistantGroupMessage', 'ToolCallsMessage', 'ensure_messages', 'FunctionCall', diff --git a/generate/chat_completion/message/core.py b/generate/chat_completion/message/core.py index 28176de..6b27250 100644 --- a/generate/chat_completion/message/core.py +++ b/generate/chat_completion/message/core.py @@ -79,7 +79,13 @@ class ToolCallsMessage(Message): content: List[ToolCall] -UnionAssistantMessage = Union[AssistantMessage, FunctionCallMessage, ToolCallsMessage] +class AssistantGroupMessage(Message): + role: Literal['assistant'] = 'assistant' + name: Optional[str] = None + content: List[Union[AssistantMessage, FunctionMessage, FunctionCallMessage]] + + +UnionAssistantMessage = Union[AssistantMessage, FunctionCallMessage, ToolCallsMessage, AssistantGroupMessage] UnionUserMessage = Union[UserMessage, UserMultiPartMessage] UnionUserPart = Union[TextPart, ImageUrlPart] UnionMessage = Union[SystemMessage, FunctionMessage, ToolMessage, UnionAssistantMessage, UnionUserMessage] diff --git a/generate/chat_completion/model_output.py b/generate/chat_completion/model_output.py index fa536d2..cbe4487 100644 --- a/generate/chat_completion/model_output.py +++ b/generate/chat_completion/model_output.py @@ -1,25 +1,19 @@ from __future__ import annotations -from typing import Generic, List, Literal, Optional, TypeVar +from typing import Generic, Literal, Optional, TypeVar from pydantic import BaseModel -from generate.chat_completion.message import AssistantMessage, UnionMessage +from generate.chat_completion.message import AssistantMessage, UnionAssistantMessage from generate.model import ModelOutput -M = TypeVar('M', bound=UnionMessage) +M = TypeVar('M', bound=UnionAssistantMessage) class ChatCompletionOutput(ModelOutput, Generic[M]): - messages: List[M] = [] + message: M finish_reason: Optional[str] = None - @property - def message(self) -> M: - if len(self.messages) != 1: - raise ValueError('Expected exactly one message') - return self.messages[0] - @property def reply(self) -> str: if self.message and isinstance(self.message, AssistantMessage): diff --git a/generate/chat_completion/models/baichuan.py b/generate/chat_completion/models/baichuan.py index e707b08..df013df 100644 --- a/generate/chat_completion/models/baichuan.py +++ b/generate/chat_completion/models/baichuan.py @@ -148,7 +148,7 @@ def 
_parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[Assist extra = {} return ChatCompletionOutput( model_info=self.model_info, - messages=[AssistantMessage(content=text)], + message=AssistantMessage(content=text), finish_reason=finish_reason, cost=cost, extra=extra, @@ -199,10 +199,9 @@ def _parse_stream_line( else: stream = Stream(delta=output_message.content, control='finish' if output.is_finish else 'continue') message.content += output_message.content - output.messages = [message] return ChatCompletionStreamOutput[AssistantMessage]( model_info=output.model_info, - messages=output.messages, + message=message, finish_reason=output.finish_reason, cost=output.cost, extra=output.extra, diff --git a/generate/chat_completion/models/bailian.py b/generate/chat_completion/models/bailian.py index a94fe17..682b4d0 100644 --- a/generate/chat_completion/models/bailian.py +++ b/generate/chat_completion/models/bailian.py @@ -47,7 +47,7 @@ def generate_default_request_id() -> str: def _convert_to_bailian_chat_qa_pair(messages: Messages) -> list[BailianChatQAPair]: pairs: list[BailianChatQAPair] = [] - if isinstance(messages[0], SystemMessage): + if messages and isinstance(messages[0], SystemMessage): pairs.append({'User': messages[0].content, 'Bot': '好的'}) messages = messages[1:] @@ -155,10 +155,9 @@ async def async_generate( def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]: if not response['Success']: raise UnexpectedResponseError(response) - messages = [AssistantMessage(content=response['Data']['Text'])] return ChatCompletionOutput( model_info=self.model_info, - messages=messages, + message=AssistantMessage(content=response['Data']['Text']), extra={ 'thoughts': response['Data']['Thoughts'], 'doc_references': response['Data']['DocReferences'], @@ -220,7 +219,7 @@ def _parse_stream_line( if len(reply) == len(message.content): return ChatCompletionStreamOutput( model_info=self.model_info, - messages=[message], + message=message, extra=extra, finish_reason='stop', stream=Stream(delta='', control='finish'), @@ -230,7 +229,7 @@ def _parse_stream_line( message.content = reply return ChatCompletionStreamOutput( model_info=self.model_info, - messages=[message], + message=message, extra=extra, stream=Stream(delta=delta, control='start' if is_start else 'continue'), ) diff --git a/generate/chat_completion/models/hunyuan.py b/generate/chat_completion/models/hunyuan.py index 530398a..ac8bc4f 100644 --- a/generate/chat_completion/models/hunyuan.py +++ b/generate/chat_completion/models/hunyuan.py @@ -111,10 +111,9 @@ async def async_generate( def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]: if response.get('error'): raise UnexpectedResponseError(response) - messages = [AssistantMessage(content=response['choices'][0]['messages']['content'])] return ChatCompletionOutput( model_info=self.model_info, - messages=messages, + message=AssistantMessage(content=response['choices'][0]['messages']['content']), finish_reason=response['choices'][0]['finish_reason'], cost=self.calculate_cost(response['usage']), extra={'usage': response['usage']}, @@ -172,7 +171,7 @@ def _parse_stream_line( if message_dict['finish_reason']: return ChatCompletionStreamOutput( model_info=self.model_info, - messages=[message], + message=message, finish_reason=message_dict['finish_reason'], cost=self.calculate_cost(parsed_line['usage']), stream=Stream(delta=delta, control='finish'), @@ -180,7 +179,7 @@ def _parse_stream_line( ) return 
ChatCompletionStreamOutput( model_info=self.model_info, - messages=[message], + message=message, finish_reason=None, stream=Stream(delta=delta, control='start' if is_start else 'continue'), ) diff --git a/generate/chat_completion/models/minimax.py b/generate/chat_completion/models/minimax.py index c6061b5..a02fdb3 100644 --- a/generate/chat_completion/models/minimax.py +++ b/generate/chat_completion/models/minimax.py @@ -145,12 +145,11 @@ async def async_generate( response = await self.http_client.async_post(request_parameters=request_parameters) return self._parse_reponse(response.json()) - def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput: + def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]: try: - messages = [AssistantMessage(content=response['choices'][0]['text'])] - return ChatCompletionOutput( + return ChatCompletionOutput[AssistantMessage]( model_info=self.model_info, - messages=messages, + message=AssistantMessage(content=response['choices'][0]['text']), finish_reason=response['choices'][0]['finish_reason'], cost=self.calculate_cost(response['usage']), extra={ @@ -207,7 +206,7 @@ def _parse_stream_line( return ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, finish_reason=parsed_line['choices'][0]['finish_reason'], - messages=[message], + message=message, cost=self.calculate_cost(parsed_line['usage']), extra={ 'logprobes': parsed_line['choices'][0]['logprobes'], @@ -220,7 +219,7 @@ def _parse_stream_line( return ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, finish_reason=None, - messages=[message], + message=message, stream=Stream(delta=delta, control='start' if is_start else 'continue'), ) diff --git a/generate/chat_completion/models/minimax_pro.py b/generate/chat_completion/models/minimax_pro.py index 7a00d63..f6801c1 100644 --- a/generate/chat_completion/models/minimax_pro.py +++ b/generate/chat_completion/models/minimax_pro.py @@ -9,6 +9,7 @@ from generate.chat_completion.base import ChatCompletionModel from generate.chat_completion.function_call import FunctionJsonSchema from generate.chat_completion.message import ( + AssistantGroupMessage, AssistantMessage, FunctionCall, FunctionCallMessage, @@ -29,11 +30,11 @@ ResponseValue, UnexpectedResponseError, ) -from generate.model import ModelParameters, ModelParametersDict +from generate.model import ModelInfo, ModelParameters, ModelParametersDict from generate.platforms.minimax import MinimaxSettings from generate.types import Probability, Temperature -MinimaxProAssistantMessage = Union[FunctionCallMessage, AssistantMessage, UserMessage, FunctionMessage] +MinimaxProAssistantMessage = Union[AssistantMessage, FunctionCallMessage, AssistantGroupMessage] class BotSettingDict(TypedDict): @@ -174,6 +175,91 @@ def _convert_to_minimax_pro_message( raise MessageTypeError(message, allowed_message_type=(UserMessage, AssistantMessage, FunctionMessage, FunctionCallMessage)) +def _convert_to_message(message: MinimaxProMessage) -> FunctionCallMessage | AssistantMessage | FunctionMessage: + if 'function_call' in message: + return FunctionCallMessage( + name=message['sender_name'], + content=FunctionCall(name=message['function_call']['name'], arguments=message['function_call']['arguments']), + ) + if message['sender_type'] == 'BOT': + return AssistantMessage( + name=message['sender_name'], + content=message['text'], + ) + if message['sender_type'] == 'FUNCTION': + return FunctionMessage( + name=message['sender_name'], + 
content=message['text'], + ) + raise ValueError(f'unknown sender_type: {message["sender_type"]}') + + +class _StreamResponseProcessor: + def __init__(self, model_info: ModelInfo) -> None: + self.message: MinimaxProAssistantMessage | None = None + self.model_info = model_info + + def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[MinimaxProAssistantMessage]: + if response.get('usage'): + return ChatCompletionStreamOutput[MinimaxProAssistantMessage]( + model_info=self.model_info, + message=self.message, + finish_reason=response['choices'][0]['finish_reason'], + cost=calculate_cost(response['usage']), + extra={ + 'input_sensitive': response['input_sensitive'], + 'output_sensitive': response['output_sensitive'], + 'usage': response['usage'], + }, + stream=Stream(delta='', control='finish'), + ) + + if self.message is None: + self.message = self.initial_message(response) + delta = self.message.content if isinstance(self.message, AssistantMessage) else '' + control = 'start' + else: + delta = self.update_existing_message(response) + control = 'continue' + + return ChatCompletionStreamOutput[MinimaxProAssistantMessage]( + model_info=self.model_info, + message=self.message, + finish_reason=None, + stream=Stream(delta=delta, control=control), + ) + + def initial_message(self, response: ResponseValue) -> MinimaxProAssistantMessage: + output_messages = [_convert_to_message(i) for i in response['choices'][0]['messages']] + message = output_messages[0] if len(output_messages) == 1 else AssistantGroupMessage(content=output_messages) + return cast(MinimaxProAssistantMessage, message) + + def update_existing_message(self, response: ResponseValue) -> str: + output_messages = [_convert_to_message(i) for i in response['choices'][0]['messages']] + if len(output_messages) > 1 and not isinstance(self.message, AssistantGroupMessage): + self.message = AssistantGroupMessage(content=[self.message]) # type: ignore + messages = self.message.content if isinstance(self.message, AssistantGroupMessage) else [self.message] + delta = '' + for index, output_message in enumerate(output_messages, start=1): + if index > len(messages): + messages.append(output_message) # type: ignore + if isinstance(output_message, AssistantMessage): + delta = output_message.content + elif isinstance(output_message, FunctionCallMessage): + messages[index - 1] = output_message + elif isinstance(output_message, AssistantMessage): + message = cast(AssistantMessage, messages[index - 1]) + delta = output_message.content + message.content += output_message.content + else: + raise ValueError(f'unknown message type: {output_message}') + return delta + + +def calculate_cost(usage: dict[str, int], num_web_search: int = 0) -> float: + return 0.015 * (usage['total_tokens'] / 1000) + (0.03 * num_web_search) + + class MinimaxProChat(ChatCompletionModel): model_type: ClassVar[str] = 'minimax_pro' @@ -236,17 +322,18 @@ async def async_generate( def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[MinimaxProAssistantMessage]: try: - messages: list[FunctionCallMessage | UserMessage | AssistantMessage | FunctionMessage] = [ - self._convert_to_message(i) for i in response['choices'][0]['messages'] + messages: list[FunctionCallMessage | AssistantMessage | FunctionMessage] = [ + _convert_to_message(i) for i in response['choices'][0]['messages'] ] + message = messages[0] if len(messages) == 1 else AssistantGroupMessage(content=messages) + message = cast(MinimaxProAssistantMessage, message) finish_reason = 
response['choices'][0]['finish_reason'] - num_web_search = sum([i for i in response['choices'][0]['messages'] if i['sender_name'] == 'plugin_web_search']) - + num_web_search = sum([1 for i in response['choices'][0]['messages'] if i['sender_name'] == 'plugin_web_search']) return ChatCompletionOutput[MinimaxProAssistantMessage]( model_info=self.model_info, - messages=messages, + message=message, finish_reason=finish_reason, - cost=self.calculate_cost(response['usage'], num_web_search), + cost=calculate_cost(response['usage'], num_web_search), extra={ 'input_sensitive': response['input_sensitive'], 'output_sensitive': response['output_sensitive'], @@ -268,12 +355,9 @@ def stream_generate( messages = ensure_messages(prompt) parameters = self.parameters.update_with_validate(**kwargs) request_parameters = self._get_stream_request_parameters(messages, parameters) - message = [] - is_start = True + stream_processor = _StreamResponseProcessor(model_info=self.model_info) for line in self.http_client.stream_post(request_parameters=request_parameters): - output = self._parse_stream_line(line, message, is_start) - is_start = False - yield output + yield stream_processor.process(json.loads(line)) @override async def async_stream_generate( @@ -282,82 +366,9 @@ async def async_stream_generate( messages = ensure_messages(prompt) parameters = self.parameters.update_with_validate(**kwargs) request_parameters = self._get_stream_request_parameters(messages, parameters) - message = [] - is_start = True + stream_processor = _StreamResponseProcessor(model_info=self.model_info) async for line in self.http_client.async_stream_post(request_parameters=request_parameters): - output = self._parse_stream_line(line, message, is_start) - is_start = False - yield output - - def _parse_stream_line( - self, line: str, messages: Messages, is_start: bool - ) -> ChatCompletionStreamOutput[MinimaxProAssistantMessage]: - messages = list(messages) - parsed_line = json.loads(line) - output_messages = [self._convert_to_message(i) for i in parsed_line['choices'][0]['messages']] - delta = '' - for index, output_message in enumerate(output_messages, start=1): - if index > len(messages): - messages.append(output_message) - if isinstance(output_message, AssistantMessage): - delta = output_message.content - elif isinstance(output_message, FunctionCallMessage): - messages[index - 1] = output_message - elif isinstance(output_message, AssistantMessage): - message = cast(AssistantMessage, messages[index - 1]) - delta = output_message.content - message.content += output_message.content - else: - raise ValueError(f'unknown message type: {output_message}') - - if parsed_line.get('usage'): - return ChatCompletionStreamOutput[MinimaxProAssistantMessage]( - model_info=self.model_info, - messages=messages, - finish_reason=parsed_line['choices'][0]['finish_reason'], - cost=self.calculate_cost(parsed_line['usage']), - extra={ - 'input_sensitive': parsed_line['input_sensitive'], - 'output_sensitive': parsed_line['output_sensitive'], - 'usage': parsed_line['usage'], - }, - stream=Stream(delta=delta, control='finish'), - ) - return ChatCompletionStreamOutput[MinimaxProAssistantMessage]( - model_info=self.model_info, - messages=messages, - finish_reason=None, - stream=Stream(delta=delta, control='start' if is_start else 'continue'), - ) - - @staticmethod - def _convert_to_message( - message: MinimaxProMessage - ) -> FunctionCallMessage | UserMessage | AssistantMessage | FunctionMessage: - if 'function_call' in message: - return FunctionCallMessage( - 
name=message['sender_name'], - content=FunctionCall(name=message['function_call']['name'], arguments=message['function_call']['arguments']), - ) - if message['sender_type'] == 'USER': - return UserMessage( - name=message['sender_name'], - content=message['text'], - ) - if message['sender_type'] == 'BOT': - return AssistantMessage( - name=message['sender_name'], - content=message['text'], - ) - if message['sender_type'] == 'FUNCTION': - return FunctionMessage( - name=message['sender_name'], - content=message['text'], - ) - raise ValueError(f'unknown sender_type: {message["sender_type"]}') - - def calculate_cost(self, usage: dict[str, int], num_web_search: int = 0) -> float: - return 0.015 * (usage['total_tokens'] / 1000) + (0.03 * num_web_search) + yield stream_processor.process(json.loads(line)) @property @override diff --git a/generate/chat_completion/models/openai.py b/generate/chat_completion/models/openai.py index 57ca31f..77cf0bb 100644 --- a/generate/chat_completion/models/openai.py +++ b/generate/chat_completion/models/openai.py @@ -259,7 +259,7 @@ def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput[ return ChatCompletionOutput[OpenAIAssistantMessage]( model_info=ModelInfo(task='chat_completion', type='openai', name=response['model']), - messages=[message], + message=message, finish_reason=finish_reason or '', cost=calculate_cost(response['model'], response['usage']['prompt_tokens'], response['usage']['completion_tokens']), extra=extra, @@ -272,7 +272,6 @@ def __init__(self) -> None: self.is_start = True def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[OpenAIAssistantMessage] | None: - delta, extra = '', {} delta_dict = response['choices'][0]['delta'] if self.message is None: @@ -288,11 +287,11 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput[OpenAIA self.is_start = False return ChatCompletionStreamOutput[OpenAIAssistantMessage]( model_info=ModelInfo(task='chat_completion', type='openai', name=response['model']), - messages=[self.message], + message=self.message, finish_reason=finish_reason, cost=cost, extra=extra, - stream=Stream(delta=delta, control=stream_control), + stream=Stream(delta=delta_dict.get('content', ''), control=stream_control), ) def process_initial_message(self, delta_dict: dict[str, Any]) -> OpenAIAssistantMessage | None: diff --git a/generate/chat_completion/models/wenxin.py b/generate/chat_completion/models/wenxin.py index 081ecd8..5363ca4 100644 --- a/generate/chat_completion/models/wenxin.py +++ b/generate/chat_completion/models/wenxin.py @@ -184,20 +184,18 @@ def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[Wenxin if response.get('error_msg'): raise UnexpectedResponseError(response) if response.get('function_call'): - messages: list[WenxinAssistantMessage] = [ - FunctionCallMessage( - content=FunctionCall( - name=response['function_call']['name'], - arguments=response['function_call']['arguments'], - thoughts=response['function_call']['thoughts'], - ), - ) - ] + message = FunctionCallMessage( + content=FunctionCall( + name=response['function_call']['name'], + arguments=response['function_call']['arguments'], + thoughts=response['function_call']['thoughts'], + ), + ) else: - messages = [AssistantMessage(content=response['result'])] + message = AssistantMessage(content=response['result']) return ChatCompletionOutput[WenxinAssistantMessage]( model_info=self.model_info, - messages=messages, + message=message, cost=self.calculate_cost(response['usage']), extra={ 
'is_truncated': response['is_truncated'], @@ -258,13 +256,13 @@ def _parse_stream_line( 'need_clear_history': parsed_line['need_clear_history'], 'usage': parsed_line['usage'], }, - messages=[message], + message=message, finish_reason='stop', stream=Stream(delta=delta, control='finish'), ) return ChatCompletionStreamOutput[WenxinAssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, finish_reason=None, stream=Stream(delta=delta, control='start' if is_start else 'continue'), ) diff --git a/generate/chat_completion/models/zhipu.py b/generate/chat_completion/models/zhipu.py index 4362788..ed33dbf 100644 --- a/generate/chat_completion/models/zhipu.py +++ b/generate/chat_completion/models/zhipu.py @@ -159,10 +159,9 @@ def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]: def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]: if response['success']: text = response['data']['choices'][0]['content'] - messages = [AssistantMessage(content=text)] return ChatCompletionOutput[AssistantMessage]( model_info=self.model_info, - messages=messages, + message=AssistantMessage(content=text), cost=self.calculate_cost(response['data']['usage']), extra={'usage': response['data']['usage']}, ) @@ -232,13 +231,13 @@ def stream_generate( message.content += line yield ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, stream=Stream(delta=line, control='start' if is_start else 'continue'), ) is_start = False yield ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, finish_reason='stop', stream=Stream(delta='', control='finish'), ) @@ -256,13 +255,13 @@ async def async_stream_generate( message.content += line yield ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, stream=Stream(delta=line, control='start' if is_start else 'continue'), ) is_start = False yield ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, finish_reason='stop', stream=Stream(delta='', control='finish'), ) @@ -324,13 +323,13 @@ def stream_generate( message.content += line yield ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, stream=Stream(delta=line, control='start' if is_start else 'continue'), ) is_start = False yield ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, finish_reason='stop', stream=Stream(delta='', control='finish'), ) @@ -348,13 +347,13 @@ async def async_stream_generate( message.content += line yield ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, stream=Stream(delta=line, control='start' if is_start else 'continue'), ) is_start = False yield ChatCompletionStreamOutput[AssistantMessage]( model_info=self.model_info, - messages=[message], + message=message, finish_reason='stop', stream=Stream(delta='', control='finish'), ) diff --git a/generate/chat_completion/printer.py b/generate/chat_completion/printer.py index 637c648..8b2107e 100644 --- a/generate/chat_completion/printer.py +++ b/generate/chat_completion/printer.py @@ -22,6 +22,18 @@ def print_stream(self, stream: Stream) -> None: ... +class SilentMessagePrinter(MessagePrinter): + """ + A printer that does nothing. 
+ """ + + def print_message(self, message: Message) -> None: + pass + + def print_stream(self, stream: Stream) -> None: + pass + + class SimpleMessagePrinter(MessagePrinter): """ A simple printer that prints messages and streams to the console. diff --git a/generate/chat_engine.py b/generate/chat_engine.py index 07de76d..c934a68 100644 --- a/generate/chat_engine.py +++ b/generate/chat_engine.py @@ -1,30 +1,29 @@ from __future__ import annotations import json -from typing import Any, Callable, List, Literal, Mapping, Sequence, TypedDict +from typing import Any, Callable, List, Mapping, TypedDict from typing_extensions import Self, Unpack -from generate.chat_completion import ChatCompletionModel, ChatCompletionOutput, function +from generate.chat_completion import ChatCompletionModel, ChatCompletionOutput from generate.chat_completion.message import ( AssistantMessage, FunctionCall, FunctionCallMessage, FunctionMessage, - MessageTypeError, ToolCall, ToolCallsMessage, ToolMessage, UnionMessage, UserMessage, ) -from generate.chat_completion.printer import MessagePrinter, SimpleMessagePrinter +from generate.chat_completion.printer import MessagePrinter, SilentMessagePrinter, SimpleMessagePrinter from generate.utils import load_chat_model class ChatEngineKwargs(TypedDict, total=False): - functions: Sequence[function] | Mapping[str, Callable] | None - call_raise_error: bool + functions: Mapping[str, Callable] | None + function_call_raise_error: bool max_calls_per_turn: int @@ -33,55 +32,39 @@ class ChatEngine: Args: chat_model (ChatCompletionModel): The chat completion model for generating responses. - functions (Sequence[function] | Mapping[str, Callable] | None, optional): Functions to be used in the chat. - It can be a list of function or a dictionary mapping function names to a Callable. - call_raise_error (bool, optional): Whether to raise an error when calling a function fails. Defaults to False. + functions (Mapping[str, Callable] | None, optional): Functions to be used in the chat. + function_call_raise_error (bool, optional): Whether to raise an error when calling a function fails. Defaults to False. max_calls_per_turn (int, optional): Maximum number of function calls allowed per turn. Defaults to 5. stream (bool | Literal['auto'], optional): Whether to use streaming. If 'auto', it is determined based on the presence of functions. printer (MessagePrinter | Literal['auto'] | None, optional): An instance for printing messages. If 'auto', a simple printer is used when streaming. 
""" - printer: MessagePrinter | None + printer: MessagePrinter def __init__( self, chat_model: ChatCompletionModel, - functions: Sequence[function] | Mapping[str, Callable] | None = None, - call_raise_error: bool = False, + functions: Mapping[str, Callable] | None = None, + function_call_raise_error: bool = False, max_calls_per_turn: int = 5, - stream: bool | Literal['auto'] = 'auto', - printer: MessagePrinter | Literal['auto'] | None = 'auto', + stream: bool = True, + printer: MessagePrinter | None = SimpleMessagePrinter(), ) -> None: self._chat_model = chat_model + self._function_map = functions or {} - if isinstance(functions, list): - self._function_map: dict[str, Callable] = {} - for _function in functions: - self._function_map[_function.json_schema['name']] = _function - elif isinstance(functions, dict): - self._function_map = functions - else: - self._function_map = {} - - self.call_raise_error = call_raise_error + self.function_call_raise_error = function_call_raise_error self.max_calls_per_turn = max_calls_per_turn - - if stream == 'auto': - self.stream = not bool(self._function_map) - else: - if self._function_map and stream: - raise ValueError('Cannot stream when functions are provided.') - self.stream = stream - - if printer == 'auto': - self.printer = SimpleMessagePrinter() if stream else None - else: - self.printer = printer - + self.stream: bool = stream + self.printer = printer or SilentMessagePrinter() self.history: list[UnionMessage] = [] self.model_ouptuts: list[ChatCompletionOutput] = [] self._call_count = 0 + @property + def cost(self) -> float: + return sum(output.cost or 0.0 for output in self.model_ouptuts) + @classmethod def from_model_id(cls, model_id: str, **kwargs: Unpack[ChatEngineKwargs]) -> Self: chat_model = load_chat_model(model_id) @@ -91,38 +74,26 @@ def from_model_id(cls, model_id: str, **kwargs: Unpack[ChatEngineKwargs]) -> Sel def chat_model(self) -> ChatCompletionModel: return self._chat_model - @property - def print_message(self) -> bool: - return self.printer is not None - def chat(self, user_input: str, **kwargs: Any) -> str: self._call_count = 0 user_input_message = UserMessage(content=user_input) self.history.append(user_input_message) - if self.printer: - self.printer.print_message(user_input_message) + self.printer.print_message(user_input_message) while True: if self.stream: model_output = self._stream_chat_helper(**kwargs) else: model_output = self._chat_model.generate(self.history, **kwargs) + self.printer.print_message(model_output.message) self._handle_model_output(model_output) if isinstance(model_output.message, AssistantMessage): return model_output.reply def _handle_model_output(self, model_output: ChatCompletionOutput, **kwargs: Any) -> None: - if not model_output.message: - raise RuntimeError('messages in model output is empty.', model_output.model_dump()) - self.model_ouptuts.append(model_output) - self.history.extend(model_output.messages) - if self.printer: - for message in model_output.messages: - if self.stream and isinstance(message, AssistantMessage): - continue - self.printer.print_message(message) + self.history.append(model_output.message) if isinstance(model_output.message, FunctionCallMessage): self._call_count += 1 @@ -160,33 +131,7 @@ def _handle_tool_calls(self, tool_calls: List[ToolCall], **kwargs: Any) -> None: def _stream_chat_helper(self, **kwargs: Any) -> ChatCompletionOutput: for stream_output in self._chat_model.stream_generate(self.history, **kwargs): - if self.printer: - 
self.printer.print_stream(stream_output.stream) - if stream_output.is_finish: - return stream_output - raise RuntimeError('Stream finished unexpectedly.') - - async def async_chat(self, user_input: str, **kwargs: Any) -> str | None: - self._call_count = 0 - - user_input_message = UserMessage(content=user_input) - self.history.append(user_input_message) - if self.printer: - self.printer.print_message(user_input_message) - - while True: - if self.stream: - model_output = await self._async_stream_chat_helper(**kwargs) - else: - model_output = await self._chat_model.async_generate(self.history, **kwargs) - self._handle_model_output(model_output) - if isinstance(model_output.message, AssistantMessage): - return model_output.reply - - async def _async_stream_chat_helper(self, **kwargs: Any) -> ChatCompletionOutput: - async for stream_output in self._chat_model.async_stream_generate(self.history, **kwargs): - if self.printer: - self.printer.print_stream(stream_output.stream) + self.printer.print_stream(stream_output.stream) if stream_output.is_finish: return stream_output raise RuntimeError('Stream finished unexpectedly.') @@ -194,7 +139,7 @@ async def _async_stream_chat_helper(self, **kwargs: Any) -> ChatCompletionOutput def _run_function_call(self, function_call: FunctionCall) -> Any | str: function = self._function_map.get(function_call.name) if function is None: - if self.call_raise_error: + if self.function_call_raise_error: raise ValueError(f'Function {function_call.name} not found') return 'Function not found, please try another function.' @@ -203,35 +148,10 @@ def _run_function_call(self, function_call: FunctionCall) -> Any | str: arguments = json.loads(function_call.arguments, strict=False) return function(**arguments) except Exception as e: - if self.call_raise_error: + if self.function_call_raise_error: raise return str(e) - async def _async_recursive_function_call(self, function_call: FunctionCall, **kwargs: Any) -> str: - function_output = self._run_function_call(function_call) - function_message = FunctionMessage(name=function_call.name, content=json.dumps(function_output, ensure_ascii=False)) - self.history.append(function_message) - if self.printer: - self.printer.print_message(function_message) - - model_output = await self._chat_model.async_generate(self.history, **kwargs) - self._handle_model_output(model_output) - - if not model_output.message: - raise RuntimeError('messages in model output is empty.', model_output.model_dump()) - - if isinstance(model_output.message, AssistantMessage): - return model_output.message.content - - if isinstance(model_output.message, FunctionCallMessage): - self._call_count += 1 - if self._call_count > self.max_calls_per_turn: - raise RuntimeError('Maximum number of function calls reached.') - function_call = model_output.message.content - return await self._async_recursive_function_call(function_call, **kwargs) - - raise MessageTypeError(model_output.message, allowed_message_type=(AssistantMessage, FunctionCallMessage)) - def reset(self) -> None: self.history.clear() diff --git a/generate/completion_engine.py b/generate/completion_engine.py index 353c759..eb90e9c 100644 --- a/generate/completion_engine.py +++ b/generate/completion_engine.py @@ -15,7 +15,7 @@ from typing_extensions import Self, TypedDict, Unpack from generate.chat_completion import ChatCompletionModel, ChatCompletionOutput -from generate.chat_completion.message import Prompt, Prompts, ensure_messages +from generate.chat_completion.message import AssistantMessage, Prompt, Prompts, 
ensure_messages from generate.utils import load_chat_model @@ -88,7 +88,9 @@ def _run_single_task( if self.error_mode == 'raise': raise if self.error_mode == 'ignore': - return ChatCompletionOutput(model_info=self.chat_model.model_info, extra={'error': str(e)}) + return ChatCompletionOutput( + model_info=self.chat_model.model_info, message=AssistantMessage(content=''), extra={'error': str(e)} + ) raise ValueError(f'Unknown error mode: {self.error_mode}') from e else: progress_bar.update(1) @@ -138,7 +140,9 @@ async def _async_run_single_task( if self.error_mode == 'raise': raise if self.error_mode == 'ignore': - return ChatCompletionOutput(model_info=self.chat_model.model_info, extra={'error': str(e)}) + return ChatCompletionOutput( + model_info=self.chat_model.model_info, message=AssistantMessage(content=''), extra={'error': str(e)} + ) raise ValueError(f'Unknown error mode: {self.error_mode}') from e else: diff --git a/generate/utils.py b/generate/utils.py index b498f51..a0b1582 100644 --- a/generate/utils.py +++ b/generate/utils.py @@ -13,11 +13,11 @@ def load_chat_model(model_id: str) -> ChatCompletionModel: return model_cls.from_name(name) -def load_speech_model(speech_model_id: str) -> TextToSpeechModel: - if '/' not in speech_model_id: - model_type = speech_model_id +def load_speech_model(model_id: str) -> TextToSpeechModel: + if '/' not in model_id: + model_type = model_id return SpeechModelRegistry[model_type][0]() - model_type, name = speech_model_id.split('/') + model_type, name = model_id.split('/') model_cls = SpeechModelRegistry[model_type][0] return model_cls.from_name(name) diff --git a/generate/version.py b/generate/version.py index fc79d63..020ed73 100644 --- a/generate/version.py +++ b/generate/version.py @@ -1 +1 @@ -__version__ = '0.2.1' +__version__ = '0.2.2' diff --git a/pyproject.toml b/pyproject.toml index e02aa0a..486eab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "generate-core" -version = "0.2.1" +version = "0.2.2" description = "文本生成,图像生成,语音生成" authors = ["wangyuxin "] license = "MIT" diff --git a/tests/test_chat_completion_model.py b/tests/test_chat_completion_model.py index 42ec61c..24eec27 100644 --- a/tests/test_chat_completion_model.py +++ b/tests/test_chat_completion_model.py @@ -52,6 +52,8 @@ def test_http_stream_chat_model(chat_completion_model: ChatCompletionModel) -> N assert outputs[-1].stream.control == 'finish' for output in outputs[1:-1]: assert output.stream.control == 'continue' + text = ''.join(output.stream.delta for output in outputs) + assert text == outputs[-1].reply assert outputs[0].stream.control == 'start' assert outputs[-1].reply != '' assert async_output.reply != '' diff --git a/tests/test_completion.py b/tests/test_completion.py index 20cbe0d..7e46e60 100644 --- a/tests/test_completion.py +++ b/tests/test_completion.py @@ -35,13 +35,13 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> Cha messages = ensure_messages(prompt) parameters = self.parameters.update_with_validate(**kwargs) content = f'{parameters.prefix}{messages[-1].content}' - return ChatCompletionOutput(model_info=self.model_info, messages=[AssistantMessage(content=content)]) + return ChatCompletionOutput(model_info=self.model_info, message=AssistantMessage(content=content)) async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> ChatCompletionOutput: messages = ensure_messages(prompt) parameters = self.parameters.update_with_validate(**kwargs) content = 
f'{parameters.prefix}{messages[-1].content}' - return ChatCompletionOutput(model_info=self.model_info, messages=[AssistantMessage(content=content)]) + return ChatCompletionOutput(model_info=self.model_info, message=AssistantMessage(content=content)) def stream_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> Iterator[ChatCompletionStreamOutput]: messages = ensure_messages(prompt) @@ -49,11 +49,12 @@ def stream_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) content = f'{parameters.prefix}{messages[-1].content}' yield ChatCompletionStreamOutput( model_info=self.model_info, + message=AssistantMessage(content=''), stream=Stream(delta='', control='start'), ) yield ChatCompletionStreamOutput( model_info=self.model_info, - messages=[AssistantMessage(content=content)], + message=AssistantMessage(content=content), stream=Stream(delta=content, control='finish'), ) @@ -65,11 +66,12 @@ async def async_stream_generate( content = f'{parameters.prefix}{messages[-1].content}' yield ChatCompletionStreamOutput( model_info=self.model_info, + message=AssistantMessage(content=''), stream=Stream(delta='', control='start'), ) yield ChatCompletionStreamOutput( model_info=self.model_info, - messages=[AssistantMessage(content=content)], + message=AssistantMessage(content=content), stream=Stream(delta=content, control='finish'), ) diff --git a/tests/test_function.py b/tests/test_function.py index 76b055a..1a6a012 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -87,7 +87,9 @@ def google(keyword: str) -> str: def test_openai_function() -> None: model = OpenAIChat(parameters=OpenAIChatParameters(functions=[get_weather.json_schema, google.json_schema], temperature=0)) - engine = ChatEngine(model, functions=[get_weather, google], stream=False, call_raise_error=True) + engine = ChatEngine( + model, functions={f.name: f for f in [get_weather, google]}, stream=False, function_call_raise_error=True + ) reply = engine.chat('今天北京天气怎么样?') assert '27' in reply @@ -101,6 +103,8 @@ def test_openai_tool() -> None: ] ) ) - engine = ChatEngine(model, functions=[get_weather, google], stream=False, call_raise_error=True) + engine = ChatEngine( + model, functions={f.name: f for f in [get_weather, google]}, stream=False, function_call_raise_error=True + ) reply = engine.chat('今天北京天气怎么样?') assert '27' in reply diff --git a/tests/test_parameters.py b/tests/test_parameters.py index ee18a97..ad7c8a3 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -9,12 +9,12 @@ class ToolChoice(TypedDict): name: str -class TestParameters(ModelParameters): +class FakeParameters(ModelParameters): name: str = 'TestModel' tool_choice: Union[str, ToolChoice, None] = None def test_parameters() -> None: - parameters = TestParameters(tool_choice=None) + parameters = FakeParameters(tool_choice=None) except_dump_data = {'name': 'TestModel', 'tool_choice': None} assert parameters.custom_model_dump() == except_dump_data
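
For reviewers, a minimal usage sketch of the public API after this change, assembled only from the hunks above: `ChatCompletionOutput` now carries a single `message` instead of a `messages` list, `ChatEngine` takes `functions` as a name-to-callable mapping, and `call_raise_error` is renamed to `function_call_raise_error`. Import paths, the weather helper, and the parameter values are assumptions for illustration, not part of this diff.

```python
# Sketch only: import locations and the helper below are assumed, not taken from this diff.
from generate.chat_completion import OpenAIChat, function                 # assumed export path
from generate.chat_completion.models.openai import OpenAIChatParameters   # assumed export path
from generate.chat_engine import ChatEngine


@function
def get_weather(city: str) -> str:
    """Return a fake weather report for the given city (illustrative helper)."""
    return f'{city}: sunny, 27°C'


model = OpenAIChat(parameters=OpenAIChatParameters(temperature=0))
output = model.generate('你好')

# Before this change: output.messages[0] / output.last_message; now a single message.
print(output.message)   # the generated assistant (or function-call / tool-calls) message
print(output.reply)     # convenience property, unchanged
print(output.cost)      # cost in yuan, may be None

# ChatEngine now expects a name -> callable mapping, and `call_raise_error`
# is renamed to `function_call_raise_error`.
engine = ChatEngine(
    model,
    functions={f.name: f for f in [get_weather]},   # `name` is a new property on `function`
    stream=False,
    function_call_raise_error=True,
)
print(engine.chat('今天北京天气怎么样?'))
```

The mapping-based `functions` argument lets `_run_function_call` resolve a call with a plain `dict.get`, which is why the list-handling and auto-naming branch was dropped from `ChatEngine.__init__`.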