add timeout #48

Merged: merged 1 commit on Mar 6, 2024

add timeout
wangyuxin committed Mar 6, 2024
commit 278a6740824d7975c9b2fa9b4bb6b1b7876c3c0c
8 changes: 8 additions & 0 deletions generate/chat_completion/base.py
@@ -97,27 +97,35 @@ def _get_request_parameters(self, prompt: Prompt, stream: bool = False, **kwargs

    @override
    def generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput:
        timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
        request_parameters = self._get_request_parameters(prompt, **kwargs)
        request_parameters['timeout'] = timeout
        response = self.http_client.post(request_parameters=request_parameters)
        return self._process_reponse(response.json())

    @override
    async def async_generate(self, prompt: Prompt, **kwargs: Any) -> ChatCompletionOutput:
        timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
        request_parameters = self._get_request_parameters(prompt, **kwargs)
        request_parameters['timeout'] = timeout
        response = await self.http_client.async_post(request_parameters=request_parameters)
        return self._process_reponse(response.json())

    @override
    def stream_generate(self, prompt: Prompt, **kwargs: Any) -> Iterator[ChatCompletionStreamOutput]:
        timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
        request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs)
        request_parameters['timeout'] = timeout
        stream_manager = StreamManager(info=self.model_info)
        for line in self.http_client.stream_post(request_parameters=request_parameters):
            if output := self._process_stream_line(line, stream_manager):
                yield output

    @override
    async def async_stream_generate(self, prompt: Prompt, **kwargs: Any) -> AsyncIterator[ChatCompletionStreamOutput]:
        timeout = kwargs.pop('timeout') if 'timeout' in kwargs else None
        request_parameters = self._get_request_parameters(prompt, stream=True, **kwargs)
        request_parameters['timeout'] = timeout
        stream_manager = StreamManager(info=self.model_info)
        async for line in self.http_client.async_stream_post(request_parameters=request_parameters):
            if output := self._process_stream_line(line, stream_manager):
                yield output
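All four public methods now pop `timeout` from kwargs before the model-specific parameters are built, then attach it to the outgoing HTTP request. A minimal usage sketch of the new keyword; it assumes `OpenAIChat` is importable from the package root and that the httpx-based client treats the value as seconds (neither is shown in this diff):

    from generate import OpenAIChat  # assumed export

    model = OpenAIChat()
    # `timeout` is consumed by the base class and forwarded to the HTTP layer;
    # it never reaches the model-specific chat parameters.
    output = model.generate('Hello!', timeout=20)
    print(output.message.content)

The same keyword works with `async_generate`, `stream_generate`, and `async_stream_generate`, since all four share the pop-and-forward pattern above. The provider modules that follow all make the same mechanical substitution: each `*ParametersDict` re-bases from `ModelParametersDict` onto the new `RemoteModelParametersDict`.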
4 changes: 2 additions & 2 deletions generate/chat_completion/models/anthropic.py
@@ -24,7 +24,7 @@
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.chat_completion.stream_manager import StreamManager
from generate.http import HttpClient, HttpxPostKwargs
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms import AnthropicSettings
from generate.types import Probability, Temperature

@@ -44,7 +44,7 @@ class AnthropicChatParameters(ModelParameters):
    top_k: Optional[PositiveInt] = None


class AnthropicParametersDict(ModelParametersDict, total=False):
class AnthropicParametersDict(RemoteModelParametersDict, total=False):
    system: Optional[str]
    max_tokens: PositiveInt
    metadata: Optional[Dict[str, Any]]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/baichuan.py
@@ -26,7 +26,7 @@
    ResponseValue,
    UnexpectedResponseError,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.baichuan import BaichuanSettings
from generate.types import Probability, Temperature

@@ -43,7 +43,7 @@ class BaichuanChatParameters(ModelParameters):
    search: Optional[bool] = Field(default=None, alias='with_search_enhance')


class BaichuanChatParametersDict(ModelParametersDict, total=False):
class BaichuanChatParametersDict(RemoteModelParametersDict, total=False):
    temperature: Optional[Temperature]
    top_k: Optional[int]
    top_p: Optional[Probability]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/bailian.py
@@ -25,7 +25,7 @@
    ResponseValue,
    UnexpectedResponseError,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.bailian import BailianSettings, BailianTokenManager
from generate.types import Probability

@@ -67,7 +67,7 @@ def custom_model_dump(self) -> dict[str, Any]:
        return output


class BailianChatParametersDict(ModelParametersDict, total=False):
class BailianChatParametersDict(RemoteModelParametersDict, total=False):
    request_id: str
    top_p: Optional[Probability]
    top_k: Optional[int]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/dashscope.py
@@ -23,7 +23,7 @@
    HttpxPostKwargs,
    ResponseValue,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.dashscope import DashScopeSettings
from generate.types import Probability

@@ -44,7 +44,7 @@ class DashScopeChatParameters(ModelParameters):
    search: Annotated[Optional[bool], Field(alias='enable_search')] = None


class DashScopeChatParametersDict(ModelParametersDict, total=False):
class DashScopeChatParametersDict(RemoteModelParametersDict, total=False):
    seed: Optional[PositiveInt]
    max_tokens: Optional[PositiveInt]
    top_p: Optional[Probability]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/dashscope_multimodal.py
@@ -30,7 +30,7 @@
    HttpxPostKwargs,
    ResponseValue,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.dashscope import DashScopeSettings
from generate.types import Probability

@@ -41,7 +41,7 @@ class DashScopeMultiModalChatParameters(ModelParameters):
    top_k: Optional[Annotated[int, Field(ge=0, le=100)]] = None


class DashScopeMultiModalChatParametersDict(ModelParametersDict, total=False):
class DashScopeMultiModalChatParametersDict(RemoteModelParametersDict, total=False):
    seed: Optional[PositiveInt]
    top_p: Optional[Probability]
    top_k: Optional[int]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/deepseek.py
@@ -9,7 +9,7 @@
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.chat_completion.models.openai_like import OpenAILikeChat
from generate.http import HttpClient
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms import DeepSeekSettings
from generate.types import Probability

@@ -23,7 +23,7 @@ class DeepSeekChatParameters(ModelParameters):
    stop: Optional[Union[str, List[str]]] = None


class DeepSeekParametersDict(ModelParametersDict, total=False):
class DeepSeekParametersDict(RemoteModelParametersDict, total=False):
    temperature: Optional[float]
    top_p: Optional[Probability]
    max_tokens: Optional[PositiveInt]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/hunyuan.py
@@ -29,7 +29,7 @@
    ResponseValue,
    UnexpectedResponseError,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.hunyuan import HunyuanSettings
from generate.types import Probability, Temperature

@@ -44,7 +44,7 @@ class HunyuanChatParameters(ModelParameters):
    top_p: Optional[Probability] = None


class HunyuanChatParametersDict(ModelParametersDict, total=False):
class HunyuanChatParametersDict(RemoteModelParametersDict, total=False):
    temperature: Optional[Temperature]
    top_p: Optional[Probability]

4 changes: 2 additions & 2 deletions generate/chat_completion/models/minimax.py
@@ -25,7 +25,7 @@
    ResponseValue,
    UnexpectedResponseError,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.minimax import MinimaxSettings
from generate.types import Probability, Temperature

@@ -62,7 +62,7 @@ def custom_model_dump(self) -> dict[str, Any]:
        return output


class MinimaxChatParametersDict(ModelParametersDict, total=False):
class MinimaxChatParametersDict(RemoteModelParametersDict, total=False):
    system_prompt: str
    role_meta: RoleMeta
    beam_width: Optional[int]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/minimax_pro.py
@@ -28,7 +28,7 @@
    ResponseValue,
    UnexpectedResponseError,
)
from generate.model import ModelInfo, ModelParameters, ModelParametersDict
from generate.model import ModelInfo, ModelParameters, RemoteModelParametersDict
from generate.platforms.minimax import MinimaxSettings
from generate.types import OrIterable, Probability, Temperature
from generate.utils import ensure_iterable
@@ -118,7 +118,7 @@ def custom_model_dump(self) -> dict[str, Any]:
        return output


class MinimaxProChatParametersDict(ModelParametersDict, total=False):
class MinimaxProChatParametersDict(RemoteModelParametersDict, total=False):
    reply_constraints: ReplyConstrainsDict
    bot_setting: List[BotSettingDict]
    temperature: Optional[Temperature]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/moonshot.py
@@ -9,7 +9,7 @@
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.chat_completion.models.openai_like import OpenAILikeChat
from generate.http import HttpClient
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms import MoonshotSettings
from generate.types import Probability, Temperature

@@ -20,7 +20,7 @@ class MoonshotChatParameters(ModelParameters):
    max_tokens: Optional[PositiveInt] = None


class MoonshotParametersDict(ModelParametersDict, total=False):
class MoonshotParametersDict(RemoteModelParametersDict, total=False):
    temperature: Temperature
    top_p: Probability
    max_tokens: PositiveInt
4 changes: 2 additions & 2 deletions generate/chat_completion/models/openai.py
@@ -19,7 +19,7 @@
from generate.http import (
    HttpClient,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.openai import OpenAISettings
from generate.types import OrIterable, Probability, Temperature
from generate.utils import ensure_iterable
@@ -42,7 +42,7 @@ class OpenAIChatParameters(ModelParameters):
    tool_choice: Union[Literal['auto'], OpenAIToolChoice, None] = None


class OpenAIChatParametersDict(ModelParametersDict, total=False):
class OpenAIChatParametersDict(RemoteModelParametersDict, total=False):
    temperature: Optional[Temperature]
    top_p: Optional[Probability]
    max_tokens: Optional[PositiveInt]
14 changes: 8 additions & 6 deletions generate/chat_completion/models/test.py
@@ -11,14 +11,14 @@
)
from generate.chat_completion.message import AssistantMessage, Prompt, ensure_messages
from generate.chat_completion.model_output import Stream
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict


class FakeChatParameters(ModelParameters):
    prefix: str = 'Completed:'


class FakeChatParametersDict(ModelParametersDict, total=False):
class FakeChatParametersDict(RemoteModelParametersDict, total=False):
    prefix: str


@@ -28,19 +28,21 @@ class FakeChat(ChatCompletionModel):
    def __init__(self, parameters: FakeChatParameters | None = None) -> None:
        self.parameters = parameters or FakeChatParameters()

    def generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> ChatCompletionOutput:
    def generate(self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]) -> ChatCompletionOutput:
        messages = ensure_messages(prompt)
        parameters = self.parameters.clone_with_changes(**kwargs)
        content = f'{parameters.prefix}{messages[-1].content}'
        return ChatCompletionOutput(model_info=self.model_info, message=AssistantMessage(content=content))

    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> ChatCompletionOutput:
    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]) -> ChatCompletionOutput:
        messages = ensure_messages(prompt)
        parameters = self.parameters.clone_with_changes(**kwargs)
        content = f'{parameters.prefix}{messages[-1].content}'
        return ChatCompletionOutput(model_info=self.model_info, message=AssistantMessage(content=content))

    def stream_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> Iterator[ChatCompletionStreamOutput]:
    def stream_generate(
        self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]
    ) -> Iterator[ChatCompletionStreamOutput]:
        messages = ensure_messages(prompt)
        parameters = self.parameters.clone_with_changes(**kwargs)
        content = f'{parameters.prefix}{messages[-1].content}'
@@ -56,7 +58,7 @@ def stream_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict])
        )

    async def async_stream_generate(
        self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]
        self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]
    ) -> AsyncIterator[ChatCompletionStreamOutput]:
        messages = ensure_messages(prompt)
        parameters = self.parameters.clone_with_changes(**kwargs)
4 changes: 2 additions & 2 deletions generate/chat_completion/models/wenxin.py
@@ -28,7 +28,7 @@
    ResponseValue,
    UnexpectedResponseError,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.baidu import QianfanSettings, QianfanTokenManager
from generate.types import JsonSchema, OrIterable, Probability, Temperature
from generate.utils import ensure_iterable
@@ -124,7 +124,7 @@ def custom_model_dump(self) -> dict[str, Any]:
        return output


class WenxinChatParametersDict(ModelParametersDict, total=False):
class WenxinChatParametersDict(RemoteModelParametersDict, total=False):
    temperature: Optional[Temperature]
    top_p: Optional[Probability]
    functions: Optional[List[WenxinFunction]]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/yi.py
@@ -9,7 +9,7 @@
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.chat_completion.models.openai_like import OpenAILikeChat
from generate.http import HttpClient
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms import YiSettings


@@ -18,7 +18,7 @@ class YiChatParameters(ModelParameters):
    max_tokens: Optional[PositiveInt] = None


class YiParametersDict(ModelParametersDict, total=False):
class YiParametersDict(RemoteModelParametersDict, total=False):
    temperature: Optional[Annotated[float, Field(ge=0, lt=2)]]
    max_tokens: Optional[PositiveInt]

6 changes: 3 additions & 3 deletions generate/chat_completion/models/zhipu.py
@@ -33,7 +33,7 @@
    ResponseValue,
    UnexpectedResponseError,
)
from generate.model import ModelInfo, ModelParameters, ModelParametersDict
from generate.model import ModelInfo, ModelParameters, RemoteModelParametersDict
from generate.platforms.zhipu import ZhipuSettings, generate_zhipu_token
from generate.types import JsonSchema, Probability, Temperature

@@ -92,7 +92,7 @@ def can_not_equal_zero(cls, v: Optional[Temperature]) -> Optional[Temperature]:
        return v


class ZhipuChatParametersDict(ModelParametersDict, total=False):
class ZhipuChatParametersDict(RemoteModelParametersDict, total=False):
    temperature: Optional[Temperature]
    top_p: Optional[Probability]
    request_id: Optional[str]
@@ -448,7 +448,7 @@ def custom_model_dump(self) -> dict[str, Any]:
        return output


class ZhipuCharacterChatParametersDict(ModelParametersDict, total=False):
class ZhipuCharacterChatParametersDict(RemoteModelParametersDict, total=False):
    meta: ZhipuMeta
    request_id: Optional[str]

4 changes: 2 additions & 2 deletions generate/image_generation/models/baidu.py
@@ -13,7 +13,7 @@
ImageGenerationOutput,
RemoteImageGenerationModel,
)
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.baidu import BaiduCreationSettings, BaiduCreationTokenManager

ValidSize = Literal[
@@ -52,7 +52,7 @@ def custom_model_dump(self) -> dict[str, Any]:
        return output_data


class BaiduImageGenerationParametersDict(ModelParametersDict, total=False):
class BaiduImageGenerationParametersDict(RemoteModelParametersDict, total=False):
    size: ValidSize
    n: Optional[int]
    reference_image: Union[HttpUrl, Base64Str, None]
4 changes: 2 additions & 2 deletions generate/image_generation/models/openai.py
@@ -9,7 +9,7 @@

from generate.http import HttpClient, HttpxPostKwargs
from generate.image_generation.base import GeneratedImage, ImageGenerationOutput, RemoteImageGenerationModel
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.openai import OpenAISettings

MAX_PROMPT_LENGTH_DALLE_3 = 4000
@@ -46,7 +46,7 @@ class OpenAIImageGenerationParameters(ModelParameters):
    user: Optional[str] = None


class OpenAIImageGenerationParametersDict(ModelParametersDict, total=False):
class OpenAIImageGenerationParametersDict(RemoteModelParametersDict, total=False):
    quality: Optional[Literal['hd', 'standard']]
    response_format: Optional[Literal['url', 'b64_json']]
    size: Optional[Literal['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']]
4 changes: 2 additions & 2 deletions generate/image_generation/models/qianfan.py
@@ -8,7 +8,7 @@

from generate.http import HttpClient, HttpxPostKwargs, ResponseValue
from generate.image_generation.base import GeneratedImage, ImageGenerationOutput, RemoteImageGenerationModel
from generate.model import ModelParameters, ModelParametersDict
from generate.model import ModelParameters, RemoteModelParametersDict
from generate.platforms.baidu import QianfanSettings, QianfanTokenManager

ValidSize = Literal[
@@ -30,7 +30,7 @@ class QianfanImageGenerationParameters(ModelParameters):
    user: Optional[str] = Field(default=None, serialization_alias='user_id')


class QianfanImageGenerationParametersDict(ModelParametersDict, total=False):
class QianfanImageGenerationParametersDict(RemoteModelParametersDict, total=False):
    size: Optional[ValidSize]
    n: Optional[int]
    negative_prompt: Optional[str]
4 changes: 2 additions & 2 deletions generate/model.py
@@ -28,8 +28,8 @@ def model_update(self, **kwargs: Any) -> None:
            setattr(self, k, v)


class ModelParametersDict(TypedDict, total=False):
    ...
class RemoteModelParametersDict(TypedDict, total=False):
    timeout: Optional[int]


class ModelInfo(BaseModel):
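This hunk is the heart of the PR: the previously empty `ModelParametersDict` becomes `RemoteModelParametersDict` with a single shared `timeout` key, so every `*ParametersDict` subclass above now accepts `timeout` under `Unpack[...]` typing. A self-contained sketch of the mechanism; the `ExampleChatParametersDict` and `generate` names here are illustrative, not from the codebase:

    from typing import Optional

    from typing_extensions import TypedDict, Unpack


    class RemoteModelParametersDict(TypedDict, total=False):
        timeout: Optional[int]


    class ExampleChatParametersDict(RemoteModelParametersDict, total=False):
        temperature: Optional[float]


    def generate(**kwargs: Unpack[ExampleChatParametersDict]) -> None:
        timeout = kwargs.pop('timeout', None)  # shared key, consumed first
        print(f'timeout={timeout}, model params={kwargs}')


    generate(temperature=0.5, timeout=20)  # both keys type-check

Because the TypedDict is `total=False`, `timeout` remains optional everywhere, which is why call sites that never pass it keep working unchanged.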
6 changes: 3 additions & 3 deletions generate/modifiers/structure.py
@@ -17,7 +17,7 @@
    UserMessage,
    ensure_messages,
)
from generate.model import GenerateModel, ModelOutput, ModelParametersDict
from generate.model import GenerateModel, ModelOutput, RemoteModelParametersDict

field_info_title = 'Output JSON strictly based the format and pydantic field information below:\n'
json_schema_title = 'Output JSON strictly based the OpenAI JSON Schema:\n'
@@ -177,7 +177,7 @@ def system_message(self) -> SystemMessage:
        )
        return SystemMessage(content=system_content)

    def generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> StructureModelOutput[O]:
    def generate(self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]) -> StructureModelOutput[O]:
        messages = deepcopy(self.messages)
        messages.extend(ensure_messages(prompt))
        num_reask = 0
@@ -205,7 +205,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> Str

        raise ValueError(f'Failed to generate valid JSON after {self.max_num_reask} reasks.')

    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[ModelParametersDict]) -> StructureModelOutput[O]:
    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[RemoteModelParametersDict]) -> StructureModelOutput[O]:
        messages = deepcopy(self.messages)
        messages.extend(ensure_messages(prompt))
        num_reask = 0
12 changes: 11 additions & 1 deletion generate/utils.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from typing import AsyncIterator, Awaitable, Generator, Generic, Iterable, TypeVar
from typing import Any, AsyncIterator, Awaitable, Generator, Generic, Iterable, TypeVar

from generate.types import OrIterable

@@ -50,3 +50,13 @@ def sync_aiter(aiterator: AsyncIterator[T]) -> Generator[T, None, None]:

    nest_asyncio.apply()
    yield from sync_aiter(aiterator)


def unwrap_model(model: Any) -> Any:
    from generate.model import GenerateModel

    if hasattr(model, 'model'):
        if isinstance(model.model, GenerateModel):
            return unwrap_model(model.model)
        return model
    return model
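The new `unwrap_model` helper recursively follows the `.model` attribute that modifier wrappers use to hold their inner model, stopping once that attribute is absent or no longer a `GenerateModel`. A hedged sketch of the intent; `Wrapper` is an illustrative stand-in for a real modifier, and it assumes `FakeChat` is a `GenerateModel` subclass:

    from generate.chat_completion.models.test import FakeChat
    from generate.utils import unwrap_model


    class Wrapper:
        def __init__(self, model):
            self.model = model  # modifiers keep the wrapped model here


    inner = FakeChat()
    assert unwrap_model(Wrapper(inner)) is inner  # one wrapper layer peeled
    assert unwrap_model(inner) is inner           # no-op without a wrapper

The import of `GenerateModel` inside the function body, rather than at module level, presumably avoids a circular import between `generate.utils` and `generate.model`.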
2 changes: 1 addition & 1 deletion tests/test_chat_completion_model.py
@@ -29,7 +29,7 @@ def test_model_type_is_unique() -> None:
    'parameters',
    [
        {},
        {'temperature': 0.5, 'top_p': 0.85, 'max_tokens': 20},
        {'temperature': 0.5, 'top_p': 0.85, 'max_tokens': 20, 'timeout': 20},
    ],
)
def test_http_chat_model(model_cls: Type[ChatCompletionModel], parameters: dict[str, Any]) -> None:
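For context, a runnable sketch of the call pattern this parametrization exercises; `EchoChat` is a hypothetical stand-in for the real model classes, which the actual test receives via `model_cls`:

    from typing import Any

    import pytest


    class EchoChat:
        def generate(self, prompt: str, **kwargs: Any) -> str:
            timeout = kwargs.pop('timeout', None)  # consumed like the base class does
            return f'{prompt} (timeout={timeout}, params={kwargs})'


    @pytest.mark.parametrize(
        'parameters',
        [
            {},
            {'temperature': 0.5, 'top_p': 0.85, 'max_tokens': 20, 'timeout': 20},
        ],
    )
    def test_timeout_kwarg(parameters: dict) -> None:
        assert EchoChat().generate('hi', **parameters)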