Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

From name bug & wenxin4 #20

Merged
merged 2 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lmclient/models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
from lmclient.models.openai import (
OpenAIChatParameters,
convert_lmclient_to_openai,
format_message_to_openai,
parse_openai_model_reponse,
)
from lmclient.types import Messages, ModelResponse
Expand Down Expand Up @@ -41,7 +41,7 @@ def get_request_parameters(self, messages: Messages, parameters: OpenAIChatParam
'api-key': self.api_key,
}
parameters_dict = parameters.model_dump(exclude_defaults=True)
openai_messages = [convert_lmclient_to_openai(message) for message in messages]
openai_messages = [format_message_to_openai(message) for message in messages]
if self.system_prompt:
openai_messages.insert(0, {'role': 'system', 'content': self.system_prompt})
params = {
Expand Down
65 changes: 39 additions & 26 deletions lmclient/models/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from copy import copy
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

Expand All @@ -9,7 +10,17 @@
from lmclient.exceptions import MessageError
from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
from lmclient.parser import ParserError
from lmclient.types import FunctionCallDict, FunctionDict, GeneralParameters, Message, Messages, ModelParameters, ModelResponse
from lmclient.types import (
FunctionCallDict,
FunctionDict,
GeneralParameters,
Message,
Messages,
ModelParameters,
ModelResponse,
is_function_call_message,
is_text_message,
)


class FunctionCallNameDict(TypedDict):
Expand Down Expand Up @@ -68,35 +79,37 @@ def from_general_parameters(cls, general_parameters: GeneralParameters):
)


def format_message_to_openai(message: Message) -> OpenAIMessageDict:
    """Convert an lmclient ``Message`` into the OpenAI chat message format.

    Function-call content becomes an assistant message carrying a
    ``function_call`` payload; plain text is passed through, with ``name``
    required for function-result messages.

    Raises:
        MessageError: for the 'error' role, for content that is neither
            str nor function-call dict, or when a function-role message
            has no name.
    """
    role = message.role
    if role == 'error':
        raise MessageError(f'Invalid message role: {role}, only "user", "assistant", "system" and "function" are allowed')

    if is_function_call_message(message):
        # Copy before mutating so the caller's message is untouched, then
        # drop the auxiliary 'thoughts' key (presumably not part of the
        # OpenAI function_call schema -- confirm against the API reference).
        function_call = copy(message.content)
        function_call.pop('thoughts', None)
        return {
            'role': 'assistant',
            'function_call': function_call,
            'content': None,  # OpenAI expects null content alongside function_call
        }

    if not is_text_message(message):
        raise MessageError(f'Invalid message type: {message}')

    if role == 'function':
        # Function results must identify which function produced them.
        name = message.name
        if name is None:
            raise MessageError(f'Function name is required, message: {message}')
        return {
            'role': role,
            'name': name,
            'content': message.content,
        }

    return {
        'role': role,
        'content': message.content,
    }


def parse_openai_model_reponse(response: ModelResponse) -> Messages:
Expand Down Expand Up @@ -147,7 +160,7 @@ def get_request_parameters(self, messages: Messages, parameters: OpenAIChatParam
'Authorization': f'Bearer {self.api_key}',
}
parameters_dict = parameters.model_dump(exclude_defaults=True)
openai_messages = [convert_lmclient_to_openai(message) for message in messages]
openai_messages = [format_message_to_openai(message) for message in messages]
if self.system_prompt:
openai_messages.insert(0, {'role': 'system', 'content': self.system_prompt})
params = {
Expand Down
105 changes: 91 additions & 14 deletions lmclient/models/wenxin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,110 @@
from datetime import datetime, timedelta
from email.errors import MessageError
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Dict, List, Literal, Optional

import httpx
from typing_extensions import Self, TypedDict
from typing_extensions import NotRequired, Self, TypedDict

from lmclient.exceptions import ResponseError
from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse
from lmclient.types import (
GeneralParameters,
Message,
Messages,
ModelParameters,
ModelResponse,
is_function_call_message,
is_text_message,
)


class WenxinMessageDict(TypedDict):
    """A single chat message in the Wenxin (ERNIE) wire format."""

    # Diff residue removed: the span previously carried two conflicting
    # `role` annotations; 'function' is part of the merged literal.
    role: Literal['user', 'assistant', 'function']
    content: str
    name: NotRequired[str]  # only set for role == 'function' (see format_message_to_wenxin)
    function_call: NotRequired[WenxinFunctionCallDict]  # only on assistant messages


class WenxinFunctionCallDict(TypedDict):
    """Function-call payload carried inside a Wenxin chat message."""

    name: str  # name of the function the model wants to call
    arguments: str  # raw argument string (presumably JSON-encoded -- confirm against API docs)
    thoughts: NotRequired[str]  # optional model reasoning text


class WenxinFunctionDict(TypedDict):
    """Function definition sent to the Wenxin API in the `functions` parameter."""

    name: str
    description: str
    parameters: Dict[str, Any]  # parameter spec; presumably JSON-Schema-like -- confirm
    responses: NotRequired[Dict[str, str]]  # optional response field descriptions
    examples: NotRequired[List[WenxinMessageDict]]  # optional example exchanges, in wire format


def format_message_to_wenxin(message: Message) -> WenxinMessageDict:
    """Convert an lmclient ``Message`` into the Wenxin (ERNIE) wire format.

    Raises MessageError for roles the API cannot represent ('error',
    'system'), for function-role messages without a name, and for content
    that is neither text nor a function call.
    """
    role = message.role
    if role in ('error', 'system'):
        raise MessageError(f'Invalid message role: {role}, only "user", "assistant" and "function" are allowed')

    if is_function_call_message(message):
        call = message.content
        return {
            'role': 'assistant',
            'function_call': {
                'name': call['name'],
                'arguments': call['arguments'],
                'thoughts': call.get('thoughts', ''),
            },
            'content': '',
        }

    if not is_text_message(message):
        raise MessageError(f'Invalid message type: {message}')

    if role != 'function':
        return {'role': role, 'content': message.content}

    # Function results must carry the name of the function that produced them.
    name = message.name
    if name is None:
        raise MessageError(f'Function name is required, message: {message}')
    return {'role': role, 'name': name, 'content': message.content}


class WenxinChatParameters(ModelParameters):
    """Chat-completion parameters accepted by the Wenxin (ERNIE) API.

    Every field defaults to None; unset fields are dropped from the request
    by the caller (which dumps with exclude_none).
    """

    temperature: Optional[float] = None
    top_p: Optional[float] = None
    functions: Optional[List[WenxinFunctionDict]] = None
    penalty_score: Optional[float] = None
    system: Optional[str] = None
    user_id: Optional[str] = None

    @classmethod
    def from_general_parameters(cls, general_parameters: GeneralParameters) -> Self:
        """Build Wenxin parameters from backend-agnostic GeneralParameters.

        Only temperature, top_p and functions are mapped; the Wenxin-only
        fields (penalty_score, system, user_id) are left at their defaults.
        """
        if general_parameters.functions is not None:
            wenxin_functions: list[WenxinFunctionDict] | None = []
            for general_function in general_parameters.functions:
                wenxin_function = WenxinFunctionDict(
                    name=general_function['name'],
                    description=general_function.get('description', ''),
                    parameters=general_function['parameters'],
                )
                # Optional keys are copied only when present so the request
                # payload stays minimal.
                if 'responses' in general_function:
                    wenxin_function['responses'] = general_function['responses']
                if 'examples' in general_function:
                    # Examples arrive as generic message dicts; normalize them
                    # to the Wenxin wire format.
                    messages = [Message(**message_dict) for message_dict in general_function['examples']]
                    wenxin_messages = [format_message_to_wenxin(message) for message in messages]
                    wenxin_function['examples'] = wenxin_messages
                wenxin_functions.append(wenxin_function)
        else:
            wenxin_functions = None
        return cls(
            temperature=general_parameters.temperature,
            top_p=general_parameters.top_p,
            functions=wenxin_functions,
        )


Expand Down Expand Up @@ -96,14 +175,7 @@ def get_access_token(self) -> str:
def get_request_parameters(self, messages: Messages, parameters: WenxinChatParameters) -> dict[str, Any]:
self.maybe_refresh_access_token()

message_dicts: list[WenxinMessageDict] = []
for message in messages:
role = message.role
if role != 'assistant' and role != 'user':
raise MessageError(f'Invalid message role: {role}, only "user" and "assistant" are allowed')
if not isinstance(content := message.content, str):
raise MessageError(f'Invalid message content: {content}, only string is allowed')
message_dicts.append(WenxinMessageDict(content=content, role=role))
message_dicts: list[WenxinMessageDict] = [format_message_to_wenxin(message) for message in messages]
parameters_dict = parameters.model_dump(exclude_none=True)
if 'temperature' in parameters_dict:
parameters_dict['temperature'] = max(0.01, parameters_dict['temperature'])
Expand All @@ -119,7 +191,12 @@ def get_request_parameters(self, messages: Messages, parameters: WenxinChatParam
def parse_model_reponse(self, response: ModelResponse) -> Messages:
    """Parse a raw Wenxin API response into lmclient Messages.

    Raises ResponseError when the payload carries an 'error_msg'. A
    'function_call' payload becomes an assistant message with dict
    content; otherwise the plain 'result' text is returned.

    (Diff residue removed: a superseded unconditional text return sat
    above the function_call branch, making that branch unreachable.)
    """
    if response.get('error_msg'):
        raise ResponseError(response['error_msg'])
    if response.get('function_call'):
        arguments = response['function_call']['arguments']
        name = response['function_call']['name']
        return [Message(role='assistant', content={'name': name, 'arguments': arguments})]
    return [Message(role='assistant', content=response['result'])]

def maybe_refresh_access_token(self):
if self._access_token_expires_at < datetime.now():
Expand All @@ -128,4 +205,4 @@ def maybe_refresh_access_token(self):

@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Self:
    """Create a model instance from its registry name.

    `name` is forwarded as the `model` argument; the pre-merge version
    dropped it (`cls(**kwargs)`), so every instance silently fell back
    to the default model -- the bug this PR fixes.
    """
    return cls(model=name, **kwargs)
37 changes: 31 additions & 6 deletions lmclient/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,58 @@
from typing import Any, Dict, List, Literal, Optional, Sequence, Union

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import NotRequired, Self, TypedDict
from typing_extensions import NotRequired, Self, TypedDict, TypeGuard

from lmclient.exceptions import MessageError

Messages = List['Message']
ModelResponse = Dict[str, Any]
Prompt = Union[str, 'Message', 'MessageDict', Sequence[Union['MessageDict', 'Message']]]
Role = Literal['user', 'assistant', 'function', 'error']
Role = Literal['user', 'assistant', 'function', 'error', 'system']


class FunctionDict(TypedDict):
    """Backend-agnostic function definition (OpenAI-style shape)."""

    name: str
    description: NotRequired[str]
    parameters: Dict[str, Any]  # parameter schema; not validated here
    responses: NotRequired[Dict[str, str]]  # optional response descriptions (consumed by the Wenxin backend)
    examples: NotRequired[List[MessageDict]]  # optional example messages (consumed by the Wenxin backend)


class FunctionCallDict(TypedDict):
    """A function call emitted as assistant-message content."""

    name: str  # function being invoked
    arguments: str  # raw argument string; parsing is left to the caller
    thoughts: NotRequired[str]  # optional reasoning text (stripped for OpenAI, forwarded for Wenxin)


class FunctionCallMessage(BaseModel):
    """An assistant message whose content is a function call.

    Serves as the narrowing target of ``is_function_call_message``.
    """

    # BUG FIX: the role literal was misspelled 'assitant'; the guard in
    # is_function_call_message only accepts 'assistant', so the old
    # literal could never describe a real message.
    role: Literal['assistant']
    name: Optional[str] = None
    content: FunctionCallDict


class TextMessage(BaseModel):
    """A chat message whose content is plain text."""

    role: Role
    name: Optional[str] = None  # the producing function's name when role == 'function'
    content: str


class Message(BaseModel):
    """A chat message: plain text, or a function call made by an assistant."""

    role: Role
    name: Optional[str] = None
    # str for ordinary messages; FunctionCallDict for assistant function
    # calls. (Diff residue removed: the field list previously carried two
    # 'content' annotations from a field reorder.)
    content: Union[str, FunctionCallDict]


def is_function_call_message(message: Message) -> TypeGuard[FunctionCallMessage]:
    """Narrow *message* to a function-call message (dict content).

    Raises MessageError when dict content appears under any role other
    than 'assistant'.
    """
    has_call_content = isinstance(message.content, dict)
    if has_call_content and message.role != 'assistant':
        raise MessageError(f'Invalid role "{message.role}" for function call, can only be made by "assistant"')
    return has_call_content


@property
def is_function_call(self) -> bool:
return isinstance(self.content, dict)
def is_text_message(message: Message) -> TypeGuard[TextMessage]:
    """Narrow *message* to a text message: True iff its content is a str."""
    payload = message.content
    return isinstance(payload, str)


class MessageDict(TypedDict):
Expand Down