diff --git a/lmclient/__init__.py b/lmclient/__init__.py index ff19653..0606b63 100644 --- a/lmclient/__init__.py +++ b/lmclient/__init__.py @@ -1,10 +1,6 @@ from lmclient.client import LMClient as LMClient -from lmclient.client import LMClientForStructuredData as LMClientForStructuredData from lmclient.models import AzureChat as AzureChat from lmclient.models import MinimaxChat as MinimaxChat from lmclient.models import OpenAIChat as OpenAIChat +from lmclient.models import OpenAIExtract as OpenAIExtract from lmclient.models import ZhiPuChat as ZhiPuChat -from lmclient.parsers import * # noqa: F403 - -AzureCompletion = AzureChat -OpenAICompletion = OpenAIChat diff --git a/lmclient/client.py b/lmclient/client.py index bae17c9..a516eee 100644 --- a/lmclient/client.py +++ b/lmclient/client.py @@ -1,30 +1,20 @@ from __future__ import annotations -import hashlib import os import time from enum import Enum from pathlib import Path -from typing import ClassVar, Generic, Sequence, Type, TypeVar, cast +from typing import ClassVar, Generic, Sequence, TypeVar import anyio import asyncer -import diskcache import tqdm -from lmclient.models import AzureChat, BaseChatModel, OpenAIChat -from lmclient.parsers import MinimaxTextParser, ModelResponseParser, OpenAIParser, OpenAISchema, ZhiPuParser -from lmclient.types import ModelResponse, Prompt, TaskResult -from lmclient.utils import ensure_messages -from lmclient.version import __cache_version__ +from lmclient.models import BaseChatModel +from lmclient.openai_schema import OpenAISchema +from lmclient.types import ChatModelOutput, Prompt DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve() -DEFAULT_MODEL_PARSER_MAP: dict[str, Type[ModelResponseParser]] = { - 'OpenAIChat': OpenAIParser, - 'AzureChat': OpenAIParser, - 'MinimaxChat': MinimaxTextParser, - 'ZhiPuChat': ZhiPuParser, -} T = TypeVar('T') T_O = TypeVar('T_O', bound=OpenAISchema) @@ -43,61 +33,39 @@ class ProgressBarMode(str, Enum): class LMClient(Generic[T]): error_mode: ErrorMode - _cache_dir: Path | None NUM_SECONDS_PER_MINUTE: ClassVar[int] = 60 PROGRESS_BAR_THRESHOLD: ClassVar[int] = 20 def __init__( self, - model: BaseChatModel, - max_requests_per_minute: int = 20, + chat_model: BaseChatModel[T], async_capacity: int = 3, + max_requests_per_minute: int = 20, error_mode: ErrorMode | str = ErrorMode.RAISE, - cache_dir: Path | str | None = DEFAULT_CACHE_DIR, progress_bar: ProgressBarMode | str = ProgressBarMode.AUTO, - max_retry_attempt: int | None = None, - output_parser: ModelResponseParser[T] | None = None, ): - self.model = model + self.chat_model = chat_model self.async_capacity = async_capacity self.max_requests_per_minute = max_requests_per_minute self.error_mode = ErrorMode(error_mode) self.progress_bar_mode = ProgressBarMode(progress_bar) - self.max_retry_attempt = max_retry_attempt self._task_created_time_list: list[int] = [] - self.cache_dir = Path(cache_dir) if cache_dir is not None else None - self.output_parser = output_parser or DEFAULT_MODEL_PARSER_MAP[self.model.__class__.__name__]() - - @property - def cache_dir(self) -> Path | None: - return self._cache_dir - - @cache_dir.setter - def cache_dir(self, value: Path | None) -> None: - if value is not None: - if value.exists() and not value.is_dir(): - raise ValueError(f'Cache directory {value} is not a directory') - value.mkdir(parents=True, exist_ok=True) - self._cache = diskcache.Cache(value) - else: - self._cache = None - - def run(self, prompts: Sequence[Prompt], **kwargs) -> 
list[TaskResult[T]]: + def run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput[T]]: progress_bar = self._get_progress_bar(num_tasks=len(prompts)) - task_results: list[TaskResult] = [] + task_results: list[ChatModelOutput[T]] = [] for prompt in prompts: task_result = self._run_single_task(prompt=prompt, progress_bar=progress_bar, **kwargs) task_results.append(task_result) progress_bar.close() return task_results - async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[TaskResult[T]]: + async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput[T]]: limiter = anyio.CapacityLimiter(self.async_capacity) task_created_lock = anyio.Lock() progress_bar = self._get_progress_bar(num_tasks=len(prompts)) - soon_values: list[asyncer.SoonValue[TaskResult]] = [] + soon_values: list[asyncer.SoonValue[ChatModelOutput[T]]] = [] async with asyncer.create_task_group() as task_group: soon_func = task_group.soonify(self._async_run_single_task) for prompt in prompts: @@ -114,7 +82,7 @@ async def _async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[TaskResu values = [soon_value.value for soon_value in soon_values] return values - def async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[TaskResult[T]]: + def async_run(self, prompts: Sequence[Prompt], **kwargs) -> list[ChatModelOutput[T]]: return asyncer.runnify(self._async_run)(prompts, **kwargs) async def _async_run_single_task( @@ -124,10 +92,10 @@ async def _async_run_single_task( task_created_lock: anyio.Lock, progress_bar: tqdm.tqdm, **kwargs, - ) -> TaskResult: + ) -> ChatModelOutput: async with limiter: - task_key = self._gen_task_key(prompt=prompt, **kwargs) - response = self.read_from_cache(task_key) + task_key = self.chat_model.generate_hash_key(prompt=prompt, **kwargs) + response = self.chat_model.try_load_response(task_key) if response is None: async with task_created_lock: @@ -136,83 +104,40 @@ async def _async_run_single_task( await anyio.sleep(sleep_time) self._task_created_time_list.append(int(time.time())) - try: - if self.max_retry_attempt is None: - response = await self.model.async_chat(prompt=prompt, **kwargs) - else: - response = await self.model.async_chat_with_retry( - prompt=prompt, max_attempt=self.max_retry_attempt, **kwargs - ) - if self._cache is not None: - self._cache[task_key] = response - except BaseException as e: - if self.error_mode is ErrorMode.RAISE: - raise - elif self.error_mode is ErrorMode.IGNORE: - return TaskResult(error_message=str(e)) - else: - raise ValueError(f'Unknown error mode: {self.error_mode}') from e - try: - output = self.output_parser(response) - result = TaskResult(response=response, output=output) + output = await self.chat_model.async_chat(prompt=prompt, **kwargs) + progress_bar.update(1) + return output except BaseException as e: if self.error_mode is ErrorMode.RAISE: raise elif self.error_mode is ErrorMode.IGNORE: - result = TaskResult(error_message=str(e), response=response) + return ChatModelOutput(error_message=str(e)) else: raise ValueError(f'Unknown error mode: {self.error_mode}') from e - progress_bar.update(1) - return result + def _run_single_task(self, prompt: Prompt, progress_bar: tqdm.tqdm, **kwargs) -> ChatModelOutput: + task_key = self.chat_model.generate_hash_key(prompt=prompt, **kwargs) + response = self.chat_model.try_load_response(task_key) - def _run_single_task(self, prompt: Prompt, progress_bar: tqdm.tqdm, **kwargs) -> TaskResult: - task_key = self._gen_task_key(prompt=prompt, **kwargs) - - response = 
self.read_from_cache(task_key) if response is None: sleep_time = self._calculate_sleep_time() if sleep_time > 0: time.sleep(sleep_time) self._task_created_time_list.append(int(time.time())) - try: - if self.max_retry_attempt is None: - response = self.model.chat(prompt=prompt, **kwargs) - else: - response = self.model.chat_with_retry(prompt=prompt, max_retry_attempt=self.max_retry_attempt, **kwargs) - if self._cache is not None: - self._cache[task_key] = response - except BaseException as e: - if self.error_mode is ErrorMode.RAISE: - raise - elif self.error_mode is ErrorMode.IGNORE: - return TaskResult(error_message=str(e)) - else: - raise ValueError(f'Unknown error mode: {self.error_mode}') from e - try: - output = self.output_parser(response) - result = TaskResult(response=response, output=output) + output = self.chat_model.chat(prompt=prompt, **kwargs) + progress_bar.update(1) + return output except BaseException as e: if self.error_mode is ErrorMode.RAISE: raise elif self.error_mode is ErrorMode.IGNORE: - result = TaskResult(error_message=str(e), response=response) + return ChatModelOutput(output=f'Response Error: {e}', response={}) else: raise ValueError(f'Unknown error mode: {self.error_mode}') from e - progress_bar.update(1) - return result - - def read_from_cache(self, key: str) -> ModelResponse | None: - if self._cache is not None and key in self._cache: - response = self._cache[key] - response = cast(ModelResponse, response) - return response - return - def _calculate_sleep_time(self): idx = 0 current_time = time.time() @@ -227,82 +152,9 @@ def _calculate_sleep_time(self): else: return max(self.NUM_SECONDS_PER_MINUTE - int(current_time - self._task_created_time_list[0]) + 1, 0) - def _gen_task_key(self, prompt: Prompt, **kwargs) -> str: - messages = ensure_messages(prompt) - if not isinstance(prompt, str): - hash_text = '---'.join([f'{k}={v}' for message in messages for k, v in message.items()]) - else: - hash_text = prompt - items = sorted([f'{key}={value}' for key, value in kwargs.items()]) - items += [f'__cache_version__={__cache_version__}'] - items = [hash_text, self.model.identifier] + items - task_string = '---'.join(items) - return self.md5_hash(task_string) - - @staticmethod - def md5_hash(string: str): - return hashlib.md5(string.encode()).hexdigest() - def _get_progress_bar(self, num_tasks: int) -> tqdm.tqdm: use_progress_bar = (self.progress_bar_mode is ProgressBarMode.ALWAYS) or ( self.progress_bar_mode is ProgressBarMode.AUTO and num_tasks > self.PROGRESS_BAR_THRESHOLD ) - progress_bar = tqdm.tqdm(desc=f'{self.model.__class__.__name__}', total=num_tasks, disable=not use_progress_bar) + progress_bar = tqdm.tqdm(desc=f'{self.chat_model.__class__.__name__}', total=num_tasks, disable=not use_progress_bar) return progress_bar - - -class LMClientForStructuredData(LMClient[T_O]): - SupportedModels = [OpenAIChat, AzureChat] - - def __init__( - self, - model: BaseChatModel, - schema: Type[T_O], - system_prompt: str = 'Generate structured data from a given text', - max_requests_per_minute: int = 20, - async_capacity: int = 3, - error_mode: ErrorMode | str = ErrorMode.RAISE, - cache_dir: Path | str | None = DEFAULT_CACHE_DIR, - progress_bar: ProgressBarMode | str = ProgressBarMode.AUTO, - max_retry_attempt: int | None = None, - ): - if not any(isinstance(model, supported_model) for supported_model in self.SupportedModels): - raise ValueError(f'Unsupported model: {model.__class__.__name__}. 
Supported models: {self.SupportedModels}') - self.system_prompt = system_prompt - self.default_kwargs = { - 'functions': [schema.openai_schema()], - 'function_call': {'name': schema.openai_schema()['name']}, - } - - super().__init__( - model=model, - max_requests_per_minute=max_requests_per_minute, - async_capacity=async_capacity, - error_mode=error_mode, - cache_dir=cache_dir, - progress_bar=progress_bar, - output_parser=schema.from_response, - max_retry_attempt=max_retry_attempt, - ) - - def run(self, prompts: Sequence[str], **kwargs) -> list[TaskResult[T_O]]: - assembled_prompts = [] - for prompt in prompts: - messages = [ - {'role': 'system', 'text': self.system_prompt}, - {'role': 'user', 'text': prompt}, - ] - assembled_prompts.append(messages) - kwargs = {**self.default_kwargs, **kwargs} - return super().run(prompts, **kwargs) - - async def _async_run(self, prompts: Sequence[str], **kwargs) -> list[TaskResult[T_O]]: - assembled_prompts = [] - for prompt in prompts: - messages = [ - {'role': 'system', 'text': self.system_prompt}, - {'role': 'user', 'text': prompt}, - ] - assembled_prompts.append(messages) - kwargs = {**self.default_kwargs, **kwargs} - return await super()._async_run(prompts, **kwargs) diff --git a/lmclient/exceptions.py b/lmclient/exceptions.py deleted file mode 100644 index 01f32c1..0000000 --- a/lmclient/exceptions.py +++ /dev/null @@ -1,2 +0,0 @@ -class ParserError(Exception): - """Error raised when postprocess function fails.""" diff --git a/lmclient/models/__init__.py b/lmclient/models/__init__.py index ceb2e6c..c9a898d 100644 --- a/lmclient/models/__init__.py +++ b/lmclient/models/__init__.py @@ -2,4 +2,5 @@ from lmclient.models.base import BaseChatModel as BaseChatModel from lmclient.models.minimax import MinimaxChat as MinimaxChat from lmclient.models.openai import OpenAIChat as OpenAIChat +from lmclient.models.openai import OpenAIExtract as OpenAIExtract from lmclient.models.zhipu import ZhiPuChat as ZhiPuChat diff --git a/lmclient/models/azure.py b/lmclient/models/azure.py index 5c30fb0..b3ea9a0 100644 --- a/lmclient/models/azure.py +++ b/lmclient/models/azure.py @@ -1,68 +1,51 @@ from __future__ import annotations import os +from pathlib import Path +from typing import Any, TypeVar -import httpx +from lmclient.models.base import HttpChatModel, RetryStrategy +from lmclient.models.openai import OpenAIContentParser +from lmclient.parser import ModelResponseParser +from lmclient.types import Messages -from lmclient.models.base import BaseChatModel -from lmclient.types import ModelResponse, Prompt -from lmclient.utils import ensure_messages +T = TypeVar('T') -class AzureChat(BaseChatModel): +class AzureChat(HttpChatModel[T]): def __init__( self, - model_name: str | None = None, + model: str | None = None, api_key: str | None = None, api_base: str | None = None, api_version: str | None = None, timeout: int | None = 60, + response_parser: ModelResponseParser[T] | None = None, + retry: bool | RetryStrategy = False, + use_cache: Path | str | bool = False, ): - self.model_name = model_name or os.environ['AZURE_CHAT_API_ENGINE'] or os.environ['AZURE_CHAT_MODEL_NAME'] + response_parser = response_parser or OpenAIContentParser() + super().__init__(timeout=timeout, response_parser=response_parser, retry=retry, use_cache=use_cache) + self.model = model or os.environ['AZURE_CHAT_API_ENGINE'] or os.environ['AZURE_CHAT_MODEL_NAME'] self.api_key = api_key or os.environ['AZURE_API_KEY'] self.api_base = api_base or os.environ['AZURE_API_BASE'] self.api_version = api_version or 
os.getenv('AZURE_API_VERSION') - self.timeout = timeout - def chat(self, prompt: Prompt, **kwargs) -> ModelResponse: - messages = ensure_messages(prompt) + def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: headers = { 'api-key': self.api_key, } params = { - 'model': self.model_name, + 'model': self.model, 'messages': messages, **kwargs, } - response = httpx.post( - url=f'{self.api_base}/openai/deployments/{self.model_name}/chat/completions?api-version={self.api_version}', - headers=headers, - json=params, - timeout=self.timeout, - ) - response.raise_for_status() - return response.json() - - async def async_chat(self, prompt: Prompt, **kwargs) -> ModelResponse: - messages = ensure_messages(prompt) - headers = { - 'api-key': self.api_key, - } - params = { - 'model': self.model_name, - 'messages': messages, - **kwargs, + return { + 'url': f'{self.api_base}/openai/deployments/{self.model}/chat/completions?api-version={self.api_version}', + 'headers': headers, + 'json': params, } - async with httpx.AsyncClient() as client: - response = await client.post( - url=f'{self.api_base}/openai/deployments/{self.model_name}/chat/completions?api-version={self.api_version}', - headers=headers, - json=params, - timeout=self.timeout, - ) - response.raise_for_status() - return response.json() @property def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model_name})' + return f'{self.__class__.__name__}({self.model})' diff --git a/lmclient/models/base.py b/lmclient/models/base.py index b115bbe..23fa834 100644 --- a/lmclient/models/base.py +++ b/lmclient/models/base.py @@ -1,28 +1,201 @@ from __future__ import annotations +import hashlib +import os +from pathlib import Path +from typing import Any, Generic, TypeVar, cast + +import diskcache +import httpx + +try: + from pydantic.v1 import BaseModel +except ImportError: + from pydantic import BaseModel from tenacity import retry, stop_after_attempt, wait_random_exponential -from lmclient.types import ModelResponse, Prompt +from lmclient.parser import ModelResponseParser +from lmclient.types import ChatModelOutput, Messages, ModelResponse, Prompt +from lmclient.utils import ensure_messages +from lmclient.version import __cache_version__ +T = TypeVar('T') +DEFAULT_CACHE_DIR = Path(os.getenv('LMCLIENT_CACHE_DIR', '~/.cache/lmclient')).expanduser().resolve() -class BaseChatModel: - def chat(self, prompt: Prompt, **kwargs) -> ModelResponse: - ... - async def async_chat(self, prompt: Prompt, **kwargs) -> ModelResponse: - ... 
+class BaseChatModel(Generic[T]): + _cache: diskcache.Cache | None + _cache_dir: Path | None - def chat_with_retry(self, prompt: Prompt, max_wait: int = 20, max_attempt: int = 3, **kwargs) -> ModelResponse: - return retry(wait=wait_random_exponential(min=1, max=max_wait), stop=stop_after_attempt(max_attempt))(self.chat)( - prompt=prompt, **kwargs - ) + def __init__( + self, + response_parser: ModelResponseParser[T] | None = None, + use_cache: Path | str | bool = False, + ) -> None: + self.response_parser = response_parser - async def async_chat_with_retry(self, prompt: Prompt, max_wait: int = 20, max_attempt: int = 3, **kwargs) -> ModelResponse: - return await retry(wait=wait_random_exponential(min=1, max=max_wait), stop=stop_after_attempt(max_attempt))( - self.async_chat - )(prompt=prompt, **kwargs) + if isinstance(use_cache, (str, Path)): + self.cache_dir = Path(use_cache) + elif use_cache: + self.cache_dir = DEFAULT_CACHE_DIR + else: + self.cache_dir = None @property def identifier(self) -> str: - self_dict_string = ', '.join(f'{k}={v}' for k, v in self.__dict__.items()) - return f'{self.__class__.__name__}({self_dict_string})' + raise NotImplementedError + + def call_model(self, messages: Messages, **kwargs) -> ModelResponse: + raise NotImplementedError + + async def async_call_model(self, messages: Messages, **kwargs) -> ModelResponse: + raise NotImplementedError + + def chat(self, prompt: Prompt, **kwargs) -> ChatModelOutput[T]: + messages = ensure_messages(prompt) + + if self.use_cache: + hash_key = self.generate_hash_key(prompt) + model_response = self.try_load_response(hash_key) + if model_response is None: + model_response = self.call_model(messages, **kwargs) + self.cache_response(hash_key, model_response) + else: + model_response = self.call_model(messages, **kwargs) + + if self.response_parser is None: + parsed_result = None + else: + parsed_result = self.response_parser(model_response) + + return ChatModelOutput( + parsed_result=parsed_result, + response=model_response, + ) + + async def async_chat(self, prompt: Prompt, **kwargs) -> ChatModelOutput[T]: + messages = ensure_messages(prompt) + + if self.use_cache: + hash_key = self.generate_hash_key(prompt) + model_response = self.try_load_response(hash_key) + if model_response is None: + model_response = await self.async_call_model(messages, **kwargs) + self.cache_response(hash_key, model_response) + else: + model_response = await self.async_call_model(messages, **kwargs) + + if self.response_parser is None: + parsed_result = None + else: + parsed_result = self.response_parser(model_response) + + return ChatModelOutput( + parsed_result=parsed_result, + response=model_response, + ) + + def cache_response(self, key: str, response: ModelResponse) -> None: + if self._cache is not None: + self._cache[key] = response + else: + raise RuntimeError('Cache is not enabled') + + def try_load_response(self, key: str): + if self._cache is not None and key in self._cache: + response = self._cache[key] + response = cast(ModelResponse, response) + return response + + def generate_hash_key(self, prompt: Prompt, **kwargs) -> str: + if isinstance(prompt, str): + hash_text = prompt + else: + hash_text = '---'.join([f'{k}={v}' for message in prompt for k, v in message.items()]) + items = sorted([f'{key}={value}' for key, value in kwargs.items()]) + items += [f'__cache_version__={__cache_version__}'] + items = [hash_text, self.identifier] + items + task_string = '---'.join(items) + return self.md5_hash(task_string) + + @staticmethod + def 
md5_hash(string: str): + return hashlib.md5(string.encode()).hexdigest() + + @property + def use_cache(self) -> bool: + return self._cache is not None + + @property + def cache_dir(self) -> Path | None: + return self._cache_dir + + @cache_dir.setter + def cache_dir(self, value: Path | None) -> None: + if value is not None: + if value.exists() and not value.is_dir(): + raise ValueError(f'Cache directory {value} is not a directory') + value.mkdir(parents=True, exist_ok=True) + self._cache = diskcache.Cache(value) + else: + self._cache = None + + +class RetryStrategy(BaseModel): # type: ignore + min_wait_seconds: int = 2 + max_wait_seconds: int = 20 + max_attempt: int = 3 + + +class HttpChatModel(BaseChatModel[T]): + def __init__( + self, + timeout: int | None = None, + retry: bool | RetryStrategy = False, + response_parser: ModelResponseParser[T] | None = None, + use_cache: Path | str | bool = False, + ): + super().__init__(response_parser=response_parser, use_cache=use_cache) + self.timeout = timeout + if isinstance(retry, RetryStrategy): + self.retry_strategy = retry + else: + self.retry_strategy = RetryStrategy() if retry else None + + def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: + raise NotImplementedError + + def call_model(self, messages: Messages, **kwargs) -> ModelResponse: + parameters = self.get_post_parameters(messages, **kwargs) + parameters = {'timeout': self.timeout, **parameters} + http_response = httpx.post(**parameters) + http_response.raise_for_status() + model_response = http_response.json() + return model_response + + async def async_call_model(self, messages: Messages, **kwargs) -> ModelResponse: + async with httpx.AsyncClient() as client: + parameters = self.get_post_parameters(messages, **kwargs) + parameters = {'timeout': self.timeout, **parameters} + http_response = await client.post(**parameters) + http_response.raise_for_status() + model_response = http_response.json() + return model_response + + def chat(self, prompt: Prompt, **kwargs) -> ChatModelOutput[T]: + if self.retry_strategy is None: + return super().chat(prompt, **kwargs) + + wait = wait_random_exponential(min=self.retry_strategy.min_wait_seconds, max=self.retry_strategy.max_wait_seconds) + stop = stop_after_attempt(self.retry_strategy.max_attempt) + output = retry(wait=wait, stop=stop)(super().chat)(prompt=prompt, **kwargs) + return output + + async def async_chat(self, prompt: Prompt, **kwargs) -> ChatModelOutput[T]: + if self.retry_strategy is None: + return await super().async_chat(prompt, **kwargs) + + wait = wait_random_exponential(min=self.retry_strategy.min_wait_seconds, max=self.retry_strategy.max_wait_seconds) + stop = stop_after_attempt(self.retry_strategy.max_attempt) + output = await retry(wait=wait, stop=stop)(super().async_chat)(prompt=prompt, **kwargs) + return output diff --git a/lmclient/models/minimax.py b/lmclient/models/minimax.py index 47acb98..4857059 100644 --- a/lmclient/models/minimax.py +++ b/lmclient/models/minimax.py @@ -1,31 +1,43 @@ from __future__ import annotations import os -from typing import Any +from pathlib import Path +from typing import Any, TypeVar -import httpx +from lmclient.models.base import HttpChatModel, RetryStrategy +from lmclient.parser import ModelResponseParser, ParserError +from lmclient.types import Messages, ModelResponse -from lmclient.models.base import BaseChatModel -from lmclient.types import Messages, ModelResponse, Prompt -from lmclient.utils import ensure_messages +T = TypeVar('T') -class 
MinimaxChat(BaseChatModel): +class MinimaxTextParser(ModelResponseParser): + def __call__(self, response: ModelResponse) -> str: + try: + output = response['choices'][0]['text'] + except (KeyError, IndexError) as e: + raise ParserError('Parse response failed') from e + return output + + +class MinimaxChat(HttpChatModel[T]): def __init__( self, - model_name: str = 'abab5.5-chat', + model: str = 'abab5.5-chat', group_id: str | None = None, api_key: str | None = None, timeout: int | None = 60, + response_parser: ModelResponseParser[T] | None = None, + retry: bool | RetryStrategy = False, + use_cache: Path | str | bool = False, ): - self.model_name = model_name + response_parser = response_parser or MinimaxTextParser() + super().__init__(timeout=timeout, response_parser=response_parser, retry=retry, use_cache=use_cache) + self.model = model self.group_id = group_id or os.environ['MINIMAX_GROUP_ID'] self.api_key = api_key or os.environ['MINIMAX_API_KEY'] - self.timeout = timeout - - def chat(self, prompt: Prompt, **kwargs) -> ModelResponse: - messages = ensure_messages(prompt) + def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: headers = { 'Authorization': f'Bearer {self.api_key}', 'Content-Type': 'application/json', @@ -34,38 +46,15 @@ def chat(self, prompt: Prompt, **kwargs) -> ModelResponse: if 'temperature' in kwargs: kwargs['temperature'] = max(0.01, kwargs['temperature']) json_data.update(kwargs) - response = httpx.post( - f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={self.group_id}', - json=json_data, - headers=headers, - timeout=self.timeout, - ).json() - return response - - async def async_chat(self, prompt: Prompt, **kwargs) -> ModelResponse: - messages = ensure_messages(prompt) - - async with httpx.AsyncClient() as client: - headers = { - 'Authorization': f'Bearer {self.api_key}', - 'Content-Type': 'application/json', - } - json_data = self._messages_to_request_json_data(messages) - if 'temperature' in kwargs: - kwargs['temperature'] = max(0.01, kwargs['temperature']) - json_data.update(kwargs) - response = await client.post( - f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={self.group_id}', - json=json_data, - headers=headers, - timeout=self.timeout, - ) - response = response.json() - return response + return { + 'url': f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={self.group_id}', + 'json': json_data, + 'headers': headers, + } def _messages_to_request_json_data(self, messages: Messages): data: dict[str, Any] = { - 'model': self.model_name, + 'model': self.model, 'role_meta': {'user_name': '用户', 'bot_name': 'MM智能助理'}, } @@ -94,4 +83,4 @@ def _messages_to_request_json_data(self, messages: Messages): @property def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model_name})' + return f'{self.__class__.__name__}({self.model})' diff --git a/lmclient/models/openai.py b/lmclient/models/openai.py index 73fc194..b65cd16 100644 --- a/lmclient/models/openai.py +++ b/lmclient/models/openai.py @@ -1,29 +1,73 @@ from __future__ import annotations import os +from pathlib import Path +from typing import Any, Type, TypeVar -import httpx +from lmclient.models.base import HttpChatModel, RetryStrategy +from lmclient.openai_schema import OpenAISchema +from lmclient.parser import ModelResponseParser, ParserError +from lmclient.types import Messages, ModelResponse -from lmclient.models.base import BaseChatModel -from lmclient.types import ModelResponse, Prompt -from lmclient.utils import ensure_messages +T = 
TypeVar('T') +T_O = TypeVar('T_O', bound=OpenAISchema) -class OpenAIChat(BaseChatModel): +class OpenAIParser(ModelResponseParser): + def __call__(self, response: ModelResponse) -> str | dict[str, str]: + try: + if self.is_function_call(response): + fucntion_call_output: dict[str, str] = response['choices'][0]['message']['function_call'] + return fucntion_call_output + else: + content_output: str = response['choices'][0]['message']['content'] + return content_output + except (KeyError, IndexError) as e: + raise ParserError('Parse response failed') from e + + @staticmethod + def is_function_call(reponse: ModelResponse) -> bool: + message = reponse['choices'][0]['message'] + return bool(message.get('function_call')) + + +class OpenAIFunctionCallParser(ModelResponseParser): + def __call__(self, response: ModelResponse) -> dict[str, str]: + try: + output: dict[str, str] = response['choices'][0]['message']['function_call'] + except (KeyError, IndexError) as e: + raise ParserError('Parse response failed') from e + return output + + +class OpenAIContentParser(ModelResponseParser): + def __call__(self, response: ModelResponse) -> str: + try: + output: str = response['choices'][0]['message']['content'] + except (KeyError, IndexError) as e: + raise ParserError('Parse response failed') from e + return output + + +class OpenAIChat(HttpChatModel[T]): def __init__( self, - model_name: str = 'gpt-3.5-turbo', + model: str = 'gpt-3.5-turbo', api_key: str | None = None, api_base: str | None = None, timeout: int | None = 60, + response_parser: ModelResponseParser[T] | None = None, + retry: bool | RetryStrategy = False, + use_cache: Path | str | bool = False, ): - self.model = model_name + response_parser = response_parser or OpenAIContentParser() + super().__init__(timeout=timeout, response_parser=response_parser, retry=retry, use_cache=use_cache) + self.model = model self.api_base = api_base or os.getenv('OPENAI_API_BASE') or 'https://api.openai.com/v1' self.api_key = api_key or os.environ['OPENAI_API_KEY'] self.timeout = timeout - def chat(self, prompt: Prompt, **kwargs) -> ModelResponse: - messages = ensure_messages(prompt) + def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: headers = { 'Authorization': f'Bearer {self.api_key}', } @@ -32,35 +76,55 @@ def chat(self, prompt: Prompt, **kwargs) -> ModelResponse: 'messages': messages, **kwargs, } - response = httpx.post( - url=f'{self.api_base}/chat/completions', - headers=headers, - json=params, - timeout=self.timeout, - ) - response.raise_for_status() - return response.json() - - async def async_chat(self, prompt: Prompt, **kwargs) -> ModelResponse: - messages = ensure_messages(prompt) + return { + 'url': f'{self.api_base}/chat/completions', + 'headers': headers, + 'json': params, + } + + @property + def identifier(self) -> str: + return f'{self.__class__.__name__}({self.model})' + + +class OpenAIExtract(HttpChatModel[T_O]): + def __init__( + self, + schema: Type[T_O], + model: str = 'gpt-3.5-turbo', + system_prompt: str = 'Extract structured data from a given text', + api_key: str | None = None, + api_base: str | None = None, + timeout: int | None = 60, + retry: bool | RetryStrategy = False, + use_cache: Path | str | bool = False, + ): + super().__init__(timeout=timeout, response_parser=schema.from_response, retry=retry, use_cache=use_cache) + self.schema = schema + self.model = model + self.system_prompt = system_prompt + self.api_base = api_base or os.getenv('OPENAI_API_BASE') or 'https://api.openai.com/v1' + self.api_key = 
api_key or os.environ['OPENAI_API_KEY'] + self.timeout = timeout + + def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: + messages = [{'role': 'system', 'content': self.system_prompt}] + list(messages) # type: ignore headers = { 'Authorization': f'Bearer {self.api_key}', } params = { 'model': self.model, 'messages': messages, + 'functions': [self.schema.openai_schema()], + 'function_call': {'name': self.schema.openai_schema()['name']}, **kwargs, } - async with httpx.AsyncClient() as client: - response = await client.post( - url=f'{self.api_base}/chat/completions', - headers=headers, - json=params, - timeout=self.timeout, - ) - response.raise_for_status() - return response.json() + return { + 'url': f'{self.api_base}/chat/completions', + 'headers': headers, + 'json': params, + } @property def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model})' + return f'{self.__class__.__name__}(model={self.model}, system_prompt={self.system_prompt})' diff --git a/lmclient/models/zhipu.py b/lmclient/models/zhipu.py index ac02a69..08518e7 100644 --- a/lmclient/models/zhipu.py +++ b/lmclient/models/zhipu.py @@ -2,15 +2,17 @@ import os import time +from pathlib import Path +from typing import Any, TypeVar import cachetools.func # type: ignore -import httpx import jwt -from lmclient.models.base import BaseChatModel -from lmclient.types import ModelResponse, Prompt -from lmclient.utils import ensure_messages +from lmclient.models.base import HttpChatModel, RetryStrategy +from lmclient.parser import ModelResponseParser, ParserError +from lmclient.types import Messages, ModelResponse +T = TypeVar('T') API_TOKEN_TTL_SECONDS = 3 * 60 CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 @@ -36,53 +38,48 @@ def generate_token(api_key: str): ) -class ZhiPuChat(BaseChatModel): +class ZhiPuParser(ModelResponseParser): + def __call__(self, response: ModelResponse) -> str: + try: + output = response['data']['choices'][0]['content'].strip('"').strip() + except (KeyError, IndexError) as e: + raise ParserError(f'Parse response failed, reponse: {response}') from e + return output + + +class ZhiPuChat(HttpChatModel[T]): def __init__( - self, model_name: str = 'chatglm_pro', api_base: str | None = None, api_key: str | None = None, timeout: int | None = 60 + self, + model: str = 'chatglm_pro', + api_base: str | None = None, + api_key: str | None = None, + timeout: int | None = 60, + response_parser: ModelResponseParser[T] | None = None, + retry: bool | RetryStrategy = False, + use_cache: Path | str | bool = False, ) -> None: - self.model_name = model_name + response_parser = response_parser or ZhiPuParser() + super().__init__(timeout=timeout, response_parser=response_parser, retry=retry, use_cache=use_cache) + self.model = model self.api_key = api_key or os.environ['ZHIPU_API_KEY'] - self.timeout = timeout self.api_base = api_base or os.getenv('ZHIPU_API_BASE') or 'https://open.bigmodel.cn/api/paas/v3/model-api' self.api_base.rstrip('/') - def chat( - self, prompt: Prompt, temperature: float | None = None, top_p: float | None = None, request_id: str | None = None - ) -> ModelResponse: - messages = ensure_messages(prompt) + def get_post_parameters(self, messages: Messages, **kwargs) -> dict[str, Any]: + for message in messages: + if message['role'] not in ('user', 'assistant'): + raise ValueError(f'Role of message must be user or assistant, but got {message["role"]}') headers = { 'Authorization': generate_token(self.api_key), } - params = {'prompt': messages, 'temperature': 
temperature, 'top_p': top_p, 'request_id': request_id} - reponse = httpx.post( - url=f'{self.api_base}/{self.model_name}/invoke', - headers=headers, - json=params, - timeout=self.timeout, - ) - reponse.raise_for_status() - return reponse.json() - - async def async_chat( - self, prompt: Prompt, temperature: float | None = None, top_p: float | None = None, request_id: str | None = None - ) -> ModelResponse: - messages = ensure_messages(prompt) - - headers = { - 'Authorization': generate_token(self.api_key), + params = {'prompt': messages, **kwargs} + return { + 'url': f'{self.api_base}/{self.model}/invoke', + 'headers': headers, + 'json': params, } - params = {'prompt': messages, 'temperature': temperature, 'top_p': top_p, 'request_id': request_id} - async with httpx.AsyncClient() as client: - reponse = await client.post( - url=f'{self.api_base}/{self.model_name}/invoke', - headers=headers, - json=params, - timeout=self.timeout, - ) - reponse.raise_for_status() - return reponse.json() @property def identifier(self) -> str: - return f'{self.__class__.__name__}({self.model_name})' + return f'{self.__class__.__name__}({self.model})' diff --git a/lmclient/parsers/openai.py b/lmclient/openai_schema.py similarity index 58% rename from lmclient/parsers/openai.py rename to lmclient/openai_schema.py index 2691273..ad44b21 100644 --- a/lmclient/parsers/openai.py +++ b/lmclient/openai_schema.py @@ -1,7 +1,4 @@ -from __future__ import annotations - import json -from typing import TypeVar try: from pydantic.v1 import BaseModel @@ -10,50 +7,10 @@ from pydantic import BaseModel from pydantic import Field as Field -from lmclient.exceptions import ParserError -from lmclient.parsers.base import ModelResponseParser +from lmclient.parser import ParserError from lmclient.types import ModelResponse -T = TypeVar('T', bound='OpenAISchema') - - -class OpenAIParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str | dict[str, str]: - try: - if self.is_function_call(response): - fucntion_call_output: dict[str, str] = response['choices'][0]['message']['function_call'] - return fucntion_call_output - else: - content_output: str = response['choices'][0]['message']['content'] - return content_output - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - - @staticmethod - def is_function_call(reponse: ModelResponse) -> bool: - message = reponse['choices'][0]['message'] - return bool(message.get('function_call')) - - -class OpenAIFunctionCallParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> dict[str, str]: - try: - output: dict[str, str] = response['choices'][0]['message']['function_call'] - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - return output - - -class OpenAIContentParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str: - try: - output: str = response['choices'][0]['message']['content'] - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - return output - -# COPY FROM openai_function_call def _remove_a_key(d, remove_key) -> None: """Remove a key from a dictionary recursively""" if isinstance(d, dict): diff --git a/lmclient/parsers/base.py b/lmclient/parser.py similarity index 71% rename from lmclient/parsers/base.py rename to lmclient/parser.py index e6dcc13..b2990db 100644 --- a/lmclient/parsers/base.py +++ b/lmclient/parser.py @@ -8,3 +8,7 @@ class ModelResponseParser(Protocol[T]): def __call__(self, response: 
ModelResponse) -> T: ... + + +class ParserError(Exception): + """Error raised when postprocess function fails.""" diff --git a/lmclient/parsers/__init__.py b/lmclient/parsers/__init__.py deleted file mode 100644 index 95029a6..0000000 --- a/lmclient/parsers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from lmclient.parsers.base import ModelResponseParser -from lmclient.parsers.minimax import MinimaxTextParser -from lmclient.parsers.openai import Field, OpenAIContentParser, OpenAIFunctionCallParser, OpenAIParser, OpenAISchema -from lmclient.parsers.zhipu import ZhiPuParser - -__all__ = [ - 'ModelResponseParser', - 'MinimaxTextParser', - 'ZhiPuParser', - 'OpenAIContentParser', - 'OpenAIFunctionCallParser', - 'OpenAIParser', - 'OpenAISchema', - 'Field', -] diff --git a/lmclient/parsers/minimax.py b/lmclient/parsers/minimax.py deleted file mode 100644 index ad208fa..0000000 --- a/lmclient/parsers/minimax.py +++ /dev/null @@ -1,12 +0,0 @@ -from lmclient.exceptions import ParserError -from lmclient.parsers.base import ModelResponseParser -from lmclient.types import ModelResponse - - -class MinimaxTextParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str: - try: - output = response['choices'][0]['message']['text'] - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - return output diff --git a/lmclient/parsers/zhipu.py b/lmclient/parsers/zhipu.py deleted file mode 100644 index eacd232..0000000 --- a/lmclient/parsers/zhipu.py +++ /dev/null @@ -1,12 +0,0 @@ -from lmclient.exceptions import ParserError -from lmclient.parsers.base import ModelResponseParser -from lmclient.types import ModelResponse - - -class ZhiPuParser(ModelResponseParser): - def __call__(self, response: ModelResponse) -> str: - try: - output = response['data']['choices'][0]['content'].strip('"').strip() - except (KeyError, IndexError) as e: - raise ParserError('Parse response failed') from e - return output diff --git a/lmclient/types.py b/lmclient/types.py index da82259..6c95405 100644 --- a/lmclient/types.py +++ b/lmclient/types.py @@ -26,7 +26,6 @@ class Message(TypedDict): Prompt = Union[str, Sequence[dict]] -class TaskResult(BaseModel, Generic[T]): # type: ignore - output: Optional[T] = None +class ChatModelOutput(BaseModel, Generic[T]): # type: ignore + parsed_result: Optional[T] = None response: ModelResponse = Field(default_factory=dict) - error_message: Optional[str] = None diff --git a/lmclient/version.py b/lmclient/version.py index a2618f8..c11ef48 100644 --- a/lmclient/version.py +++ b/lmclient/version.py @@ -1,2 +1,2 @@ -__version__ = '0.5.1' +__version__ = '0.6.0' __cache_version__ = '3' diff --git a/pyproject.toml b/pyproject.toml index 068248c..f30df87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "lmclient-core" -version = "0.5.1" +version = "0.6.0" description = "LM Async Client, openai client, azure openai client ..." 
authors = ["wangyuxin "] readme = "README.md" diff --git a/scripts/ner.py b/scripts/ner.py index 42c66cb..a32f992 100644 --- a/scripts/ner.py +++ b/scripts/ner.py @@ -7,8 +7,9 @@ import typer -from lmclient import AzureChat, Field, LMClientForStructuredData, OpenAIChat, OpenAISchema +from lmclient import LMClient, OpenAIExtract from lmclient.client import ErrorMode +from lmclient.openai_schema import Field, OpenAISchema class ModelType(str, Enum): @@ -35,43 +36,28 @@ def read_from_jsonl(file: str | Path): def main( input_josnl_file: Path, output_file: Path, - model_type: ModelType = ModelType.openai, max_requests_per_minute: int = 20, async_capacity: int = 3, - error_mode: ErrorMode = ErrorMode.IGNORE, - cache: bool = False, + error_mode: ErrorMode = ErrorMode.RAISE, + use_cache: bool = False, ): - if model_type is ModelType.azure: - model = AzureChat( - 'gpt-35-turbo-16k', - api_version='2023-07-01-preview', - ) - else: - model = OpenAIChat('gpt-3.5-turbo') + model = OpenAIExtract( + schema=NerInfo, + use_cache=use_cache, + ) - client = LMClientForStructuredData( + client = LMClient( model, - schema=NerInfo, - system_prompt='You are a NER model, extract entity information from the text.', max_requests_per_minute=max_requests_per_minute, async_capacity=async_capacity, error_mode=error_mode, ) - if not cache: - client.cache_dir = None - texts = read_from_jsonl(input_josnl_file) - results = client.async_run(texts) + model_outputs = client.async_run(texts) with open(output_file, 'w') as f: - for text, result in zip(texts, results): - if result.output is None: - output = None - else: - try: - output = result.output.model_dump() - except AttributeError: - output = result.output.dict() + for text, output in zip(texts, model_outputs): + output = output.parsed_result.dict() if output.parsed_result else None output_dict = {'text': text, 'output': output} f.write(json.dumps(output_dict, ensure_ascii=False) + '\n') diff --git a/scripts/translate.py b/scripts/translate.py index 5fc7af5..1e817f6 100644 --- a/scripts/translate.py +++ b/scripts/translate.py @@ -24,22 +24,21 @@ def main( max_requests_per_minute: int = 20, async_capacity: int = 3, error_mode: ErrorMode = ErrorMode.IGNORE, - cache_dir: str = 'lmclient-translate-cache', + use_cache: bool = True, ): if model_name == 'azure': - model = AzureChat() + model = AzureChat(use_cache=use_cache) elif model_name == 'minimax': - model = MinimaxChat('abab5.5-chat') + model = MinimaxChat(use_cache=use_cache) else: - model = OpenAIChat(model_name) + model = OpenAIChat(model=model_name, use_cache=use_cache) client = LMClient[str]( model, max_requests_per_minute=max_requests_per_minute, async_capacity=async_capacity, error_mode=error_mode, - cache_dir=cache_dir, ) texts = read_from_jsonl(input_josnl_file) @@ -51,7 +50,7 @@ def main( with open(output_file, 'w') as f: for text, result in zip(texts, results): - f.write(json.dumps({'text': text, 'translation': result.output}, ensure_ascii=False) + '\n') + f.write(json.dumps({'text': text, 'translation': result.parsed_result}, ensure_ascii=False) + '\n') if __name__ == '__main__': diff --git a/tests/test_client.py b/tests/test_client.py index ea880b5..6e8219f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,43 +8,49 @@ class TestModel(BaseChatModel): - def chat(self, prompt: str | Messages, **kwargs) -> ModelResponse: + def call_model(self, messages: Messages, **kwargs) -> ModelResponse: return { - 'content': f'Completed: {prompt}', + 'content': f'Completed: {messages[-1]["content"]}', } - async 
def async_chat(self, prompt: str | Messages, **kwargs) -> ModelResponse: + async def async_call_model(self, messages: Messages, **kwargs) -> ModelResponse: return { - 'content': f'Completed: {prompt}', + 'content': f'Completed: {messages[-1]["content"]}', } def default_postprocess_function(self, response: ModelResponse) -> str: return response['content'] + @property + def identifier(self) -> str: + return 'TestModel' + def model_parser(response): return response['content'] def test_sync_completion(): - completion_model = TestModel() - client = LMClient(completion_model, output_parser=model_parser, cache_dir=None) - - messages = [ - {'role': 'system', 'content': 'your are lmclient demo assistant'}, - {'role': 'user', 'content': 'hello, who are you?'}, + completion_model = TestModel(response_parser=model_parser, use_cache=False) + client = LMClient(completion_model) + prompts = [ + 'Hello, my name is', + [ + {'role': 'system', 'content': 'your are lmclient demo assistant'}, + {'role': 'user', 'content': 'hello, who are you?'}, + ], ] - prompts = ['Hello, my name is', 'I am a student', 'I like to play basketball', messages] results = client.run(prompts) - assert isinstance(results[0].output, str) - assert results[0].output == 'Completed: Hello, my name is' + assert isinstance(results[0].parsed_result, str) + assert results[0].parsed_result == 'Completed: Hello, my name is' + assert results[1].parsed_result == 'Completed: hello, who are you?' assert len(results) == len(prompts) def test_async_completion(): - completion_model = TestModel() - client = LMClient(completion_model, async_capacity=2, max_requests_per_minute=5, cache_dir=None, output_parser=model_parser) + completion_model = TestModel(response_parser=model_parser, use_cache=False) + client = LMClient(completion_model, async_capacity=2, max_requests_per_minute=5) LMClient.NUM_SECONDS_PER_MINUTE = 2 start_time = time.perf_counter() @@ -57,16 +63,14 @@ def test_async_completion(): elapsed_time = time.perf_counter() - start_time assert results[0].response['content'] == 'Completed: Hello, my name is' - assert results[0].output == 'Completed: Hello, my name is' + assert results[0].parsed_result == 'Completed: Hello, my name is' assert len(results) == len(prompts) assert elapsed_time > 4 def test_async_completion_with_cache(tmp_path): - completion_model = TestModel() - client = LMClient( - completion_model, async_capacity=2, max_requests_per_minute=5, cache_dir=tmp_path, output_parser=model_parser - ) + completion_model = TestModel(use_cache=tmp_path) + client = LMClient(completion_model, async_capacity=2, max_requests_per_minute=5) LMClient.NUM_SECONDS_PER_MINUTE = 2 start_time = time.perf_counter() @@ -78,4 +82,4 @@ def test_async_completion_with_cache(tmp_path): assert results[3].response['content'] == 'Completed: Hello, my name is' assert len(results) == len(prompts) assert elapsed_time < 2 - assert len(list(client._cache)) == 3 # type: ignore + assert len(list(completion_model._cache)) == 3 # type: ignore diff --git a/tests/test_model.py b/tests/test_model.py index e07bf80..9dee9b5 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,6 +2,7 @@ import pytest from lmclient.models import AzureChat, MinimaxChat, OpenAIChat, ZhiPuChat +from lmclient.models.openai import OpenAIContentParser @pytest.mark.parametrize( @@ -12,13 +13,15 @@ ], ) def test_azure_model(prompt): - model = AzureChat() + model = AzureChat(response_parser=OpenAIContentParser()) sync_output = model.chat(prompt) async_output = anyio.run(model.async_chat, 
prompt) - assert isinstance(sync_output, dict) - assert isinstance(async_output, dict) + assert isinstance(sync_output.response, dict) + assert isinstance(sync_output.parsed_result, str) + assert isinstance(async_output.response, dict) + assert isinstance(async_output.parsed_result, str) @pytest.mark.parametrize( @@ -29,13 +32,15 @@ def test_azure_model(prompt): ], ) def test_openai_model(prompt): - completion_model = OpenAIChat('gpt-3.5-turbo') + chat_model = OpenAIChat('gpt-3.5-turbo', response_parser=OpenAIContentParser()) - sync_output = completion_model.chat(prompt) - async_output = anyio.run(completion_model.async_chat, prompt) + sync_output = chat_model.chat(prompt) + async_output = anyio.run(chat_model.async_chat, prompt) - assert isinstance(sync_output, dict) - assert isinstance(async_output, dict) + assert isinstance(sync_output.response, dict) + assert isinstance(sync_output.parsed_result, str) + assert isinstance(async_output.response, dict) + assert isinstance(async_output.parsed_result, str) @pytest.mark.parametrize( @@ -51,15 +56,17 @@ def test_minimax_model(prompt): sync_output = completion_model.chat(prompt) async_output = anyio.run(completion_model.async_chat, prompt) - assert isinstance(sync_output, dict) - assert isinstance(async_output, dict) + assert isinstance(sync_output.response, dict) + assert isinstance(sync_output.parsed_result, str) + assert isinstance(async_output.response, dict) + assert isinstance(async_output.parsed_result, str) @pytest.mark.parametrize( 'prompt', [ 'Hello, my name is', - [{'role': 'system', 'content': 'your are lmclient demo assistant'}, {'role': 'user', 'content': 'hello, who are you?'}], + [{'role': 'user', 'content': 'hello, who are you?'}], ], ) def test_zhipu_model(prompt): @@ -68,5 +75,7 @@ def test_zhipu_model(prompt): sync_output = completion_model.chat(prompt) async_output = anyio.run(completion_model.async_chat, prompt) - assert isinstance(sync_output, dict) - assert isinstance(async_output, dict) + assert isinstance(sync_output.response, dict) + assert isinstance(sync_output.parsed_result, str) + assert isinstance(async_output.response, dict) + assert isinstance(async_output.parsed_result, str)
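
Usage notes (reviewer sketches, not part of the patch):

For readers migrating off the removed `LMClientForStructuredData`, a minimal sketch of the new extraction flow, assuming only the signatures introduced in this diff; the `Person` schema and the sample texts are invented for illustration:

```python
from lmclient import LMClient, OpenAIExtract
from lmclient.openai_schema import Field, OpenAISchema


class Person(OpenAISchema):
    """Hypothetical extraction schema; any OpenAISchema subclass should work here."""

    name: str = Field(..., description='Full name of the person')
    age: int = Field(..., description='Age in years')


# The extraction model now owns the schema, the system prompt and the cache,
# instead of LMClientForStructuredData wiring them together.
extractor = OpenAIExtract(
    schema=Person,
    model='gpt-3.5-turbo',
    system_prompt='Extract structured data from a given text',
    use_cache=True,
)

# LMClient is now only responsible for concurrency and rate limiting.
client = LMClient(extractor, async_capacity=3, max_requests_per_minute=20)

outputs = client.async_run(['Alice is 31 years old.', 'Bob turned 45 last week.'])
for output in outputs:
    # ChatModelOutput carries the parsed schema instance plus the raw response dict.
    print(output.parsed_result.dict() if output.parsed_result else None)
```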
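Retries and response caching also move from `LMClient` onto the model itself. A sketch assuming the `RetryStrategy` and `use_cache` parameters added above (`use_cache=True` falls back to the default `~/.cache/lmclient` directory; a `Path` or `str` selects a custom location):

```python
from lmclient import LMClient, OpenAIChat
from lmclient.models.base import RetryStrategy
from lmclient.models.openai import OpenAIContentParser

# Retry/backoff and the on-disk cache are configured per model.
chat_model = OpenAIChat(
    model='gpt-3.5-turbo',
    response_parser=OpenAIContentParser(),  # parsed_result will be the message content (str)
    retry=RetryStrategy(max_attempt=3, max_wait_seconds=20),
    use_cache=True,
)

# A single chat call now returns a ChatModelOutput, not a raw response dict.
output = chat_model.chat('Hello, who are you?')
print(output.parsed_result)  # parsed text
print(output.response)       # raw OpenAI-style response payload

# Batch usage is unchanged apart from the renamed field (output -> parsed_result).
client = LMClient[str](chat_model, async_capacity=3, max_requests_per_minute=20)
results = client.run(['Hello, my name is', 'I like to play basketball'])
print([r.parsed_result for r in results])
```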
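Finally, custom parsers plug into the new per-model `response_parser` hook via the `ModelResponseParser` protocol in `lmclient/parser.py`. The `ContentWithUsageParser` below is our own example name, and the `usage` fields assume a standard OpenAI-style payload:

```python
from lmclient import OpenAIChat
from lmclient.parser import ParserError
from lmclient.types import ModelResponse


class ContentWithUsageParser:
    """Any callable matching ModelResponseParser's __call__ signature satisfies the protocol."""

    def __call__(self, response: ModelResponse) -> dict:
        try:
            return {
                'content': response['choices'][0]['message']['content'],
                'total_tokens': response['usage']['total_tokens'],  # standard OpenAI usage field
            }
        except (KeyError, IndexError) as e:
            raise ParserError('Parse response failed') from e


chat_model = OpenAIChat(response_parser=ContentWithUsageParser())
print(chat_model.chat('Hello, who are you?').parsed_result)
```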