Merge pull request #9 from wangyuxinwhy/chat-completion
Chat completion
wangyuxinwhy authored Sep 3, 2023
2 parents 283cc07 + 6d6a382 commit 909a25c
Showing 31 changed files with 1,424 additions and 826 deletions.
18 changes: 15 additions & 3 deletions README.md
@@ -13,6 +13,8 @@
5. Disk cache support
6. 100% type hints
7. Very easy to use
8. Supports OpenAI, Azure, Minimax, ZhiPu (智谱), and Baidu Wenxin (百度文心)
9. FunctionCall support

## Installation
Supports Python 3.8 and above
@@ -22,10 +24,11 @@ pip install lmclient-core

## Usage

1. LMClient
```python
from lmclient import LMClient, OpenAIChat
from lmclient import LMClient, OpenAIChat, OpenAIChatParameters

model = OpenAIChat('gpt-3.5-turbo')
model = OpenAIChat('gpt-3.5-turbo', parameters=OpenAIChatParameters(temperature=0))
# limit to at most 20 requests per minute, with an async capacity of 5
client = LMClient(model, async_capacity=5, max_requests_per_minute=20)
prompts = [
@@ -34,9 +37,18 @@ prompts = [
[{'role': 'system', 'content': 'you are lmclient demo assistant'}, {'role': 'user', 'content': 'hello, who are you?'}],
'what is your name?',
]
values = client.async_run(prompts=prompts, temperature=0)
values = client.run(prompts=prompts)
print(values)
```
2. ChatEngine
```python
from lmclient import ChatEngine, OpenAIChat

model = OpenAIChat('gpt-3.5-turbo')
chat_engine = ChatEngine(model)
print(chat_engine.chat('Hello, I am chat_engine'))
print(chat_engine.chat('What was my previous message?'))
```
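3. FunctionCall (hypothetical sketch)

The feature list above advertises FunctionCall support, and `ChatEngine` accepts a `functions` argument of `lmclient.utils.function` objects whose `.name` and `.schema` it uses when the model requests a call. The snippet below is an illustrative sketch rather than an example taken from this commit: it assumes `function` can be applied as a decorator to a plain callable, and `get_weather` is a made-up demo function.
```python
from lmclient import ChatEngine, OpenAIChat
from lmclient.utils import function


@function  # assumption: `function` wraps a callable and exposes .name / .schema
def get_weather(city: str) -> str:
    """Return a short weather description for `city` (demo stub)."""
    return f'{city}: sunny, 25°C'


model = OpenAIChat('gpt-3.5-turbo')
chat_engine = ChatEngine(model, functions=[get_weather])
# If the model asks to call get_weather, ChatEngine runs it and feeds the result
# back into the conversation until a plain text reply is produced.
print(chat_engine.chat('What is the weather like in Beijing?'))
```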

## Example: Large-Scale Translation

46 changes: 40 additions & 6 deletions lmclient/__init__.py
@@ -1,6 +1,40 @@
from lmclient.client import LMClient as LMClient
from lmclient.models import AzureChat as AzureChat
from lmclient.models import MinimaxChat as MinimaxChat
from lmclient.models import OpenAIChat as OpenAIChat
from lmclient.models import OpenAIExtract as OpenAIExtract
from lmclient.models import ZhiPuChat as ZhiPuChat
from lmclient.chat_engine import ChatEngine
from lmclient.client import LMClient
from lmclient.models import (
AzureChat,
MinimaxProChat,
MinimaxProChatParameters,
OpenAIChat,
OpenAIChatParameters,
WenxinChat,
WenxinChatParameters,
ZhiPuChat,
ZhiPuChatParameters,
)
from lmclient.utils import BaseSchema, PydanticVersion, function
from lmclient.version import __version__

if PydanticVersion == 1:
from pydantic import BaseModel

BaseModel.model_copy = BaseModel.copy # type: ignore
BaseModel.model_dump = BaseModel.dict # type: ignore
BaseModel.model_dump_json = BaseModel.json # type: ignore


__all__ = [
'LMClient',
'ChatEngine',
'AzureChat',
'OpenAIChat',
'OpenAIChatParameters',
'MinimaxProChat',
'MinimaxProChatParameters',
'ZhiPuChat',
'ZhiPuChatParameters',
'WenxinChat',
'WenxinChatParameters',
'BaseSchema',
'function',
'__version__',
]
53 changes: 53 additions & 0 deletions lmclient/cache.py
@@ -0,0 +1,53 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import cast

import diskcache

from lmclient.types import ChatModelOutput

DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve()


class ChatCacheMixin:
_cache: diskcache.Cache | None
_cache_dir: Path | None

def __init__(self, use_cache: Path | str | bool = False) -> None:
if isinstance(use_cache, (str, Path)):
self.cache_dir = Path(use_cache)
elif use_cache:
self.cache_dir = DEFAULT_CACHE_DIR
else:
self.cache_dir = None

def cache_model_output(self, key: str, model_output: ChatModelOutput) -> None:
if self._cache is not None:
self._cache[key] = model_output
else:
raise RuntimeError('Cache is not enabled')

    def try_load_model_output(self, key: str) -> ChatModelOutput | None:
        # Return the cached output for `key`, or None on a cache miss.
        if self._cache is not None and key in self._cache:
            model_output = cast(ChatModelOutput, self._cache[key])
            return model_output
        return None

@property
def use_cache(self) -> bool:
return self._cache is not None

@property
def cache_dir(self) -> Path | None:
return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, value: Path | None) -> None:
        if value is not None:
            if value.exists() and not value.is_dir():
                raise ValueError(f'Cache directory {value} is not a directory')
            value.mkdir(parents=True, exist_ok=True)
            self._cache = diskcache.Cache(value)
        else:
            self._cache = None
        # Record the directory so the `cache_dir` property reflects the active cache.
        self._cache_dir = value
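
For orientation, a minimal sketch of enabling the `ChatCacheMixin` above; it is not part of this commit and only exercises the public surface shown in the diff (the temporary directory is illustrative).
```python
import tempfile
from pathlib import Path

from lmclient.cache import ChatCacheMixin

# Passing a path (or True for the default ~/.cache/lmclient) turns the disk cache on.
cache = ChatCacheMixin(use_cache=Path(tempfile.mkdtemp()) / 'lmclient-demo-cache')
print(cache.use_cache)  # True: a diskcache.Cache now backs this mixin
print(cache.cache_dir)  # the directory created by the cache_dir setter

# A chat model mixing this in can call try_load_model_output(key) before an API
# request and cache_model_output(key, output) afterwards to skip repeated calls.
```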
129 changes: 129 additions & 0 deletions lmclient/chat_engine.py
@@ -0,0 +1,129 @@
from __future__ import annotations

import json
from typing import Any, Generic, List, Optional, TypeVar, cast

from lmclient.models import BaseChatModel, load_from_model_id
from lmclient.types import ChatModelOutput, FunctionCallDict, GeneralParameters, Message, Messages, ModelParameters
from lmclient.utils import function

T_P = TypeVar('T_P', bound=ModelParameters)
T_O = TypeVar('T_O', bound=ChatModelOutput)


class ChatEngine(Generic[T_P, T_O]):
def __init__(
self,
chat_model: BaseChatModel[T_P, T_O] | str,
temperature: float = 1,
top_p: float = 1,
functions: Optional[List[function]] = None,
function_call_raise_error: bool = False,
**extra_parameters: Any,
):
if isinstance(chat_model, str):
self._chat_model: BaseChatModel[T_P, T_O] = load_from_model_id(chat_model) # type: ignore
else:
self._chat_model = chat_model

self.functions = functions or []
if functions:
functions_schema = [function.schema for function in functions]
function_call = 'auto'
else:
functions_schema = None
function_call = None

self.engine_parameters = GeneralParameters(
temperature=temperature,
top_p=top_p,
functions=functions_schema,
function_call=function_call,
)
self._extra_parameters = extra_parameters
_parameters = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters).model_copy(
update=self._extra_parameters
)
self._parameters = cast(T_P, _parameters)
self.function_call_raise_error = function_call_raise_error
self.history: Messages = []

@property
def extra_parameters(self) -> dict[str, Any]:
return self._extra_parameters

@extra_parameters.setter
def extra_parameters(self, extra_parameters: dict[str, Any]):
self._extra_parameters = extra_parameters
self._parameters = self._parameters.model_copy(update=self._extra_parameters)

@property
def chat_model(self):
return self._chat_model

def chat(self, user_input: str, **extra_parameters: Any) -> str:
parameters = self._parameters.model_copy(update=extra_parameters)
self.history.append(Message(role='user', content=user_input))
model_response = self._chat_model.chat_completion(self.history, parameters)
self.history.extend(model_response.messages)
if isinstance(reply := model_response.messages[-1].content, str):
return reply
else:
return self._recursive_function_call(reply, parameters)

async def async_chat(self, user_input: str, **extra_parameters: Any) -> str:
parameters = self._parameters.model_copy(update=extra_parameters)
self.history.append(Message(role='user', content=user_input))
model_response = await self._chat_model.async_chat_completion(self.history, parameters)
self.history.extend(model_response.messages)
if isinstance(reply := model_response.messages[-1].content, str):
return reply
else:
return await self._async_recursive_function_call(reply, parameters)

def run_function_call(self, function_call: FunctionCallDict):
function = None
for i in self.functions:
if i.name == function_call['name']:
function = i
if function is None:
if self.function_call_raise_error:
raise ValueError(f'Function {function_call["name"]} not found')
else:
return 'Function not found, please try another function.'

try:
arguments = json.loads(function_call['arguments'], strict=False)
return function(**arguments)
except Exception as e:
if self.function_call_raise_error:
raise e
else:
return f'Error: {e}'

def _recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str:
function_output = self.run_function_call(function_call)
self.history.append(
Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False))
)
model_response = self._chat_model.chat_completion(self.history, parameters)
self.history.extend(model_response.messages)
if isinstance(reply := model_response.messages[-1].content, str):
return reply
else:
return self._recursive_function_call(reply, parameters)

async def _async_recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str:
function_output = self.run_function_call(function_call)
self.history.append(
Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False))
)
model_response = await self._chat_model.async_chat_completion(self.history, parameters)
self.history.extend(model_response.messages)
if isinstance(reply := model_response.messages[-1].content, str):
return reply
else:
            return await self._async_recursive_function_call(reply, parameters)

def reset(self) -> None:
self.history.clear()
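
A small hypothetical usage note for the class above, not taken from this commit: `ChatEngine` accumulates the conversation in `history` and `reset()` clears it, so one engine can be reused across conversations. As with the README examples, an OpenAI API key and network access are assumed.
```python
from lmclient import ChatEngine, OpenAIChat

engine = ChatEngine(OpenAIChat('gpt-3.5-turbo'), temperature=0)
engine.chat('My name is Ada.')
print(engine.chat('What is my name?'))  # answered from the accumulated history

print(len(engine.history))  # user turns plus the model's replies so far
engine.reset()              # drop the conversation state
print(engine.history)       # []
```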