Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Glm4 #16

Merged
merged 4 commits into from
Jan 18, 2024
Merged

Glm4 #16

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
OpenAIImageGenerationParameters,
QianfanImageGeneration,
QianfanImageGenerationParameters,
ZhipuImageGeneration,
)
from generate.text_to_speech import (
MinimaxProSpeech,
Expand Down Expand Up @@ -82,6 +83,7 @@
'BaiduImageGenerationParameters',
'QianfanImageGeneration',
'QianfanImageGenerationParameters',
'ZhipuImageGeneration',
'function',
'load_chat_model',
'load_speech_model',
Expand Down
8 changes: 4 additions & 4 deletions generate/access_token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ class AccessTokenManager(ABC):
_token: Optional[str] = None
_token_expires_at: datetime

def __init__(self, token_refresh_days: int = 1) -> None:
def __init__(self, token_refresh_seconds: int = 24 * 60 * 60) -> None:
self._token = None
self.token_refresh_days = token_refresh_days
self.token_refresh_seconds = token_refresh_seconds

@property
def token(self) -> str:
if self._token is None:
self._token = self._get_token()
self._token_expires_at = datetime.now() + timedelta(days=self.token_refresh_days)
self._token_expires_at = datetime.now() + timedelta(seconds=self.token_refresh_seconds)
else:
self._maybe_refresh_token()
return self._token
Expand All @@ -27,4 +27,4 @@ def _get_token(self) -> str:
def _maybe_refresh_token(self) -> None:
if self._token_expires_at < datetime.now():
self._token = self._get_token()
self._token_expires_at = datetime.now() + timedelta(days=self.token_refresh_days)
self._token_expires_at = datetime.now() + timedelta(seconds=self.token_refresh_seconds)
2 changes: 1 addition & 1 deletion generate/chat_completion/message/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class FunctionCall(BaseModel):

class ToolCall(BaseModel):
id: str # noqa: A003
type: Literal['function'] = 'function' # noqa: A003
type: str = 'function'
function: FunctionCall


Expand Down
2 changes: 1 addition & 1 deletion generate/chat_completion/model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from generate.model import ModelOutput


class ChatCompletionOutput(ModelOutput, AssistantMessage):
class ChatCompletionOutput(ModelOutput):
message: AssistantMessage
finish_reason: Optional[str] = None

Expand Down
27 changes: 14 additions & 13 deletions generate/chat_completion/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> fl
return None


def convert_openai_message_to_generate_message(message: dict[str, Any]) -> AssistantMessage:
def _convert_to_assistant_message(message: dict[str, Any]) -> AssistantMessage:
if function_call_dict := message.get('function_call'):
function_call = FunctionCall(
name=function_call_dict.get('name') or '',
Expand All @@ -236,7 +236,7 @@ def convert_openai_message_to_generate_message(message: dict[str, Any]) -> Assis


def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput:
message = convert_openai_message_to_generate_message(response['choices'][0]['message'])
message = _convert_to_assistant_message(response['choices'][0]['message'])
extra = {'usage': response['usage']}
if system_fingerprint := response.get('system_fingerprint'):
extra['system_fingerprint'] = system_fingerprint
Expand All @@ -263,11 +263,11 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput | None:
delta_dict = response['choices'][0]['delta']

if self.message is None:
self.message = self.process_initial_message(delta_dict)
if self.message is None:
return None
else:
self.update_existing_message(delta_dict)
if self._is_contains_content(delta_dict):
self.message = self.process_initial_message(delta_dict)
return None

self.update_existing_message(delta_dict)
extra = self.extract_extra_info(response)
cost = cost = self.calculate_response_cost(response)
finish_reason = self.determine_finish_reason(response)
Expand All @@ -282,14 +282,15 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput | None:
stream=Stream(delta=delta_dict.get('content') or '', control=stream_control),
)

def process_initial_message(self, delta_dict: dict[str, Any]) -> AssistantMessage | None:
if (
def _is_contains_content(self, delta_dict: dict[str, Any]) -> bool:
return not (
delta_dict.get('content') is None
and delta_dict.get('tool_calls') is None
and delta_dict.get('function_call') is None
):
return None
return convert_openai_message_to_generate_message(delta_dict)
)

def process_initial_message(self, delta_dict: dict[str, Any]) -> AssistantMessage:
return _convert_to_assistant_message(delta_dict)

def update_existing_message(self, delta_dict: dict[str, Any]) -> None:
if not delta_dict:
Expand All @@ -302,7 +303,7 @@ def update_existing_message(self, delta_dict: dict[str, Any]) -> None:
if delta_dict.get('tool_calls'):
index = delta_dict['tool_calls'][0]['index']
if index >= len(self.message.tool_calls or []):
new_tool_calls_message = convert_openai_message_to_generate_message(delta_dict).tool_calls
new_tool_calls_message = _convert_to_assistant_message(delta_dict).tool_calls
assert new_tool_calls_message is not None
if self.message.tool_calls is None:
self.message.tool_calls = []
Expand Down
Loading