Add proxy support #11

Merged
merged 1 commit into from
Sep 7, 2023
Merged
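This pull request adds proxy support: an optional `proxies` argument is threaded through `HttpChatModel` and every model built on it (Azure, Minimax Pro, OpenAI, Wenxin, ZhiPu) and forwarded to the underlying `httpx` client. A minimal usage sketch (the class name and proxy address are illustrative assumptions, not taken from this diff):

    from lmclient.models.openai import OpenAIChat  # class name assumed

    # All requests from this model are routed through a local proxy
    # (placeholder address).
    model = OpenAIChat(proxies='http://127.0.0.1:7890')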
3 changes: 2 additions & 1 deletion .gitignore
@@ -158,4 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-playground/
+playground/
+.vscode/
7 changes: 4 additions & 3 deletions lmclient/chat_engine.py
@@ -2,7 +2,7 @@
 
 import json
 import logging
-from typing import Generic, List, Optional, TypeVar
+from typing import Generic, List, Optional, TypeVar, cast
 
 from lmclient.models import BaseChatModel, load_from_model_id
 from lmclient.types import ChatModelOutput, FunctionCallDict, GeneralParameters, Message, Messages, ModelParameters
@@ -23,7 +23,7 @@ def __init__(
         function_call_raise_error: bool = False,
     ):
         if isinstance(chat_model, str):
-            self._chat_model: BaseChatModel[T_P, T_O] = load_from_model_id(chat_model) # type: ignore
+            self._chat_model = cast(BaseChatModel[T_P, T_O], load_from_model_id(chat_model))
         else:
             self._chat_model = chat_model
@@ -42,7 +42,8 @@ def __init__(
             function_call=function_call,
         )
         _parameters: T_P = self._chat_model.parameters_type.from_general_parameters(self.engine_parameters)
-        self._chat_model.parameters = self._chat_model.parameters.model_copy(update=_parameters.model_dump(exclude_unset=True))
+        self._chat_model.update_parameters(**_parameters.model_dump(exclude_unset=True))
+
         self.function_call_raise_error = function_call_raise_error
         self.history: Messages = []
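Two changes here: the `# type: ignore` is replaced by an explicit `typing.cast`, which documents the intended type without any runtime effect, and the engine now updates model parameters through the new `update_parameters` helper instead of reassigning `parameters` directly. A minimal sketch of the `cast` idiom (the model id string is hypothetical):

    from typing import cast

    from lmclient.models import BaseChatModel, load_from_model_id

    model = load_from_model_id('openai')  # hypothetical model id
    # cast() performs no conversion at runtime; it only tells the type
    # checker to treat the loaded model as a BaseChatModel.
    chat_model = cast(BaseChatModel, model)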
5 changes: 3 additions & 2 deletions lmclient/models/azure.py
@@ -4,7 +4,7 @@
 from pathlib import Path
 from typing import Any
 
-from lmclient.models.http import HttpChatModel, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
 from lmclient.models.openai import (
     OpenAIChatParameters,
     convert_lmclient_to_openai,
@@ -27,8 +27,9 @@ def __init__(
         retry: bool | RetryStrategy = False,
         parameters: OpenAIChatParameters = OpenAIChatParameters(),
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = model or os.environ['AZURE_CHAT_API_ENGINE'] or os.environ['AZURE_CHAT_MODEL_NAME']
         self.system_prompt = system_prompt
         self.api_key = api_key or os.environ['AZURE_API_KEY']
3 changes: 3 additions & 0 deletions lmclient/models/base.py
@@ -98,3 +98,6 @@ async def async_chat_completion(self, messages: Messages, override_parameters: T
         else:
             model_output = await self._async_chat_completion(messages, parameters)
         return model_output
+
+    def update_parameters(self, **kwargs: Any):
+        self.parameters = self.parameters.model_copy(update=kwargs)
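The new helper relies on pydantic v2's `model_copy(update=...)`, which returns a copy of the model with only the named fields replaced, so all other parameter fields keep their values. A minimal sketch of the resulting behavior, assuming the parameters object has a `temperature` field (a hypothetical field name):

    # Only `temperature` changes; every other parameter keeps its value.
    chat_model.update_parameters(temperature=0.2)
    assert chat_model.parameters.temperature == 0.2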
16 changes: 10 additions & 6 deletions lmclient/models/http.py
@@ -6,6 +6,7 @@
 from typing import Any
 
 import httpx
+from httpx._types import ProxiesTypes
 from tenacity import retry, stop_after_attempt, wait_random_exponential
 
 from lmclient.models.base import T_P, BaseChatModel
@@ -22,6 +23,7 @@ def __init__(
         parameters: T_P,
         timeout: int | None = None,
         retry: bool | RetryStrategy = False,
+        proxies: ProxiesTypes | None = None,
         use_cache: Path | str | bool = False,
     ):
         super().__init__(parameters=parameters, cache=use_cache)
@@ -30,6 +32,7 @@ def __init__(
             self.retry_strategy = retry
         else:
             self.retry_strategy = RetryStrategy() if retry else None
+        self.proxies = proxies
 
     @abstractmethod
     def get_request_parameters(self, messages: Messages, parameters: T_P) -> dict[str, Any]:
@@ -40,11 +43,12 @@ def parse_model_reponse(self, response: ModelResponse) -> Messages:
         ...
 
     def _chat_completion_without_retry(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput:
-        http_parameters = self.get_request_parameters(messages, parameters)
-        http_parameters = {'timeout': self.timeout, **http_parameters}
-        logger.info(f'HTTP Request: {http_parameters}')
-        http_response = httpx.post(**http_parameters) # type: ignore
-        http_response.raise_for_status()
+        with httpx.Client(proxies=self.proxies) as client:
+            http_parameters = self.get_request_parameters(messages, parameters)
+            http_parameters = {'timeout': self.timeout, **http_parameters}
+            logger.info(f'HTTP Request: {http_parameters}')
+            http_response = client.post(**http_parameters) # type: ignore
+            http_response.raise_for_status()
         model_response = http_response.json()
         logger.info(f'HTTP Response: {model_response}')
         new_messages = self.parse_model_reponse(model_response)
@@ -57,7 +61,7 @@ def _chat_completion_without_retry(self, messages: Messages, parameters: T_P) ->
         )
 
     async def _async_chat_completion_without_retry(self, messages: Messages, parameters: T_P) -> HttpChatModelOutput:
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(proxies=self.proxies) as client:
             http_parameters = self.get_request_parameters(messages, parameters)
             http_parameters = {'timeout': self.timeout, **http_parameters}
             logger.info(f'ASYNC HTTP Request: {http_parameters}')
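Both paths now build their client with the stored proxy; the synchronous path also moves from module-level `httpx.post` to a per-request `httpx.Client`, which keeps proxy handling in one place at the cost of per-request connection setup. In the httpx releases that ship `ProxiesTypes` (pre-1.0), `proxies` accepts either a single proxy URL or a mapping from URL patterns to proxy URLs. A standalone sketch of the mechanism (the addresses are placeholders):

    import httpx

    # A plain string applies one proxy to all traffic; a mapping can split
    # traffic by scheme or host pattern.
    proxies = {
        'http://': 'http://127.0.0.1:7890',
        'https://': 'http://127.0.0.1:7890',
    }
    with httpx.Client(proxies=proxies) as client:
        response = client.get('https://httpbin.org/ip')
        print(response.json())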
5 changes: 3 additions & 2 deletions lmclient/models/minimax_pro.py
@@ -8,7 +8,7 @@
 from typing_extensions import NotRequired, TypedDict
 
 from lmclient.exceptions import MessageError
-from lmclient.models.http import HttpChatModel, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
 from lmclient.types import (
     FunctionCallDict,
     FunctionDict,
@@ -91,8 +91,9 @@ def __init__(
         retry: bool | RetryStrategy = False,
         parameters: MinimaxProChatParameters = MinimaxProChatParameters(),
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = model
         self.base_url = base_url
         self.group_id = group_id or os.environ['MINIMAX_GROUP_ID']
5 changes: 3 additions & 2 deletions lmclient/models/openai.py
@@ -7,7 +7,7 @@
 from typing_extensions import NotRequired, TypedDict
 
 from lmclient.exceptions import MessageError
-from lmclient.models.http import HttpChatModel, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
 from lmclient.parser import ParserError
 from lmclient.types import FunctionCallDict, FunctionDict, GeneralParameters, Message, Messages, ModelParameters, ModelResponse
@@ -134,8 +134,9 @@ def __init__(
         retry: bool | RetryStrategy = False,
         parameters: OpenAIChatParameters = OpenAIChatParameters(),
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = model
         self.system_prompt = system_prompt
         self.api_base = api_base or os.getenv('OPENAI_API_BASE') or 'https://api.openai.com/v1'
7 changes: 4 additions & 3 deletions lmclient/models/wenxin.py
@@ -10,8 +10,8 @@
 from typing_extensions import Self, TypedDict
 
 from lmclient.exceptions import ResponseError
-from lmclient.models.http import HttpChatModel
-from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
+from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse
 
 WENXIN_ACCESS_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
 WENXIN_BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/'
@@ -55,8 +55,9 @@ def __init__(
         timeout: int | None = None,
         retry: bool | RetryStrategy = False,
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters, timeout, retry, use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = self.normalize_model(model)
         self._api_key = api_key or os.getenv('WENXIN_API_KEY')
         self._secret_key = secret_key or os.getenv('WENXIN_SECRET_KEY')
5 changes: 3 additions & 2 deletions lmclient/models/zhipu.py
@@ -9,7 +9,7 @@
 import jwt
 
 from lmclient.exceptions import MessageError
-from lmclient.models.http import HttpChatModel, RetryStrategy
+from lmclient.models.http import HttpChatModel, ProxiesTypes, RetryStrategy
 from lmclient.parser import ParserError
 from lmclient.types import GeneralParameters, Message, Messages, ModelParameters, ModelResponse
 
@@ -68,8 +68,9 @@ def __init__(
         retry: bool | RetryStrategy = False,
         parameters: ZhiPuChatParameters = ZhiPuChatParameters(),
         use_cache: Path | str | bool = False,
+        proxies: ProxiesTypes | None = None,
     ):
-        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache)
+        super().__init__(parameters=parameters, timeout=timeout, retry=retry, use_cache=use_cache, proxies=proxies)
         self.model = model
         self.api_key = api_key or os.environ['ZHIPU_API_KEY']
         self.api_base = api_base or os.getenv('ZHIPU_API_BASE') or 'https://open.bigmodel.cn/api/paas/v3/model-api'