Hotfix (#2)
* fix system message bug

* Update version number to 0.2.1

---------

Co-authored-by: wangyuxin <[email protected]>
wangyuxinwhy and wangyuxin authored Nov 16, 2023
1 parent 1bd33f2 commit 539a896
Showing 26 changed files with 201 additions and 252 deletions.
2 changes: 1 addition & 1 deletion generate/chat_completion/__init__.py
@@ -27,7 +27,7 @@
ZhipuChatParameters,
)
from generate.chat_completion.printer import MessagePrinter, SimpleMessagePrinter
from generate.parameters import ModelParameters
from generate.model import ModelParameters

ChatModels: list[tuple[Type[ChatCompletionModel], Type[ModelParameters]]] = [
(AzureChat, OpenAIChatParameters),
36 changes: 4 additions & 32 deletions generate/chat_completion/base.py
@@ -2,35 +2,20 @@

import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, ClassVar, Generic, Iterator, TypeVar

from typing_extensions import Self, TypeGuard
from typing import Any, AsyncIterator, ClassVar, Iterator, TypeVar

from generate.chat_completion.message import Messages, Prompt, ensure_messages
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.model import ModelInfo
from generate.parameters import ModelParameters
from generate.model import GenerateModel, ModelParameters

P = TypeVar('P', bound=ModelParameters)
logger = logging.getLogger(__name__)


class ChatCompletionModel(Generic[P], ABC):
class ChatCompletionModel(GenerateModel[P, Prompt, ChatCompletionOutput], ABC):
model_task: ClassVar[str] = 'chat_completion'
model_type: ClassVar[str]

def __init__(self, parameters: P) -> None:
self.parameters = parameters

@property
@abstractmethod
def name(self) -> str:
...

@classmethod
@abstractmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
...

@abstractmethod
def _completion(self, messages: Messages, parameters: P) -> ChatCompletionOutput:
...
@@ -47,10 +32,6 @@ def _stream_completion(self, messages: Messages, parameters: P) -> Iterator[Chat
def _async_stream_completion(self, messages: Messages, parameters: P) -> AsyncIterator[ChatCompletionStreamOutput]:
...

@property
def model_info(self) -> ModelInfo:
return ModelInfo(task='chat_completion', type=self.model_type, name=self.name)

def generate(self, prompt: Prompt, **override_parameters: Any) -> ChatCompletionOutput:
parameters = self._merge_parameters(**override_parameters)
messages = ensure_messages(prompt)
@@ -74,12 +55,3 @@ def async_stream_generate(self, prompt: Prompt, **override_parameters: Any) -> A
messages = ensure_messages(prompt)
logger.debug(f'{messages=}, {parameters=}')
return self._async_stream_completion(messages, parameters)

def _merge_parameters(self, **override_parameters: Any) -> P:
return self.parameters.__class__.model_validate(
{**self.parameters.model_dump(exclude_unset=True), **override_parameters}
)


def is_stream_model_output(model_output: ChatCompletionOutput) -> TypeGuard[ChatCompletionStreamOutput]:
return getattr(model_output, 'stream', None) is not None
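
Note (not part of the diff): most of this hunk deletes per-model boilerplate — the name property, the from_name constructor, the model_info property, and _merge_parameters — and ChatCompletionModel now inherits that behavior from the generic GenerateModel[P, Prompt, ChatCompletionOutput] in generate.model. The diff does not show generate/model.py, so the following is only a sketch of what that base class presumably provides, inferred from the methods removed here.

    # Sketch only: inferred from the methods deleted in this hunk; the real
    # GenerateModel in generate/model.py may differ.
    from abc import ABC, abstractmethod
    from typing import Any, ClassVar, Generic, TypeVar

    from pydantic import BaseModel
    from typing_extensions import Self

    class ModelParameters(BaseModel):
        """Base class for per-model parameter sets (now exported from generate.model)."""

    class ModelInfo(BaseModel):
        task: str
        type: str
        name: str

    P = TypeVar('P', bound=ModelParameters)
    InputType = TypeVar('InputType')    # e.g. Prompt for chat completion
    OutputType = TypeVar('OutputType')  # e.g. ChatCompletionOutput

    class GenerateModel(Generic[P, InputType, OutputType], ABC):
        model_task: ClassVar[str]
        model_type: ClassVar[str]

        def __init__(self, parameters: P) -> None:
            self.parameters = parameters

        @property
        @abstractmethod
        def name(self) -> str:
            ...

        @classmethod
        @abstractmethod
        def from_name(cls, name: str, **kwargs: Any) -> Self:
            ...

        @property
        def model_info(self) -> ModelInfo:
            return ModelInfo(task=self.model_task, type=self.model_type, name=self.name)

        def _merge_parameters(self, **override_parameters: Any) -> P:
            # Keyword overrides passed to generate() win over the stored defaults.
            return self.parameters.__class__.model_validate(
                {**self.parameters.model_dump(exclude_unset=True), **override_parameters}
            )
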
2 changes: 1 addition & 1 deletion generate/chat_completion/models/baichuan.py
@@ -28,7 +28,7 @@
HttpxPostKwargs,
UnexpectedResponseError,
)
from generate.parameters import ModelParameters
from generate.model import ModelParameters
from generate.types import Probability, Temperature


2 changes: 1 addition & 1 deletion generate/chat_completion/models/bailian.py
@@ -24,7 +24,7 @@
HttpxPostKwargs,
UnexpectedResponseError,
)
from generate.parameters import ModelParameters
from generate.model import ModelParameters
from generate.types import Probability


2 changes: 1 addition & 1 deletion generate/chat_completion/models/hunyuan.py
@@ -28,7 +28,7 @@
HttpxPostKwargs,
UnexpectedResponseError,
)
from generate.parameters import ModelParameters
from generate.model import ModelParameters
from generate.types import Probability, Temperature


2 changes: 1 addition & 1 deletion generate/chat_completion/models/minimax.py
@@ -24,7 +24,7 @@
HttpxPostKwargs,
UnexpectedResponseError,
)
from generate.parameters import ModelParameters
from generate.model import ModelParameters
from generate.types import Probability, Temperature


2 changes: 1 addition & 1 deletion generate/chat_completion/models/minimax_pro.py
@@ -29,7 +29,7 @@
HttpxPostKwargs,
UnexpectedResponseError,
)
from generate.parameters import ModelParameters
from generate.model import ModelParameters
from generate.types import Probability, Temperature


182 changes: 93 additions & 89 deletions generate/chat_completion/models/openai.py
@@ -2,7 +2,8 @@

import json
import os
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, List, Literal, Optional, Union, cast
from functools import partial
from typing import Any, AsyncIterator, Callable, ClassVar, Dict, Iterator, List, Literal, Optional, Type, Union, cast

from pydantic import Field, PositiveInt
from typing_extensions import Annotated, NotRequired, Self, TypedDict, Unpack, override
Expand All @@ -19,6 +20,7 @@
Messages,
MessageTypeError,
MessageValueError,
SystemMessage,
TextPart,
ToolCall,
ToolCallsMessage,
Expand All @@ -35,8 +37,7 @@
HttpxPostKwargs,
UnexpectedResponseError,
)
from generate.model import ModelInfo
from generate.parameters import ModelParameters
from generate.model import ModelInfo, ModelParameters
from generate.types import Probability, Temperature


@@ -95,94 +96,97 @@ class OpenAIChatParameters(ModelParameters):
tool_choice: Union[Literal['auto'], OpenAIToolChoice, None] = None


def convert_to_openai_message(message: Message) -> OpenAIMessage:
if isinstance(message, UserMessage):
return {
'role': 'user',
'content': message.content,
}

if isinstance(message, UserMultiPartMessage):
content = []
for part in message.content:
if isinstance(part, TextPart):
content.append({'type': 'text', 'text': part.text})
elif isinstance(part, ImageUrlPart):
image_url_part_dict = {
'type': 'image_url',
'image_url': {
'url': part.image_url.url,
},
}
if part.image_url.detail:
image_url_part_dict['image_url']['detail'] = part.image_url.detail
content.append(image_url_part_dict)
else:
raise MessageValueError(message, f'OpenAI does not support {type(part)} ')

return {
'role': 'user',
'content': content,
}

if isinstance(message, AssistantMessage):
return {
'role': 'assistant',
'content': message.content,
}

if isinstance(message, ToolCallsMessage):
return {
'role': 'assistant',
'content': None,
'tool_calls': [
{
'id': tool_call.id,
'type': 'function',
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments,
},
}
for tool_call in message.content
],
}

if isinstance(message, ToolMessage):
return {
'role': 'tool',
'tool_call_id': message.tool_call_id,
'content': message.content,
}

if isinstance(message, FunctionCallMessage):
return {
'role': 'assistant',
'function_call': {
'name': message.content.name,
'arguments': message.content.arguments,
},
'content': None,
}
def _to_text_message_dict(role: str, message: Message) -> OpenAIMessage:
return {
'role': role,
'content': message.content,
}


def _to_user_multipart_message_dict(message: UserMultiPartMessage) -> OpenAIMessage:
content = []
for part in message.content:
if isinstance(part, TextPart):
content.append({'type': 'text', 'text': part.text})
elif isinstance(part, ImageUrlPart):
image_url_part_dict = {
'type': 'image_url',
'image_url': {
'url': part.image_url.url,
},
}
if part.image_url.detail:
image_url_part_dict['image_url']['detail'] = part.image_url.detail
content.append(image_url_part_dict)
else:
raise MessageValueError(message, f'OpenAI does not support {type(part)} ')

return {
'role': 'user',
'content': content,
}


def _to_tool_message_dict(message: ToolMessage) -> OpenAIMessage:
return {
'role': 'tool',
'tool_call_id': message.tool_call_id,
'content': message.content,
}


def _to_tool_calls_message_dict(message: ToolCallsMessage) -> OpenAIMessage:
return {
'role': 'assistant',
'content': None,
'tool_calls': [
{
'id': tool_call.id,
'type': 'function',
'function': {
'name': tool_call.function.name,
'arguments': tool_call.function.arguments,
},
}
for tool_call in message.content
],
}


def _to_function_message_dict(message: FunctionMessage) -> OpenAIMessage:
return {
'role': 'function',
'name': message.name,
'content': message.content,
}


def _to_function_call_message_dict(message: FunctionCallMessage) -> OpenAIMessage:
return {
'role': 'assistant',
'function_call': {
'name': message.content.name,
'arguments': message.content.arguments,
},
'content': None,
}

if isinstance(message, FunctionMessage):
return {
'role': 'function',
'name': message.name,
'content': message.content,
}

raise MessageTypeError(
message,
allowed_message_type=(
UserMessage,
AssistantMessage,
FunctionMessage,
FunctionCallMessage,
ToolCallsMessage,
ToolMessage,
),
)
def convert_to_openai_message(message: Message) -> OpenAIMessage:
to_function_map: dict[Type[Message], Callable[[Any], OpenAIMessage]] = {
SystemMessage: partial(_to_text_message_dict, 'system'),
UserMessage: partial(_to_text_message_dict, 'user'),
AssistantMessage: partial(_to_text_message_dict, 'assistant'),
UserMultiPartMessage: _to_user_multipart_message_dict,
ToolMessage: _to_tool_message_dict,
ToolCallsMessage: _to_tool_calls_message_dict,
FunctionMessage: _to_function_message_dict,
FunctionCallMessage: _to_function_call_message_dict,
}
if to_function := to_function_map.get(type(message)):
return to_function(message)

raise MessageTypeError(message, allowed_message_type=tuple(to_function_map.keys()))


def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float | None:
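
Note (not part of the diff): this rewrite is the "system message bug" fix named in the commit message. The old convert_to_openai_message had no branch for SystemMessage, so system messages fell through to the MessageTypeError at the end; the new dispatch table maps SystemMessage to partial(_to_text_message_dict, 'system'). A minimal usage sketch, assuming SystemMessage is constructed with a content keyword:

    # Illustrative usage, not part of the diff; SystemMessage's constructor
    # signature is assumed here.
    from generate.chat_completion.message import SystemMessage
    from generate.chat_completion.models.openai import convert_to_openai_message

    message = SystemMessage(content='You are a helpful assistant.')
    print(convert_to_openai_message(message))
    # Before this commit: raises MessageTypeError.
    # After this commit:  {'role': 'system', 'content': 'You are a helpful assistant.'}
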
2 changes: 1 addition & 1 deletion generate/chat_completion/models/wenxin.py
@@ -29,7 +29,7 @@
HttpxPostKwargs,
UnexpectedResponseError,
)
from generate.parameters import ModelParameters
from generate.model import ModelParameters
from generate.types import JsonSchema, Probability, Temperature


2 changes: 1 addition & 1 deletion generate/chat_completion/models/zhipu.py
@@ -25,7 +25,7 @@
HttpxPostKwargs,
UnexpectedResponseError,
)
from generate.parameters import ModelParameters
from generate.model import ModelParameters
from generate.types import Probability, Temperature

P = TypeVar('P', bound=ModelParameters)
5 changes: 2 additions & 3 deletions generate/image_generation/__init__.py
@@ -2,10 +2,9 @@

from typing import Type

from generate.image_generation.base import ImageGenerationModel
from generate.image_generation.model_output import GeneratedImage, ImageGenerationOutput
from generate.image_generation.base import GeneratedImage, ImageGenerationModel, ImageGenerationOutput
from generate.image_generation.models import OpenAIImageGeneration, OpenAIImageGenerationParameters
from generate.parameters import ModelParameters
from generate.model import ModelParameters

ImageGenerationModels: list[tuple[Type[ImageGenerationModel], Type[ModelParameters]]] = [
(OpenAIImageGeneration, OpenAIImageGenerationParameters),