From 690fbf16f2154a76084412d0f052c9827f27808a Mon Sep 17 00:00:00 2001 From: "yuxin.wang" Date: Mon, 5 Feb 2024 18:18:08 +0800 Subject: [PATCH] fixbug (#29) Co-authored-by: wangyuxin --- generate/chat_completion/models/dashscope.py | 2 +- generate/chat_completion/models/zhipu.py | 2 ++ generate/test.py | 4 ++-- generate/ui.py | 9 +++++---- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/generate/chat_completion/models/dashscope.py b/generate/chat_completion/models/dashscope.py index ac807c6..a0a446a 100644 --- a/generate/chat_completion/models/dashscope.py +++ b/generate/chat_completion/models/dashscope.py @@ -46,7 +46,7 @@ class DashScopeChatParameters(ModelParameters): enable_search: Optional[bool] = None -class DashScopeChatParametersDict(ModelParametersDict): +class DashScopeChatParametersDict(ModelParametersDict, total=False): seed: int max_tokens: int top_p: float diff --git a/generate/chat_completion/models/zhipu.py b/generate/chat_completion/models/zhipu.py index ac6881b..f9be849 100644 --- a/generate/chat_completion/models/zhipu.py +++ b/generate/chat_completion/models/zhipu.py @@ -86,6 +86,8 @@ class ZhipuChatParameters(ModelParameters): def can_not_equal_zero(cls, v: Optional[Temperature]) -> Optional[Temperature]: if v == 0: return 0.01 + if v == 1: + return 0.99 return v diff --git a/generate/test.py b/generate/test.py index 05e285a..70ba357 100644 --- a/generate/test.py +++ b/generate/test.py @@ -17,7 +17,7 @@ def get_pytest_params( include: Sequence[str] | None = None, ) -> list[Any]: exclude = exclude or [] - include = include or [] + include = include if isinstance(types, str): types = [types] @@ -25,7 +25,7 @@ def get_pytest_params( for model_name, (model_cls, paramter_cls) in model_registry.items(): if model_name in exclude: continue - if model_name not in include: + if include and model_name not in include: continue values: list[Any] = [] for t in types: diff --git a/generate/ui.py b/generate/ui.py index ee2f083..31655ef 100644 --- a/generate/ui.py +++ b/generate/ui.py @@ -28,10 +28,10 @@ class UserState(BaseModel): - chat_model_id: str = 'openai/gpt-3.5-turbo' + chat_model_id: str = 'openai' temperature: float = 1.0 system_message: str = '' - max_tokens: int = 4000 + max_tokens: Optional[int] = None _chat_history: Messages = [] @property @@ -89,7 +89,7 @@ def get_generate_settings() -> List[Any]: initial='', ) temperature_slider = Slider(id='Temperature', label='Temperature', min=0, max=1.0, step=0.1, initial=1) - max_tokens = Slider(id='MaxTokens', label='Max Tokens', min=1, max=5000, step=100, initial=4000) + max_tokens = Slider(id='MaxTokens', label='Max Tokens', min=1, max=5000, step=100, initial=0) return [model_select, model_id, system_message_input, temperature_slider, max_tokens] @@ -111,7 +111,8 @@ async def settings_update(settings: dict) -> None: state.chat_model_id = settings['Model'] state.temperature = settings['Temperature'] state.system_message = settings['SystemMessage'] - state.max_tokens = settings['MaxTokens'] + if settings['MaxTokens']: + state.max_tokens = settings['MaxTokens'] @cl.on_message