Commit
Merge pull request #9 from wangyuxinwhy/chat-completion
Chat completion
Showing 31 changed files with 1,424 additions and 826 deletions.
@@ -1,6 +1,40 @@
-from lmclient.client import LMClient as LMClient
-from lmclient.models import AzureChat as AzureChat
-from lmclient.models import MinimaxChat as MinimaxChat
-from lmclient.models import OpenAIChat as OpenAIChat
-from lmclient.models import OpenAIExtract as OpenAIExtract
-from lmclient.models import ZhiPuChat as ZhiPuChat
+from lmclient.chat_engine import ChatEngine
+from lmclient.client import LMClient
+from lmclient.models import (
+    AzureChat,
+    MinimaxProChat,
+    MinimaxProChatParameters,
+    OpenAIChat,
+    OpenAIChatParameters,
+    WenxinChat,
+    WenxinChatParameters,
+    ZhiPuChat,
+    ZhiPuChatParameters,
+)
+from lmclient.utils import BaseSchema, PydanticVersion, function
+from lmclient.version import __version__
+
+if PydanticVersion == 1:
+    from pydantic import BaseModel
+
+    BaseModel.model_copy = BaseModel.copy  # type: ignore
+    BaseModel.model_dump = BaseModel.dict  # type: ignore
+    BaseModel.model_dump_json = BaseModel.json  # type: ignore
+
+
+__all__ = [
+    'LMClient',
+    'ChatEngine',
+    'AzureChat',
+    'OpenAIChat',
+    'OpenAIChatParameters',
+    'MinimaxProChat',
+    'MinimaxProChatParameters',
+    'ZhiPuChat',
+    'ZhiPuChatParameters',
+    'WenxinChat',
+    'WenxinChatParameters',
+    'BaseSchema',
+    'function',
+    '__version__',
+]
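This hunk appears to be the package __init__: it now re-exports ChatEngine, the per-provider chat models, and their parameter classes at the top level, and the PydanticVersion == 1 block aliases the Pydantic v2 method names (model_copy, model_dump, model_dump_json) onto the v1 BaseModel so the rest of the code can call the v2 API under either major version. A minimal usage sketch of the new top-level surface; building OpenAIChat with no arguments and having it read its API key from the environment are assumptions, since the model constructors are not part of this diff.

from lmclient import ChatEngine, OpenAIChat

# Assumption: OpenAIChat accepts defaults and picks up its API key from the
# environment; its constructor is not shown in this commit.
model = OpenAIChat()
engine = ChatEngine(model, temperature=0.5)
print(engine.chat('Hello!'))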
@@ -0,0 +1,53 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import cast

import diskcache

from lmclient.types import ChatModelOutput

DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve()


class ChatCacheMixin:
    _cache: diskcache.Cache | None
    _cache_dir: Path | None

    def __init__(self, use_cache: Path | str | bool = False) -> None:
        if isinstance(use_cache, (str, Path)):
            self.cache_dir = Path(use_cache)
        elif use_cache:
            self.cache_dir = DEFAULT_CACHE_DIR
        else:
            self.cache_dir = None

    def cache_model_output(self, key: str, model_output: ChatModelOutput) -> None:
        if self._cache is not None:
            self._cache[key] = model_output
        else:
            raise RuntimeError('Cache is not enabled')

    def try_load_model_output(self, key: str):
        if self._cache is not None and key in self._cache:
            model_output = cast(ChatModelOutput, self._cache[key])
            return model_output

    @property
    def use_cache(self) -> bool:
        return self._cache is not None

    @property
    def cache_dir(self) -> Path | None:
        return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, value: Path | None) -> None:
        if value is not None:
            if value.exists() and not value.is_dir():
                raise ValueError(f'Cache directory {value} is not a directory')
            value.mkdir(parents=True, exist_ok=True)
            self._cache = diskcache.Cache(value)
        else:
            self._cache = None
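ChatCacheMixin owns an optional diskcache.Cache: passing a path (or True, which selects the default ~/.cache/lmclient directory) as use_cache creates the directory and opens the cache, False leaves caching disabled, and cache_model_output / try_load_model_output store and fetch ChatModelOutput objects by string key. A sketch of how a chat model class might layer the mixin over a provider call; the subclass, the key scheme, and _call_provider_api are illustrative assumptions, not code from this commit.

import hashlib
import json

# ChatCacheMixin is the class in the diff above; its module path is not shown in this view.


class CachedChatModel(ChatCacheMixin):
    def _cache_key(self, messages, parameters) -> str:
        # Assumed key scheme: hash the serialized request.
        payload = json.dumps([messages, parameters], default=str, sort_keys=True)
        return hashlib.md5(payload.encode()).hexdigest()

    def chat_completion(self, messages, parameters):
        key = self._cache_key(messages, parameters)
        # Return a cached output when one exists for this request.
        if self.use_cache and (cached := self.try_load_model_output(key)) is not None:
            return cached
        output = self._call_provider_api(messages, parameters)  # hypothetical provider call
        if self.use_cache:
            self.cache_model_output(key, output)
        return output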
@@ -0,0 +1,129 @@
from __future__ import annotations

import json
from typing import Any, Generic, List, Optional, TypeVar, cast

from lmclient.models import BaseChatModel, load_from_model_id
from lmclient.types import ChatModelOutput, FunctionCallDict, GeneralParameters, Message, Messages, ModelParameters
from lmclient.utils import function

T_P = TypeVar('T_P', bound=ModelParameters)
T_O = TypeVar('T_O', bound=ChatModelOutput)


class ChatEngine(Generic[T_P, T_O]):
    def __init__(
        self,
        chat_model: BaseChatModel[T_P, T_O] | str,
        temperature: float = 1,
        top_p: float = 1,
        functions: Optional[List[function]] = None,
        function_call_raise_error: bool = False,
        **extra_parameters: Any,
    ):
        if isinstance(chat_model, str):
            self._chat_model: BaseChatModel[T_P, T_O] = load_from_model_id(chat_model)  # type: ignore
        else:
            self._chat_model = chat_model

        self.functions = functions or []
        if functions:
            functions_schema = [function.schema for function in functions]
            function_call = 'auto'
        else:
            functions_schema = None
            function_call = None

        self.engine_parameters = GeneralParameters(
            temperature=temperature,
            top_p=top_p,
            functions=functions_schema,
            function_call=function_call,
        )
        self._extra_parameters = extra_parameters
        _parameters = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters).model_copy(
            update=self._extra_parameters
        )
        self._parameters = cast(T_P, _parameters)
        self.function_call_raise_error = function_call_raise_error
        self.history: Messages = []

    @property
    def extra_parameters(self) -> dict[str, Any]:
        return self._extra_parameters

    @extra_parameters.setter
    def extra_parameters(self, extra_parameters: dict[str, Any]):
        self._extra_parameters = extra_parameters
        self._parameters = self._parameters.model_copy(update=self._extra_parameters)

    @property
    def chat_model(self):
        return self._chat_model

    def chat(self, user_input: str, **extra_parameters: Any) -> str:
        parameters = self._parameters.model_copy(update=extra_parameters)
        self.history.append(Message(role='user', content=user_input))
        model_response = self._chat_model.chat_completion(self.history, parameters)
        self.history.extend(model_response.messages)
        if isinstance(reply := model_response.messages[-1].content, str):
            return reply
        else:
            return self._recursive_function_call(reply, parameters)

    async def async_chat(self, user_input: str, **extra_parameters: Any) -> str:
        parameters = self._parameters.model_copy(update=extra_parameters)
        self.history.append(Message(role='user', content=user_input))
        model_response = await self._chat_model.async_chat_completion(self.history, parameters)
        self.history.extend(model_response.messages)
        if isinstance(reply := model_response.messages[-1].content, str):
            return reply
        else:
            return await self._async_recursive_function_call(reply, parameters)

    def run_function_call(self, function_call: FunctionCallDict):
        function = None
        for i in self.functions:
            if i.name == function_call['name']:
                function = i
        if function is None:
            if self.function_call_raise_error:
                raise ValueError(f'Function {function_call["name"]} not found')
            else:
                return 'Function not found, please try another function.'

        try:
            arguments = json.loads(function_call['arguments'], strict=False)
            return function(**arguments)
        except Exception as e:
            if self.function_call_raise_error:
                raise e
            else:
                return f'Error: {e}'

    def _recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str:
        function_output = self.run_function_call(function_call)
        self.history.append(
            Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False))
        )
        model_response = self._chat_model.chat_completion(self.history, parameters)
        self.history.extend(model_response.messages)
        if isinstance(reply := model_response.messages[-1].content, str):
            return reply
        else:
            return self._recursive_function_call(reply, parameters)

    async def _async_recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str:
        function_output = self.run_function_call(function_call)
        self.history.append(
            Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False))
        )
        model_response = await self._chat_model.async_chat_completion(self.history, parameters)
        self.history.extend(model_response.messages)
        if isinstance(reply := model_response.messages[-1].content, str):
            return reply
        else:
            return self._recursive_function_call(reply, parameters)

    def reset(self) -> None:
        self.history.clear()
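This is the new lmclient.chat_engine module imported by the package __init__. ChatEngine.chat appends the user message to history, runs the model, and, when the last message's content is not a string, treats it as a function-call dict: run_function_call looks up the matching entry in functions, calls it with the JSON-decoded arguments, appends a role='function' message carrying the JSON-encoded result, and queries the model again until a plain string reply comes back. A sketch of wiring a tool into the engine; the @function decorator form and the 'openai' model id passed to load_from_model_id are assumptions based on the imports, not shown in this commit.

from lmclient import ChatEngine, function


@function  # assumed decorator form; the diff only uses .name, .schema, and calling the object
def get_weather(city: str) -> str:
    """Toy tool: return a canned weather report for city."""
    return f'{city}: sunny, 25 degrees'


# 'openai' as a model id is an assumption; the ids accepted by load_from_model_id are not shown here.
engine = ChatEngine('openai', temperature=0, functions=[get_weather], function_call_raise_error=True)
print(engine.chat('What is the weather in Beijing?'))
engine.reset()  # clears the accumulated history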