diff --git a/.gitignore b/.gitignore
index 87862a2..8a9a383 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,4 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-playground/
\ No newline at end of file
+playground/
+.vscode/
\ No newline at end of file
diff --git a/lmclient/chat_engine.py b/lmclient/chat_engine.py
index db4c840..1e20b7f 100644
--- a/lmclient/chat_engine.py
+++ b/lmclient/chat_engine.py
@@ -2,7 +2,7 @@
 
 import json
 import logging
-from typing import Generic, List, Optional, TypeVar
+from typing import Generic, List, Optional, TypeVar, cast
 
 from lmclient.models import BaseChatModel, load_from_model_id
 from lmclient.types import ChatModelOutput, FunctionCallDict, GeneralParameters, Message, Messages, ModelParameters
@@ -23,7 +23,7 @@ def __init__(
         function_call_raise_error: bool = False,
     ):
         if isinstance(chat_model, str):
-            self._chat_model: BaseChatModel[T_P, T_O] = load_from_model_id(chat_model)  # type: ignore
+            self._chat_model = cast(BaseChatModel[T_P, T_O], load_from_model_id(chat_model))
         else:
             self._chat_model = chat_model
 
@@ -42,7 +42,8 @@ def __init__(
             function_call=function_call,
         )
         _parameters: T_P = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters)
-        self._chat_model.parameters = self._chat_model.parameters.model_copy(update=_parameters.model_dump(exclude_unset=True))
+        self._chat_model.update_parameters(**_parameters.model_dump(exclude_unset=True))
+
         self.function_call_raise_error = function_call_raise_error
         self.history: Messages = []
 
diff --git a/lmclient/models/azure.py b/lmclient/models/azure.py
index bae7fab..4725be5 100644
--- a/lmclient/models/azure.py
+++ b/lmclient/models/azure.py
@@ -4,7 +4,7 @@
 from pathlib import Path
 from typing import Any
 
-from lmclient.models.http import HttpChatModel, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
 from lmclient.models.openai import (
     OpenAIChatParameters,
     convert_lmclient_to_openai,
@@ -27,8 +27,9 @@ def __init__(
         retry: bool | RetryStrategy = False,
         parameters: OpenAIChatParameters = OpenAIChatParameters(),
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = model or os.environ['AZURE_CHAT_API_ENGINE'] or os.environ['AZURE_CHAT_MODEL_NAME']
         self.system_prompt = system_prompt
         self.api_key = api_key or os.environ['AZURE_API_KEY']
diff --git a/lmclient/models/base.py b/lmclient/models/base.py
index 5719a54..9eabbb1 100644
--- a/lmclient/models/base.py
+++ b/lmclient/models/base.py
@@ -98,3 +98,6 @@ async def async_chat_completion(self, messages: Messages, override_parameters: T
         else:
             model_output = await self._async_chat_completion(messages, parameters)
         return model_output
+
+    def update_parameters(self, **kwargs: Any):
+        self.parameters = self.parameters.model_copy(update=kwargs)
diff --git a/lmclient/models/http.py b/lmclient/models/http.py
index f031098..6df81e1 100644
--- a/lmclient/models/http.py
+++ b/lmclient/models/http.py
@@ -6,6 +6,7 @@
 from typing import Any
 
 import httpx
+from httpx._types import ProxiesTypes
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 
 from lmclient.models.base import T_P, BaseChatModel
@@ -22,6 +23,7 @@ def __init__(
         parameters: T_P,
         timeout: int | None = None,
         retry: bool | RetryStrategy = False,
+        proxies: ProxiesTypes | None = None,
         use_cache: Path | str | bool = False,
     ):
         super().__init__(parameters=parameters, cache=use_cache)
@@ -30,6 +32,7 @@ def __init__(
             self.retry_strategy = retry
         else:
             self.retry_strategy = RetryStrategy() if retry else None
+        self.proxies = proxies
 
     @abstractmethod
     def get_request_parameters(self, messages: Messages, parameters: T_P) -> dict[str, Any]:
@@ -40,11 +43,12 @@ def parse_model_reponse(self, response: ModelResponse) -> Messages:
         ...
 
     def _chat_completion_without_retry(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput:
-        http_parameters = self.get_request_parameters(messages, parameters)
-        http_parameters = {'timeout': self.timeout, **http_parameters}
-        logger.info(f'HTTP Request: {http_parameters}')
-        http_response = httpx.post(**http_parameters)  # type: ignore
-        http_response.raise_for_status()
+        with httpx.Client(proxies=self.proxies) as client:
+            http_parameters = self.get_request_parameters(messages, parameters)
+            http_parameters = {'timeout': self.timeout, **http_parameters}
+            logger.info(f'HTTP Request: {http_parameters}')
+            http_response = client.post(**http_parameters)  # type: ignore
+            http_response.raise_for_status()
         model_response = http_response.json()
         logger.info(f'HTTP Response: {model_response}')
         new_messages = self.parse_model_reponse(model_response)
@@ -57,7 +61,7 @@ def _chat_completion_without_retry(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput:
         )
 
     async def _async_chat_completion_without_retry(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput:
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(proxies=self.proxies) as client:
             http_parameters = self.get_request_parameters(messages, parameters)
             http_parameters = {'timeout': self.timeout, **http_parameters}
             logger.info(f'ASYNC HTTP Request: {http_parameters}')
diff --git a/lmclient/models/minimax_pro.py b/lmclient/models/minimax_pro.py
index 034317e..45228cc 100644
--- a/lmclient/models/minimax_pro.py
+++ b/lmclient/models/minimax_pro.py
@@ -8,7 +8,7 @@
 from typing_extensions import NotRequired, TypedDict
 
 from lmclient.exceptions import MessageError
-from lmclient.models.http import HttpChatModel, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
 from lmclient.types import (
     FunctionCallDict,
     FunctionDict,
@@ -91,8 +91,9 @@ def __init__(
         retry: bool | RetryStrategy = False,
         parameters: MinimaxProChatParameters = MinimaxProChatParameters(),
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = model
         self.base_url = base_url
         self.group_id = group_id or os.environ['MINIMAX_GROUP_ID']
diff --git a/lmclient/models/openai.py b/lmclient/models/openai.py
index c515016..cc40d98 100644
--- a/lmclient/models/openai.py
+++ b/lmclient/models/openai.py
@@ -7,7 +7,7 @@
 from typing_extensions import NotRequired, TypedDict
 
 from lmclient.exceptions import MessageError
-from lmclient.models.http import HttpChatModel, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
 from lmclient.parser import ParserError
 from lmclient.types import FunctionCallDict, FunctionDict, GeneralParameters, Message, Messages, ModelParameters, ModelResponse
 
@@ -134,8 +134,9 @@ def __init__(
         retry: bool | RetryStrategy = False,
         parameters: OpenAIChatParameters = OpenAIChatParameters(),
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = model
         self.system_prompt = system_prompt
         self.api_base = api_base or os.getenv('OPENAI_API_BASE') or 'https://api.openai.com/v1'
diff --git a/lmclient/models/wenxin.py b/lmclient/models/wenxin.py
index 0a21d38..7d4b30c 100644
--- a/lmclient/models/wenxin.py
+++ b/lmclient/models/wenxin.py
@@ -10,8 +10,8 @@
 from typing_extensions import Self, TypedDict
 
 from lmclient.exceptions import ResponseError
-from lmclient.models.http import HttpChatModel
-from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
+from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse
 
 WENXIN_ACCESS_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
 WENXIN_BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/'
@@ -55,8 +55,9 @@ def __init__(
         timeout: int | None = None,
         retry: bool | RetryStrategy = False,
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters, timeout, retry, use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = self.normalize_model(model)
         self._api_key = api_key or os.getenv('WENXIN_API_KEY')
         self._secret_key = secret_key or os.getenv('WENXIN_SECRET_KEY')
diff --git a/lmclient/models/zhipu.py b/lmclient/models/zhipu.py
index e41fe50..14c9be7 100644
--- a/lmclient/models/zhipu.py
+++ b/lmclient/models/zhipu.py
@@ -9,7 +9,7 @@
 import jwt
 
 from lmclient.exceptions import MessageError
-from lmclient.models.http import HttpChatModel, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
 from lmclient.parser import ParserError
 from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse
 
@@ -68,8 +68,9 @@ def __init__(
         retry: bool | RetryStrategy = False,
         parameters: ZhiPuChatParameters = ZhiPuChatParameters(),
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = model
         self.api_key = api_key or os.environ['ZHIPU_API_KEY']
         self.api_base = api_base or os.getenv('ZHIPU_API_BASE') or 'https://open.bigmodel.cn/api/paas/v3/model-api'
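Note (not part of the patch): the new proxies argument is accepted by every HttpChatModel subclass, stored on the instance, and only used when the httpx client is built. A minimal usage sketch follows; the class name OpenAIChat and the proxy URL are assumptions for illustration, not taken from this diff.

    # Sketch only: OpenAIChat and the proxy URL are placeholders, not confirmed by this diff.
    from lmclient.models.openai import OpenAIChat

    model = OpenAIChat(
        model='gpt-3.5-turbo',
        proxies='http://127.0.0.1:7890',  # any httpx ProxiesTypes value: a URL string or a scheme-keyed dict
    )

    # HttpChatModel keeps the value as self.proxies and applies it when creating the client:
    #   with httpx.Client(proxies=self.proxies) as client: ...            (sync path)
    #   async with httpx.AsyncClient(proxies=self.proxies) as client: ... (async path)

The sync path now also builds a client per request instead of calling the module-level httpx.post, which is what makes the proxy configuration possible there. Note that newer httpx releases replace the proxies keyword with proxy/mounts; this patch targets the proxies-style API.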
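The new BaseChatModel.update_parameters is a thin wrapper around pydantic's model_copy, and ChatEngine now calls it instead of assigning to parameters directly. A standalone sketch of the same pattern, using a made-up parameters model rather than the library's own types:

    from typing import Optional
    from pydantic import BaseModel

    class DummyParameters(BaseModel):
        # Stand-in for a ModelParameters subclass such as OpenAIChatParameters.
        temperature: float = 1.0
        max_tokens: Optional[int] = None

    class DummyModel:
        def __init__(self) -> None:
            self.parameters = DummyParameters()

        def update_parameters(self, **kwargs) -> None:
            # Same logic as the new BaseChatModel.update_parameters:
            # copy the pydantic model, overriding only the supplied fields.
            self.parameters = self.parameters.model_copy(update=kwargs)

    model = DummyModel()
    model.update_parameters(temperature=0.2)
    print(model.parameters)  # temperature=0.2 max_tokens=None

In ChatEngine the kwargs come from _parameters.model_dump(exclude_unset=True), so only fields the engine explicitly set override the chat model's existing parameters.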
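The chat_engine change also swaps a "# type: ignore" for typing.cast, which records the intended type of load_from_model_id's result instead of silencing every diagnostic on the line. A minimal illustration of the pattern with placeholder classes, not the library's API:

    from typing import cast

    class Animal:
        ...

    class Dog(Animal):
        def bark(self) -> str:
            return 'woof'

    def load(name: str) -> Animal:
        # Loosely typed factory, analogous to load_from_model_id returning a base type.
        return Dog()

    # cast() narrows the static type without any runtime check and without
    # suppressing unrelated errors the way "# type: ignore" would.
    dog = cast(Dog, load('dog'))
    print(dog.bark())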