Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add baichuan #22

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lmclient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from lmclient.completion_engine import CompletionEngine
from lmclient.models import (
AzureChat,
BaichuanChat,
BaichuanChatParameters,
HunyuanChat,
HunyuanChatParameters,
MinimaxChat,
Expand Down Expand Up @@ -45,6 +47,8 @@
'WenxinChatParameters',
'HunyuanChat',
'HunyuanChatParameters',
'BaichuanChat',
'BaichuanChatParameters',
'BaseSchema',
'function',
'GeneralParameters',
Expand Down
4 changes: 4 additions & 0 deletions lmclient/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any

from lmclient.models.azure import AzureChat
from lmclient.models.baichuan import BaichuanChat, BaichuanChatParameters
from lmclient.models.base import BaseChatModel
from lmclient.models.hunyuan import HunyuanChat, HunyuanChatParameters
from lmclient.models.minimax import MinimaxChat, MinimaxChatParameters
Expand All @@ -17,6 +18,7 @@
ZhiPuChat.model_type: ZhiPuChat,
WenxinChat.model_type: WenxinChat,
HunyuanChat.model_type: HunyuanChat,
BaichuanChat.model_type: BaichuanChat,
}


Expand Down Expand Up @@ -48,4 +50,6 @@ def list_chat_model_types():
'WenxinChatParameters',
'HunyuanChat',
'HunyuanChatParameters',
'BaichuanChat',
'BaichuanChatParameters',
]
117 changes: 117 additions & 0 deletions lmclient/models/baichuan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations

import hashlib
import json
import os
import time
import uuid
from pathlib import Path
from typing import Any, ClassVar, Literal, Optional, TypedDict

from lmclient.exceptions import MessageError
from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
from lmclient.parser import ParserError
from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse


class BaichuanChatParameters(ModelParameters):
    """Optional sampling parameters for the Baichuan chat API.

    Fields left as ``None`` are omitted from the request payload
    (the caller dumps with ``exclude_none=True``).
    """

    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    with_search_enhance: Optional[bool] = None

    @classmethod
    def from_general_parameters(cls, general_parameters: GeneralParameters):
        """Map the provider-agnostic parameters onto Baichuan's fields.

        Only ``temperature`` and ``top_p`` are carried over, matching the
        original implementation.
        """
        carried_over = {
            'temperature': general_parameters.temperature,
            'top_p': general_parameters.top_p,
        }
        return cls(**carried_over)


# Shape of one entry in the ``messages`` array of a Baichuan request.
BaichuanMessageDict = TypedDict(
    'BaichuanMessageDict',
    {'role': Literal['user', 'assistant'], 'content': str},
)


class BaichuanChat(HttpChatModel[BaichuanChatParameters]):
    """Chat client for the Baichuan chat-completion HTTP API.

    Each request carries a bearer ``api_key`` plus an MD5 signature of
    ``secret_key + request_body + timestamp`` in the ``X-BC-*`` headers.
    """

    # Registry key used by lmclient.models to look this class up.
    # Was 'zhipu' (copy-paste bug), which collided with ZhiPuChat's entry
    # in the model-type registry.
    model_type = 'baichuan'
    default_api_base: ClassVar[str] = 'https://api.baichuan-ai.com/v1/chat'

    def __init__(
        self,
        model: str = 'Baichuan2-53B',
        api_key: str | None = None,
        secret_key: str | None = None,
        api_base: str | None = None,
        timeout: int | None = 60,
        retry: bool | RetryStrategy = False,
        parameters: BaichuanChatParameters | None = None,
        use_cache: Path | str | bool = False,
        proxies: ProxiesTypes | None = None,
    ):
        """Create a Baichuan chat client.

        ``api_key``/``secret_key`` fall back to the BAICHUAN_API_KEY /
        BAICHUAN_SECRET_KEY environment variables (KeyError if unset).
        """
        # Build a fresh default per call instead of sharing one mutable
        # pydantic instance across every construction (mutable-default bug).
        parameters = BaichuanChatParameters() if parameters is None else parameters
        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
        self.model = model
        self.api_key = api_key or os.environ['BAICHUAN_API_KEY']
        self.secret_key = secret_key or os.environ['BAICHUAN_SECRET_KEY']
        # str.rstrip returns a new string; the original discarded the result,
        # so a user-supplied base kept its trailing '/'.
        self.api_base = (api_base or self.default_api_base).rstrip('/')

    def get_request_parameters(self, messages: Messages, parameters: BaichuanChatParameters) -> dict[str, Any]:
        """Build the url/headers/json kwargs for the HTTP POST.

        Raises ValueError for roles other than user/assistant and
        MessageError for non-string message content.
        """
        baichuan_messages: list[BaichuanMessageDict] = []
        for message in messages:
            role = message.role
            if role not in ('user', 'assistant'):
                raise ValueError(f'Role of message must be user or assistant, but got {message.role}')
            if not isinstance(message.content, str):
                raise MessageError(f'Message content must be str, but got {type(message.content)}')
            baichuan_messages.append(
                {
                    'role': role,
                    'content': message.content,
                }
            )

        data: dict[str, Any] = {
            'model': self.model,
            'messages': baichuan_messages,
        }
        parameters_dict = parameters.model_dump(exclude_none=True)
        if parameters_dict:
            data['parameters'] = parameters_dict
        time_stamp = int(time.time())
        # NOTE(review): the signature is computed over json.dumps(data); this
        # assumes the HTTP layer serializes the json body identically (same
        # key order and separators) — verify against HttpChatModel's client.
        signature = self.calculate_md5(self.secret_key + json.dumps(data) + str(time_stamp))

        headers = {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer ' + self.api_key,
            'X-BC-Timestamp': str(time_stamp),
            'X-BC-Signature': signature,
            'X-BC-Sign-Algo': 'MD5',
            # Was the literal placeholder 'your requestId'; send a unique id
            # per request instead.
            'X-BC-Request-Id': str(uuid.uuid4()),
        }
        return {
            'url': self.api_base,
            'headers': headers,
            'json': data,
        }

    @staticmethod
    def calculate_md5(input_string: str) -> str:
        """Return the hex MD5 digest of *input_string* (UTF-8 encoded)."""
        md5 = hashlib.md5()
        md5.update(input_string.encode('utf-8'))
        return md5.hexdigest()

    def parse_model_reponse(self, response: ModelResponse) -> Messages:
        """Extract the assistant reply from a raw API response.

        NOTE(review): the method name (including the 'reponse' typo)
        presumably matches the hook it overrides in HttpChatModel — renaming
        it here alone would break dispatch; fix project-wide instead.
        """
        try:
            text = response['data']['messages'][-1]['content']
            return [Message(role='assistant', content=text)]
        except (KeyError, IndexError) as e:
            raise ParserError(f'Parse response failed, reponse: {response}') from e

    @property
    def name(self) -> str:
        """Identifier of the underlying model (e.g. 'Baichuan2-53B')."""
        return self.model

    @classmethod
    def from_name(cls, name: str, **kwargs: Any):
        """Alternate constructor: build a client from a bare model name."""
        return cls(model=name, **kwargs)
2 changes: 1 addition & 1 deletion lmclient/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '0.8.3'
__version__ = '0.8.4'
__cache_version__ = '4'
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "lmclient-core"
version = "0.8.3"
version = "0.8.4"
description = "LM Async Client, openai client, azure openai client ..."
authors = ["wangyuxin <[email protected]>"]
readme = "README.md"
Expand Down
4 changes: 3 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from lmclient.models import (
AzureChat,
BaichuanChat,
BaseChatModel,
HunyuanChat,
MinimaxChat,
Expand All @@ -15,7 +16,8 @@


@pytest.mark.parametrize(
'chat_model', (AzureChat(), MinimaxProChat(), MinimaxChat(), OpenAIChat(), ZhiPuChat(), WenxinChat(), HunyuanChat())
'chat_model',
(AzureChat(), MinimaxProChat(), MinimaxChat(), OpenAIChat(), ZhiPuChat(), WenxinChat(), HunyuanChat(), BaichuanChat()),
)
def test_http_chat_model(chat_model: BaseChatModel[ModelParameters, HttpChatModelOutput]):
test_messages = [Message(role='user', content='hello')]
Expand Down