Fix Stream #13

Merged · 2 commits · Dec 7, 2023
29 changes: 16 additions & 13 deletions examples/tutorial.ipynb
@@ -9,7 +9,7 @@
     "👏🏻 Welcome to the Generate tutorial, where you will learn how to:\n",
     "\n",
     "1. Replace the scattered SDKs of different platforms with one unified, concise API\n",
-    "2. Use `Generate` to generate text, images, and audio"
+    "2. Use the models integrated in `Generate` to generate text, images, and audio"
    ]
   },
   {

@@ -66,7 +66,7 @@
    "source": [
     "### Configure the OpenAI Key\n",
     "\n",
-    "Before using `generate`, you first need to configure an OpenAI API Key so that the OpenAI API can be called. \n",
+    "Before using `generate`, you first need to configure an OpenAI API Key. \n",
     "\n",
     "The `generate` library uses [Pydantic Settings](https://docs.pydantic.dev/latest/concepts/pydantic_settings/) to manage the configuration of the different platforms; Pydantic Settings reads it from a `.env` file, from environment variables, or at **runtime**.\n",
     "\n",

@@ -150,12 +150,15 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "We can see that the model returned a structured object, `ChatCompletionOutput`, containing some useful information, such as:\n",
-    "- model info, in the model_info field: the task type, platform, and model name\n",
-    "- generated messages, in the messages field: the messages produced by the gpt model\n",
-    "- billing info, in the cost field: the cost of this call, in CNY\n",
-    "- finish reason, in the finish_reason field: why this call finished\n",
-    "- extra info, in the extra field: additional details that may be useful, such as how many tokens the call used"
+    "We can see that the model returned a structured object, `ChatCompletionOutput`, containing some useful information\n",
+    "\n",
+    "| field | description |\n",
+    "| --- | --- |\n",
+    "| model_info | the task type, platform, and model name |\n",
+    "| message | the message generated by the model |\n",
+    "| cost | the cost of this call, in CNY |\n",
+    "| finish_reason | why this call finished |\n",
+    "| extra | additional details that may be useful |\n"
    ]
   },
   {

@@ -164,7 +167,7 @@
    "source": [
     "`ChatCompletionOutput` is based on [Pydantic BaseModel](https://docs.pydantic.dev/latest/concepts/models/), so these fields can be read as plain attributes.\n",
     "\n",
-    "Beyond that, `ChatCompletionOutput` also provides some handy computed properties, such as `reply` and `last_message`, like this:"
+    "Beyond that, `ChatCompletionOutput` also provides some handy computed properties, such as `reply`, like this:"
    ]
   },
   {

@@ -175,10 +178,10 @@
    "source": [
     "cost = model_output.cost\n",
     "reply = model_output.reply\n",
-    "last_message = model_output.last_message\n",
+    "message = model_output.message\n",
     "rich.print(f'{cost=}')\n",
     "rich.print(f'{reply=}')\n",
-    "rich.print(f'{last_message=}')"
+    "rich.print(f'{message=}')"
    ]
   },
   {

@@ -187,9 +190,9 @@
    "source": [
     "### Setting the model and its parameters\n",
     "\n",
-    "Of course, we can also skip the default model and parameters and customize them instead.\n",
+    "In the previous example we did not set a model type or any parameters and relied on the defaults. Now let's see how to specify both.\n",
     "\n",
-    "Model parameters can be set when the model is initialized, to serve as the model's default parameters. They can also be set when calling the `generate` method, to serve as the parameters for that call.\n",
+    "Model parameters can be set when the model is initialized, to serve as the model's default parameters. They can also be set when calling the `generate` method, as the parameters for that call.\n",
     "\n",
     "- Initialization parameters must be declared explicitly; for `OpenAIChat`, they are an `OpenAIChatParameters` instance.\n",
     "- Call-time parameters need no explicit declaration; just pass keyword arguments, e.g. `model.generate('你好', temperature=0.5)`\n",
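The tutorial edits above all track this PR's central rename: the output's `messages` list and `last_message` property give way to a single `message`. A minimal sketch of the updated usage, assuming `OpenAIChat` is importable from the top-level `generate` package and constructible with defaults (only `generate`, `cost`, `reply`, and `message` appear in this diff):

import rich

from generate import OpenAIChat  # import path assumed, not shown in this diff

model = OpenAIChat()
model_output = model.generate('你好', temperature=0.5)  # call-time parameter, as in the tutorial

cost = model_output.cost        # cost of this call, in CNY
reply = model_output.reply      # computed property: the assistant's reply text
message = model_output.message  # the single message; replaces messages / last_message
rich.print(f'{cost=}')
rich.print(f'{reply=}')
rich.print(f'{message=}')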
29 changes: 12 additions & 17 deletions generate/chat_completion/function_call.py
@@ -39,26 +39,21 @@ def get_json_schema(function: Callable[..., Any]) -> FunctionJsonSchema:


 class function(Generic[P, T]):  # noqa: N801
-    """
-    A decorator class that wraps a callable function and provides additional functionality.
-
-    Args:
-        function (Callable[P, T]): The function to be wrapped.
+    def __init__(self, function: Callable[P, T]) -> None:
+        self.function: Callable[P, T] = validate_call(function)
+        self.json_schema: FunctionJsonSchema = get_json_schema(function)

-    Attributes:
-        function (Callable[P, T]): The wrapped function.
-        name (str): The name of the wrapped function.
-        docstring (ParsedDocstring): The parsed docstring of the wrapped function.
-        json_schema (Function): The JSON schema of the wrapped function.
+    @property
+    def name(self) -> str:
+        return self.json_schema['name']

-    Methods:
-        __call__(self, *args: Any, **kwargs: Any) -> Any: Calls the wrapped function with the provided arguments.
-        call_with_message(self, message: Message) -> T: Calls the wrapped function with the arguments provided in the message.
-    """
+    @property
+    def description(self) -> str:
+        return self.json_schema.get('description', '')

-    def __init__(self, function: Callable[P, T]) -> None:
-        self.function: Callable[P, T] = validate_call(function)
-        self.json_schema = get_json_schema(function)
+    @property
+    def parameters(self) -> JsonSchema:
+        return self.json_schema['parameters']

     def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
         return self.function(*args, **kwargs)
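The docstring that merely promised `name` and `json_schema` attributes is replaced by real properties. A minimal sketch of using the reworked decorator; the `get_weather` function is hypothetical, while the `function` class, its `validate_call` wrapping, and the `name`/`description`/`parameters` properties are taken from this diff:

from generate.chat_completion.function_call import function  # module path taken from this diff

@function
def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f'Sunny in {city}'

# The wrapped object is still callable; arguments are validated at runtime
# by pydantic's validate_call.
print(get_weather('Paris'))

# Schema pieces are now exposed as properties backed by the JSON schema.
print(get_weather.name)         # 'get_weather'
print(get_weather.description)  # 'Get the current weather for a city.' ('' if absent)
print(get_weather.parameters)   # JSON schema of the function's parameters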
2 changes: 2 additions & 0 deletions generate/chat_completion/message/__init__.py
@@ -1,4 +1,5 @@
 from generate.chat_completion.message.core import (
+    AssistantGroupMessage,
     AssistantMessage,
     FunctionCall,
     FunctionCallMessage,

@@ -34,6 +35,7 @@
     'UnionMessage',
     'FunctionCallMessage',
     'AssistantMessage',
+    'AssistantGroupMessage',
     'ToolCallsMessage',
     'ensure_messages',
     'FunctionCall',
8 changes: 7 additions & 1 deletion generate/chat_completion/message/core.py
@@ -79,7 +79,13 @@ class ToolCallsMessage(Message):
     content: List[ToolCall]


-UnionAssistantMessage = Union[AssistantMessage, FunctionCallMessage, ToolCallsMessage]
+class AssistantGroupMessage(Message):
+    role: Literal['assistant'] = 'assistant'
+    name: Optional[str] = None
+    content: List[Union[AssistantMessage, FunctionMessage, FunctionCallMessage]]
+
+
+UnionAssistantMessage = Union[AssistantMessage, FunctionCallMessage, ToolCallsMessage, AssistantGroupMessage]
 UnionUserMessage = Union[UserMessage, UserMultiPartMessage]
 UnionUserPart = Union[TextPart, ImageUrlPart]
 UnionMessage = Union[SystemMessage, FunctionMessage, ToolMessage, UnionAssistantMessage, UnionUserMessage]
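A minimal sketch of constructing the new message type; the `AssistantGroupMessage` shape comes from this diff, while the field names on `FunctionCall` are an assumption, since its definition is not shown here:

from generate.chat_completion.message import (
    AssistantGroupMessage,
    AssistantMessage,
    FunctionCall,
    FunctionCallMessage,
)

# One assistant turn that groups a function call with the final textual reply.
# The name/arguments fields on FunctionCall are assumptions, not shown in this diff.
group = AssistantGroupMessage(
    content=[
        FunctionCallMessage(content=FunctionCall(name='get_weather', arguments='{"city": "Paris"}')),
        AssistantMessage(content='It is sunny in Paris.'),
    ]
)
assert group.role == 'assistant'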
14 changes: 4 additions & 10 deletions generate/chat_completion/model_output.py
@@ -1,25 +1,19 @@
 from __future__ import annotations

-from typing import Generic, List, Literal, Optional, TypeVar
+from typing import Generic, Literal, Optional, TypeVar

 from pydantic import BaseModel

-from generate.chat_completion.message import AssistantMessage, UnionMessage
+from generate.chat_completion.message import AssistantMessage, UnionAssistantMessage
 from generate.model import ModelOutput

-M = TypeVar('M', bound=UnionMessage)
+M = TypeVar('M', bound=UnionAssistantMessage)


 class ChatCompletionOutput(ModelOutput, Generic[M]):
-    messages: List[M] = []
+    message: M
     finish_reason: Optional[str] = None

-    @property
-    def message(self) -> M:
-        if len(self.messages) != 1:
-            raise ValueError('Expected exactly one message')
-        return self.messages[0]
-
     @property
     def reply(self) -> str:
         if self.message and isinstance(self.message, AssistantMessage):
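This file carries the PR's core breaking change: `message` goes from a computed property over a `messages` list (which raised ValueError unless the list held exactly one item) to a plain required field, and the type variable is narrowed to assistant-side messages. A minimal sketch; the `ModelInfo` import path and field names are assumptions, since its definition is not shown in this diff:

from generate.chat_completion.message import AssistantMessage
from generate.chat_completion.model_output import ChatCompletionOutput
from generate.model import ModelInfo  # import path assumed

# Field names on ModelInfo are assumptions based on the tutorial's description
# (task type, platform, and model name).
model_info = ModelInfo(task='chat_completion', type='openai', name='gpt-3.5-turbo')

output = ChatCompletionOutput[AssistantMessage](
    model_info=model_info,
    message=AssistantMessage(content='Hello!'),  # required field; no longer messages=[...]
    finish_reason='stop',
)
print(output.reply)  # 'Hello!' when the message is an AssistantMessage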
5 changes: 2 additions & 3 deletions generate/chat_completion/models/baichuan.py
@@ -148,7 +148,7 @@ def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
         extra = {}
         return ChatCompletionOutput(
             model_info=self.model_info,
-            messages=[AssistantMessage(content=text)],
+            message=AssistantMessage(content=text),
             finish_reason=finish_reason,
             cost=cost,
             extra=extra,

@@ -199,10 +199,9 @@ def _parse_stream_line(
         else:
             stream = Stream(delta=output_message.content, control='finish' if output.is_finish else 'continue')
             message.content += output_message.content
-        output.messages = [message]
         return ChatCompletionStreamOutput[AssistantMessage](
             model_info=output.model_info,
-            messages=output.messages,
+            message=message,
             finish_reason=output.finish_reason,
             cost=output.cost,
             extra=output.extra,
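The provider adapters in this PR (baichuan, bailian, hunyuan, minimax) all make the same mechanical change, so one consumption sketch covers them; the `BaichuanChat` class name and the streaming method name are assumptions, while `message`, `stream.delta`, and the 'start'/'continue'/'finish' control values come from this diff:

from generate.chat_completion.models.baichuan import BaichuanChat  # class name assumed

model = BaichuanChat()
for chunk in model.stream_generate('你好'):  # streaming method name is an assumption
    print(chunk.stream.delta, end='', flush=True)  # incremental text
    if chunk.stream.control == 'finish':
        final_message = chunk.message  # the fully accumulated AssistantMessage
        print(f'\nfinish_reason={chunk.finish_reason} cost={chunk.cost}')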
9 changes: 4 additions & 5 deletions generate/chat_completion/models/bailian.py
@@ -47,7 +47,7 @@ def generate_default_request_id() -> str:

 def _convert_to_bailian_chat_qa_pair(messages: Messages) -> list[BailianChatQAPair]:
     pairs: list[BailianChatQAPair] = []
-    if isinstance(messages[0], SystemMessage):
+    if messages and isinstance(messages[0], SystemMessage):
         pairs.append({'User': messages[0].content, 'Bot': '好的'})
         messages = messages[1:]

@@ -155,10 +155,9 @@ async def async_generate(
     def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
         if not response['Success']:
             raise UnexpectedResponseError(response)
-        messages = [AssistantMessage(content=response['Data']['Text'])]
         return ChatCompletionOutput(
             model_info=self.model_info,
-            messages=messages,
+            message=AssistantMessage(content=response['Data']['Text']),
             extra={
                 'thoughts': response['Data']['Thoughts'],
                 'doc_references': response['Data']['DocReferences'],

@@ -220,7 +219,7 @@ def _parse_stream_line(
         if len(reply) == len(message.content):
             return ChatCompletionStreamOutput(
                 model_info=self.model_info,
-                messages=[message],
+                message=message,
                 extra=extra,
                 finish_reason='stop',
                 stream=Stream(delta='', control='finish'),

@@ -230,7 +229,7 @@
         message.content = reply
         return ChatCompletionStreamOutput(
             model_info=self.model_info,
-            messages=[message],
+            message=message,
             extra=extra,
             stream=Stream(delta=delta, control='start' if is_start else 'continue'),
         )
7 changes: 3 additions & 4 deletions generate/chat_completion/models/hunyuan.py
@@ -111,10 +111,9 @@ async def async_generate(
     def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
         if response.get('error'):
             raise UnexpectedResponseError(response)
-        messages = [AssistantMessage(content=response['choices'][0]['messages']['content'])]
         return ChatCompletionOutput(
             model_info=self.model_info,
-            messages=messages,
+            message=AssistantMessage(content=response['choices'][0]['messages']['content']),
             finish_reason=response['choices'][0]['finish_reason'],
             cost=self.calculate_cost(response['usage']),
             extra={'usage': response['usage']},

@@ -172,15 +171,15 @@ def _parse_stream_line(
         if message_dict['finish_reason']:
             return ChatCompletionStreamOutput(
                 model_info=self.model_info,
-                messages=[message],
+                message=message,
                 finish_reason=message_dict['finish_reason'],
                 cost=self.calculate_cost(parsed_line['usage']),
                 stream=Stream(delta=delta, control='finish'),
                 extra={'usage': parsed_line['usage']},
             )
         return ChatCompletionStreamOutput(
             model_info=self.model_info,
-            messages=[message],
+            message=message,
             finish_reason=None,
             stream=Stream(delta=delta, control='start' if is_start else 'continue'),
         )
11 changes: 5 additions & 6 deletions generate/chat_completion/models/minimax.py
@@ -145,12 +145,11 @@ async def async_generate(
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())

-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
         try:
-            messages = [AssistantMessage(content=response['choices'][0]['text'])]
-            return ChatCompletionOutput(
+            return ChatCompletionOutput[AssistantMessage](
                 model_info=self.model_info,
-                messages=messages,
+                message=AssistantMessage(content=response['choices'][0]['text']),
                 finish_reason=response['choices'][0]['finish_reason'],
                 cost=self.calculate_cost(response['usage']),
                 extra={

@@ -207,7 +206,7 @@ def _parse_stream_line(
             return ChatCompletionStreamOutput[AssistantMessage](
                 model_info=self.model_info,
                 finish_reason=parsed_line['choices'][0]['finish_reason'],
-                messages=[message],
+                message=message,
                 cost=self.calculate_cost(parsed_line['usage']),
                 extra={
                     'logprobes': parsed_line['choices'][0]['logprobes'],

@@ -220,7 +219,7 @@
             return ChatCompletionStreamOutput[AssistantMessage](
                 model_info=self.model_info,
                 finish_reason=None,
-                messages=[message],
+                message=message,
                 stream=Stream(delta=delta, control='start' if is_start else 'continue'),
             )