From 12ee15c5e664736dd4f4828a2040c8ebd5695215 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Fri, 1 Sep 2023 19:02:04 +0800 Subject: [PATCH 1/6] chat completion --- demo.py | 0 lmclient/__init__.py | 2 +- lmclient/cache.py | 73 +++++++++++ lmclient/chat_engine.py | 104 +++++++++++++++ lmclient/client.py | 34 +++-- lmclient/exceptions.py | 5 + lmclient/function.py | 86 +++++++++++++ lmclient/models/__init__.py | 3 +- lmclient/models/azure.py | 38 ++++-- lmclient/models/base.py | 206 +++--------------------------- lmclient/models/http.py | 88 +++++++++++++ lmclient/models/minimax.py | 86 ------------- lmclient/models/minimax_pro.py | 220 ++++++++++++++++++++++++++++++++ lmclient/models/openai.py | 223 +++++++++++++++++++-------------- lmclient/models/spark.py | 78 ------------ lmclient/models/zhipu.py | 66 ++++++++-- lmclient/openai_schema.py | 68 ---------- lmclient/types.py | 61 ++++++--- lmclient/utils.py | 138 +++++++++++++++++--- poetry.lock | 13 +- pyproject.toml | 4 + scripts/ner.py | 2 +- scripts/translate.py | 2 +- tests/test_client.py | 18 +-- tests/test_model.py | 36 +++--- 25 files changed, 1030 insertions(+), 624 deletions(-) create mode 100644 demo.py create mode 100644 lmclient/cache.py create mode 100644 lmclient/chat_engine.py create mode 100644 lmclient/exceptions.py create mode 100644 lmclient/function.py create mode 100644 lmclient/models/http.py delete mode 100644 lmclient/models/minimax.py create mode 100644 lmclient/models/minimax_pro.py delete mode 100644 lmclient/models/spark.py delete mode 100644 lmclient/openai_schema.py diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..e69de29 diff --git a/lmclient/__init__.py b/lmclient/__init__.py index 0606b63..43634b5 100644 --- a/lmclient/__init__.py +++ b/lmclient/__init__.py @@ -1,6 +1,6 @@ from lmclient.client import LMClient as LMClient from lmclient.models import AzureChat as AzureChat -from lmclient.models import MinimaxChat as MinimaxChat +from lmclient.models import MinimaxProChat as MinimaxProChat from lmclient.models import OpenAIChat as OpenAIChat from lmclient.models import OpenAIExtract as OpenAIExtract from lmclient.models import ZhiPuChat as ZhiPuChat diff --git a/lmclient/cache.py b/lmclient/cache.py new file mode 100644 index 0000000..7bda4b1 --- /dev/null +++ b/lmclient/cache.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import hashlib +import os +from pathlib import Path +from typing import cast + +import diskcache + +from lmclient.types import Messages, ModelResponse, Prompt, ModelParameters +from lmclient.utils import to_dict +from lmclient.version import __cache_version__ + +DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve() + + +class ChatCacheMixin: + identifier: str + _cache: diskcache.Cache | None + _cache_dir: Path | None + + def __init__(self, use_cache: Path | str | bool = False) -> None: + if isinstance(use_cache, (str, Path)): + self.cache_dir = Path(use_cache) + elif use_cache: + self.cache_dir = DEFAULT_CACHE_DIR + else: + self.cache_dir = None + + def cache_response(self, key: str, response: ModelResponse) -> None: + if self._cache is not None: + self._cache[key] = response + else: + raise RuntimeError('Cache is not enabled') + + def try_load_response(self, key: str): + if self._cache is not None and key in self._cache: + response = self._cache[key] + response = cast(ModelResponse, response) + return response + + def generate_hash_key(self, messages: Messages, parameters: ModelParameters) -> str: + if 
isinstance(messages, str):
+            hash_text = messages
+        else:
+            hash_text = '---'.join([f'{k}={v}' for message in messages for k, v in to_dict(message).items()])
+        items = sorted([f'{key}={value}' for key, value in parameters.model_dump().items()])
+        items += [f'__cache_version__={__cache_version__}']
+        items = [hash_text, self.identifier] + items
+        task_string = '---'.join(items)
+        return self.md5_hash(task_string)
+
+    @staticmethod
+    def md5_hash(string: str):
+        return hashlib.md5(string.encode()).hexdigest()
+
+    @property
+    def use_cache(self) -> bool:
+        return self._cache is not None
+
+    @property
+    def cache_dir(self) -> Path | None:
+        return self._cache_dir
+
+    @cache_dir.setter
+    def cache_dir(self, value: Path | None) -> None:
+        if value is not None:
+            if value.exists() and not value.is_dir():
+                raise ValueError(f'Cache directory {value} is not a directory')
+            value.mkdir(parents=True, exist_ok=True)
+            self._cache = diskcache.Cache(value)
+        else:
+            self._cache = None
diff --git a/lmclient/chat_engine.py b/lmclient/chat_engine.py
new file mode 100644
index 0000000..1f3b1e5
--- /dev/null
+++ b/lmclient/chat_engine.py
@@ -0,0 +1,104 @@
+from __future__ import annotations
+
+import json
+from typing import List, Optional
+
+from lmclient.models.base import BaseChatModel
+from lmclient.types import FunctionCallDict, GeneralParameters, Message, Messages, ModelParameters
+from lmclient.utils import lm_function
+
+
+class ChatEngine:
+    def __init__(
+        self,
+        chat_model: BaseChatModel,
+        temperature: float = 1,
+        top_p: float = 1,
+        functions: Optional[List[lm_function]] = None,
+        function_call_raise_error: bool = False,
+        **extra_model_parameters
+    ):
+        self._chat_model = chat_model
+        self.functions = functions or []
+        if functions:
+            functions_schema = [function.schema for function in functions]
+            function_call = 'auto'
+        else:
+            functions_schema = None
+            function_call = None
+
+        self.engine_parameters = GeneralParameters(
+            temperature=temperature,
+            top_p=top_p,
+            functions=functions_schema,
+            function_call=function_call,
+        )
+        self._extra_model_parameters = extra_model_parameters
+        self._model_parameters: ModelParameters = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters).model_copy(update=self._extra_model_parameters)
+        self.function_call_raise_error = function_call_raise_error
+        self.history: Messages = []
+
+    @property
+    def chat_model(self):
+        return self._chat_model
+
+    @chat_model.setter
+    def chat_model(self, model: BaseChatModel):
+        self._chat_model = model
+        self._model_parameters = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters).model_copy(update=self._extra_model_parameters)
+
+    @property
+    def model_parameters(self):
+        return self._model_parameters
+
+    @property
+    def extra_model_parameters(self):
+        return self._extra_model_parameters
+
+    @extra_model_parameters.setter
+    def extra_model_parameters(self, extra_model_parameters: dict):
+        self._extra_model_parameters = extra_model_parameters
+        self._model_parameters = self.model_parameters.model_copy(update=self._extra_model_parameters)
+
+    def chat(self, user_input: str, **extra_model_parameters) -> str:
+        model_parameters = self.model_parameters.model_copy(update=extra_model_parameters)
+        self.history.append(Message(role='user', content=user_input))
+        model_response = self.chat_model.chat_completion(self.history, model_parameters)
+        self.history.extend(model_response.messages)
+        if isinstance(reply := model_response.messages[-1].content, str):
+            return reply
+        else:
+            return self._recursive_function_call(reply, model_parameters)
+
+    def run_function_call(self, function_call: FunctionCallDict):
+        function = None
+        for i in self.functions:
+            if i.name == function_call['name']:
+                function = i
+        if function is None:
+            if self.function_call_raise_error:
+                raise ValueError(f'Function {function_call["name"]} not found')
+            else:
+                return 'Function not found, please try another function.'
+
+        try:
+            arguments = json.loads(function_call["arguments"], strict=False)
+            return function(**arguments)
+        except Exception as e:
+            if self.function_call_raise_error:
+                raise e
+            else:
+                return f'Error: {e}'
+
+    def _recursive_function_call(self, function_call: FunctionCallDict, model_parameters: ModelParameters) -> str:
+        function_output = self.run_function_call(function_call)
+        self.history.append(Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False)))
+        model_response = self.chat_model.chat_completion(self.history, model_parameters)
+        self.history.extend(model_response.messages)
+        if isinstance(reply := model_response.messages[-1].content, str):
+            return reply
+        else:
+            return self._recursive_function_call(reply, model_parameters)
+
+    def reset(self) -> None:
+        self.history.clear()
diff --git a/lmclient/client.py b/lmclient/client.py
index a516eee..9874731 100644
--- a/lmclient/client.py
+++ b/lmclient/client.py
@@ -4,21 +4,17 @@
 import time
 from enum import Enum
 from pathlib import Path
-from typing import ClassVar, Generic, Sequence, TypeVar
+from typing import ClassVar, Generic, Sequence
 
 import anyio
 import asyncer
 import tqdm
 
-from lmclient.models import BaseChatModel
-from lmclient.openai_schema import OpenAISchema
-from lmclient.types import ChatModelOutput, Prompt
+from lmclient.models.base import BaseChatModel, T
+from lmclient.types import ChatModelOutput, Message, Prompt
 
 DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve()
 
-T = TypeVar('T')
-T_O = TypeVar('T_O', bound=OpenAISchema)
-
 
 class ErrorMode(str, Enum):
     RAISE = 'raise'
@@ -51,21 +47,21 @@ def __init__(
         self.progress_bar_mode = ProgressBarMode(progress_bar)
         self._task_created_time_list: list[int] = []
 
-    def run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput[T]]:
+    def run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput]:
         progress_bar = self._get_progress_bar(num_tasks=len(prompts))
 
-        task_results: list[ChatModelOutput[T]] = []
+        task_results: list[ChatModelOutput] = []
         for prompt in prompts:
            task_result = self._run_single_task(prompt=prompt, progress_bar=progress_bar, **kwargs)
            task_results.append(task_result)
         progress_bar.close()
         return task_results
 
-    async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput[T]]:
+    async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput]:
         limiter = anyio.CapacityLimiter(self.async_capacity)
         task_created_lock = anyio.Lock()
         progress_bar = self._get_progress_bar(num_tasks=len(prompts))
-        soon_values: list[asyncer.SoonValue[ChatModelOutput[T]]] = []
+        soon_values: list[asyncer.SoonValue[ChatModelOutput]] = []
         async with asyncer.create_task_group() as task_group:
             soon_func = task_group.soonify(self._async_run_single_task)
             for prompt in prompts:
@@ -82,7 +78,7 @@ async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatMode
         values = [soon_value.value for soon_value in soon_values]
         return values
 
-    def async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput[T]]:
+    def async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput]:
         return asyncer.runnify(self._async_run)(prompts, **kwargs)
 
     async def _async_run_single_task(
@@ -91,10 +87,12 @@ async def _async_run_single_task(
         limiter: anyio.CapacityLimiter,
         task_created_lock: anyio.Lock,
         progress_bar: tqdm.tqdm,
-        **kwargs,
+        override_parameters: T | None = None,
     ) -> ChatModelOutput:
+        if isinstance(prompt, str):
+            prompt = [Message(role='user', content=prompt)]
         async with limiter:
-            task_key = self.chat_model.generate_hash_key(prompt=prompt, **kwargs)
+            task_key = self.chat_model.generate_hash_key(prompt, override_parameters)
            response = self.chat_model.try_load_response(task_key)
 
            if response is None:
@@ -105,14 +103,14 @@
                 self._task_created_time_list.append(int(time.time()))
 
                 try:
-                    output = await self.chat_model.async_chat(prompt=prompt, **kwargs)
+                    output = await self.chat_model.async_chat_completion(messages=prompt, override_parameters=override_parameters)
                     progress_bar.update(1)
                     return output
                 except BaseException as e:
                     if self.error_mode is ErrorMode.RAISE:
                         raise
                     elif self.error_mode is ErrorMode.IGNORE:
-                        return ChatModelOutput(error_message=str(e))
+                        return ChatModelOutput(messages=[Message(role='Error', content=f'Error: {e}')])
                     else:
                         raise ValueError(f'Unknown error mode: {self.error_mode}') from e
 
@@ -127,14 +125,14 @@ def _run_single_task(self, prompt: Prompt, progress_bar: tqdm.tqdm, **kwargs) ->
             self._task_created_time_list.append(int(time.time()))
 
             try:
-                output = self.chat_model.chat(prompt=prompt, **kwargs)
+                output = self.chat_model.chat_completion(messages=prompt, **kwargs)
                 progress_bar.update(1)
                 return output
             except BaseException as e:
                 if self.error_mode is ErrorMode.RAISE:
                     raise
                 elif self.error_mode is ErrorMode.IGNORE:
-                    return ChatModelOutput(output=f'Response Error: {e}', response={})
+                    return ChatModelOutput(messages=[Message(role='Error', content=f'Response Error: {e}')], response={})
                 else:
                     raise ValueError(f'Unknown error mode: {self.error_mode}') from e
 
diff --git a/lmclient/exceptions.py b/lmclient/exceptions.py
new file mode 100644
index 0000000..3af3f6f
--- /dev/null
+++ b/lmclient/exceptions.py
@@ -0,0 +1,5 @@
+class MessageError(Exception):
+    """
+    Base class for all message errors.
+    """
+    pass
diff --git a/lmclient/function.py b/lmclient/function.py
new file mode 100644
index 0000000..370b36e
--- /dev/null
+++ b/lmclient/function.py
@@ -0,0 +1,86 @@
+# MIT License
+#
+# Copyright (c) 2023 Jason Liu
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import json +from functools import wraps +from typing import Any, Callable + +from docstring_parser import parse +from pydantic import validate_arguments + +from lmclient.exceptions import MessageError +from lmclient.types import Message + + +def _remove_a_key(d, remove_key) -> None: + """Remove a key from a dictionary recursively""" + if isinstance(d, dict): + for key in list(d.keys()): + if key == remove_key: + del d[key] + else: + _remove_a_key(d[key], remove_key) + + +class lm_function: + def __init__(self, func: Callable) -> None: + self.func = func + self.name = self.func.__name__ + self.validate_func = validate_arguments(func) + self.docstring = parse(self.func.__doc__ or '') + + parameters = self.validate_func.model.model_json_schema() + parameters["properties"] = { + k: v + for k, v in parameters["properties"].items() + if k not in ("v__duplicate_kwargs", "args", "kwargs") + } + for param in self.docstring.params: + if (name := param.arg_name) in parameters["properties"] and ( + description := param.description + ): + parameters["properties"][name]["description"] = description + parameters["required"] = sorted( + k for k, v in parameters["properties"].items() if "default" not in v + ) + _remove_a_key(parameters, "additionalProperties") + _remove_a_key(parameters, "title") + self.openai_schema = { + "name": self.name, + "description": self.docstring.short_description, + "parameters": parameters, + } + self.model = self.validate_func.model + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + @wraps(self.func) + def wrapper(*args, **kwargs): + return self.validate_func(*args, **kwargs) + + return wrapper(*args, **kwargs) + + def from_message(self, message: Message): + function_call = message.content + if isinstance(function_call, str): + raise MessageError(f'{message} is not a valid function call message') + arguments = json.loads(function_call["arguments"], strict=False) + return self.validate_func(**arguments) diff --git a/lmclient/models/__init__.py b/lmclient/models/__init__.py index c9a898d..a0a3a13 100644 --- a/lmclient/models/__init__.py +++ b/lmclient/models/__init__.py @@ -1,6 +1,5 @@ from lmclient.models.azure import AzureChat as AzureChat from lmclient.models.base import BaseChatModel as BaseChatModel -from lmclient.models.minimax import MinimaxChat as MinimaxChat +from lmclient.models.minimax_pro import MinimaxProChat as MinimaxProChat from lmclient.models.openai import OpenAIChat as OpenAIChat -from lmclient.models.openai import OpenAIExtract as OpenAIExtract from lmclient.models.zhipu import ZhiPuChat as ZhiPuChat diff --git a/lmclient/models/azure.py b/lmclient/models/azure.py index b3ea9a0..f5a2441 100644 --- a/lmclient/models/azure.py +++ b/lmclient/models/azure.py @@ -2,43 +2,52 @@ import os from pathlib import Path -from typing import Any, TypeVar +from typing import Any -from lmclient.models.base import HttpChatModel, RetryStrategy -from lmclient.models.openai import OpenAIContentParser -from lmclient.parser import ModelResponseParser -from lmclient.types import Messages +from lmclient.models.http import HttpChatModel, RetryStrategy +from lmclient.models.openai import ( + OpenAIChatParameters, + OpenAIMessageDict, + convert_lmclient_to_openai, + parse_openai_model_reponse, +) +from 
lmclient.types import Messages, ModelResponse +from lmclient.utils import to_dict -T = TypeVar('T') +class AzureChat(HttpChatModel[OpenAIChatParameters]): + parameters_type = OpenAIChatParameters -class AzureChat(HttpChatModel[T]): def __init__( self, model: str | None = None, + system_prompt: str | None = None, api_key: str | None = None, api_base: str | None = None, api_version: str | None = None, timeout: int | None = 60, - response_parser: ModelResponseParser[T] | None = None, retry: bool | RetryStrategy = False, + default_parameters: OpenAIChatParameters | None = None, use_cache: Path | str | bool = False, ): - response_parser = response_parser or OpenAIContentParser() - super().__init__(timeout=timeout, response_parser=response_parser, retry=retry, use_cache=use_cache) + super().__init__(default_parameters=default_parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model or os.environ['AZURE_CHAT_API_ENGINE'] or os.environ['AZURE_CHAT_MODEL_NAME'] + self.system_prompt = system_prompt self.api_key = api_key or os.environ['AZURE_API_KEY'] self.api_base = api_base or os.environ['AZURE_API_BASE'] self.api_version = api_version or os.getenv('AZURE_API_VERSION') - def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: + def get_post_parameters(self, messages: Messages, parameters: OpenAIChatParameters | None = None) -> dict[str, Any]: headers = { 'api-key': self.api_key, } + parameters_dict = {} if parameters is None else to_dict(parameters, exclude_defaults=True) + openai_messages: list[OpenAIMessageDict] = [] if self.system_prompt is None else [{'role': 'system', 'content': self.system_prompt}] + openai_messages = openai_messages + [convert_lmclient_to_openai(message) for message in messages] params = { 'model': self.model, - 'messages': messages, - **kwargs, + 'messages': [convert_lmclient_to_openai(message) for message in messages], + **parameters_dict, } return { 'url': f'{self.api_base}/openai/deployments/{self.model}/chat/completions?api-version={self.api_version}', @@ -46,6 +55,9 @@ def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: 'json': params, } + def parse_model_reponse(self, response: ModelResponse) -> Messages: + return parse_openai_model_reponse(response) + @property def identifier(self) -> str: return f'{self.__class__.__name__}({self.model})' diff --git a/lmclient/models/base.py b/lmclient/models/base.py index 23fa834..ddb5898 100644 --- a/lmclient/models/base.py +++ b/lmclient/models/base.py @@ -1,201 +1,31 @@ from __future__ import annotations -import hashlib -import os +from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Generic, TypeVar, cast +from typing import ClassVar, Generic, Type, TypeVar -import diskcache -import httpx +from lmclient.cache import ChatCacheMixin +from lmclient.types import ChatModelOutput, Messages, ModelParameters -try: - from pydantic.v1 import BaseModel -except ImportError: - from pydantic import BaseModel -from tenacity import retry, stop_after_attempt, wait_random_exponential +T = TypeVar("T", bound=ModelParameters) -from lmclient.parser import ModelResponseParser -from lmclient.types import ChatModelOutput, Messages, ModelResponse, Prompt -from lmclient.utils import ensure_messages -from lmclient.version import __cache_version__ -T = TypeVar('T') -DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve() +class BaseChatModel(ABC, Generic[T], ChatCacheMixin): + parameters_type: 
ClassVar[Type[ModelParameters]] - -class BaseChatModel(Generic[T]): - _cache: diskcache.Cache | None - _cache_dir: Path | None - - def __init__( - self, - response_parser: ModelResponseParser[T] | None = None, - use_cache: Path | str | bool = False, - ) -> None: - self.response_parser = response_parser - - if isinstance(use_cache, (str, Path)): - self.cache_dir = Path(use_cache) - elif use_cache: - self.cache_dir = DEFAULT_CACHE_DIR - else: - self.cache_dir = None + def __init__(self, default_parameters: T | None = None, use_cache: Path | str | bool = False) -> None: + super().__init__(use_cache=use_cache) + self.default_parameters = default_parameters @property + @abstractmethod def identifier(self) -> str: - raise NotImplementedError - - def call_model(self, messages: Messages, **kwargs) -> ModelResponse: - raise NotImplementedError - - async def async_call_model(self, messages: Messages, **kwargs) -> ModelResponse: - raise NotImplementedError - - def chat(self, prompt: Prompt, **kwargs) -> ChatModelOutput[T]: - messages = ensure_messages(prompt) - - if self.use_cache: - hash_key = self.generate_hash_key(prompt) - model_response = self.try_load_response(hash_key) - if model_response is None: - model_response = self.call_model(messages, **kwargs) - self.cache_response(hash_key, model_response) - else: - model_response = self.call_model(messages, **kwargs) - - if self.response_parser is None: - parsed_result = None - else: - parsed_result = self.response_parser(model_response) - - return ChatModelOutput( - parsed_result=parsed_result, - response=model_response, - ) - - async def async_chat(self, prompt: Prompt, **kwargs) -> ChatModelOutput[T]: - messages = ensure_messages(prompt) - - if self.use_cache: - hash_key = self.generate_hash_key(prompt) - model_response = self.try_load_response(hash_key) - if model_response is None: - model_response = await self.async_call_model(messages, **kwargs) - self.cache_response(hash_key, model_response) - else: - model_response = await self.async_call_model(messages, **kwargs) - - if self.response_parser is None: - parsed_result = None - else: - parsed_result = self.response_parser(model_response) - - return ChatModelOutput( - parsed_result=parsed_result, - response=model_response, - ) - - def cache_response(self, key: str, response: ModelResponse) -> None: - if self._cache is not None: - self._cache[key] = response - else: - raise RuntimeError('Cache is not enabled') - - def try_load_response(self, key: str): - if self._cache is not None and key in self._cache: - response = self._cache[key] - response = cast(ModelResponse, response) - return response - - def generate_hash_key(self, prompt: Prompt, **kwargs) -> str: - if isinstance(prompt, str): - hash_text = prompt - else: - hash_text = '---'.join([f'{k}={v}' for message in prompt for k, v in message.items()]) - items = sorted([f'{key}={value}' for key, value in kwargs.items()]) - items += [f'__cache_version__={__cache_version__}'] - items = [hash_text, self.identifier] + items - task_string = '---'.join(items) - return self.md5_hash(task_string) - - @staticmethod - def md5_hash(string: str): - return hashlib.md5(string.encode()).hexdigest() - - @property - def use_cache(self) -> bool: - return self._cache is not None - - @property - def cache_dir(self) -> Path | None: - return self._cache_dir - - @cache_dir.setter - def cache_dir(self, value: Path | None) -> None: - if value is not None: - if value.exists() and not value.is_dir(): - raise ValueError(f'Cache directory {value} is not a directory') - 
value.mkdir(parents=True, exist_ok=True) - self._cache = diskcache.Cache(value) - else: - self._cache = None - - -class RetryStrategy(BaseModel): # type: ignore - min_wait_seconds: int = 2 - max_wait_seconds: int = 20 - max_attempt: int = 3 - - -class HttpChatModel(BaseChatModel[T]): - def __init__( - self, - timeout: int | None = None, - retry: bool | RetryStrategy = False, - response_parser: ModelResponseParser[T] | None = None, - use_cache: Path | str | bool = False, - ): - super().__init__(response_parser=response_parser, use_cache=use_cache) - self.timeout = timeout - if isinstance(retry, RetryStrategy): - self.retry_strategy = retry - else: - self.retry_strategy = RetryStrategy() if retry else None - - def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: - raise NotImplementedError - - def call_model(self, messages: Messages, **kwargs) -> ModelResponse: - parameters = self.get_post_parameters(messages, **kwargs) - parameters = {'timeout': self.timeout, **parameters} - http_response = httpx.post(**parameters) - http_response.raise_for_status() - model_response = http_response.json() - return model_response - - async def async_call_model(self, messages: Messages, **kwargs) -> ModelResponse: - async with httpx.AsyncClient() as client: - parameters = self.get_post_parameters(messages, **kwargs) - parameters = {'timeout': self.timeout, **parameters} - http_response = await client.post(**parameters) - http_response.raise_for_status() - model_response = http_response.json() - return model_response - - def chat(self, prompt: Prompt, **kwargs) -> ChatModelOutput[T]: - if self.retry_strategy is None: - return super().chat(prompt, **kwargs) - - wait = wait_random_exponential(min=self.retry_strategy.min_wait_seconds, max=self.retry_strategy.max_wait_seconds) - stop = stop_after_attempt(self.retry_strategy.max_attempt) - output = retry(wait=wait, stop=stop)(super().chat)(prompt=prompt, **kwargs) - return output + ... - async def async_chat(self, prompt: Prompt, **kwargs) -> ChatModelOutput[T]: - if self.retry_strategy is None: - return await super().async_chat(prompt, **kwargs) + @abstractmethod + def chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + ... - wait = wait_random_exponential(min=self.retry_strategy.min_wait_seconds, max=self.retry_strategy.max_wait_seconds) - stop = stop_after_attempt(self.retry_strategy.max_attempt) - output = await retry(wait=wait, stop=stop)(super().async_chat)(prompt=prompt, **kwargs) - return output + @abstractmethod + async def async_chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + ... 
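[Example, not part of the patch] A minimal usage sketch of the chat_completion interface defined above, assuming OPENAI_API_KEY is set in the environment; OpenAIChat, OpenAIChatParameters, Message, and ChatModelOutput are the classes introduced elsewhere in this commit.

from lmclient.models.openai import OpenAIChat, OpenAIChatParameters
from lmclient.types import Message

# OpenAIChat reads OPENAI_API_KEY from the environment when api_key is not given
chat_model = OpenAIChat(model='gpt-3.5-turbo')
# override_parameters is merged over default_parameters before the HTTP request is built
output = chat_model.chat_completion(
    messages=[Message(role='user', content='Hello!')],
    override_parameters=OpenAIChatParameters(temperature=0.2),
)
print(output.messages[-1].content)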
diff --git a/lmclient/models/http.py b/lmclient/models/http.py new file mode 100644 index 0000000..7e94c56 --- /dev/null +++ b/lmclient/models/http.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import httpx +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from lmclient.models.base import BaseChatModel, T +from lmclient.types import BaseModel, ChatModelOutput, Messages, ModelResponse + + +class RetryStrategy(BaseModel): # type: ignore + min_wait_seconds: int = 2 + max_wait_seconds: int = 20 + max_attempt: int = 3 + + +class HttpChatModel(BaseChatModel[T], ABC): + def __init__( + self, + timeout: int | None = None, + retry: bool | RetryStrategy = False, + default_parameters: T | None = None, + use_cache: Path | str | bool = False, + ): + super().__init__(default_parameters=default_parameters, use_cache=use_cache) + self.timeout = timeout + if isinstance(retry, RetryStrategy): + self.retry_strategy = retry + else: + self.retry_strategy = RetryStrategy() if retry else None + + @abstractmethod + def get_post_parameters(self, messages: Messages, parameters: T | None = None) -> dict[str, Any]: + ... + + @abstractmethod + def parse_model_reponse(self, response: ModelResponse) -> Messages: + ... + + def _chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + if self.default_parameters is not None and override_parameters is not None: + override_parameters = self.default_parameters.model_copy(update=override_parameters.model_dump()) + + http_parameters = self.get_post_parameters(messages, override_parameters) + http_parameters = {'timeout': self.timeout, **http_parameters} + http_response = httpx.post(**http_parameters) + http_response.raise_for_status() + model_response = http_response.json() + return ChatModelOutput( + messages=self.parse_model_reponse(model_response), + response=model_response, + ) + + async def _async_chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + if self.default_parameters is not None and override_parameters is not None: + override_parameters = self.default_parameters.model_copy(update=override_parameters.model_dump()) + + async with httpx.AsyncClient() as client: + http_parameters = self.get_post_parameters(messages, override_parameters) + http_parameters = {'timeout': self.timeout, **http_parameters} + http_response = await client.post(**http_parameters) + http_response.raise_for_status() + model_response = http_response.json() + return ChatModelOutput( + messages=self.parse_model_reponse(model_response), + response=model_response, + ) + + def chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + if self.retry_strategy is None: + return self._chat_completion(messages, override_parameters) + + wait = wait_random_exponential(min=self.retry_strategy.min_wait_seconds, max=self.retry_strategy.max_wait_seconds) + stop = stop_after_attempt(self.retry_strategy.max_attempt) + output = retry(wait=wait, stop=stop)(self._chat_completion)(messages, override_parameters) + return output + + async def async_chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + if self.retry_strategy is None: + return await self._async_chat_completion(messages, override_parameters) + + wait = wait_random_exponential(min=self.retry_strategy.min_wait_seconds, 
max=self.retry_strategy.max_wait_seconds) + stop = stop_after_attempt(self.retry_strategy.max_attempt) + output = await retry(wait=wait, stop=stop)(self._async_chat_completion)(messages, override_parameters) + return output diff --git a/lmclient/models/minimax.py b/lmclient/models/minimax.py deleted file mode 100644 index 4857059..0000000 --- a/lmclient/models/minimax.py +++ /dev/null @@ -1,86 +0,0 @@ -from __future__ import annotations - -import os -from pathlib import Path -from typing import Any, TypeVar - -from lmclient.models.base import HttpChatModel, RetryStrategy -from lmclient.parser import ModelResponseParser, ParserError -from lmclient.types import Messages, ModelResponse - -T = TypeVar('T') - - -class MinimaxTextParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str: - try: - output = response['choices'][0]['text'] - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - return output - - -class MinimaxChat(HttpChatModel[T]): - def __init__( - self, - model: str = 'abab5.5-chat', - group_id: str | None = None, - api_key: str | None = None, - timeout: int | None = 60, - response_parser: ModelResponseParser[T] | None = None, - retry: bool | RetryStrategy = False, - use_cache: Path | str | bool = False, - ): - response_parser = response_parser or MinimaxTextParser() - super().__init__(timeout=timeout, response_parser=response_parser, retry=retry, use_cache=use_cache) - self.model = model - self.group_id = group_id or os.environ['MINIMAX_GROUP_ID'] - self.api_key = api_key or os.environ['MINIMAX_API_KEY'] - - def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: - headers = { - 'Authorization': f'Bearer {self.api_key}', - 'Content-Type': 'application/json', - } - json_data = self._messages_to_request_json_data(messages) - if 'temperature' in kwargs: - kwargs['temperature'] = max(0.01, kwargs['temperature']) - json_data.update(kwargs) - return { - 'url': f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={self.group_id}', - 'json': json_data, - 'headers': headers, - } - - def _messages_to_request_json_data(self, messages: Messages): - data: dict[str, Any] = { - 'model': self.model, - 'role_meta': {'user_name': '用户', 'bot_name': 'MM智能助理'}, - } - - if messages[0]['role'] == 'system': - data['prompt'] = messages[0]['content'] - messages = messages[1:] - else: - data['prompt'] = '你是MM智能助理' - minimax_messages = [] - for message in messages: - if message['role'] == 'user': - role = 'USER' - elif message['role'] == 'assistant': - role = 'BOT' - else: - raise ValueError(f'Invalid role: {message["role"]}') - - minimax_messages.append( - { - 'sender_type': role, - 'text': message['content'], - } - ) - data['messages'] = minimax_messages - return data - - @property - def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model})' diff --git a/lmclient/models/minimax_pro.py b/lmclient/models/minimax_pro.py new file mode 100644 index 0000000..ae7de84 --- /dev/null +++ b/lmclient/models/minimax_pro.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any, ClassVar, List, Literal, Optional, Type + +from typing_extensions import NotRequired, TypedDict + +from lmclient.exceptions import MessageError +from lmclient.models.http import HttpChatModel, RetryStrategy +from lmclient.parser import ModelResponseParser, ParserError +from lmclient.types import ( + Field, + FunctionCallDict, + FunctionDict, + GeneralParameters, 
+ Message, + Messages, + ModelParameters, + ModelResponse, +) +from lmclient.utils import to_dict + +DEFAULT_BOT_PROMPT = "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。" + + +class BotSettingDict(TypedDict): + bot_name: str + content: str + + +class GlyphDict(TypedDict): + type: str + raw_glpyh: str + json_properties: dict + + +class ReplyConstrainsDict(TypedDict): + sender_type: str + sender_name: str + glyph: NotRequired[GlyphDict] + + +class MinimaxMessageDict(TypedDict): + sender_type: Literal['USER', 'BOT', 'FUNCTION'] + sender_name: str + text: str + function_call: NotRequired[FunctionCallDict] + + +class MinimaxProChatParameters(ModelParameters): + temperature: float = 1 + top_p: float = 1 + tokens_to_generate: int = 1024 + mask_sensitive_info: bool = True + bot_setting: List[BotSettingDict] = Field(default_factory=list) + reply_constrains: ReplyConstrainsDict = Field(default_factory=list) + sample_messages: Optional[List[MinimaxMessageDict]] = None + functions: Optional[List[FunctionDict]] = None + plugins : Optional[List[str]] = None + + @classmethod + def from_general_parameters(cls, general_parameters: GeneralParameters): + return cls( + temperature=general_parameters.temperature, + top_p=general_parameters.top_p, + tokens_to_generate=general_parameters.max_tokens or 1024, + functions=general_parameters.functions, + ) + +class MinimaxProFunctionCallParser(ModelResponseParser): + def __call__(self, response: ModelResponse) -> FunctionCallDict: + try: + function_call_dict = response['choices'][0]['messages'][-1]['function_call'] + return FunctionCallDict( + name=function_call_dict['name'], + arguments=json.loads(function_call_dict['arguments']) + ) + except (KeyError, IndexError) as e: + raise ParserError('Parse response failed') from e + + +class MinimaxProTextParser(ModelResponseParser): + def __call__(self, response: ModelResponse) -> str: + try: + output = response['reply'] + except (KeyError, IndexError) as e: + raise ParserError('Parse response failed') from e + return output + + +class MinimaxProParser(ModelResponseParser): + def __call__(self, response: ModelResponse) -> Messages: + return [self._minimax_to_lmclient(i) for i in response['choices'][0]['messages']] + + @staticmethod + def _minimax_to_lmclient(message: MinimaxMessageDict) -> Message: + if 'function_call' in message: + return Message( + role=message['sender_type'], + name=message['sender_name'], + content=message['function_call'] + ) + else: + return Message( + role=message['sender_type'], + name=message['sender_name'], + content=message['text'], + ) + + +class MinimaxProChat(HttpChatModel[MinimaxProChatParameters]): + parameters_type = MinimaxProChatParameters + + def __init__( + self, + model: str = 'abab5.5-chat', + base_url: str = 'https://api.minimax.chat/v1/text/chatcompletion_pro', + group_id: str | None = None, + api_key: str | None = None, + bot_name: str = 'MM智能助理', + user_name: str = '用户', + system_prompt: str | None = None, + timeout: int | None = 60, + retry: bool | RetryStrategy = False, + default_parameters: MinimaxProChatParameters | None = None, + use_cache: Path | str | bool = False, + ): + super().__init__(default_parameters=default_parameters, timeout=timeout, retry=retry, use_cache=use_cache) + self.model = model + self.base_url = base_url + self.group_id = group_id or os.environ['MINIMAX_GROUP_ID'] + self.api_key = api_key or os.environ['MINIMAX_API_KEY'] + self.bot_name = bot_name + self.system_prompt = system_prompt or DEFAULT_BOT_PROMPT + self.user_name = 
user_name + + def get_post_parameters(self, messages: Messages, parameters: MinimaxProChatParameters | None = None) -> dict[str, Any]: + headers = { + 'Authorization': f'Bearer {self.api_key}', + 'Content-Type': 'application/json', + } + + json_data = { + 'model': self.model, + 'messages': [self._lmclient_to_minimax(message, self.bot_name, self.user_name) for message in messages] + } + + parameters = parameters or MinimaxProChatParameters() + if not parameters.bot_setting: + parameters.bot_setting = [{'bot_name': self.bot_name, 'content': self.system_prompt}] + if not parameters.reply_constrains: + parameters.reply_constrains = {'sender_type': 'USER', 'sender_name': self.bot_name} + parameters_dict = to_dict(parameters, exclude_defaults=True) if parameters else {} + if 'temperature' in parameters_dict: + parameters_dict['temperature'] = max(0.01, parameters_dict['temperature']) + json_data.update(parameters_dict) + + return { + 'url': self.base_url, + 'json': json_data, + 'headers': headers, + 'params': {'GroupId': self.group_id}, + } + + def parse_model_reponse(self, response: ModelResponse) -> Messages: + return [self._minimax_to_lmclient(i) for i in response['choices'][0]['messages']] + + @staticmethod + def _minimax_to_lmclient(message: MinimaxMessageDict) -> Message: + if 'function_call' in message: + return Message( + role=message['sender_type'], + name=message['sender_name'], + content=message['function_call'] + ) + else: + return Message( + role=message['sender_type'], + name=message['sender_name'], + content=message['text'], + ) + + def _lmclient_to_minimax(self, message: Message, default_bot_name: str = 'MM智能助理', default_user_name: str = '用户') -> MinimaxMessageDict: + if isinstance(message.content, dict): + if message.role != 'BOT': + raise MessageError(f'Invalid role {message.role} for function call, must be BOT') + return { + 'sender_type': message.role, + 'sender_name': message.name or default_bot_name, + 'text': '', + 'function_call': message.content, + } + elif message.role == 'BOT': + return { + 'sender_type': message.role, + 'sender_name': message.name or default_bot_name, + 'text': message.content, + } + elif message.role == 'FUNCTION': + if message.name is None: + raise MessageError(f'Function name is required, message: {message}') + return { + 'sender_type': message.role, + 'sender_name': message.name, + 'text': message.content, + } + elif message.role == 'USER': + return { + 'sender_type': message.role, + 'sender_name': message.name or default_user_name, + 'text': message.content, + } + else: + raise MessageError(f'Invalid role {message.role}, must be BOT, FUNCTION, or USER') + + @property + def identifier(self) -> str: + return f'{self.__class__.__name__}({self.model})' diff --git a/lmclient/models/openai.py b/lmclient/models/openai.py index b65cd16..0a682fa 100644 --- a/lmclient/models/openai.py +++ b/lmclient/models/openai.py @@ -2,122 +2,154 @@ import os from pathlib import Path -from typing import Any, Type, TypeVar - -from lmclient.models.base import HttpChatModel, RetryStrategy -from lmclient.openai_schema import OpenAISchema -from lmclient.parser import ModelResponseParser, ParserError -from lmclient.types import Messages, ModelResponse - -T = TypeVar('T') -T_O = TypeVar('T_O', bound=OpenAISchema) - - -class OpenAIParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str | dict[str, str]: - try: - if self.is_function_call(response): - fucntion_call_output: dict[str, str] = response['choices'][0]['message']['function_call'] - return 
fucntion_call_output - else: - content_output: str = response['choices'][0]['message']['content'] - return content_output - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - - @staticmethod - def is_function_call(reponse: ModelResponse) -> bool: - message = reponse['choices'][0]['message'] - return bool(message.get('function_call')) - - -class OpenAIFunctionCallParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> dict[str, str]: - try: - output: dict[str, str] = response['choices'][0]['message']['function_call'] - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - return output - - -class OpenAIContentParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str: - try: - output: str = response['choices'][0]['message']['content'] - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - return output - - -class OpenAIChat(HttpChatModel[T]): - def __init__( - self, - model: str = 'gpt-3.5-turbo', - api_key: str | None = None, - api_base: str | None = None, - timeout: int | None = 60, - response_parser: ModelResponseParser[T] | None = None, - retry: bool | RetryStrategy = False, - use_cache: Path | str | bool = False, - ): - response_parser = response_parser or OpenAIContentParser() - super().__init__(timeout=timeout, response_parser=response_parser, retry=retry, use_cache=use_cache) - self.model = model - self.api_base = api_base or os.getenv('OPENAI_API_BASE') or 'https://api.openai.com/v1' - self.api_key = api_key or os.environ['OPENAI_API_KEY'] - self.timeout = timeout - - def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: - headers = { - 'Authorization': f'Bearer {self.api_key}', - } - params = { - 'model': self.model, - 'messages': messages, - **kwargs, - } +from typing import Any, Dict, List, Literal, Optional, Union + +from typing_extensions import NotRequired, TypedDict + +from lmclient.exceptions import MessageError +from lmclient.models.http import HttpChatModel, RetryStrategy +from lmclient.parser import ParserError +from lmclient.types import FunctionCallDict, FunctionDict, GeneralParameters, Message, Messages, ModelParameters, ModelResponse +from lmclient.utils import to_dict + + +class FunctionCallNameDict(TypedDict): + name: str + + +class OpenAIMessageDict(TypedDict): + role: str + content: Optional[str] + name: NotRequired[str] + function_call: NotRequired[FunctionCallDict] + + +class OpenAIChatParameters(ModelParameters): + temperature: float = 1 + top_p: float = 1 + max_tokens: Optional[int] = None + functions: Optional[List[FunctionDict]] = None + function_call: Union[Literal['auto'], FunctionCallNameDict, None] = None + stop: Union[str, List[str], None] = None + presence_penalty: Optional[float] = 0 + frequency_penalty: Optional[float] = 0 + logit_bias: Optional[Dict[int, int]] = None + user: Optional[str] = None + + @classmethod + def from_general_parameters(cls, general_parameters: GeneralParameters): + if general_parameters.function_call != 'auto' and general_parameters.function_call is not None: + function_call = FunctionCallNameDict(name=general_parameters.function_call) + else: + function_call = general_parameters.function_call + + return cls( + temperature=general_parameters.temperature, + top_p=general_parameters.top_p, + max_tokens=general_parameters.max_tokens, + functions=general_parameters.functions, + function_call=function_call, + ) + + +class 
OpenAIExtractParameters(ModelParameters): + temperature: float = 1 + top_p: float = 1 + stop: Union[str, List[str], None] = None + presence_penalty: Optional[float] = 0 + frequency_penalty: Optional[float] = 0 + logit_bias: Optional[Dict[int, int]] = None + user: Optional[str] = None + + @classmethod + def from_general_parameters(cls, general_parameters: GeneralParameters): + return cls( + temperature=general_parameters.temperature, + top_p=general_parameters.top_p, + ) + + +def convert_lmclient_to_openai(message: Message, valid_roles: set[str] | None = None) -> OpenAIMessageDict: + valid_roles = valid_roles or {'user', 'assistant', 'function', 'system'} + if message.role not in valid_roles: + raise MessageError(f'Invalid role "{message.role}", supported roles are {valid_roles}') + + content = message.content + + if isinstance(content, dict): + if message.role != 'assistant': + raise MessageError(f'Invalid role "{message.role}" for function call, can only be made by "assistant"') return { - 'url': f'{self.api_base}/chat/completions', - 'headers': headers, - 'json': params, - } + 'role': message.role, + 'function_call': content, + 'content': None, + } + elif message.role == 'function': + name = message.name + if name is None: + raise MessageError(f'Function name is required, message: {message}') + return { + 'role': message.role, + 'name': name, + 'content': content, + } + else: + return { + 'role': message.role, + 'content': content, + } + - @property - def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model})' +def parse_openai_model_reponse(response: ModelResponse) -> Messages: + funcation_call = response['choices'][0]['message'].get('function_call') + try: + if bool(funcation_call): + return [Message( + role='assistant', + content=funcation_call, + )] + else: + text: str = response['choices'][0]['message']['content'] + return [Message( + role='assistant', + content=text, + )] + except (KeyError, IndexError) as e: + raise ParserError('Parse response failed') from e + + +class OpenAIChat(HttpChatModel[OpenAIChatParameters]): + parameters_type = OpenAIChatParameters -class OpenAIExtract(HttpChatModel[T_O]): def __init__( self, - schema: Type[T_O], model: str = 'gpt-3.5-turbo', - system_prompt: str = 'Extract structured data from a given text', + system_prompt: str | None = None, api_key: str | None = None, api_base: str | None = None, timeout: int | None = 60, retry: bool | RetryStrategy = False, + default_parameters: OpenAIChatParameters | None = None, use_cache: Path | str | bool = False, ): - super().__init__(timeout=timeout, response_parser=schema.from_response, retry=retry, use_cache=use_cache) - self.schema = schema + super().__init__(default_parameters=default_parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model self.system_prompt = system_prompt self.api_base = api_base or os.getenv('OPENAI_API_BASE') or 'https://api.openai.com/v1' self.api_key = api_key or os.environ['OPENAI_API_KEY'] - self.timeout = timeout - def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: - messages = [{'role': 'system', 'content': self.system_prompt}] + list(messages) # type: ignore + def get_post_parameters(self, messages: Messages, parameters: OpenAIChatParameters | None = None) -> dict[str, Any]: headers = { 'Authorization': f'Bearer {self.api_key}', } + parameters_dict = {} if parameters is None else to_dict(parameters, exclude_defaults=True) + openai_messages: list[OpenAIMessageDict] = [] if self.system_prompt is None else 
[{'role': 'system', 'content': self.system_prompt}] + openai_messages = openai_messages + [convert_lmclient_to_openai(message) for message in messages] params = { 'model': self.model, - 'messages': messages, - 'functions': [self.schema.openai_schema()], - 'function_call': {'name': self.schema.openai_schema()['name']}, - **kwargs, + 'messages': openai_messages, + **parameters_dict, } return { 'url': f'{self.api_base}/chat/completions', @@ -125,6 +157,9 @@ def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: 'json': params, } + def parse_model_reponse(self, response: ModelResponse) -> Messages: + return parse_openai_model_reponse(response) + @property def identifier(self) -> str: - return f'{self.__class__.__name__}(model={self.model}, system_prompt={self.system_prompt})' + return f'{self.__class__.__name__}({self.model})' diff --git a/lmclient/models/spark.py b/lmclient/models/spark.py deleted file mode 100644 index df2e70d..0000000 --- a/lmclient/models/spark.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -import base64 -import hashlib -import hmac -import json -import os -from datetime import datetime -from time import mktime -from typing import Any -from urllib.parse import urlencode, urlparse -from wsgiref.handlers import format_date_time - -import websocket - -from lmclient.models.base import BaseChatModel - - -class SparkChat(BaseChatModel): - def __init__( - self, - app_id: str | None = None, - api_key: str | None = None, - api_secret: str | None = None, - spark_url: str | None = None, - ) -> None: - self.app_id = app_id or os.environ['SPARK_APP_ID'] - self.api_key = api_key or os.environ['SPARK_API_KEY'] - self.api_secret = api_secret or os.environ['SPARK_API_SECRET'] - self.spark_url = spark_url or os.environ['SPARK_URL'] - - self.response: dict[str, Any] = {} - self.receive_round = 0 - - @property - def host(self) -> str: - return urlparse(self.spark_url).netloc - - @property - def path(self) -> str: - return urlparse(self.spark_url).path - - def generate_new_request_url(self): - now = datetime.now() - date = format_date_time(mktime(now.timetuple())) - - signature_origin = 'host: ' + self.host + '\n' - signature_origin += 'date: ' + date + '\n' - signature_origin += 'GET ' + self.path + ' HTTP/1.1' - - signature_sha = hmac.new( - self.api_secret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256 - ).digest() - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') - - authorization_origin = f'api_key="{self.api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - - values = {'authorization': authorization, 'date': date, 'host': self.host} - # 拼接鉴权参数,生成url - url = self.spark_url + '?' 
+ urlencode(values) - return url - - def on_message(self, ws_app: websocket.WebSocketApp, message: str) -> None: - receive_reponse = json.loads(message) - - code = receive_reponse['header']['code'] - if code != 0: - ws_app.close() - raise Exception(f'Error code: {code}') - - status = receive_reponse['payload']['choices']['status'] - if status == 2: - ws_app.close() - return - - self.response[f'round_{self.receive_round}'] = receive_reponse - self.receive_round += 1 diff --git a/lmclient/models/zhipu.py b/lmclient/models/zhipu.py index 08518e7..3628298 100644 --- a/lmclient/models/zhipu.py +++ b/lmclient/models/zhipu.py @@ -1,20 +1,50 @@ from __future__ import annotations +import logging import os -import time from pathlib import Path -from typing import Any, TypeVar +import time +from typing import Any, TypedDict, TypeVar import cachetools.func # type: ignore import jwt -from lmclient.models.base import HttpChatModel, RetryStrategy +from lmclient.exceptions import MessageError +from lmclient.models.http import HttpChatModel, RetryStrategy from lmclient.parser import ModelResponseParser, ParserError -from lmclient.types import Messages, ModelResponse +from lmclient.types import GeneralParameters, Messages, ModelParameters, ModelResponse +from lmclient.utils import to_dict T = TypeVar('T') API_TOKEN_TTL_SECONDS = 3 * 60 CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 +logger = logging.getLogger(__name__) + + +class ZhiPuChatParameters(ModelParameters): + temperature: float = 1 + top_p: float = 1 + + @classmethod + def from_general_parameters(cls, general_parameters: GeneralParameters): + return cls( + temperature=general_parameters.temperature, + top_p=general_parameters.top_p, + ) + + +class ZhiPuResponse(ModelResponse): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.content = self.content.replace('\n', ' ') + + +class ZhiPuModel(HttpChatModel): + name = 'zhipu' + +class ZhiPuMessageDict(TypedDict): + role: str + content: str @cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) @@ -47,33 +77,43 @@ def __call__(self, response: ModelResponse) -> str: return output -class ZhiPuChat(HttpChatModel[T]): +class ZhiPuChat(HttpChatModel[ZhiPuChatParameters]): def __init__( self, model: str = 'chatglm_pro', api_base: str | None = None, api_key: str | None = None, timeout: int | None = 60, - response_parser: ModelResponseParser[T] | None = None, retry: bool | RetryStrategy = False, + default_parameters: ZhiPuChatParameters | None = None, use_cache: Path | str | bool = False, - ) -> None: - response_parser = response_parser or ZhiPuParser() - super().__init__(timeout=timeout, response_parser=response_parser, retry=retry, use_cache=use_cache) + ): + super().__init__(default_parameters=default_parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model self.api_key = api_key or os.environ['ZHIPU_API_KEY'] self.api_base = api_base or os.getenv('ZHIPU_API_BASE') or 'https://open.bigmodel.cn/api/paas/v3/model-api' self.api_base.rstrip('/') - def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: + def get_post_parameters(self, messages: Messages, parameters: ZhiPuChatParameters | None = None) -> dict[str, Any]: for message in messages: - if message['role'] not in ('user', 'assistant'): - raise ValueError(f'Role of message must be user or assistant, but got {message["role"]}') + if message.role not in ('user', 'assistant'): + raise ValueError(f'Role of message must be user or assistant, but got {message.role}') + 
zhipu_messages: list[ZhiPuMessageDict] = [] + for message in messages: + if message.role not in ('user', 'assistant'): + raise MessageError(f'Role of message must be user or assistant, but got {message.role}') + if not isinstance(message.content, str): + raise MessageError(f'Message content must be str, but got {type(message.content)}') + zhipu_messages.append({ + 'role': message.role, + 'content': message.content, + }) headers = { 'Authorization': generate_token(self.api_key), } - params = {'prompt': messages, **kwargs} + parameters_dict = {} if parameters is None else to_dict(parameters, exclude_defaults=True) + params = {'prompt': messages, **parameters_dict} return { 'url': f'{self.api_base}/{self.model}/invoke', 'headers': headers, diff --git a/lmclient/openai_schema.py b/lmclient/openai_schema.py deleted file mode 100644 index ad44b21..0000000 --- a/lmclient/openai_schema.py +++ /dev/null @@ -1,68 +0,0 @@ -import json - -try: - from pydantic.v1 import BaseModel - from pydantic.v1 import Field as Field -except ImportError: - from pydantic import BaseModel - from pydantic import Field as Field - -from lmclient.parser import ParserError -from lmclient.types import ModelResponse - - -def _remove_a_key(d, remove_key) -> None: - """Remove a key from a dictionary recursively""" - if isinstance(d, dict): - for key in list(d.keys()): - if key == remove_key: - del d[key] - else: - _remove_a_key(d[key], remove_key) - - -class OpenAISchema(BaseModel): # type: ignore - @classmethod - def openai_schema(cls): - """ - Return the schema in the format of OpenAI's schema as jsonschema - - Note: - Its important to add a docstring to describe how to best use this class, it will be included in the description attribute and be part of the prompt. - - Returns: - model_json_schema (dict): A dictionary in the format of OpenAI's schema as jsonschema - """ - schema = cls.schema() - parameters = {k: v for k, v in schema.items() if k not in ('title', 'description')} - parameters['required'] = sorted(parameters['properties']) - _remove_a_key(parameters, 'title') - - if 'description' not in schema: - schema['description'] = f'Correctly extracted `{cls.__name__}` with all the required parameters with correct types' - - return { - 'name': schema['title'], - 'description': schema['description'], - 'parameters': parameters, - } - - @classmethod - def from_response(cls, response: ModelResponse): - """Execute the function from the response of an openai chat completion - - Parameters: - completion (openai.ChatCompletion): The response from an openai chat completion - throw_error (bool): Whether to throw an error if the function call is not detected - - Returns: - cls (OpenAISchema): An instance of the class - """ - message = response['choices'][0]['message'] - - if 'function_call' not in message: - raise ParserError('No function call detected') - - function_call = message['function_call'] - arguments = json.loads(function_call['arguments']) - return cls(**arguments) diff --git a/lmclient/types.py b/lmclient/types.py index 8161c70..663b3be 100644 --- a/lmclient/types.py +++ b/lmclient/types.py @@ -1,31 +1,52 @@ from __future__ import annotations -from typing import Any, Dict, Generic, Sequence, TypedDict, TypeVar, Union +from typing import Any, Dict, List, Optional, Union -try: - from pydantic.v1 import BaseModel, Field -except ImportError: - from pydantic import BaseModel, Field +from pydantic import BaseModel, Field +from typing_extensions import NotRequired, TypedDict -from typing_extensions import NotRequired -T = 
TypeVar('T') +class Message(BaseModel): + role: str + content: Union[str, FunctionCallDict] + name: Optional[str] = None + @property + def is_function_call(self) -> bool: + return isinstance(self.content, dict) -class Message(TypedDict): - role: str - content: str - name: NotRequired[str] - function_call: NotRequired[str] +class ChatModelOutput(BaseModel): + messages: Messages + response: ModelResponse = Field(default_factory=dict) -MessageRequiredKeys = ('role', 'content') -MessageNotRequiredKeys = ('name', 'function') -Messages = Sequence[Message] -ModelResponse = Dict[str, Any] -Prompt = Union[str, Sequence[dict]] +class FunctionDict(TypedDict): + name: str + description: NotRequired[str] + parameters: dict -class ChatModelOutput(BaseModel, Generic[T]): # type: ignore - parsed_result: T - response: ModelResponse = Field(default_factory=dict) + +class GeneralParameters(BaseModel): + temperature: float = 1 + top_p: float = 1 + max_tokens: Optional[int] = None + functions: Optional[List[FunctionDict]] = None + function_call: Optional[str] = None + + +class ModelParameters(BaseModel): + + @classmethod + def from_general_parameters(cls, general_parameters: GeneralParameters): + raise NotImplementedError + + +class FunctionCallDict(TypedDict): + name: str + arguments: str + + +Messages = List[Message] +ModelResponse = Dict[str, Any] +Prompt = Union[str, Messages] diff --git a/lmclient/utils.py b/lmclient/utils.py index 4f0da8b..832527d 100644 --- a/lmclient/utils.py +++ b/lmclient/utils.py @@ -1,19 +1,129 @@ from __future__ import annotations -from lmclient.types import Message, MessageNotRequiredKeys, MessageRequiredKeys, Messages, Prompt +import json +from functools import wraps +from typing import Any, Callable +from docstring_parser import parse +from pydantic import BaseModel, validate_arguments -def ensure_messages(value: Prompt) -> Messages: - if isinstance(value, str): - return [Message(role='user', content=value)] +from lmclient.exceptions import MessageError +from lmclient.types import FunctionDict, Message + + +def get_pydantic_version(): + import pydantic + from packaging import version + return version.parse(pydantic.__version__).major + + +PydanticVersion = get_pydantic_version() + + + +def _remove_a_key(d, remove_key) -> None: + """Remove a key from a dictionary recursively""" + if isinstance(d, dict): + for key in list(d.keys()): + if key == remove_key: + del d[key] + else: + _remove_a_key(d[key], remove_key) + + +class lm_function: + def __init__(self, func: Callable) -> None: + self.func = func + self.name = self.func.__name__ + self.validate_func = validate_arguments(func) + self.docstring = parse(self.func.__doc__ or '') + + parameters = self.validate_func.model.model_json_schema() + parameters["properties"] = { + k: v + for k, v in parameters["properties"].items() + if k not in ("v__duplicate_kwargs", "args", "kwargs") + } + for param in self.docstring.params: + if (name := param.arg_name) in parameters["properties"] and ( + description := param.description + ): + parameters["properties"][name]["description"] = description + parameters["required"] = sorted( + k for k, v in parameters["properties"].items() if "default" not in v + ) + _remove_a_key(parameters, "additionalProperties") + _remove_a_key(parameters, "title") + self.schema: FunctionDict = { + "name": self.name, + "description": self.docstring.short_description or '', + "parameters": parameters, + } + self.model = self.validate_func.model + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + 
@wraps(self.func) + def wrapper(*args, **kwargs): + return self.validate_func(*args, **kwargs) + + return wrapper(*args, **kwargs) + + def from_message(self, message: Message): + function_call = message.content + if isinstance(function_call, str): + raise MessageError(f'{message} is not a valid function call message') + arguments = json.loads(function_call["arguments"], strict=False) + return self.validate_func(**arguments) + + +class LMSchema(BaseModel): + @classmethod + @property + def openai_schema(cls): + schema = cls.model_json_schema() + docstring = parse(cls.__doc__ or '') + parameters = { + k: v for k, v in schema.items() if k not in ("title", "description") + } + for param in docstring.params: + if (name := param.arg_name) in parameters["properties"] and ( + description := param.description + ): + if "description" not in parameters["properties"][name]: + parameters["properties"][name]["description"] = description + + parameters["required"] = sorted( + k for k, v in parameters["properties"].items() if "default" not in v + ) + + if "description" not in schema: + if docstring.short_description: + schema["description"] = docstring.short_description + else: + schema["description"] = ( + f"Correctly extracted `{cls.__name__}` with all " + f"the required parameters with correct types" + ) + + _remove_a_key(parameters, "additionalProperties") + _remove_a_key(parameters, "title") + return { + "name": schema["title"], + "description": schema["description"], + "parameters": parameters, + } + + @classmethod + def from_message(cls, message: Message): + function_call = message.content + if isinstance(function_call, str): + raise MessageError(f'{message} is not a valid function call message') + arguments = json.loads(function_call["arguments"], strict=False) + return cls(**arguments) + + +def to_dict(value: BaseModel, exclude_defaults: bool = False): + if PydanticVersion == 2: + return value.model_dump(exclude_defaults=exclude_defaults) else: - messages: list[Message] = [] - for message_dict in value: - temp_dict = {} - for key in MessageRequiredKeys: - temp_dict[key] = message_dict[key] - for key in MessageNotRequiredKeys: - if key in message_dict: - temp_dict[key] = message_dict[key] - messages.append(Message(**temp_dict)) - return messages + return value.dict(exclude_defaults=exclude_defaults) diff --git a/poetry.lock b/poetry.lock index c48b314..09a9803 100644 --- a/poetry.lock +++ b/poetry.lock @@ -168,6 +168,17 @@ files = [ {file = "diskcache-5.6.1.tar.gz", hash = "sha256:e4c978532feff5814c4cc00fe1e11e40501985946643d73220d41ee7737c72c3"}, ] +[[package]] +name = "docstring-parser" +version = "0.15" +description = "Parse Python docstrings in reST, Google and Numpydoc format" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "docstring_parser-0.15-py3-none-any.whl", hash = "sha256:d1679b86250d269d06a99670924d6bce45adc00b08069dae8c47d98e89b667a9"}, + {file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"}, +] + [[package]] name = "exceptiongroup" version = "1.1.1" @@ -724,4 +735,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "42a9a8ba6ad4111cab7314d4a678703947262926ca73cdf65f6aae64a1da7afa" +content-hash = "f3975d2b5cc36b16efc98df46aef32c64d86dac5fbb9c822c30119eac8094ebd" diff --git a/pyproject.toml b/pyproject.toml index f30df87..f34bef2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ pyjwt = "^2.8.0" cachetools = "^5.3.1" tqdm = 
"^4.66.1" pydantic = ">1.0" +docstring-parser = "^0.15" [tool.ruff] line-length = 128 @@ -38,6 +39,9 @@ ignore = [ "C901", # too complex ] +[tool.pyright] +reportIncompatibleMethodOverride=true + [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" blue = "^0.9.1" diff --git a/scripts/ner.py b/scripts/ner.py index a32f992..50d4142 100644 --- a/scripts/ner.py +++ b/scripts/ner.py @@ -57,7 +57,7 @@ def main( model_outputs = client.async_run(texts) with open(output_file, 'w') as f: for text, output in zip(texts, model_outputs): - output = output.parsed_result.dict() if output.parsed_result else None + output = output.message.dict() if output.message else None output_dict = {'text': text, 'output': output} f.write(json.dumps(output_dict, ensure_ascii=False) + '\n') diff --git a/scripts/translate.py b/scripts/translate.py index 1e817f6..dbc207f 100644 --- a/scripts/translate.py +++ b/scripts/translate.py @@ -50,7 +50,7 @@ def main( with open(output_file, 'w') as f: for text, result in zip(texts, results): - f.write(json.dumps({'text': text, 'translation': result.parsed_result}, ensure_ascii=False) + '\n') + f.write(json.dumps({'text': text, 'translation': result.message}, ensure_ascii=False) + '\n') if __name__ == '__main__': diff --git a/tests/test_client.py b/tests/test_client.py index 6e8219f..cf31265 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,14 +8,14 @@ class TestModel(BaseChatModel): - def call_model(self, messages: Messages, **kwargs) -> ModelResponse: + def chat_completion(self, messages: Messages, **kwargs) -> ModelResponse: return { - 'content': f'Completed: {messages[-1]["content"]}', + 'content': f'Completed: {messages[-1].content}', } - async def async_call_model(self, messages: Messages, **kwargs) -> ModelResponse: + async def async_chat_completion(self, messages: Messages, **kwargs) -> ModelResponse: return { - 'content': f'Completed: {messages[-1]["content"]}', + 'content': f'Completed: {messages[-1].content}', } def default_postprocess_function(self, response: ModelResponse) -> str: @@ -42,9 +42,9 @@ def test_sync_completion(): ] results = client.run(prompts) - assert isinstance(results[0].parsed_result, str) - assert results[0].parsed_result == 'Completed: Hello, my name is' - assert results[1].parsed_result == 'Completed: hello, who are you?' + assert isinstance(results[0].message, str) + assert results[0].message == 'Completed: Hello, my name is' + assert results[1].message == 'Completed: hello, who are you?' 
assert len(results) == len(prompts) @@ -63,13 +63,13 @@ def test_async_completion(): elapsed_time = time.perf_counter() - start_time assert results[0].response['content'] == 'Completed: Hello, my name is' - assert results[0].parsed_result == 'Completed: Hello, my name is' + assert results[0].message == 'Completed: Hello, my name is' assert len(results) == len(prompts) assert elapsed_time > 4 def test_async_completion_with_cache(tmp_path): - completion_model = TestModel(use_cache=tmp_path) + completion_model = TestModel(use_cache=tmp_path, response_parser=model_parser) client = LMClient(completion_model, async_capacity=2, max_requests_per_minute=5) LMClient.NUM_SECONDS_PER_MINUTE = 2 diff --git a/tests/test_model.py b/tests/test_model.py index 9dee9b5..59807c0 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,72 +1,74 @@ import anyio import pytest -from lmclient.models import AzureChat, MinimaxChat, OpenAIChat, ZhiPuChat -from lmclient.models.openai import OpenAIContentParser +from lmclient.models import AzureChat, MinimaxProChat, OpenAIChat, ZhiPuChat +from lmclient.models.openai import OpenAITextParser +from lmclient.types import Message +test_messages = [Message(role='user', content='hello')] @pytest.mark.parametrize( 'prompt', [ 'Hello, my name is', - [{'role': 'system', 'content': 'your are lmclient demo assistant'}, {'role': 'user', 'content': 'hello, who are you?'}], + test_messages ], ) def test_azure_model(prompt): - model = AzureChat(response_parser=OpenAIContentParser()) + model = AzureChat(response_parser=OpenAITextParser()) sync_output = model.chat(prompt) async_output = anyio.run(model.async_chat, prompt) assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.parsed_result, str) + assert isinstance(sync_output.message, str) assert isinstance(async_output.response, dict) - assert isinstance(async_output.parsed_result, str) + assert isinstance(async_output.message, str) @pytest.mark.parametrize( 'prompt', [ 'Hello, my name is', - [{'role': 'system', 'content': 'your are lmclient demo assistant'}, {'role': 'user', 'content': 'hello, who are you?'}], + test_messages ], ) def test_openai_model(prompt): - chat_model = OpenAIChat('gpt-3.5-turbo', response_parser=OpenAIContentParser()) + chat_model = OpenAIChat('gpt-3.5-turbo', response_parser=OpenAITextParser()) sync_output = chat_model.chat(prompt) async_output = anyio.run(chat_model.async_chat, prompt) assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.parsed_result, str) + assert isinstance(sync_output.message, str) assert isinstance(async_output.response, dict) - assert isinstance(async_output.parsed_result, str) + assert isinstance(async_output.message, str) @pytest.mark.parametrize( 'prompt', [ 'Hello, my name is', - [{'role': 'system', 'content': 'your are lmclient demo assistant'}, {'role': 'user', 'content': 'hello, who are you?'}], + test_messages ], ) def test_minimax_model(prompt): - completion_model = MinimaxChat('abab5.5-chat') + completion_model = MinimaxProChat('abab5.5-chat') sync_output = completion_model.chat(prompt) async_output = anyio.run(completion_model.async_chat, prompt) assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.parsed_result, str) + assert isinstance(sync_output.message, str) assert isinstance(async_output.response, dict) - assert isinstance(async_output.parsed_result, str) + assert isinstance(async_output.message, str) @pytest.mark.parametrize( 'prompt', [ 'Hello, my name is', - [{'role': 'user', 'content': 
'hello, who are you?'}], + test_messages ], ) def test_zhipu_model(prompt): @@ -76,6 +78,6 @@ def test_zhipu_model(prompt): async_output = anyio.run(completion_model.async_chat, prompt) assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.parsed_result, str) + assert isinstance(sync_output.message, str) assert isinstance(async_output.response, dict) - assert isinstance(async_output.parsed_result, str) + assert isinstance(async_output.message, str) From 1ed208e7cc27c5d25e4bf5ed227d2e8a6e327a91 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Sat, 2 Sep 2023 16:46:41 +0800 Subject: [PATCH 2/6] add function --- README.md | 18 +++- demo.py | 0 lmclient/__init__.py | 31 ++++-- lmclient/cache.py | 32 ++---- lmclient/chat_engine.py | 77 +++++++------- lmclient/client.py | 98 +++++++++++------- lmclient/exceptions.py | 1 + lmclient/function.py | 86 ---------------- lmclient/models/__init__.py | 44 +++++++- lmclient/models/azure.py | 26 ++--- lmclient/models/base.py | 77 ++++++++++++-- lmclient/models/http.py | 68 ++++++------ lmclient/models/minimax_pro.py | 159 +++++++++++------------------ lmclient/models/openai.py | 68 ++++++------ lmclient/models/zhipu.py | 65 ++++++------ lmclient/types.py | 37 +++++-- lmclient/utils.py | 91 +++++++++-------- lmclient/version.py | 4 +- pyproject.toml | 4 +- scripts/data/ner_input.jsonl | 3 - scripts/data/translate_input.jsonl | 33 +++++- scripts/ner.py | 66 ------------ scripts/translate.py | 38 +++---- tests/test_client.py | 67 ++++++------ tests/test_model.py | 83 ++++++--------- 25 files changed, 614 insertions(+), 662 deletions(-) delete mode 100644 demo.py delete mode 100644 lmclient/function.py delete mode 100644 scripts/data/ner_input.jsonl delete mode 100644 scripts/ner.py diff --git a/README.md b/README.md index 17d23ad..4c917f6 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ 5. 支持磁盘缓存 6. 100% type hints 7. 非常易用 +8. 支持 OpenAI, Azure, Minimax, ZhiPu 等模型 +9. 支持 FunctionCall ## 安装方式 支持 python3.8 及以上 @@ -22,10 +24,11 @@ pip install lmclient-core ## 使用方法 +1. LMClient ```python -from lmclient import LMClient, OpenAIChat +from lmclient import LMClient, OpenAIChat, OpenAIChatParameters -model = OpenAIChat('gpt-3.5-turbo') +model = OpenAIChat('gpt-3.5-turbo', parameters=OpenAIChatParameters(temperature=0)) # 控制每分钟最大请求次数为 20, 异步容量为 5 client = LMClient(model, async_capacity=5, max_requests_per_minute=20) prompts = [ @@ -34,9 +37,18 @@ prompts = [ [{'role': 'system', 'content': 'your are lmclient demo assistant'}, {'role': 'user', 'content': 'hello, who are you?'}], 'what is your name?', ] -values = client.async_run(prompts=prompts, temperature=0) +values = client.run(prompts=prompts) print(values) ``` +2. 
ChatEngine
+```python
+from lmclient import ChatEngine, OpenAIChat
+
+model = OpenAIChat('gpt-3.5-turbo')
+chat_engine = ChatEngine(model)
+print(chat_engine.chat('你好,我是 chat_engine'))
+print(chat_engine.chat('我上一句话是什么?'))
+```
 
 ## 使用样例: 大规模翻译
 
diff --git a/demo.py b/demo.py
deleted file mode 100644
index e69de29..0000000
diff --git a/lmclient/__init__.py b/lmclient/__init__.py
index 43634b5..b36069e 100644
--- a/lmclient/__init__.py
+++ b/lmclient/__init__.py
@@ -1,6 +1,25 @@
-from lmclient.client import LMClient as LMClient
-from lmclient.models import AzureChat as AzureChat
-from lmclient.models import MinimaxProChat as MinimaxProChat
-from lmclient.models import OpenAIChat as OpenAIChat
-from lmclient.models import OpenAIExtract as OpenAIExtract
-from lmclient.models import ZhiPuChat as ZhiPuChat
+from lmclient.chat_engine import ChatEngine
+from lmclient.client import LMClient
+from lmclient.models import (
+    AzureChat,
+    MinimaxProChat,
+    MinimaxProChatParameters,
+    OpenAIChat,
+    OpenAIChatParameters,
+    ZhiPuChat,
+    ZhiPuChatParameters,
+)
+from lmclient.version import __version__
+
+__all__ = [
+    'LMClient',
+    'ChatEngine',
+    'AzureChat',
+    'OpenAIChat',
+    'OpenAIChatParameters',
+    'MinimaxProChat',
+    'MinimaxProChatParameters',
+    'ZhiPuChat',
+    'ZhiPuChatParameters',
+    '__version__',
+]
diff --git a/lmclient/cache.py b/lmclient/cache.py
index 7bda4b1..4490bb9 100644
--- a/lmclient/cache.py
+++ b/lmclient/cache.py
@@ -1,21 +1,17 @@
 from __future__ import annotations
 
-import hashlib
 import os
 from pathlib import Path
 from typing import cast
 
 import diskcache
 
-from lmclient.types import Messages, ModelResponse, Prompt, ModelParameters
-from lmclient.utils import to_dict
-from lmclient.version import __cache_version__
+from lmclient.types import ChatModelOutput
 
 DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve()
 
 
 class ChatCacheMixin:
-    identifier: str
     _cache: diskcache.Cache | None
     _cache_dir: Path | None
 
@@ -27,32 +23,16 @@ def __init__(self, use_cache: Path | str | bool = False) -> None:
         else:
             self.cache_dir = None
 
-    def cache_response(self, key: str, response: ModelResponse) -> None:
+    def cache_model_output(self, key: str, model_output: ChatModelOutput) -> None:
         if self._cache is not None:
-            self._cache[key] = response
+            self._cache[key] = model_output
         else:
             raise RuntimeError('Cache is not enabled')
 
-    def try_load_response(self, key: str):
+    def try_load_model_output(self, key: str):
         if self._cache is not None and key in self._cache:
-            response = self._cache[key]
-            response = cast(ModelResponse, response)
-            return response
-
-    def generate_hash_key(self, messages: Messages, parameters: ModelParameters) -> str:
-        if isinstance(prompt, str):
-            hash_text = prompt
-        else:
-            hash_text = '---'.join([f'{k}={v}' for message in prompt for k, v in to_dict(message).items()])
-        items = sorted([f'{key}={value}' for key, value in parameters.model_dump()])
-        items += [f'__cache_version__={__cache_version__}']
-        items = [hash_text, self.identifier] + items
-        task_string = '---'.join(items)
-        return self.md5_hash(task_string)
-
-    @staticmethod
-    def md5_hash(string: str):
-        return hashlib.md5(string.encode()).hexdigest()
+            model_output = cast(ChatModelOutput, self._cache[key])
+            return model_output
 
     @property
     def use_cache(self) -> bool:
diff --git a/lmclient/chat_engine.py b/lmclient/chat_engine.py
index 1f3b1e5..82082b2 100644
--- a/lmclient/chat_engine.py
+++ b/lmclient/chat_engine.py
@@ -1,24 +1,31 @@
 from __future__ import annotations
 
 import json
-from typing import List, Optional +from typing import Any, Generic, List, Optional, TypeVar, cast -from lmclient.models.base import BaseChatModel -from lmclient.types import FunctionCallDict, GeneralParameters, Message, Messages, ModelParameters -from lmclient.utils import lm_function +from lmclient.models import BaseChatModel, load_from_model_id +from lmclient.types import ChatModelOutput, FunctionCallDict, GeneralParameters, Message, Messages, ModelParameters +from lmclient.utils import function +T_P = TypeVar('T_P', bound=ModelParameters) +T_O = TypeVar('T_O', bound=ChatModelOutput) -class ChatEngine: + +class ChatEngine(Generic[T_P, T_O]): def __init__( self, - chat_model: BaseChatModel, + chat_model: BaseChatModel[T_P, T_O] | str, temperature: float = 1, top_p: float = 1, - functions: Optional[List[lm_function]] = None, + functions: Optional[List[function]] = None, function_call_raise_error: bool = False, - **extra_model_parameters + **extra_parameters: Any, ): - self._chat_model = chat_model + if isinstance(chat_model, str): + self._chat_model: BaseChatModel[T_P, T_O] = load_from_model_id(chat_model) # type: ignore + else: + self._chat_model = chat_model + self.functions = functions or [] if functions: functions_schema = [function.schema for function in functions] @@ -33,42 +40,32 @@ def __init__( functions=functions_schema, function_call=function_call, ) - self._extra_model_parameters = extra_model_parameters - self._model_parameters: ModelParameters = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters).model_copy(update=self._extra_model_parameters) + self._extra_parameters = extra_parameters + _parameters = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters).model_copy( + update=self._extra_parameters + ) + self._parameters = cast(T_P, _parameters) self.function_call_raise_error = function_call_raise_error self.history: Messages = [] @property - def chat_model(self): - return self._chat_model - - @chat_model.setter - def chat_model(self, model: BaseChatModel): - self._chat_model = model - self._model_parameters = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters).model_copy(update=self._extra_model_parameters) - - @property - def model_parameters(self): - return self._model_parameters + def extra_parameters(self) -> dict[str, Any]: + return self._extra_parameters - @property - def extra_model_parameters(self): - return self._extra_model_parameters + @extra_parameters.setter + def extra_parameters(self, extra_parameters: dict[str, Any]): + self._extra_parameters = extra_parameters + self._parameters = self._parameters.model_copy(update=self._extra_parameters) - @extra_model_parameters.setter - def extra_model_parameters(self, extra_model_parameters: dict): - self._extra_model_parameters = extra_model_parameters - self._model_parameters = self.model_parameters.model_copy(update=self._extra_model_parameters) - - def chat(self, user_input: str, **extra_model_parameters) -> str: - model_parameters = self.model_parameters.model_copy(update=extra_model_parameters) + def chat(self, user_input: str, **extra_parameters: Any) -> str: + parameters = self._parameters.model_copy(update=extra_parameters) self.history.append(Message(role='user', content=user_input)) - model_response = self.chat_model.chat_completion(self.history, model_parameters) + model_response = self._chat_model.chat_completion(self.history, parameters) self.history.extend(model_response.messages) if isinstance(reply := 
model_response.messages[-1].content, str): return reply else: - return self._recursive_function_call(reply, model_parameters) + return self._recursive_function_call(reply, parameters) def run_function_call(self, function_call: FunctionCallDict): function = None @@ -82,7 +79,7 @@ def run_function_call(self, function_call: FunctionCallDict): return 'Function not found, please try another function.' try: - arguments = json.loads(function_call["arguments"], strict=False) + arguments = json.loads(function_call['arguments'], strict=False) return function(**arguments) except Exception as e: if self.function_call_raise_error: @@ -90,15 +87,17 @@ def run_function_call(self, function_call: FunctionCallDict): else: return f'Error: {e}' - def _recursive_function_call(self, function_call: FunctionCallDict, model_parameters: ModelParameters) -> str: + def _recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str: function_output = self.run_function_call(function_call) - self.history.append(Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False))) - model_response = self.chat_model.chat_completion(self.history, model_parameters) + self.history.append( + Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False)) + ) + model_response = self._chat_model.chat_completion(self.history, parameters) self.history.extend(model_response.messages) if isinstance(reply := model_response.messages[-1].content, str): return reply else: - return self._recursive_function_call(reply, model_parameters) + return self._recursive_function_call(reply, parameters) def reset(self) -> None: self.history.clear() diff --git a/lmclient/client.py b/lmclient/client.py index 9874731..7b05e15 100644 --- a/lmclient/client.py +++ b/lmclient/client.py @@ -4,14 +4,15 @@ import time from enum import Enum from pathlib import Path -from typing import ClassVar, Generic, Sequence +from typing import ClassVar, Generic, NoReturn, Sequence, cast import anyio import asyncer import tqdm -from lmclient.models.base import BaseChatModel, T -from lmclient.types import ChatModelOutput, Message, Prompt +from lmclient.models import load_from_model_id +from lmclient.models.base import T_O, T_P, BaseChatModel +from lmclient.types import ChatModelOutput, Message, Messages, Prompt DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve() @@ -27,36 +28,59 @@ class ProgressBarMode(str, Enum): NEVER = 'never' -class LMClient(Generic[T]): +def ensure_messages(prompt: Prompt) -> Messages: + if isinstance(prompt, str): + return [Message(role='user', content=prompt)] + elif isinstance(prompt, Message): + return [prompt] + elif isinstance(prompt, dict): + return [Message(**prompt)] + else: + messages: list[Message] = [] + for message in prompt: + if isinstance(message, dict): + messages.append(Message(**message)) + else: + messages.append(message) + return messages + + +class LMClient(Generic[T_P, T_O]): error_mode: ErrorMode NUM_SECONDS_PER_MINUTE: ClassVar[int] = 60 PROGRESS_BAR_THRESHOLD: ClassVar[int] = 20 def __init__( self, - chat_model: BaseChatModel[T], + chat_model: BaseChatModel[T_P, T_O] | str, async_capacity: int = 3, max_requests_per_minute: int = 20, error_mode: ErrorMode | str = ErrorMode.RAISE, progress_bar: ProgressBarMode | str = ProgressBarMode.AUTO, ): - self.chat_model = chat_model + if isinstance(chat_model, str): + chat_model = load_from_model_id(chat_model) # type: ignore 
+ self.chat_model = cast(BaseChatModel[T_P, T_O], chat_model) + else: + self.chat_model = chat_model self.async_capacity = async_capacity self.max_requests_per_minute = max_requests_per_minute self.error_mode = ErrorMode(error_mode) self.progress_bar_mode = ProgressBarMode(progress_bar) self._task_created_time_list: list[int] = [] - def run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput]: + def run(self, prompts: Sequence[Prompt], override_parameters: T_P | None = None) -> list[ChatModelOutput]: progress_bar = self._get_progress_bar(num_tasks=len(prompts)) task_results: list[ChatModelOutput] = [] for prompt in prompts: - task_result = self._run_single_task(prompt=prompt, progress_bar=progress_bar, **kwargs) + task_result = self._run_single_task( + prompt=prompt, progress_bar=progress_bar, override_parameters=override_parameters + ) task_results.append(task_result) progress_bar.close() return task_results - async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput]: + async def _async_run(self, prompts: Sequence[Prompt], override_parameters: T_P | None = None) -> list[ChatModelOutput]: limiter = anyio.CapacityLimiter(self.async_capacity) task_created_lock = anyio.Lock() progress_bar = self._get_progress_bar(num_tasks=len(prompts)) @@ -70,7 +94,7 @@ async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatMode limiter=limiter, task_created_lock=task_created_lock, progress_bar=progress_bar, - **kwargs, + override_parameters=override_parameters, ) soon_values.append(soon_value) @@ -78,32 +102,28 @@ async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatMode values = [soon_value.value for soon_value in soon_values] return values - def async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput]: - return asyncer.runnify(self._async_run)(prompts, **kwargs) + def async_run(self, prompts: Sequence[Prompt], override_parameters: T_P | None = None) -> list[ChatModelOutput]: + return asyncer.runnify(self._async_run)(prompts, override_parameters) async def _async_run_single_task( self, prompt: Prompt, limiter: anyio.CapacityLimiter, task_created_lock: anyio.Lock, - progress_bar: tqdm.tqdm, - override_parameters: T | None = None, + progress_bar: tqdm.tqdm[NoReturn], + override_parameters: T_P | None = None, ) -> ChatModelOutput: - if isinstance(prompt, str): - prompt = [Message(role='user', content=prompt)] - async with limiter: - task_key = self.chat_model.generate_hash_key(prompt=prompt, override_parameters) - response = self.chat_model.try_load_response(task_key) - - if response is None: - async with task_created_lock: - sleep_time = self._calculate_sleep_time() - if sleep_time > 0: - await anyio.sleep(sleep_time) - self._task_created_time_list.append(int(time.time())) + messages = ensure_messages(prompt) + async with limiter: try: - output = await self.chat_model.async_chat_completion(messages=prompt, override_parameters=**kwargs) + output = await self.chat_model.async_chat_completion(messages=messages, override_parameters=override_parameters) + if not output.is_cache: + async with task_created_lock: + sleep_time = self._calculate_sleep_time() + if sleep_time > 0: + await anyio.sleep(sleep_time) + self._task_created_time_list.append(int(time.time())) progress_bar.update(1) return output except BaseException as e: @@ -114,25 +134,25 @@ async def _async_run_single_task( else: raise ValueError(f'Unknown error mode: {self.error_mode}') from e - def _run_single_task(self, prompt: Prompt, 
progress_bar: tqdm.tqdm, **kwargs) -> ChatModelOutput: - task_key = self.chat_model.generate_hash_key(prompt=prompt, **kwargs) - response = self.chat_model.try_load_response(task_key) - - if response is None: - sleep_time = self._calculate_sleep_time() - if sleep_time > 0: - time.sleep(sleep_time) - self._task_created_time_list.append(int(time.time())) + def _run_single_task( + self, prompt: Prompt, progress_bar: tqdm.tqdm[NoReturn], override_parameters: T_P | None = None + ) -> ChatModelOutput: + messages = ensure_messages(prompt) try: - output = self.chat_model.chat_completion(messages=prompt, **kwargs) + output = self.chat_model.chat_completion(messages=messages, override_parameters=override_parameters) + if not output.is_cache: + sleep_time = self._calculate_sleep_time() + if sleep_time > 0: + time.sleep(sleep_time) + self._task_created_time_list.append(int(time.time())) progress_bar.update(1) return output except BaseException as e: if self.error_mode is ErrorMode.RAISE: raise elif self.error_mode is ErrorMode.IGNORE: - return ChatModelOutput(message=f'Response Error: {e}', response={}) + return ChatModelOutput(messages=[Message(role='Error', content=str(e))]) else: raise ValueError(f'Unknown error mode: {self.error_mode}') from e @@ -150,7 +170,7 @@ def _calculate_sleep_time(self): else: return max(self.NUM_SECONDS_PER_MINUTE - int(current_time - self._task_created_time_list[0]) + 1, 0) - def _get_progress_bar(self, num_tasks: int) -> tqdm.tqdm: + def _get_progress_bar(self, num_tasks: int) -> tqdm.tqdm[NoReturn]: use_progress_bar = (self.progress_bar_mode is ProgressBarMode.ALWAYS) or ( self.progress_bar_mode is ProgressBarMode.AUTO and num_tasks > self.PROGRESS_BAR_THRESHOLD ) diff --git a/lmclient/exceptions.py b/lmclient/exceptions.py index 3af3f6f..0d0c18c 100644 --- a/lmclient/exceptions.py +++ b/lmclient/exceptions.py @@ -2,4 +2,5 @@ class MessageError(Exception): """ Base class for all message errors. """ + pass diff --git a/lmclient/function.py b/lmclient/function.py deleted file mode 100644 index 370b36e..0000000 --- a/lmclient/function.py +++ /dev/null @@ -1,86 +0,0 @@ -# MIT License -# -# Copyright (c) 2023 Jason Liu -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. 
- -import json -from functools import wraps -from typing import Any, Callable - -from docstring_parser import parse -from pydantic import validate_arguments - -from lmclient.exceptions import MessageError -from lmclient.types import Message - - -def _remove_a_key(d, remove_key) -> None: - """Remove a key from a dictionary recursively""" - if isinstance(d, dict): - for key in list(d.keys()): - if key == remove_key: - del d[key] - else: - _remove_a_key(d[key], remove_key) - - -class lm_function: - def __init__(self, func: Callable) -> None: - self.func = func - self.name = self.func.__name__ - self.validate_func = validate_arguments(func) - self.docstring = parse(self.func.__doc__ or '') - - parameters = self.validate_func.model.model_json_schema() - parameters["properties"] = { - k: v - for k, v in parameters["properties"].items() - if k not in ("v__duplicate_kwargs", "args", "kwargs") - } - for param in self.docstring.params: - if (name := param.arg_name) in parameters["properties"] and ( - description := param.description - ): - parameters["properties"][name]["description"] = description - parameters["required"] = sorted( - k for k, v in parameters["properties"].items() if "default" not in v - ) - _remove_a_key(parameters, "additionalProperties") - _remove_a_key(parameters, "title") - self.openai_schema = { - "name": self.name, - "description": self.docstring.short_description, - "parameters": parameters, - } - self.model = self.validate_func.model - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - @wraps(self.func) - def wrapper(*args, **kwargs): - return self.validate_func(*args, **kwargs) - - return wrapper(*args, **kwargs) - - def from_message(self, message: Message): - function_call = message.content - if isinstance(function_call, str): - raise MessageError(f'{message} is not a valid function call message') - arguments = json.loads(function_call["arguments"], strict=False) - return self.validate_func(**arguments) diff --git a/lmclient/models/__init__.py b/lmclient/models/__init__.py index a0a3a13..86a17e1 100644 --- a/lmclient/models/__init__.py +++ b/lmclient/models/__init__.py @@ -1,5 +1,39 @@ -from lmclient.models.azure import AzureChat as AzureChat -from lmclient.models.base import BaseChatModel as BaseChatModel -from lmclient.models.minimax_pro import MinimaxProChat as MinimaxProChat -from lmclient.models.openai import OpenAIChat as OpenAIChat -from lmclient.models.zhipu import ZhiPuChat as ZhiPuChat +from typing import Any + +from lmclient.models.azure import AzureChat +from lmclient.models.base import BaseChatModel +from lmclient.models.minimax_pro import MinimaxProChat, MinimaxProChatParameters +from lmclient.models.openai import OpenAIChat, OpenAIChatParameters +from lmclient.models.zhipu import ZhiPuChat, ZhiPuChatParameters + +ModelRegistry = { + AzureChat.model_type: AzureChat, + OpenAIChat.model_type: OpenAIChat, + MinimaxProChat.model_type: MinimaxProChat, + ZhiPuChat.model_type: ZhiPuChat, +} + + +def load_from_model_id(model_id: str, **kwargs: Any): + if '/' not in model_id: + model_type = model_id + return ModelRegistry[model_type](**kwargs) # type: ignore + model_type, name = model_id.split('/') + model_cls = ModelRegistry[model_type] + return model_cls.from_name(name, **kwargs) + + +def list_chat_model_types(): + return list(ModelRegistry.keys()) + + +__all__ = [ + 'AzureChat', + 'BaseChatModel', + 'MinimaxProChat', + 'MinimaxProChatParameters', + 'OpenAIChat', + 'OpenAIChatParameters', + 'ZhiPuChat', + 'ZhiPuChatParameters', +] diff --git 
a/lmclient/models/azure.py b/lmclient/models/azure.py index f5a2441..2a1b5f8 100644 --- a/lmclient/models/azure.py +++ b/lmclient/models/azure.py @@ -7,7 +7,6 @@ from lmclient.models.http import HttpChatModel, RetryStrategy from lmclient.models.openai import ( OpenAIChatParameters, - OpenAIMessageDict, convert_lmclient_to_openai, parse_openai_model_reponse, ) @@ -16,37 +15,34 @@ class AzureChat(HttpChatModel[OpenAIChatParameters]): - parameters_type = OpenAIChatParameters + model_type = 'azure' def __init__( self, model: str | None = None, - system_prompt: str | None = None, api_key: str | None = None, api_base: str | None = None, api_version: str | None = None, timeout: int | None = 60, retry: bool | RetryStrategy = False, - default_parameters: OpenAIChatParameters | None = None, + parameters: OpenAIChatParameters = OpenAIChatParameters(), use_cache: Path | str | bool = False, ): - super().__init__(default_parameters=default_parameters, timeout=timeout, retry=retry, use_cache=use_cache) + super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model or os.environ['AZURE_CHAT_API_ENGINE'] or os.environ['AZURE_CHAT_MODEL_NAME'] - self.system_prompt = system_prompt self.api_key = api_key or os.environ['AZURE_API_KEY'] self.api_base = api_base or os.environ['AZURE_API_BASE'] self.api_version = api_version or os.getenv('AZURE_API_VERSION') - def get_post_parameters(self, messages: Messages, parameters: OpenAIChatParameters | None = None) -> dict[str, Any]: + def get_request_parameters(self, messages: Messages, parameters: OpenAIChatParameters) -> dict[str, Any]: headers = { 'api-key': self.api_key, } - parameters_dict = {} if parameters is None else to_dict(parameters, exclude_defaults=True) - openai_messages: list[OpenAIMessageDict] = [] if self.system_prompt is None else [{'role': 'system', 'content': self.system_prompt}] - openai_messages = openai_messages + [convert_lmclient_to_openai(message) for message in messages] + parameters_dict = to_dict(parameters, exclude_defaults=True) + openai_messages = [convert_lmclient_to_openai(message) for message in messages] params = { 'model': self.model, - 'messages': [convert_lmclient_to_openai(message) for message in messages], + 'messages': openai_messages, **parameters_dict, } return { @@ -59,5 +55,9 @@ def parse_model_reponse(self, response: ModelResponse) -> Messages: return parse_openai_model_reponse(response) @property - def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model})' + def name(self) -> str: + return self.model + + @classmethod + def from_name(cls, name: str, **kwargs: Any): + return cls(model=name, **kwargs) diff --git a/lmclient/models/base.py b/lmclient/models/base.py index ddb5898..0012620 100644 --- a/lmclient/models/base.py +++ b/lmclient/models/base.py @@ -2,30 +2,89 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import ClassVar, Generic, Type, TypeVar +from typing import Any, ClassVar, Generic, Type, TypeVar, cast + +from typing_extensions import Self from lmclient.cache import ChatCacheMixin from lmclient.types import ChatModelOutput, Messages, ModelParameters +from lmclient.utils import generate_chat_completion_hash_key -T = TypeVar("T", bound=ModelParameters) +T_P = TypeVar('T_P', bound=ModelParameters) +T_O = TypeVar('T_O', bound=ChatModelOutput) -class BaseChatModel(ABC, Generic[T], ChatCacheMixin): - parameters_type: ClassVar[Type[ModelParameters]] +class BaseChatModel(Generic[T_P, T_O], ChatCacheMixin, ABC): + 
model_type: ClassVar[str] - def __init__(self, default_parameters: T | None = None, use_cache: Path | str | bool = False) -> None: + def __init__(self, parameters: T_P, use_cache: Path | str | bool = False) -> None: super().__init__(use_cache=use_cache) - self.default_parameters = default_parameters + self.parameters = parameters + self.parameters_type: Type[T_P] = parameters.__class__ @property @abstractmethod - def identifier(self) -> str: + def name(self) -> str: ... @abstractmethod - def chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + def _chat_completion(self, messages: Messages, parameters: T_P) -> T_O: ... @abstractmethod - async def async_chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + async def _async_chat_completion(self, messages: Messages, parameters: T_P) -> T_O: ... + + @classmethod + @abstractmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + ... + + @property + def model_id(self) -> str: + return f'{self.model_type}/{self.name}' + + def chat_completion(self, messages: Messages, override_parameters: T_P | None = None) -> T_O: + if override_parameters is not None: + parameters = self.parameters.model_copy(update=override_parameters.model_dump()) + else: + parameters = self.parameters + + if self.use_cache: + hash_key = generate_chat_completion_hash_key(self.model_id, messages, parameters) + cached_output = self.try_load_model_output(hash_key) + if cached_output is not None: + cached_output.is_cache = True + cached_output.hash_key = hash_key + cached_output = cast(T_O, cached_output) + return cached_output + else: + model_output = self._chat_completion(messages, parameters) + model_output.hash_key = hash_key + self.cache_model_output(hash_key, model_output) + return model_output + else: + model_output = self._chat_completion(messages, parameters) + return model_output + + async def async_chat_completion(self, messages: Messages, override_parameters: T_P | None = None) -> T_O: + if override_parameters is not None: + parameters = self.parameters.model_copy(update=override_parameters.model_dump()) + else: + parameters = self.parameters + + if self.use_cache: + hash_key = generate_chat_completion_hash_key(self.model_id, messages, parameters) + cached_output = self.try_load_model_output(hash_key) + if cached_output is not None: + cached_output.is_cache = True + cached_output = cast(T_O, cached_output) + return cached_output + else: + model_output = await self._async_chat_completion(messages, parameters) + model_output.hash_key = hash_key + self.cache_model_output(hash_key, model_output) + return model_output + else: + model_output = await self._async_chat_completion(messages, parameters) + return model_output diff --git a/lmclient/models/http.py b/lmclient/models/http.py index 7e94c56..57a5f4c 100644 --- a/lmclient/models/http.py +++ b/lmclient/models/http.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from abc import ABC, abstractmethod from pathlib import Path from typing import Any @@ -7,25 +8,23 @@ import httpx from tenacity import retry, stop_after_attempt, wait_random_exponential -from lmclient.models.base import BaseChatModel, T -from lmclient.types import BaseModel, ChatModelOutput, Messages, ModelResponse +from lmclient.models.base import T_P, BaseChatModel +from lmclient.types import HttpChatModelOutput, Messages, ModelResponse, RetryStrategy +logger = logging.getLogger(__name__) -class RetryStrategy(BaseModel): # type: ignore - 
min_wait_seconds: int = 2 - max_wait_seconds: int = 20 - max_attempt: int = 3 +class HttpChatModel(BaseChatModel[T_P, HttpChatModelOutput], ABC): + model_type = 'http' -class HttpChatModel(BaseChatModel[T], ABC): def __init__( self, + parameters: T_P, timeout: int | None = None, retry: bool | RetryStrategy = False, - default_parameters: T | None = None, use_cache: Path | str | bool = False, ): - super().__init__(default_parameters=default_parameters, use_cache=use_cache) + super().__init__(parameters=parameters, use_cache=use_cache) self.timeout = timeout if isinstance(retry, RetryStrategy): self.retry_strategy = retry @@ -33,56 +32,61 @@ def __init__( self.retry_strategy = RetryStrategy() if retry else None @abstractmethod - def get_post_parameters(self, messages: Messages, parameters: T | None = None) -> dict[str, Any]: + def get_request_parameters(self, messages: Messages, parameters: T_P) -> dict[str, Any]: ... @abstractmethod def parse_model_reponse(self, response: ModelResponse) -> Messages: ... - def _chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: - if self.default_parameters is not None and override_parameters is not None: - override_parameters = self.default_parameters.model_copy(update=override_parameters.model_dump()) - - http_parameters = self.get_post_parameters(messages, override_parameters) + def _chat_completion_without_retry(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput: + http_parameters = self.get_request_parameters(messages, parameters) http_parameters = {'timeout': self.timeout, **http_parameters} - http_response = httpx.post(**http_parameters) + logger.info(f'HTTP Request: {http_parameters}') + http_response = httpx.post(**http_parameters) # type: ignore http_response.raise_for_status() model_response = http_response.json() - return ChatModelOutput( - messages=self.parse_model_reponse(model_response), + logger.info(f'HTTP Response: {model_response}') + new_messages = self.parse_model_reponse(model_response) + reply = new_messages[-1].content + reply = reply if isinstance(reply, str) else '' + return HttpChatModelOutput( + messages=new_messages, response=model_response, + reply=reply, ) - async def _async_chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: - if self.default_parameters is not None and override_parameters is not None: - override_parameters = self.default_parameters.model_copy(update=override_parameters.model_dump()) - + async def _async_chat_completion_without_retry(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput: async with httpx.AsyncClient() as client: - http_parameters = self.get_post_parameters(messages, override_parameters) + http_parameters = self.get_request_parameters(messages, parameters) http_parameters = {'timeout': self.timeout, **http_parameters} - http_response = await client.post(**http_parameters) + logger.info(f'ASYNC HTTP Request: {http_parameters}') + http_response = await client.post(**http_parameters) # type: ignore http_response.raise_for_status() model_response = http_response.json() - return ChatModelOutput( - messages=self.parse_model_reponse(model_response), + new_messages = self.parse_model_reponse(model_response) + reply = new_messages[-1].content + reply = reply if isinstance(reply, str) else '' + return HttpChatModelOutput( + messages=new_messages, response=model_response, + reply=reply, ) - def chat_completion(self, messages: Messages, override_parameters: T | None = None) -> 
ChatModelOutput: + def _chat_completion(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput: if self.retry_strategy is None: - return self._chat_completion(messages, override_parameters) + return self._chat_completion_without_retry(messages, parameters) wait = wait_random_exponential(min=self.retry_strategy.min_wait_seconds, max=self.retry_strategy.max_wait_seconds) stop = stop_after_attempt(self.retry_strategy.max_attempt) - output = retry(wait=wait, stop=stop)(self._chat_completion)(messages, override_parameters) + output = retry(wait=wait, stop=stop)(self._chat_completion_without_retry)(messages, parameters) return output - async def async_chat_completion(self, messages: Messages, override_parameters: T | None = None) -> ChatModelOutput: + async def _async_chat_completion(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput: if self.retry_strategy is None: - return await self._async_chat_completion(messages, override_parameters) + return await self._async_chat_completion_without_retry(messages, parameters) wait = wait_random_exponential(min=self.retry_strategy.min_wait_seconds, max=self.retry_strategy.max_wait_seconds) stop = stop_after_attempt(self.retry_strategy.max_attempt) - output = await retry(wait=wait, stop=stop)(self._async_chat_completion)(messages, override_parameters) + output = await retry(wait=wait, stop=stop)(self._async_chat_completion_without_retry)(messages, parameters) return output diff --git a/lmclient/models/minimax_pro.py b/lmclient/models/minimax_pro.py index ae7de84..41e700c 100644 --- a/lmclient/models/minimax_pro.py +++ b/lmclient/models/minimax_pro.py @@ -1,17 +1,15 @@ from __future__ import annotations -import json import os from pathlib import Path -from typing import Any, ClassVar, List, Literal, Optional, Type +from typing import Any, Dict, List, Literal, Optional +from pydantic import Field from typing_extensions import NotRequired, TypedDict from lmclient.exceptions import MessageError from lmclient.models.http import HttpChatModel, RetryStrategy -from lmclient.parser import ModelResponseParser, ParserError from lmclient.types import ( - Field, FunctionCallDict, FunctionDict, GeneralParameters, @@ -22,7 +20,9 @@ ) from lmclient.utils import to_dict -DEFAULT_BOT_PROMPT = "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。" +DEFAULT_MINIMAX_BOT_NAME = 'MM智能助理' +DEFAULT_MINIMAX_USER_NAME = '用户' +DEFAULT_MINIMAX_BOT_PROMPT = 'MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。' class BotSettingDict(TypedDict): @@ -33,7 +33,7 @@ class BotSettingDict(TypedDict): class GlyphDict(TypedDict): type: str raw_glpyh: str - json_properties: dict + json_properties: Dict[str, Any] class ReplyConstrainsDict(TypedDict): @@ -49,16 +49,24 @@ class MinimaxMessageDict(TypedDict): function_call: NotRequired[FunctionCallDict] +def default_bot_setting(): + return [{'bot_name': 'MM智能助理', 'content': DEFAULT_MINIMAX_BOT_PROMPT}] + + +def default_reply_constrains(): + return {'sender_type': 'BOT', 'sender_name': DEFAULT_MINIMAX_BOT_NAME} + + class MinimaxProChatParameters(ModelParameters): - temperature: float = 1 - top_p: float = 1 - tokens_to_generate: int = 1024 - mask_sensitive_info: bool = True - bot_setting: List[BotSettingDict] = Field(default_factory=list) - reply_constrains: ReplyConstrainsDict = Field(default_factory=list) + bot_setting: List[BotSettingDict] = Field(default_factory=default_bot_setting) + reply_constraints: ReplyConstrainsDict = Field(default_factory=default_reply_constrains) + 
temperature: Optional[float] = None + top_p: Optional[float] = None + tokens_to_generate: Optional[int] = None + mask_sensitive_info: Optional[bool] = None sample_messages: Optional[List[MinimaxMessageDict]] = None functions: Optional[List[FunctionDict]] = None - plugins : Optional[List[str]] = None + plugins: Optional[List[str]] = None @classmethod def from_general_parameters(cls, general_parameters: GeneralParameters): @@ -69,94 +77,38 @@ def from_general_parameters(cls, general_parameters: GeneralParameters): functions=general_parameters.functions, ) -class MinimaxProFunctionCallParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> FunctionCallDict: - try: - function_call_dict = response['choices'][0]['messages'][-1]['function_call'] - return FunctionCallDict( - name=function_call_dict['name'], - arguments=json.loads(function_call_dict['arguments']) - ) - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - - -class MinimaxProTextParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str: - try: - output = response['reply'] - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - return output - - -class MinimaxProParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> Messages: - return [self._minimax_to_lmclient(i) for i in response['choices'][0]['messages']] - - @staticmethod - def _minimax_to_lmclient(message: MinimaxMessageDict) -> Message: - if 'function_call' in message: - return Message( - role=message['sender_type'], - name=message['sender_name'], - content=message['function_call'] - ) - else: - return Message( - role=message['sender_type'], - name=message['sender_name'], - content=message['text'], - ) - class MinimaxProChat(HttpChatModel[MinimaxProChatParameters]): - parameters_type = MinimaxProChatParameters + model_type = 'minimax_pro' def __init__( self, model: str = 'abab5.5-chat', - base_url: str = 'https://api.minimax.chat/v1/text/chatcompletion_pro', group_id: str | None = None, api_key: str | None = None, - bot_name: str = 'MM智能助理', - user_name: str = '用户', - system_prompt: str | None = None, + base_url: str = 'https://api.minimax.chat/v1/text/chatcompletion_pro', timeout: int | None = 60, retry: bool | RetryStrategy = False, - default_parameters: MinimaxProChatParameters | None = None, + parameters: MinimaxProChatParameters = MinimaxProChatParameters(), use_cache: Path | str | bool = False, ): - super().__init__(default_parameters=default_parameters, timeout=timeout, retry=retry, use_cache=use_cache) + super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model self.base_url = base_url self.group_id = group_id or os.environ['MINIMAX_GROUP_ID'] self.api_key = api_key or os.environ['MINIMAX_API_KEY'] - self.bot_name = bot_name - self.system_prompt = system_prompt or DEFAULT_BOT_PROMPT - self.user_name = user_name - def get_post_parameters(self, messages: Messages, parameters: MinimaxProChatParameters | None = None) -> dict[str, Any]: + def get_request_parameters(self, messages: Messages, parameters: MinimaxProChatParameters) -> dict[str, Any]: headers = { 'Authorization': f'Bearer {self.api_key}', 'Content-Type': 'application/json', } - json_data = { - 'model': self.model, - 'messages': [self._lmclient_to_minimax(message, self.bot_name, self.user_name) for message in messages] - } - - parameters = parameters or MinimaxProChatParameters() - if not parameters.bot_setting: - 
parameters.bot_setting = [{'bot_name': self.bot_name, 'content': self.system_prompt}]
-        if not parameters.reply_constrains:
-            parameters.reply_constrains = {'sender_type': 'USER', 'sender_name': self.bot_name}
-        parameters_dict = to_dict(parameters, exclude_defaults=True) if parameters else {}
+        json_data = {'model': self.model, 'messages': [self._lmclient_to_minimax(message) for message in messages]}
+        parameters_dict = to_dict(parameters, exclude_none=True)
         if 'temperature' in parameters_dict:
             parameters_dict['temperature'] = max(0.01, parameters_dict['temperature'])
         json_data.update(parameters_dict)
-
         return {
             'url': self.base_url,
             'json': json_data,
@@ -169,46 +121,53 @@ def parse_model_reponse(self, response: ModelResponse) -> Messages:
 
     @staticmethod
     def _minimax_to_lmclient(message: MinimaxMessageDict) -> Message:
+        role_map = {
+            'USER': 'user',
+            'BOT': 'assistant',
+            'FUNCTION': 'function',
+        }
+
         if 'function_call' in message:
-            return Message(
-                role=message['sender_type'],
-                name=message['sender_name'],
-                content=message['function_call']
-            )
+            return Message(role=role_map[message['sender_type']], name=message['sender_name'], content=message['function_call'])
         else:
             return Message(
-                role=message['sender_type'],
+                role=role_map[message['sender_type']],
                 name=message['sender_name'],
                 content=message['text'],
             )
 
-    def _lmclient_to_minimax(self, message: Message, default_bot_name: str = 'MM智能助理', default_user_name: str = '用户') -> MinimaxMessageDict:
+    def _lmclient_to_minimax(
+        self,
+        message: Message,
+        default_bot_name: str = DEFAULT_MINIMAX_BOT_NAME,
+        default_user_name: str = DEFAULT_MINIMAX_USER_NAME,
+    ) -> MinimaxMessageDict:
         if isinstance(message.content, dict):
-            if message.role != 'BOT':
-                raise MessageError(f'Invalid role {message.role} for function call, must be BOT')
+            if message.role != 'assistant':
+                raise MessageError(f'Invalid role {message.role} for function call, must be assistant')
             return {
-                'sender_type': message.role,
+                'sender_type': 'BOT',
                 'sender_name': message.name or default_bot_name,
                 'text': '',
                 'function_call': message.content,
             }
-        elif message.role == 'BOT':
-            return {
-                'sender_type': message.role,
-                'sender_name': message.name or default_bot_name,
-                'text': message.content,
-            }
-        elif message.role == 'FUNCTION':
+        elif message.role == 'assistant':
+            return {
+                'sender_type': 'BOT',
+                'sender_name': message.name or default_bot_name,
+                'text': message.content,
+            }
+        elif message.role == 'function':
             if message.name is None:
                 raise MessageError(f'Function name is required, message: {message}')
             return {
-                'sender_type': message.role,
+                'sender_type': 'FUNCTION',
                 'sender_name': message.name,
                 'text': message.content,
             }
-        elif message.role == 'USER':
+        elif message.role == 'user':
             return {
-                'sender_type': message.role,
+                'sender_type': 'USER',
                 'sender_name': message.name or default_user_name,
                 'text': message.content,
             }
@@ -216,5 +175,9 @@ def _lmclient_to_minimax(self, message: Message, default_bot_name: str = 'MM智
         raise MessageError(f'Invalid role {message.role}, must be BOT, FUNCTION, or USER')
 
     @property
-    def identifier(self) -> str:
-        return f'{self.__class__.__name__}({self.model})'
+    def name(self) -> str:
+        return self.model
+
+    @classmethod
+    def from_name(cls, name: str, **kwargs: Any):
+        return cls(model=name, **kwargs)
diff --git a/lmclient/models/openai.py b/lmclient/models/openai.py
index 0a682fa..3a73f4b 100644
--- a/lmclient/models/openai.py
+++ b/lmclient/models/openai.py
@@ -80,72 +80,72 @@ def convert_lmclient_to_openai(message: Message, 
valid_roles: set[str] | None = if message.role != 'assistant': raise MessageError(f'Invalid role "{message.role}" for function call, can only be made by "assistant"') return { - 'role': message.role, - 'function_call': content, - 'content': None, - } + 'role': message.role, + 'function_call': content, + 'content': None, + } elif message.role == 'function': name = message.name if name is None: raise MessageError(f'Function name is required, message: {message}') - return { - 'role': message.role, - 'name': name, - 'content': content, - } + return { + 'role': message.role, + 'name': name, + 'content': content, + } else: return { - 'role': message.role, - 'content': content, - } - + 'role': message.role, + 'content': content, + } def parse_openai_model_reponse(response: ModelResponse) -> Messages: funcation_call = response['choices'][0]['message'].get('function_call') try: if bool(funcation_call): - return [Message( - role='assistant', - content=funcation_call, - )] + return [ + Message( + role='assistant', + content=funcation_call, + ) + ] else: text: str = response['choices'][0]['message']['content'] - return [Message( - role='assistant', - content=text, - )] + return [ + Message( + role='assistant', + content=text, + ) + ] except (KeyError, IndexError) as e: raise ParserError('Parse response failed') from e class OpenAIChat(HttpChatModel[OpenAIChatParameters]): - parameters_type = OpenAIChatParameters + model_type = 'openai' def __init__( self, model: str = 'gpt-3.5-turbo', - system_prompt: str | None = None, api_key: str | None = None, api_base: str | None = None, timeout: int | None = 60, retry: bool | RetryStrategy = False, - default_parameters: OpenAIChatParameters | None = None, + parameters: OpenAIChatParameters = OpenAIChatParameters(), use_cache: Path | str | bool = False, ): - super().__init__(default_parameters=default_parameters, timeout=timeout, retry=retry, use_cache=use_cache) + super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model - self.system_prompt = system_prompt self.api_base = api_base or os.getenv('OPENAI_API_BASE') or 'https://api.openai.com/v1' self.api_key = api_key or os.environ['OPENAI_API_KEY'] - def get_post_parameters(self, messages: Messages, parameters: OpenAIChatParameters | None = None) -> dict[str, Any]: + def get_request_parameters(self, messages: Messages, parameters: OpenAIChatParameters) -> dict[str, Any]: headers = { 'Authorization': f'Bearer {self.api_key}', } - parameters_dict = {} if parameters is None else to_dict(parameters, exclude_defaults=True) - openai_messages: list[OpenAIMessageDict] = [] if self.system_prompt is None else [{'role': 'system', 'content': self.system_prompt}] - openai_messages = openai_messages + [convert_lmclient_to_openai(message) for message in messages] + parameters_dict = to_dict(parameters, exclude_defaults=True) + openai_messages = [convert_lmclient_to_openai(message) for message in messages] params = { 'model': self.model, 'messages': openai_messages, @@ -161,5 +161,9 @@ def parse_model_reponse(self, response: ModelResponse) -> Messages: return parse_openai_model_reponse(response) @property - def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model})' + def name(self) -> str: + return self.model + + @classmethod + def from_name(cls, name: str, **kwargs: Any): + return cls(model=name, **kwargs) diff --git a/lmclient/models/zhipu.py b/lmclient/models/zhipu.py index 3628298..9fefd8f 100644 --- a/lmclient/models/zhipu.py +++ 
b/lmclient/models/zhipu.py @@ -1,9 +1,8 @@ from __future__ import annotations -import logging import os -from pathlib import Path import time +from pathlib import Path from typing import Any, TypedDict, TypeVar import cachetools.func # type: ignore @@ -11,14 +10,13 @@ from lmclient.exceptions import MessageError from lmclient.models.http import HttpChatModel, RetryStrategy -from lmclient.parser import ModelResponseParser, ParserError -from lmclient.types import GeneralParameters, Messages, ModelParameters, ModelResponse +from lmclient.parser import ParserError +from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse from lmclient.utils import to_dict T = TypeVar('T') API_TOKEN_TTL_SECONDS = 3 * 60 CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 -logger = logging.getLogger(__name__) class ZhiPuChatParameters(ModelParameters): @@ -33,15 +31,6 @@ def from_general_parameters(cls, general_parameters: GeneralParameters): ) -class ZhiPuResponse(ModelResponse): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.content = self.content.replace('\n', ' ') - - -class ZhiPuModel(HttpChatModel): - name = 'zhipu' - class ZhiPuMessageDict(TypedDict): role: str content: str @@ -60,7 +49,7 @@ def generate_token(api_key: str): 'timestamp': int(round(time.time() * 1000)), } - return jwt.encode( + return jwt.encode( # type: ignore payload, secret, algorithm='HS256', @@ -68,16 +57,9 @@ def generate_token(api_key: str): ) -class ZhiPuParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str: - try: - output = response['data']['choices'][0]['content'].strip('"').strip() - except (KeyError, IndexError) as e: - raise ParserError(f'Parse response failed, reponse: {response}') from e - return output - - class ZhiPuChat(HttpChatModel[ZhiPuChatParameters]): + model_type = 'zhipu' + def __init__( self, model: str = 'chatglm_pro', @@ -85,16 +67,16 @@ def __init__( api_key: str | None = None, timeout: int | None = 60, retry: bool | RetryStrategy = False, - default_parameters: ZhiPuChatParameters | None = None, + parameters: ZhiPuChatParameters = ZhiPuChatParameters(), use_cache: Path | str | bool = False, ): - super().__init__(default_parameters=default_parameters, timeout=timeout, retry=retry, use_cache=use_cache) + super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model self.api_key = api_key or os.environ['ZHIPU_API_KEY'] self.api_base = api_base or os.getenv('ZHIPU_API_BASE') or 'https://open.bigmodel.cn/api/paas/v3/model-api' self.api_base.rstrip('/') - def get_post_parameters(self, messages: Messages, parameters: ZhiPuChatParameters | None = None) -> dict[str, Any]: + def get_request_parameters(self, messages: Messages, parameters: ZhiPuChatParameters) -> dict[str, Any]: for message in messages: if message.role not in ('user', 'assistant'): raise ValueError(f'Role of message must be user or assistant, but got {message.role}') @@ -105,21 +87,34 @@ def get_post_parameters(self, messages: Messages, parameters: ZhiPuChatParameter raise MessageError(f'Role of message must be user or assistant, but got {message.role}') if not isinstance(message.content, str): raise MessageError(f'Message content must be str, but got {type(message.content)}') - zhipu_messages.append({ - 'role': message.role, - 'content': message.content, - }) + zhipu_messages.append( + { + 'role': message.role, + 'content': message.content, + } + ) headers = { 'Authorization': generate_token(self.api_key), } - 
parameters_dict = {} if parameters is None else to_dict(parameters, exclude_defaults=True) - params = {'prompt': messages, **parameters_dict} + parameters_dict = to_dict(parameters, exclude_defaults=True) + params = {'prompt': zhipu_messages, **parameters_dict} return { 'url': f'{self.api_base}/{self.model}/invoke', 'headers': headers, 'json': params, } + def parse_model_reponse(self, response: ModelResponse) -> Messages: + try: + text = response['data']['choices'][0]['content'].strip('"').strip() + return [Message(role='assistant', content=text)] + except (KeyError, IndexError) as e: + raise ParserError(f'Parse response failed, reponse: {response}') from e + @property - def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model})' + def name(self) -> str: + return self.model + + @classmethod + def from_name(cls, name: str, **kwargs: Any): + return cls(model=name, **kwargs) diff --git a/lmclient/types.py b/lmclient/types.py index 663b3be..9abba55 100644 --- a/lmclient/types.py +++ b/lmclient/types.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union -from pydantic import BaseModel, Field -from typing_extensions import NotRequired, TypedDict +from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import NotRequired, Self, TypedDict class Message(BaseModel): @@ -16,15 +16,16 @@ def is_function_call(self) -> bool: return isinstance(self.content, dict) -class ChatModelOutput(BaseModel): - messages: Messages - response: ModelResponse = Field(default_factory=dict) +class MessageDict(TypedDict): + role: str + content: Union[str, FunctionCallDict] + name: NotRequired[str] class FunctionDict(TypedDict): name: str description: NotRequired[str] - parameters: dict + parameters: Dict[str, Any] class GeneralParameters(BaseModel): @@ -35,13 +36,31 @@ class GeneralParameters(BaseModel): function_call: Optional[str] = None +class ChatModelOutput(BaseModel): + messages: Messages + hash_key: str = '' + is_cache: bool = False + reply: str = '' + + class ModelParameters(BaseModel): + model_config = ConfigDict(frozen=True) @classmethod - def from_general_parameters(cls, general_parameters: GeneralParameters): + def from_general_parameters(cls, general_parameters: GeneralParameters) -> Self: raise NotImplementedError +class RetryStrategy(BaseModel): + min_wait_seconds: int = 2 + max_wait_seconds: int = 20 + max_attempt: int = 3 + + +class HttpChatModelOutput(ChatModelOutput): + response: ModelResponse = Field(default_factory=dict) + + class FunctionCallDict(TypedDict): name: str arguments: str @@ -49,4 +68,4 @@ class FunctionCallDict(TypedDict): Messages = List[Message] ModelResponse = Dict[str, Any] -Prompt = Union[str, Messages] +Prompt = Union[str, Message, MessageDict, Sequence[Union[MessageDict, Message]]] diff --git a/lmclient/utils.py b/lmclient/utils.py index 832527d..2887cd2 100644 --- a/lmclient/utils.py +++ b/lmclient/utils.py @@ -1,5 +1,7 @@ +# type: ignore from __future__ import annotations +import hashlib import json from functools import wraps from typing import Any, Callable @@ -8,18 +10,30 @@ from pydantic import BaseModel, validate_arguments from lmclient.exceptions import MessageError -from lmclient.types import FunctionDict, Message +from lmclient.types import FunctionDict, Message, Messages, ModelParameters +from lmclient.version import __cache_version__ def get_pydantic_version(): import pydantic from packaging import version + return 
version.parse(pydantic.__version__).major PydanticVersion = get_pydantic_version() +def generate_chat_completion_hash_key(model_id: str, messages: Messages, parameters: ModelParameters) -> str: + messages_text = '---'.join([f'{k}={v}' for message in messages for k, v in to_dict(message).items()]) + messages_hash = md5_hash(messages_text) + parameters_hash = md5_hash(parameters.model_dump_json(exclude_none=True)) + return f'{model_id}|{messages_hash}|{parameters_hash}|v{__cache_version__}' + + +def md5_hash(string: str) -> str: + return hashlib.md5(string.encode()).hexdigest() + def _remove_a_key(d, remove_key) -> None: """Remove a key from a dictionary recursively""" @@ -31,7 +45,7 @@ def _remove_a_key(d, remove_key) -> None: _remove_a_key(d[key], remove_key) -class lm_function: +class function: def __init__(self, func: Callable) -> None: self.func = func self.name = self.func.__name__ @@ -39,25 +53,19 @@ def __init__(self, func: Callable) -> None: self.docstring = parse(self.func.__doc__ or '') parameters = self.validate_func.model.model_json_schema() - parameters["properties"] = { - k: v - for k, v in parameters["properties"].items() - if k not in ("v__duplicate_kwargs", "args", "kwargs") + parameters['properties'] = { + k: v for k, v in parameters['properties'].items() if k not in ('v__duplicate_kwargs', 'args', 'kwargs') } for param in self.docstring.params: - if (name := param.arg_name) in parameters["properties"] and ( - description := param.description - ): - parameters["properties"][name]["description"] = description - parameters["required"] = sorted( - k for k, v in parameters["properties"].items() if "default" not in v - ) - _remove_a_key(parameters, "additionalProperties") - _remove_a_key(parameters, "title") + if (name := param.arg_name) in parameters['properties'] and (description := param.description): + parameters['properties'][name]['description'] = description + parameters['required'] = sorted(k for k, v in parameters['properties'].items() if 'default' not in v) + _remove_a_key(parameters, 'additionalProperties') + _remove_a_key(parameters, 'title') self.schema: FunctionDict = { - "name": self.name, - "description": self.docstring.short_description or '', - "parameters": parameters, + 'name': self.name, + 'description': self.docstring.short_description or '', + 'parameters': parameters, } self.model = self.validate_func.model @@ -72,45 +80,38 @@ def from_message(self, message: Message): function_call = message.content if isinstance(function_call, str): raise MessageError(f'{message} is not a valid function call message') - arguments = json.loads(function_call["arguments"], strict=False) + arguments = json.loads(function_call['arguments'], strict=False) return self.validate_func(**arguments) -class LMSchema(BaseModel): +class BaseSchema(BaseModel): @classmethod @property def openai_schema(cls): schema = cls.model_json_schema() docstring = parse(cls.__doc__ or '') - parameters = { - k: v for k, v in schema.items() if k not in ("title", "description") - } + parameters = {k: v for k, v in schema.items() if k not in ('title', 'description')} for param in docstring.params: - if (name := param.arg_name) in parameters["properties"] and ( - description := param.description - ): - if "description" not in parameters["properties"][name]: - parameters["properties"][name]["description"] = description + if (name := param.arg_name) in parameters['properties'] and (description := param.description): + if 'description' not in parameters['properties'][name]: + 
parameters['properties'][name]['description'] = description - parameters["required"] = sorted( - k for k, v in parameters["properties"].items() if "default" not in v - ) + parameters['required'] = sorted(k for k, v in parameters['properties'].items() if 'default' not in v) - if "description" not in schema: + if 'description' not in schema: if docstring.short_description: - schema["description"] = docstring.short_description + schema['description'] = docstring.short_description else: - schema["description"] = ( - f"Correctly extracted `{cls.__name__}` with all " - f"the required parameters with correct types" + schema['description'] = ( + f'Correctly extracted `{cls.__name__}` with all ' f'the required parameters with correct types' ) - _remove_a_key(parameters, "additionalProperties") - _remove_a_key(parameters, "title") + _remove_a_key(parameters, 'additionalProperties') + _remove_a_key(parameters, 'title') return { - "name": schema["title"], - "description": schema["description"], - "parameters": parameters, + 'name': schema['title'], + 'description': schema['description'], + 'parameters': parameters, } @classmethod @@ -118,12 +119,12 @@ def from_message(cls, message: Message): function_call = message.content if isinstance(function_call, str): raise MessageError(f'{message} is not a valid function call message') - arguments = json.loads(function_call["arguments"], strict=False) + arguments = json.loads(function_call['arguments'], strict=False) return cls(**arguments) -def to_dict(value: BaseModel, exclude_defaults: bool = False): +def to_dict(value: BaseModel, exclude_defaults: bool = False, exclude_none: bool = False): if PydanticVersion == 2: - return value.model_dump(exclude_defaults=exclude_defaults) + return value.model_dump(exclude_defaults=exclude_defaults, exclude_none=exclude_none) else: - return value.dict(exclude_defaults=exclude_defaults) + return value.dict(exclude_defaults=exclude_defaults, exclude_none=exclude_none) diff --git a/lmclient/version.py b/lmclient/version.py index c11ef48..e7a05af 100644 --- a/lmclient/version.py +++ b/lmclient/version.py @@ -1,2 +1,2 @@ -__version__ = '0.6.0' -__cache_version__ = '3' +__version__ = '0.7.0' +__cache_version__ = '4' diff --git a/pyproject.toml b/pyproject.toml index f34bef2..2332c96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "lmclient-core" -version = "0.6.0" +version = "0.7.0" description = "LM Async Client, openai client, azure openai client ..." authors = ["wangyuxin "] readme = "README.md" @@ -40,7 +40,7 @@ ignore = [ ] [tool.pyright] -reportIncompatibleMethodOverride=true +reportMissingTypeStubs=false [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" diff --git a/scripts/data/ner_input.jsonl b/scripts/data/ner_input.jsonl deleted file mode 100644 index 365ed96..0000000 --- a/scripts/data/ner_input.jsonl +++ /dev/null @@ -1,3 +0,0 @@ -{"text": "海钓比赛地点在厦门与金门之间的海域。"} -{"text": "克马尔的女儿让娜今年读五年级,她所在的班上有30多名同学,该班的“家委会”由10名家长组成。"} -{"text": "沙特队教练佩雷拉:两支队都想胜,因此都作出了最大的努力。"} \ No newline at end of file diff --git a/scripts/data/translate_input.jsonl b/scripts/data/translate_input.jsonl index 52378cd..f7ddc87 100644 --- a/scripts/data/translate_input.jsonl +++ b/scripts/data/translate_input.jsonl @@ -1,3 +1,30 @@ -{"text": "players who have scored 5 goals in world cup finals"} -{"text": "where was christianity most strongly established by ad 325"} -{"text": "when was the last time turkey was in the world cup"} +To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? 
+What is in front of the Notre Dame Main Building? +The Basilica of the Sacred heart at Notre Dame is beside to which structure? +What is the Grotto at Notre Dame? +What sits on top of the Main Building at Notre Dame? +When did the Scholastic Magazine of Notre dame begin publishing? +How often is Notre Dame's the Juggler published? +What is the daily student paper at Notre Dame called? +How many student news papers are found at Notre Dame? +In what year did the student paper Common Sense begin publication at Notre Dame? +Where is the headquarters of the Congregation of the Holy Cross? +What is the primary seminary of the Congregation of the Holy Cross? +What is the oldest structure at Notre Dame? +What individuals live at Fatima House at Notre Dame? +Which prize did Frederick Buechner create? +How many BS level degrees are offered in the College of Engineering at Notre Dame? +In what year was the College of Engineering at Notre Dame formed? +Before the creation of the College of Engineering similar studies were carried out at which Notre Dame college? +How many departments are within the Stinson-Remick Hall of Engineering? +The College of Science began to offer civil engineering courses beginning at what time at Notre Dame? +What entity provides help with the management of time for new students at Notre Dame? +How many colleges for undergraduates are at Notre Dame? +What was created at Notre Dame in 1962 to assist first year students? +Which organization declared the First Year of Studies program at Notre Dame "outstanding?" +The granting of Doctorate degrees first occurred in what year at Notre Dame? +What type of degree is an M.Div.? +Which program at Notre Dame offers a Master of Education degree? +In what year was a Master of Arts course first offered at Notre Dame? +Which department at Notre Dame is the only one to not offer a PhD program? +What institute at Notre Dame studies the reasons for violent conflict? 
\ No newline at end of file diff --git a/scripts/ner.py b/scripts/ner.py deleted file mode 100644 index 50d4142..0000000 --- a/scripts/ner.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import json -from enum import Enum -from pathlib import Path -from typing import List - -import typer - -from lmclient import LMClient, OpenAIExtract -from lmclient.client import ErrorMode -from lmclient.openai_schema import Field, OpenAISchema - - -class ModelType(str, Enum): - openai = 'openai' - azure = 'azure' - - -class NerInfo(OpenAISchema): - """命名实体信息,包括人名,地名和组织名""" - - person: List[str] = Field(default_factory=list) - location: List[str] = Field(default_factory=list) - organization: List[str] = Field(default_factory=list) - - -def read_from_jsonl(file: str | Path): - texts: list[str] = [] - with open(file, 'r') as f: - for line in f: - texts.append(json.loads(line.strip())['text']) - return texts - - -def main( - input_josnl_file: Path, - output_file: Path, - max_requests_per_minute: int = 20, - async_capacity: int = 3, - error_mode: ErrorMode = ErrorMode.RAISE, - use_cache: bool = False, -): - - model = OpenAIExtract( - schema=NerInfo, - use_cache=use_cache, - ) - - client = LMClient( - model, - max_requests_per_minute=max_requests_per_minute, - async_capacity=async_capacity, - error_mode=error_mode, - ) - texts = read_from_jsonl(input_josnl_file) - model_outputs = client.async_run(texts) - with open(output_file, 'w') as f: - for text, output in zip(texts, model_outputs): - output = output.message.dict() if output.message else None - output_dict = {'text': text, 'output': output} - f.write(json.dumps(output_dict, ensure_ascii=False) + '\n') - - -if __name__ == '__main__': - typer.run(main) diff --git a/scripts/translate.py b/scripts/translate.py index dbc207f..8b59df5 100644 --- a/scripts/translate.py +++ b/scripts/translate.py @@ -5,44 +5,36 @@ import typer -from lmclient import AzureChat, LMClient, MinimaxChat, OpenAIChat +from lmclient import LMClient from lmclient.client import ErrorMode +from lmclient.models import load_from_model_id -def read_from_jsonl(file: str | Path): - texts: list[str] = [] - with open(file, 'r') as f: - for line in f: - texts.append(json.loads(line.strip())['text']) +def read_from_text_file(file: str | Path): + file = Path(file) + texts: list[str] = file.read_text().split('\n') return texts def main( input_josnl_file: Path, output_file: Path, - model_name: str = 'gpt-3.5-turbo', - max_requests_per_minute: int = 20, + model_id: str = 'openai', + max_requests_per_minute: int = 5, async_capacity: int = 3, - error_mode: ErrorMode = ErrorMode.IGNORE, + error_mode: ErrorMode = ErrorMode.RAISE, use_cache: bool = True, -): - - if model_name == 'azure': - model = AzureChat(use_cache=use_cache) - elif model_name == 'minimax': - model = MinimaxChat(use_cache=use_cache) - else: - model = OpenAIChat(model=model_name, use_cache=use_cache) - - client = LMClient[str]( - model, +) -> None: + model = load_from_model_id(model_id=model_id, use_cache=use_cache) + client = LMClient( + model, # type: ignore max_requests_per_minute=max_requests_per_minute, async_capacity=async_capacity, error_mode=error_mode, ) - texts = read_from_jsonl(input_josnl_file) - prompts = [] + texts = read_from_text_file(input_josnl_file) + prompts: list[str] = [] for text in texts: prompt = f'translate following sentece to chinese\nsentence: {text}\ntranslation: ' prompts.append(prompt) @@ -50,7 +42,7 @@ def main( with open(output_file, 'w') as f: for text, result in zip(texts, results): 
- f.write(json.dumps({'text': text, 'translation': result.message}, ensure_ascii=False) + '\n') + f.write(json.dumps({'text': text, 'translation': result.reply}, ensure_ascii=False) + '\n') if __name__ == '__main__': diff --git a/tests/test_client.py b/tests/test_client.py index cf31265..ad30047 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,60 +1,64 @@ from __future__ import annotations import time +from pathlib import Path +from typing import Any from lmclient.client import LMClient from lmclient.models.base import BaseChatModel -from lmclient.types import Messages, ModelResponse +from lmclient.types import ChatModelOutput, Message, MessageDict, Messages, ModelParameters -class TestModel(BaseChatModel): - def chat_completion(self, messages: Messages, **kwargs) -> ModelResponse: - return { - 'content': f'Completed: {messages[-1].content}', - } +class TestModelParameters(ModelParameters): + prefix: str = 'Completed:' - async def async_chat_completion(self, messages: Messages, **kwargs) -> ModelResponse: - return { - 'content': f'Completed: {messages[-1].content}', - } - def default_postprocess_function(self, response: ModelResponse) -> str: - return response['content'] +class TestModel(BaseChatModel[TestModelParameters, ChatModelOutput]): + model_type = 'test' + + def __init__(self, default_parameters: TestModelParameters | None = None, use_cache: Path | str | bool = False) -> None: + default_parameters = default_parameters or TestModelParameters() + super().__init__(default_parameters, use_cache) + + def _chat_completion(self, messages: Messages, parameters: TestModelParameters) -> ChatModelOutput: + content = f'Completed: {messages[-1].content}' + return ChatModelOutput(messages=[Message(role='assistant', content=content)], reply=content) + + async def _async_chat_completion(self, messages: Messages, parameters: TestModelParameters) -> ChatModelOutput: + content = f'Completed: {messages[-1].content}' + return ChatModelOutput(messages=[Message(role='assistant', content=content)], reply=content) @property - def identifier(self) -> str: + def name(self) -> str: return 'TestModel' - -def model_parser(response): - return response['content'] + @classmethod + def from_name(cls, name: str, **kwargs: Any): + return cls() def test_sync_completion(): - completion_model = TestModel(response_parser=model_parser, use_cache=False) + completion_model = TestModel() client = LMClient(completion_model) prompts = [ 'Hello, my name is', - [ - {'role': 'system', 'content': 'your are lmclient demo assistant'}, - {'role': 'user', 'content': 'hello, who are you?'}, - ], + Message(role='user', content='hello, who are you?'), ] results = client.run(prompts) - assert isinstance(results[0].message, str) - assert results[0].message == 'Completed: Hello, my name is' - assert results[1].message == 'Completed: hello, who are you?' + assert isinstance(results[0].reply, str) + assert results[0].reply == 'Completed: Hello, my name is' + assert results[1].reply == 'Completed: hello, who are you?' 
assert len(results) == len(prompts) def test_async_completion(): - completion_model = TestModel(response_parser=model_parser, use_cache=False) + completion_model = TestModel() client = LMClient(completion_model, async_capacity=2, max_requests_per_minute=5) LMClient.NUM_SECONDS_PER_MINUTE = 2 start_time = time.perf_counter() - messages = [ + messages: list[MessageDict] = [ {'role': 'system', 'content': 'your are lmclient demo assistant'}, {'role': 'user', 'content': 'hello, who are you?'}, ] @@ -62,14 +66,13 @@ def test_async_completion(): results = client.async_run(prompts) elapsed_time = time.perf_counter() - start_time - assert results[0].response['content'] == 'Completed: Hello, my name is' - assert results[0].message == 'Completed: Hello, my name is' + assert results[0].reply == 'Completed: Hello, my name is' assert len(results) == len(prompts) assert elapsed_time > 4 -def test_async_completion_with_cache(tmp_path): - completion_model = TestModel(use_cache=tmp_path, response_parser=model_parser) +def test_async_completion_with_cache(tmp_path: Path): + completion_model = TestModel(use_cache=tmp_path) client = LMClient(completion_model, async_capacity=2, max_requests_per_minute=5) LMClient.NUM_SECONDS_PER_MINUTE = 2 @@ -78,8 +81,8 @@ def test_async_completion_with_cache(tmp_path): results = client.async_run(prompts) elapsed_time = time.perf_counter() - start_time - assert isinstance(results[0].response['content'], str) - assert results[3].response['content'] == 'Completed: Hello, my name is' + assert isinstance(results[0].reply, str) + assert results[3].reply == 'Completed: Hello, my name is' assert len(results) == len(prompts) assert elapsed_time < 2 assert len(list(completion_model._cache)) == 3 # type: ignore diff --git a/tests/test_model.py b/tests/test_model.py index 59807c0..00a180b 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,83 +1,58 @@ import anyio -import pytest from lmclient.models import AzureChat, MinimaxProChat, OpenAIChat, ZhiPuChat -from lmclient.models.openai import OpenAITextParser from lmclient.types import Message test_messages = [Message(role='user', content='hello')] -@pytest.mark.parametrize( - 'prompt', - [ - 'Hello, my name is', - test_messages - ], -) -def test_azure_model(prompt): - model = AzureChat(response_parser=OpenAITextParser()) - sync_output = model.chat(prompt) - async_output = anyio.run(model.async_chat, prompt) +def test_azure_model(): + chat_model = AzureChat() + + test_messages = [Message(role='user', content='hello')] + sync_output = chat_model.chat_completion(test_messages) + async_output = anyio.run(chat_model.async_chat_completion, test_messages) assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.message, str) + assert isinstance(sync_output.reply, str) assert isinstance(async_output.response, dict) - assert isinstance(async_output.message, str) + assert isinstance(async_output.reply, str) -@pytest.mark.parametrize( - 'prompt', - [ - 'Hello, my name is', - test_messages - ], -) -def test_openai_model(prompt): - chat_model = OpenAIChat('gpt-3.5-turbo', response_parser=OpenAITextParser()) +def test_openai_model(): + chat_model = OpenAIChat('gpt-3.5-turbo') - sync_output = chat_model.chat(prompt) - async_output = anyio.run(chat_model.async_chat, prompt) + test_messages = [Message(role='user', content='hello')] + sync_output = chat_model.chat_completion(test_messages) + async_output = anyio.run(chat_model.async_chat_completion, test_messages) assert isinstance(sync_output.response, dict) - assert 
isinstance(sync_output.message, str) + assert isinstance(sync_output.reply, str) assert isinstance(async_output.response, dict) - assert isinstance(async_output.message, str) + assert isinstance(async_output.reply, str) -@pytest.mark.parametrize( - 'prompt', - [ - 'Hello, my name is', - test_messages - ], -) -def test_minimax_model(prompt): - completion_model = MinimaxProChat('abab5.5-chat') +def test_minimax_model(): + chat_model = MinimaxProChat() - sync_output = completion_model.chat(prompt) - async_output = anyio.run(completion_model.async_chat, prompt) + test_messages = [Message(role='user', content='hello')] + sync_output = chat_model.chat_completion(test_messages) + async_output = anyio.run(chat_model.async_chat_completion, test_messages) assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.message, str) + assert isinstance(sync_output.reply, str) assert isinstance(async_output.response, dict) - assert isinstance(async_output.message, str) + assert isinstance(async_output.reply, str) -@pytest.mark.parametrize( - 'prompt', - [ - 'Hello, my name is', - test_messages - ], -) -def test_zhipu_model(prompt): - completion_model = ZhiPuChat() +def test_zhipu_model(): + chat_model = ZhiPuChat() - sync_output = completion_model.chat(prompt) - async_output = anyio.run(completion_model.async_chat, prompt) + test_messages = [Message(role='user', content='hello')] + sync_output = chat_model.chat_completion(test_messages) + async_output = anyio.run(chat_model.async_chat_completion, test_messages) assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.message, str) + assert isinstance(sync_output.reply, str) assert isinstance(async_output.response, dict) - assert isinstance(async_output.message, str) + assert isinstance(async_output.reply, str) From dffb6e4acdbd5582db6cfa827f8db174be9ba5c6 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Sat, 2 Sep 2023 17:30:27 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E6=94=AF=E6=8C=81=20pydantic=20v1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lmclient/__init__.py | 11 +++++++++++ lmclient/models/azure.py | 3 +-- lmclient/models/minimax_pro.py | 3 +-- lmclient/models/openai.py | 3 +-- lmclient/models/zhipu.py | 3 +-- lmclient/types.py | 31 +++++++++++++++---------------- lmclient/utils.py | 9 +-------- tests/__init__.py | 0 tests/test_chat.py | 31 +++++++++++++++++++++++++++++++ 9 files changed, 62 insertions(+), 32 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_chat.py diff --git a/lmclient/__init__.py b/lmclient/__init__.py index b36069e..a4f4842 100644 --- a/lmclient/__init__.py +++ b/lmclient/__init__.py @@ -9,8 +9,17 @@ ZhiPuChat, ZhiPuChatParameters, ) +from lmclient.utils import BaseSchema, PydanticVersion, function from lmclient.version import __version__ +if PydanticVersion == 1: + from pydantic import BaseModel + + BaseModel.model_copy = BaseModel.copy # type: ignore + BaseModel.model_dump = BaseModel.dict # type: ignore + BaseModel.model_dump_json = BaseModel.json # type: ignore + + __all__ = [ 'LMClient', 'ChatEngine', @@ -21,5 +30,7 @@ 'MinimaxProChatParameters', 'ZhiPuChat', 'ZhiPuChatParameters', + 'BaseSchema', + 'function', '__version__', ] diff --git a/lmclient/models/azure.py b/lmclient/models/azure.py index 2a1b5f8..dbf6292 100644 --- a/lmclient/models/azure.py +++ b/lmclient/models/azure.py @@ -11,7 +11,6 @@ parse_openai_model_reponse, ) from lmclient.types import Messages, ModelResponse -from lmclient.utils import 
to_dict class AzureChat(HttpChatModel[OpenAIChatParameters]): @@ -38,7 +37,7 @@ def get_request_parameters(self, messages: Messages, parameters: OpenAIChatParam headers = { 'api-key': self.api_key, } - parameters_dict = to_dict(parameters, exclude_defaults=True) + parameters_dict = parameters.model_dump(exclude_defaults=True) openai_messages = [convert_lmclient_to_openai(message) for message in messages] params = { 'model': self.model, diff --git a/lmclient/models/minimax_pro.py b/lmclient/models/minimax_pro.py index 41e700c..6c00e02 100644 --- a/lmclient/models/minimax_pro.py +++ b/lmclient/models/minimax_pro.py @@ -18,7 +18,6 @@ ModelParameters, ModelResponse, ) -from lmclient.utils import to_dict DEFAULT_MINIMAX_BOT_NAME = 'MM智能助理' DEFAULT_MINIMAX_USER_NAME = '用户' @@ -105,7 +104,7 @@ def get_request_parameters(self, messages: Messages, parameters: MinimaxProChatP } json_data = {'model': self.model, 'messages': [self._lmclient_to_minimax(message) for message in messages]} - parameters_dict = to_dict(parameters, exclude_none=True) + parameters_dict = parameters.model_dump(exclude_none=True) if 'temperature' in parameters_dict: parameters_dict['temperature'] = max(0.01, parameters_dict['temperature']) json_data.update(parameters_dict) diff --git a/lmclient/models/openai.py b/lmclient/models/openai.py index 3a73f4b..c38fd38 100644 --- a/lmclient/models/openai.py +++ b/lmclient/models/openai.py @@ -10,7 +10,6 @@ from lmclient.models.http import HttpChatModel, RetryStrategy from lmclient.parser import ParserError from lmclient.types import FunctionCallDict, FunctionDict, GeneralParameters, Message, Messages, ModelParameters, ModelResponse -from lmclient.utils import to_dict class FunctionCallNameDict(TypedDict): @@ -144,7 +143,7 @@ def get_request_parameters(self, messages: Messages, parameters: OpenAIChatParam headers = { 'Authorization': f'Bearer {self.api_key}', } - parameters_dict = to_dict(parameters, exclude_defaults=True) + parameters_dict = parameters.model_dump(exclude_defaults=True) openai_messages = [convert_lmclient_to_openai(message) for message in messages] params = { 'model': self.model, diff --git a/lmclient/models/zhipu.py b/lmclient/models/zhipu.py index 9fefd8f..e41fe50 100644 --- a/lmclient/models/zhipu.py +++ b/lmclient/models/zhipu.py @@ -12,7 +12,6 @@ from lmclient.models.http import HttpChatModel, RetryStrategy from lmclient.parser import ParserError from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse -from lmclient.utils import to_dict T = TypeVar('T') API_TOKEN_TTL_SECONDS = 3 * 60 @@ -96,7 +95,7 @@ def get_request_parameters(self, messages: Messages, parameters: ZhiPuChatParame headers = { 'Authorization': generate_token(self.api_key), } - parameters_dict = to_dict(parameters, exclude_defaults=True) + parameters_dict = parameters.model_dump(exclude_defaults=True) params = {'prompt': zhipu_messages, **parameters_dict} return { 'url': f'{self.api_base}/{self.model}/invoke', diff --git a/lmclient/types.py b/lmclient/types.py index 9abba55..8d09c48 100644 --- a/lmclient/types.py +++ b/lmclient/types.py @@ -5,6 +5,21 @@ from pydantic import BaseModel, ConfigDict, Field from typing_extensions import NotRequired, Self, TypedDict +Messages = List['Message'] +ModelResponse = Dict[str, Any] +Prompt = Union[str, 'Message', 'MessageDict', Sequence[Union['MessageDict', 'Message']]] + + +class FunctionDict(TypedDict): + name: str + description: NotRequired[str] + parameters: Dict[str, Any] + + +class FunctionCallDict(TypedDict): + name: 
str + arguments: str + class Message(BaseModel): role: str @@ -22,12 +37,6 @@ class MessageDict(TypedDict): name: NotRequired[str] -class FunctionDict(TypedDict): - name: str - description: NotRequired[str] - parameters: Dict[str, Any] - - class GeneralParameters(BaseModel): temperature: float = 1 top_p: float = 1 @@ -59,13 +68,3 @@ class RetryStrategy(BaseModel): class HttpChatModelOutput(ChatModelOutput): response: ModelResponse = Field(default_factory=dict) - - -class FunctionCallDict(TypedDict): - name: str - arguments: str - - -Messages = List[Message] -ModelResponse = Dict[str, Any] -Prompt = Union[str, Message, MessageDict, Sequence[Union[MessageDict, Message]]] diff --git a/lmclient/utils.py b/lmclient/utils.py index 2887cd2..1591563 100644 --- a/lmclient/utils.py +++ b/lmclient/utils.py @@ -25,7 +25,7 @@ def get_pydantic_version(): def generate_chat_completion_hash_key(model_id: str, messages: Messages, parameters: ModelParameters) -> str: - messages_text = '---'.join([f'{k}={v}' for message in messages for k, v in to_dict(message).items()]) + messages_text = '---'.join([f'{k}={v}' for message in messages for k, v in message.model_dump().items()]) messages_hash = md5_hash(messages_text) parameters_hash = md5_hash(parameters.model_dump_json(exclude_none=True)) return f'{model_id}|{messages_hash}|{parameters_hash}|v{__cache_version__}' @@ -121,10 +121,3 @@ def from_message(cls, message: Message): raise MessageError(f'{message} is not a valid function call message') arguments = json.loads(function_call['arguments'], strict=False) return cls(**arguments) - - -def to_dict(value: BaseModel, exclude_defaults: bool = False, exclude_none: bool = False): - if PydanticVersion == 2: - return value.model_dump(exclude_defaults=exclude_defaults, exclude_none=exclude_none) - else: - return value.dict(exclude_defaults=exclude_defaults, exclude_none=exclude_none) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..283a585 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,31 @@ +from lmclient import ChatEngine, MinimaxProChat, function + + +@function +def get_weather(loc: str) -> str: + """ + 获取指定地区的天气信息 + + Parameters: + loc: 地区,比如北京,上海等 + """ + return f"{loc},晴朗,27度" + + +@function +def google(keyword: str) -> str: + """ + 搜索谷歌 + + Parameters: + keyword: 搜索关键词 + """ + return '没有内容' + + +def test_function(): + model = MinimaxProChat() + engine = ChatEngine(model, temperature=0, functions=[get_weather, google]) + reply = engine.chat('今天北京天气怎么样?') + assert '27' in reply + From afd06a9aa78df6fc8cc97036fad6203f8a22200a Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Sun, 3 Sep 2023 12:43:46 +0800 Subject: [PATCH 4/6] add wenxin --- lmclient/__init__.py | 4 ++ lmclient/chat_engine.py | 4 ++ lmclient/client.py | 4 +- lmclient/exceptions.py | 4 ++ lmclient/models/__init__.py | 4 ++ lmclient/models/azure.py | 4 ++ lmclient/models/minimax_pro.py | 5 +- lmclient/models/openai.py | 4 ++ lmclient/models/wenxin.py | 126 +++++++++++++++++++++++++++++++++ lmclient/types.py | 7 +- tests/test_chat.py | 3 +- tests/test_client.py | 1 - tests/test_model.py | 83 +++++++++++++--------- 13 files changed, 209 insertions(+), 44 deletions(-) create mode 100644 lmclient/models/wenxin.py diff --git a/lmclient/__init__.py b/lmclient/__init__.py index a4f4842..c313ad0 100644 --- a/lmclient/__init__.py +++ b/lmclient/__init__.py @@ -6,6 +6,8 @@ MinimaxProChatParameters, OpenAIChat, 
OpenAIChatParameters, + WenxinChat, + WenxinChatParameters, ZhiPuChat, ZhiPuChatParameters, ) @@ -30,6 +32,8 @@ 'MinimaxProChatParameters', 'ZhiPuChat', 'ZhiPuChatParameters', + 'WenxinChat', + 'WenxinChatParameters', 'BaseSchema', 'function', '__version__', diff --git a/lmclient/chat_engine.py b/lmclient/chat_engine.py index 82082b2..cddd0ed 100644 --- a/lmclient/chat_engine.py +++ b/lmclient/chat_engine.py @@ -57,6 +57,10 @@ def extra_parameters(self, extra_parameters: dict[str, Any]): self._extra_parameters = extra_parameters self._parameters = self._parameters.model_copy(update=self._extra_parameters) + @property + def chat_model(self): + return self._chat_model + def chat(self, user_input: str, **extra_parameters: Any) -> str: parameters = self._parameters.model_copy(update=extra_parameters) self.history.append(Message(role='user', content=user_input)) diff --git a/lmclient/client.py b/lmclient/client.py index 7b05e15..c3cd7da 100644 --- a/lmclient/client.py +++ b/lmclient/client.py @@ -130,7 +130,7 @@ async def _async_run_single_task( if self.error_mode is ErrorMode.RAISE: raise elif self.error_mode is ErrorMode.IGNORE: - return ChatModelOutput(messages=[Message(role='Error', content=f'Error: {e}')]) + return ChatModelOutput(messages=[Message(role='error', content=f'Error: {e}')]) else: raise ValueError(f'Unknown error mode: {self.error_mode}') from e @@ -152,7 +152,7 @@ def _run_single_task( if self.error_mode is ErrorMode.RAISE: raise elif self.error_mode is ErrorMode.IGNORE: - return ChatModelOutput(messages=[Message(role='Error', content=str(e))]) + return ChatModelOutput(messages=[Message(role='error', content=str(e))]) else: raise ValueError(f'Unknown error mode: {self.error_mode}') from e diff --git a/lmclient/exceptions.py b/lmclient/exceptions.py index 0d0c18c..9ac893d 100644 --- a/lmclient/exceptions.py +++ b/lmclient/exceptions.py @@ -4,3 +4,7 @@ class MessageError(Exception): """ pass + + +class ResponseError(Exception): + pass diff --git a/lmclient/models/__init__.py b/lmclient/models/__init__.py index 86a17e1..cc3d5fa 100644 --- a/lmclient/models/__init__.py +++ b/lmclient/models/__init__.py @@ -4,6 +4,7 @@ from lmclient.models.base import BaseChatModel from lmclient.models.minimax_pro import MinimaxProChat, MinimaxProChatParameters from lmclient.models.openai import OpenAIChat, OpenAIChatParameters +from lmclient.models.wenxin import WenxinChat, WenxinChatParameters from lmclient.models.zhipu import ZhiPuChat, ZhiPuChatParameters ModelRegistry = { @@ -11,6 +12,7 @@ OpenAIChat.model_type: OpenAIChat, MinimaxProChat.model_type: MinimaxProChat, ZhiPuChat.model_type: ZhiPuChat, + WenxinChat.model_type: WenxinChat, } @@ -36,4 +38,6 @@ def list_chat_model_types(): 'OpenAIChatParameters', 'ZhiPuChat', 'ZhiPuChatParameters', + 'WenxinChat', + 'WenxinChatParameters', ] diff --git a/lmclient/models/azure.py b/lmclient/models/azure.py index dbf6292..bae7fab 100644 --- a/lmclient/models/azure.py +++ b/lmclient/models/azure.py @@ -19,6 +19,7 @@ class AzureChat(HttpChatModel[OpenAIChatParameters]): def __init__( self, model: str | None = None, + system_prompt: str | None = None, api_key: str | None = None, api_base: str | None = None, api_version: str | None = None, @@ -29,6 +30,7 @@ def __init__( ): super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model or os.environ['AZURE_CHAT_API_ENGINE'] or os.environ['AZURE_CHAT_MODEL_NAME'] + self.system_prompt = system_prompt self.api_key = api_key or os.environ['AZURE_API_KEY'] 
self.api_base = api_base or os.environ['AZURE_API_BASE'] self.api_version = api_version or os.getenv('AZURE_API_VERSION') @@ -39,6 +41,8 @@ def get_request_parameters(self, messages: Messages, parameters: OpenAIChatParam } parameters_dict = parameters.model_dump(exclude_defaults=True) openai_messages = [convert_lmclient_to_openai(message) for message in messages] + if self.system_prompt: + openai_messages.insert(0, {'role': 'system', 'content': self.system_prompt}) params = { 'model': self.model, 'messages': openai_messages, diff --git a/lmclient/models/minimax_pro.py b/lmclient/models/minimax_pro.py index 6c00e02..034317e 100644 --- a/lmclient/models/minimax_pro.py +++ b/lmclient/models/minimax_pro.py @@ -17,6 +17,7 @@ Messages, ModelParameters, ModelResponse, + Role, ) DEFAULT_MINIMAX_BOT_NAME = 'MM智能助理' @@ -120,10 +121,10 @@ def parse_model_reponse(self, response: ModelResponse) -> Messages: @staticmethod def _minimax_to_lmclient(message: MinimaxMessageDict) -> Message: - role_map = { + role_map: dict[str, Role] = { 'USER': 'user', 'BOT': 'assistant', - 'FUNCTION': 'funtion', + 'FUNCTION': 'function', } if 'function_call' in message: diff --git a/lmclient/models/openai.py b/lmclient/models/openai.py index c38fd38..c515016 100644 --- a/lmclient/models/openai.py +++ b/lmclient/models/openai.py @@ -127,6 +127,7 @@ class OpenAIChat(HttpChatModel[OpenAIChatParameters]): def __init__( self, model: str = 'gpt-3.5-turbo', + system_prompt: str | None = None, api_key: str | None = None, api_base: str | None = None, timeout: int | None = 60, @@ -136,6 +137,7 @@ def __init__( ): super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache) self.model = model + self.system_prompt = system_prompt self.api_base = api_base or os.getenv('OPENAI_API_BASE') or 'https://api.openai.com/v1' self.api_key = api_key or os.environ['OPENAI_API_KEY'] @@ -145,6 +147,8 @@ def get_request_parameters(self, messages: Messages, parameters: OpenAIChatParam } parameters_dict = parameters.model_dump(exclude_defaults=True) openai_messages = [convert_lmclient_to_openai(message) for message in messages] + if self.system_prompt: + openai_messages.insert(0, {'role': 'system', 'content': self.system_prompt}) params = { 'model': self.model, 'messages': openai_messages, diff --git a/lmclient/models/wenxin.py b/lmclient/models/wenxin.py new file mode 100644 index 0000000..0a21d38 --- /dev/null +++ b/lmclient/models/wenxin.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import os +from datetime import datetime, timedelta +from email.errors import MessageError +from pathlib import Path +from typing import Any, Literal, Optional + +import httpx +from typing_extensions import Self, TypedDict + +from lmclient.exceptions import ResponseError +from lmclient.models.http import HttpChatModel +from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse, RetryStrategy + +WENXIN_ACCESS_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token' +WENXIN_BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/' + + +class WenxinMessageDict(TypedDict): + role: Literal['user', 'assistant'] + content: str + + +class WenxinChatParameters(ModelParameters): + temperature: Optional[float] = None + top_p: Optional[float] = None + penalty_score: Optional[float] = None + + @classmethod + def from_general_parameters(cls, general_parameters: GeneralParameters) -> Self: + return cls( + temperature=general_parameters.temperature, + top_p=general_parameters.top_p, + ) 
+ + +class WenxinChat(HttpChatModel[WenxinChatParameters]): + model_type = 'wenxin' + model_name_entrypoint_map: dict[str, str] = { + 'llama_2_7b': 'llama_2_7b', + 'llama_2_13b': 'llama_2_13b', + 'llama_2_70b': 'llama_2_70b', + 'ERNIE-Bot': 'completions', + 'ERNIE-Bot-turbo': 'eb-instant', + } + access_token_refresh_days: int = 20 + + def __init__( + self, + model: str = 'ERNIE-Bot', + api_key: str | None = None, + secret_key: str | None = None, + parameters: WenxinChatParameters = WenxinChatParameters(), + timeout: int | None = None, + retry: bool | RetryStrategy = False, + use_cache: Path | str | bool = False, + ): + super().__init__(parameters, timeout, retry, use_cache) + self.model = self.normalize_model(model) + self._api_key = api_key or os.getenv('WENXIN_API_KEY') + self._secret_key = secret_key or os.getenv('WENXIN_SECRET_KEY') + self._access_token = self.get_access_token() + self._access_token_expires_at = datetime.now() + timedelta(days=self.access_token_refresh_days) + + @property + def name(self) -> str: + return self.model + + @property + def api_url(self) -> str: + return WENXIN_BASE_URL + self.model_name_entrypoint_map[self.model] + + @staticmethod + def normalize_model(model: str): + _map = { + 'llama-2-7b-chat': 'llama_2_7b', + 'llama-2-13b-chat': 'llama_2_13b', + 'llama-2-70b-chat': 'llama_2_70b', + } + return _map.get(model, model) + + def get_access_token(self, base_url: str = WENXIN_ACCESS_TOKEN_URL) -> str: + headers = {'Content-Type': 'application/json', 'Accept': 'application/json'} + params = {'grant_type': 'client_credentials', 'client_id': self._api_key, 'client_secret': self._secret_key} + response = httpx.post(base_url, headers=headers, params=params) + response.raise_for_status() + response_dict = response.json() + if 'error' in response_dict: + raise ResponseError(response_dict['error_description']) + return response_dict['access_token'] + + def get_request_parameters(self, messages: Messages, parameters: WenxinChatParameters) -> dict[str, Any]: + self.maybe_refresh_access_token() + + message_dicts: list[WenxinMessageDict] = [] + for message in messages: + role = message.role + if role != 'assistant' and role != 'user': + raise MessageError(f'Invalid message role: {role}, only "user" and "assistant" are allowed') + if not isinstance(content := message.content, str): + raise MessageError(f'Invalid message content: {content}, only string is allowed') + message_dicts.append(WenxinMessageDict(content=content, role=role)) + parameters_dict = parameters.model_dump(exclude_none=True) + json_data = {'messages': message_dicts, **parameters_dict} + + return { + 'url': self.api_url, + 'json': json_data, + 'params': {'access_token': self._access_token}, + 'headers': {'Content-Type': 'application/json'}, + } + + def parse_model_reponse(self, response: ModelResponse) -> Messages: + if response.get('error_msg'): + raise ResponseError(response['error_msg']) + return [Message(role='assistant', content=response['result'])] + + def maybe_refresh_access_token(self): + if self._access_token_expires_at < datetime.now(): + self._access_token = self.get_access_token() + self._access_token_expires_at = datetime.now() + timedelta(days=self.access_token_refresh_days) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(model=name, **kwargs) diff --git a/lmclient/types.py b/lmclient/types.py index 8d09c48..84aebe6 100644 --- a/lmclient/types.py +++ b/lmclient/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, 
List, Optional, Sequence, Union +from typing import Any, Dict, List, Literal, Optional, Sequence, Union from pydantic import BaseModel, ConfigDict, Field from typing_extensions import NotRequired, Self, TypedDict @@ -8,6 +8,7 @@ Messages = List['Message'] ModelResponse = Dict[str, Any] Prompt = Union[str, 'Message', 'MessageDict', Sequence[Union['MessageDict', 'Message']]] +Role = Literal['user', 'assistant', 'function', 'error'] class FunctionDict(TypedDict): @@ -22,7 +23,7 @@ class FunctionCallDict(TypedDict): class Message(BaseModel): - role: str + role: Role content: Union[str, FunctionCallDict] name: Optional[str] = None @@ -32,7 +33,7 @@ def is_function_call(self) -> bool: class MessageDict(TypedDict): - role: str + role: Role content: Union[str, FunctionCallDict] name: NotRequired[str] diff --git a/tests/test_chat.py b/tests/test_chat.py index 283a585..f929182 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -9,7 +9,7 @@ def get_weather(loc: str) -> str: Parameters: loc: 地区,比如北京,上海等 """ - return f"{loc},晴朗,27度" + return f'{loc},晴朗,27度' @function @@ -28,4 +28,3 @@ def test_function(): engine = ChatEngine(model, temperature=0, functions=[get_weather, google]) reply = engine.chat('今天北京天气怎么样?') assert '27' in reply - diff --git a/tests/test_client.py b/tests/test_client.py index ad30047..1bcd0e7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -59,7 +59,6 @@ def test_async_completion(): start_time = time.perf_counter() messages: list[MessageDict] = [ - {'role': 'system', 'content': 'your are lmclient demo assistant'}, {'role': 'user', 'content': 'hello, who are you?'}, ] prompts = ['Hello, my name is', 'I am a student', messages] * 4 diff --git a/tests/test_model.py b/tests/test_model.py index 00a180b..d843451 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,14 +1,12 @@ import anyio +import pytest -from lmclient.models import AzureChat, MinimaxProChat, OpenAIChat, ZhiPuChat -from lmclient.types import Message +from lmclient.models import AzureChat, BaseChatModel, MinimaxProChat, OpenAIChat, WenxinChat, ZhiPuChat +from lmclient.types import HttpChatModelOutput, Message, ModelParameters -test_messages = [Message(role='user', content='hello')] - - -def test_azure_model(): - chat_model = AzureChat() +@pytest.mark.parametrize('chat_model', (AzureChat(), MinimaxProChat(), OpenAIChat(), ZhiPuChat(), WenxinChat())) +def test_http_chat_model(chat_model: BaseChatModel[ModelParameters, HttpChatModelOutput]): test_messages = [Message(role='user', content='hello')] sync_output = chat_model.chat_completion(test_messages) async_output = anyio.run(chat_model.async_chat_completion, test_messages) @@ -19,40 +17,57 @@ def test_azure_model(): assert isinstance(async_output.reply, str) -def test_openai_model(): - chat_model = OpenAIChat('gpt-3.5-turbo') +# def test_azure_model(): +# chat_model = AzureChat() - test_messages = [Message(role='user', content='hello')] - sync_output = chat_model.chat_completion(test_messages) - async_output = anyio.run(chat_model.async_chat_completion, test_messages) +# test_messages = [Message(role='user', content='hello')] +# sync_output = chat_model.chat_completion(test_messages) +# async_output = anyio.run(chat_model.async_chat_completion, test_messages) - assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.reply, str) - assert isinstance(async_output.response, dict) - assert isinstance(async_output.reply, str) +# assert isinstance(sync_output.response, dict) +# assert isinstance(sync_output.reply, str) +# 
assert isinstance(async_output.response, dict) +# assert isinstance(async_output.reply, str) -def test_minimax_model(): - chat_model = MinimaxProChat() +# def test_openai_model(): +# chat_model = OpenAIChat('gpt-3.5-turbo') - test_messages = [Message(role='user', content='hello')] - sync_output = chat_model.chat_completion(test_messages) - async_output = anyio.run(chat_model.async_chat_completion, test_messages) +# test_messages = [Message(role='user', content='hello')] +# sync_output = chat_model.chat_completion(test_messages) +# async_output = anyio.run(chat_model.async_chat_completion, test_messages) - assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.reply, str) - assert isinstance(async_output.response, dict) - assert isinstance(async_output.reply, str) +# assert isinstance(sync_output.response, dict) +# assert isinstance(sync_output.reply, str) +# assert isinstance(async_output.response, dict) +# assert isinstance(async_output.reply, str) -def test_zhipu_model(): - chat_model = ZhiPuChat() +# def test_minimax_model(): +# chat_model = MinimaxProChat() - test_messages = [Message(role='user', content='hello')] - sync_output = chat_model.chat_completion(test_messages) - async_output = anyio.run(chat_model.async_chat_completion, test_messages) +# test_messages = [Message(role='user', content='hello')] +# sync_output = chat_model.chat_completion(test_messages) +# async_output = anyio.run(chat_model.async_chat_completion, test_messages) - assert isinstance(sync_output.response, dict) - assert isinstance(sync_output.reply, str) - assert isinstance(async_output.response, dict) - assert isinstance(async_output.reply, str) +# assert isinstance(sync_output.response, dict) +# assert isinstance(sync_output.reply, str) +# assert isinstance(async_output.response, dict) +# assert isinstance(async_output.reply, str) + + +# def test_zhipu_model(): +# chat_model = ZhiPuChat() + +# test_messages = [Message(role='user', content='hello')] +# sync_output = chat_model.chat_completion(test_messages) +# async_output = anyio.run(chat_model.async_chat_completion, test_messages) + +# assert isinstance(sync_output.response, dict) +# assert isinstance(sync_output.reply, str) +# assert isinstance(async_output.response, dict) +# assert isinstance(async_output.reply, str) + + +# def test_wenxin_model(): +# WeninChat() From 729239bab994d92e8dc7b3ea123b05b54b5169c7 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Sun, 3 Sep 2023 12:44:58 +0800 Subject: [PATCH 5/6] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20readme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4c917f6..e95a993 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ 5. 支持磁盘缓存 6. 100% type hints 7. 非常易用 -8. 支持 OpenAI, Azure, Minimax, ZhiPu 等模型 +8. 支持 OpenAI, Azure, Minimax, 智谱, 百度文心 9. 
支持 FunctionCall ## 安装方式 From 6d6a3825b16fd8d5f5559111734697eaedb6994d Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Sun, 3 Sep 2023 16:34:09 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E6=94=AF=E6=8C=81=E5=BC=82=E6=AD=A5chat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lmclient/chat_engine.py | 22 ++++++++++++++++++++++ scripts/multimodel_chat.py | 23 +++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100644 scripts/multimodel_chat.py diff --git a/lmclient/chat_engine.py b/lmclient/chat_engine.py index cddd0ed..bd3d56a 100644 --- a/lmclient/chat_engine.py +++ b/lmclient/chat_engine.py @@ -71,6 +71,16 @@ def chat(self, user_input: str, **extra_parameters: Any) -> str: else: return self._recursive_function_call(reply, parameters) + async def async_chat(self, user_input: str, **extra_parameters: Any) -> str: + parameters = self._parameters.model_copy(update=extra_parameters) + self.history.append(Message(role='user', content=user_input)) + model_response = await self._chat_model.async_chat_completion(self.history, parameters) + self.history.extend(model_response.messages) + if isinstance(reply := model_response.messages[-1].content, str): + return reply + else: + return await self._async_recursive_function_call(reply, parameters) + def run_function_call(self, function_call: FunctionCallDict): function = None for i in self.functions: @@ -103,5 +113,17 @@ def _recursive_function_call(self, function_call: FunctionCallDict, parameters: else: return self._recursive_function_call(reply, parameters) + async def _async_recursive_function_call(self, function_call: FunctionCallDict, parameters: T_P) -> str: + function_output = self.run_function_call(function_call) + self.history.append( + Message(role='function', name=function_call['name'], content=json.dumps(function_output, ensure_ascii=False)) + ) + model_response = await self._chat_model.async_chat_completion(self.history, parameters) + self.history.extend(model_response.messages) + if isinstance(reply := model_response.messages[-1].content, str): + return reply + else: + return self._recursive_function_call(reply, parameters) + def reset(self) -> None: self.history.clear() diff --git a/scripts/multimodel_chat.py b/scripts/multimodel_chat.py new file mode 100644 index 0000000..55cd291 --- /dev/null +++ b/scripts/multimodel_chat.py @@ -0,0 +1,23 @@ +import asyncio + +from lmclient import ChatEngine, MinimaxProChat, OpenAIChat, WenxinChat, ZhiPuChat + +chat_models = { + 'wenxin': WenxinChat(timeout=20), + 'llama2-70b': WenxinChat(model='llama_2_70b', timeout=20), + 'gpt4': OpenAIChat(model='gpt-4'), + 'gpt3.5': OpenAIChat(model='gpt-3.5-turbo'), + 'minimax': MinimaxProChat(), + 'zhipu': ZhiPuChat(), +} +engines = {model_name: ChatEngine(chat_model=chat_model) for model_name, chat_model in chat_models.items()} # type: ignore + + +async def multimodel_chat(user_input: str): + reply_list = await asyncio.gather(*[engine.async_chat(user_input) for engine in engines.values()]) + for model_name, reply in zip(engines.keys(), reply_list): + print(f'{model_name}: {reply}') + + +if __name__ == '__main__': + asyncio.run(multimodel_chat('你好'))
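As a quick end-to-end illustration of the async flow introduced in PATCH 6/6, below is a minimal usage sketch combining ChatEngine.async_chat with the function decorator, modeled on tests/test_chat.py and scripts/multimodel_chat.py. It is a sketch, not part of the patch series: it assumes the ChatEngine constructor keeps the chat_model/temperature/functions signature shown in the first patch, that MINIMAX_GROUP_ID and MINIMAX_API_KEY are set in the environment, and the get_weather helper is purely illustrative.

import asyncio

from lmclient import ChatEngine, MinimaxProChat, function


@function
def get_weather(loc: str) -> str:
    """
    Look up the weather for a location (illustrative stub, not a real API call).

    Parameters:
        loc: city name, e.g. Beijing or Shanghai
    """
    return f'{loc}: sunny, 27 degrees'


async def main() -> None:
    # ChatEngine sends the user message, and if the model replies with a
    # function call, resolves it and feeds the result back until a text reply arrives.
    engine = ChatEngine(MinimaxProChat(), temperature=0, functions=[get_weather])
    reply = await engine.async_chat("What's the weather in Beijing today?")
    print(reply)


if __name__ == '__main__':
    asyncio.run(main())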