Modify message #15

Merged: 3 commits, Jan 17, 2024
9 changes: 0 additions & 9 deletions generate/chat_completion/function_call.py
@@ -1,13 +1,11 @@
 from __future__ import annotations
 
-import json
 from typing import Any, Callable, Generic, TypeVar
 
 from docstring_parser import parse
 from pydantic import TypeAdapter, validate_call
 from typing_extensions import NotRequired, ParamSpec, TypedDict
 
-from generate.chat_completion.message import FunctionCallMessage, Message
 from generate.types import JsonSchema
 
 P = ParamSpec('P')
@@ -58,13 +56,6 @@ def parameters(self) -> JsonSchema:
     def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
         return self.function(*args, **kwargs)
 
-    def call_with_message(self, message: Message) -> T:
-        if isinstance(message, FunctionCallMessage):
-            function_call = message.content
-            arguments = json.loads(function_call.arguments, strict=False)
-            return self.function(**arguments)  # type: ignore
-        raise ValueError(f'message is not a function call: {message}')
-
 
 def recusive_remove(obj: Any, remove_key: str) -> None:
     """
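
With call_with_message gone, a caller now reads the FunctionCall from the consolidated AssistantMessage (see core.py below) and parses the arguments itself. A minimal sketch of an equivalent call site, following the body of the removed method; the helper name execute_function_call is hypothetical and not part of this PR:

import json

from generate.chat_completion.message import AssistantMessage


def execute_function_call(function, message: AssistantMessage):
    # Equivalent of the removed call_with_message: the FunctionCall now
    # lives on an optional field instead of a dedicated message class.
    if message.function_call is None:
        raise ValueError(f'message is not a function call: {message}')
    arguments = json.loads(message.function_call.arguments, strict=False)
    return function(**arguments)
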
8 changes: 0 additions & 8 deletions generate/chat_completion/message/__init__.py
@@ -1,8 +1,6 @@
 from generate.chat_completion.message.core import (
-    AssistantGroupMessage,
     AssistantMessage,
     FunctionCall,
-    FunctionCallMessage,
     FunctionMessage,
     ImageUrl,
     ImageUrlPart,
@@ -13,9 +11,7 @@
     SystemMessage,
     TextPart,
     ToolCall,
-    ToolCallsMessage,
     ToolMessage,
-    UnionAssistantMessage,
     UnionMessage,
     UnionUserMessage,
     UserMessage,
@@ -30,13 +26,9 @@
 from generate.chat_completion.message.utils import ensure_messages
 
 __all__ = [
-    'UnionAssistantMessage',
     'UnionUserMessage',
     'UnionMessage',
-    'FunctionCallMessage',
     'AssistantMessage',
-    'AssistantGroupMessage',
-    'ToolCallsMessage',
     'ensure_messages',
     'FunctionCall',
     'FunctionMessage',
28 changes: 8 additions & 20 deletions generate/chat_completion/message/core.py
@@ -51,44 +51,32 @@ class ToolMessage(Message):
     content: Optional[str] = None
 
 
-class AssistantMessage(Message):
-    role: Literal['assistant'] = 'assistant'
-    content: str
-
-
 class FunctionCall(BaseModel):
     name: str
     arguments: str
     thoughts: Optional[str] = None
 
 
-class FunctionCallMessage(Message):
-    role: Literal['assistant'] = 'assistant'
-    content: FunctionCall
-
-
 class ToolCall(BaseModel):
     id: str  # noqa: A003
     type: Literal['function'] = 'function'  # noqa: A003
     function: FunctionCall
 
 
-class ToolCallsMessage(Message):
+class AssistantMessage(Message):
     role: Literal['assistant'] = 'assistant'
     name: Optional[str] = None
-    content: List[ToolCall]
+    content: str = ''
+    function_call: Optional[FunctionCall] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
-
-class AssistantGroupMessage(Message):
-    role: Literal['assistant'] = 'assistant'
-    name: Optional[str] = None
-    content: List[Union[AssistantMessage, FunctionMessage, FunctionCallMessage]]
+    @property
+    def is_over(self) -> bool:
+        return self.function_call is None and self.tool_calls is None
 
 
-UnionAssistantMessage = Union[AssistantMessage, FunctionCallMessage, ToolCallsMessage, AssistantGroupMessage]
 UnionUserMessage = Union[UserMessage, UserMultiPartMessage]
 UnionUserPart = Union[TextPart, ImageUrlPart]
-UnionMessage = Union[SystemMessage, FunctionMessage, ToolMessage, UnionAssistantMessage, UnionUserMessage]
+UnionMessage = Union[SystemMessage, FunctionMessage, ToolMessage, AssistantMessage, UnionUserMessage]
 Messages = List[UnionMessage]
 MessageDict = Dict[str, Any]
 MessageDicts = Sequence[MessageDict]
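
FunctionCallMessage, ToolCallsMessage, and AssistantGroupMessage thus collapse into a single AssistantMessage whose optional fields mark a pending call, and is_over reports whether the turn is complete. A minimal sketch of the two shapes, assuming the pydantic models above; the tool-call values are hypothetical:

from generate.chat_completion.message import AssistantMessage, FunctionCall, ToolCall

# A plain text reply: no pending function or tool calls, so the turn is over.
text_reply = AssistantMessage(content='Hello!')
assert text_reply.is_over

# A tool-call reply: content may stay at its new default of '' while tool_calls is set.
tool_reply = AssistantMessage(
    tool_calls=[
        ToolCall(id='call_0', function=FunctionCall(name='get_weather', arguments='{"city": "Beijing"}')),
    ]
)
assert not tool_reply.is_over
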
17 changes: 6 additions & 11 deletions generate/chat_completion/model_output.py
@@ -1,25 +1,20 @@
 from __future__ import annotations
 
-from typing import Generic, Literal, Optional, TypeVar, cast
+from typing import Literal, Optional
 
 from pydantic import BaseModel
 
-from generate.chat_completion.message import AssistantMessage, UnionAssistantMessage
+from generate.chat_completion.message import AssistantMessage
 from generate.model import ModelOutput
 
-M = TypeVar('M', bound=UnionAssistantMessage)
-
 
-class ChatCompletionOutput(ModelOutput, Generic[M]):
-    message: M
+class ChatCompletionOutput(ModelOutput, AssistantMessage):
+    message: AssistantMessage
     finish_reason: Optional[str] = None
 
     @property
     def reply(self) -> str:
-        if self.message and isinstance(self.message, AssistantMessage):
-            message = cast(AssistantMessage, self.message)
-            return message.content
-        return ''
+        return self.message.content
 
     @property
     def is_finish(self) -> bool:
@@ -31,5 +26,5 @@ class Stream(BaseModel):
     control: Literal['start', 'continue', 'finish']
 
 
-class ChatCompletionStreamOutput(ChatCompletionOutput, Generic[M]):
+class ChatCompletionStreamOutput(ChatCompletionOutput):
     stream: Stream
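
Since message is always an AssistantMessage now, reply needs no Generic[M] parameter, cast, or isinstance narrowing. A minimal sketch of consuming an output, assuming it came from one of the generate methods in the model files below; the handler name and printed values are illustrative only:

from generate.chat_completion.model_output import ChatCompletionOutput


def handle(output: ChatCompletionOutput) -> None:
    if output.message.tool_calls:
        # The model requested tools; inspect each requested call.
        for tool_call in output.message.tool_calls:
            print(tool_call.function.name, tool_call.function.arguments)
    else:
        print(output.reply)  # same as output.message.content
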
18 changes: 7 additions & 11 deletions generate/chat_completion/models/baichuan.py
@@ -117,24 +117,22 @@ def _get_request_parameters(self, messages: Messages, parameters: BaichuanChatPa
         }
 
     @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         try:
             text = response['data']['messages'][-1]['content']
             finish_reason = response['data']['messages'][-1]['finish_reason']
@@ -164,7 +162,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: Baichua
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -178,7 +176,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BaichuanChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -189,17 +187,15 @@ async def async_stream_generate(
             is_start = False
             yield output
 
-    def _parse_stream_line(
-        self, line: str, message: AssistantMessage, is_start: bool
-    ) -> ChatCompletionStreamOutput[AssistantMessage]:
+    def _parse_stream_line(self, line: str, message: AssistantMessage, is_start: bool) -> ChatCompletionStreamOutput:
         output = self._parse_reponse(json.loads(line))
         output_message = output.message
         if is_start:
             stream = Stream(delta=output_message.content, control='start')
         else:
             stream = Stream(delta=output_message.content, control='finish' if output.is_finish else 'continue')
         message.content += output_message.content
-        return ChatCompletionStreamOutput[AssistantMessage](
+        return ChatCompletionStreamOutput(
             model_info=output.model_info,
             message=message,
             finish_reason=output.finish_reason,
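
The streaming contract itself is unchanged: the first chunk carries control='start', later chunks 'continue', and the last one 'finish' once is_finish is true. A minimal consumer sketch, assuming an already constructed Baichuan chat model (instantiation is outside this diff) and that ensure_messages accepts a bare string as the Prompt:

# 'model' is a hypothetical instance of the chat model class in this file.
for output in model.stream_generate('Hello'):
    if output.stream.control == 'start':
        print('[stream started]')
    print(output.stream.delta, end='')
    if output.stream.control == 'finish':
        print('\nfinish_reason:', output.finish_reason)
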
16 changes: 6 additions & 10 deletions generate/chat_completion/models/bailian.py
@@ -135,24 +135,22 @@ def _get_request_parameters(self, messages: Messages, parameters: BailianChatPar
         }
 
     @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         if not response['Success']:
             raise UnexpectedResponseError(response)
         return ChatCompletionOutput(
@@ -176,7 +174,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: Bailian
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -195,7 +193,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[BailianChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -211,9 +209,7 @@ async def async_stream_generate(
             is_finish = output.is_finish
             yield output
 
-    def _parse_stream_line(
-        self, line: str, message: AssistantMessage, is_start: bool
-    ) -> ChatCompletionStreamOutput[AssistantMessage]:
+    def _parse_stream_line(self, line: str, message: AssistantMessage, is_start: bool) -> ChatCompletionStreamOutput:
         parsed_line = json.loads(line)
         reply: str = parsed_line['Data']['Text']
         extra = {
16 changes: 6 additions & 10 deletions generate/chat_completion/models/hunyuan.py
@@ -91,24 +91,22 @@ def _get_request_parameters(self, messages: Messages, parameters: HunyuanChatPar
         }
 
     @override
-    def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput[AssistantMessage]:
+    def generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = self.http_client.post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
     @override
-    async def async_generate(
-        self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
-    ) -> ChatCompletionOutput[AssistantMessage]:
+    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]) -> ChatCompletionOutput:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_request_parameters(messages, parameters)
         response = await self.http_client.async_post(request_parameters=request_parameters)
         return self._parse_reponse(response.json())
 
-    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput[AssistantMessage]:
+    def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput:
         if response.get('error'):
             raise UnexpectedResponseError(response)
         return ChatCompletionOutput(
@@ -136,7 +134,7 @@ def _get_stream_request_parameters(self, messages: Messages, parameters: Hunyuan
     @override
     def stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
-    ) -> Iterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> Iterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -150,7 +148,7 @@ def stream_generate(
     @override
     async def async_stream_generate(
         self, prompt: Prompt, **kwargs: Unpack[HunyuanChatParametersDict]
-    ) -> AsyncIterator[ChatCompletionStreamOutput[AssistantMessage]]:
+    ) -> AsyncIterator[ChatCompletionStreamOutput]:
         messages = ensure_messages(prompt)
         parameters = self.parameters.update_with_validate(**kwargs)
         request_parameters = self._get_stream_request_parameters(messages, parameters)
@@ -161,9 +159,7 @@ async def async_stream_generate(
             is_start = False
             yield output
 
-    def _parse_stream_line(
-        self, line: str, message: AssistantMessage, is_start: bool
-    ) -> ChatCompletionStreamOutput[AssistantMessage]:
+    def _parse_stream_line(self, line: str, message: AssistantMessage, is_start: bool) -> ChatCompletionStreamOutput:
         parsed_line = json.loads(line)
         message_dict = parsed_line['choices'][0]
         delta = message_dict['delta']['content']