From 408cf3e0c1f118199277a02c5850428cae896f85 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Thu, 18 Jan 2024 19:00:16 +0800 Subject: [PATCH 1/4] Update Zhipu GLM4 --- generate/__init__.py | 2 + generate/access_token_manager.py | 8 +- generate/chat_completion/message/core.py | 2 +- generate/chat_completion/model_output.py | 2 +- generate/chat_completion/models/openai.py | 27 +- generate/chat_completion/models/zhipu.py | 459 +++++++++++++------ generate/image_generation/__init__.py | 2 + generate/image_generation/models/__init__.py | 2 + generate/image_generation/models/zhipu.py | 79 ++++ generate/platforms/zhipu.py | 31 +- 10 files changed, 453 insertions(+), 161 deletions(-) create mode 100644 generate/image_generation/models/zhipu.py diff --git a/generate/__init__.py b/generate/__init__.py index f7569ea..d8eab2e 100644 --- a/generate/__init__.py +++ b/generate/__init__.py @@ -29,6 +29,7 @@ OpenAIImageGenerationParameters, QianfanImageGeneration, QianfanImageGenerationParameters, + ZhipuImageGeneration, ) from generate.text_to_speech import ( MinimaxProSpeech, @@ -82,6 +83,7 @@ 'BaiduImageGenerationParameters', 'QianfanImageGeneration', 'QianfanImageGenerationParameters', + 'ZhipuImageGeneration', 'function', 'load_chat_model', 'load_speech_model', diff --git a/generate/access_token_manager.py b/generate/access_token_manager.py index 4f145c3..c1bd715 100644 --- a/generate/access_token_manager.py +++ b/generate/access_token_manager.py @@ -7,15 +7,15 @@ class AccessTokenManager(ABC): _token: Optional[str] = None _token_expires_at: datetime - def __init__(self, token_refresh_days: int = 1) -> None: + def __init__(self, token_refresh_seconds: int = 24 * 60 * 60) -> None: self._token = None - self.token_refresh_days = token_refresh_days + self.token_refresh_seconds = token_refresh_seconds @property def token(self) -> str: if self._token is None: self._token = self._get_token() - self._token_expires_at = datetime.now() + timedelta(days=self.token_refresh_days) + self._token_expires_at = datetime.now() + timedelta(seconds=self.token_refresh_seconds) else: self._maybe_refresh_token() return self._token @@ -27,4 +27,4 @@ def _get_token(self) -> str: def _maybe_refresh_token(self) -> None: if self._token_expires_at < datetime.now(): self._token = self._get_token() - self._token_expires_at = datetime.now() + timedelta(days=self.token_refresh_days) + self._token_expires_at = datetime.now() + timedelta(seconds=self.token_refresh_seconds) diff --git a/generate/chat_completion/message/core.py b/generate/chat_completion/message/core.py index 8dc4ebc..ea93dae 100644 --- a/generate/chat_completion/message/core.py +++ b/generate/chat_completion/message/core.py @@ -59,7 +59,7 @@ class FunctionCall(BaseModel): class ToolCall(BaseModel): id: str # noqa: A003 - type: Literal['function'] = 'function' # noqa: A003 + type: str = 'function' function: FunctionCall diff --git a/generate/chat_completion/model_output.py b/generate/chat_completion/model_output.py index 055356a..ec1ac6b 100644 --- a/generate/chat_completion/model_output.py +++ b/generate/chat_completion/model_output.py @@ -8,7 +8,7 @@ from generate.model import ModelOutput -class ChatCompletionOutput(ModelOutput, AssistantMessage): +class ChatCompletionOutput(ModelOutput): message: AssistantMessage finish_reason: Optional[str] = None diff --git a/generate/chat_completion/models/openai.py b/generate/chat_completion/models/openai.py index 03ff15b..f9518a7 100644 --- a/generate/chat_completion/models/openai.py +++ b/generate/chat_completion/models/openai.py @@ -210,7 +210,7 @@ def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> fl return None -def convert_openai_message_to_generate_message(message: dict[str, Any]) -> AssistantMessage: +def _convert_to_assistant_message(message: dict[str, Any]) -> AssistantMessage: if function_call_dict := message.get('function_call'): function_call = FunctionCall( name=function_call_dict.get('name') or '', @@ -236,7 +236,7 @@ def convert_openai_message_to_generate_message(message: dict[str, Any]) -> Assis def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput: - message = convert_openai_message_to_generate_message(response['choices'][0]['message']) + message = _convert_to_assistant_message(response['choices'][0]['message']) extra = {'usage': response['usage']} if system_fingerprint := response.get('system_fingerprint'): extra['system_fingerprint'] = system_fingerprint @@ -263,11 +263,11 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput | None: delta_dict = response['choices'][0]['delta'] if self.message is None: - self.message = self.process_initial_message(delta_dict) - if self.message is None: - return None - else: - self.update_existing_message(delta_dict) + if self._is_contains_content(delta_dict): + self.message = self.process_initial_message(delta_dict) + return None + + self.update_existing_message(delta_dict) extra = self.extract_extra_info(response) cost = cost = self.calculate_response_cost(response) finish_reason = self.determine_finish_reason(response) @@ -282,14 +282,15 @@ def process(self, response: ResponseValue) -> ChatCompletionStreamOutput | None: stream=Stream(delta=delta_dict.get('content') or '', control=stream_control), ) - def process_initial_message(self, delta_dict: dict[str, Any]) -> AssistantMessage | None: - if ( + def _is_contains_content(self, delta_dict: dict[str, Any]) -> bool: + return not ( delta_dict.get('content') is None and delta_dict.get('tool_calls') is None and delta_dict.get('function_call') is None - ): - return None - return convert_openai_message_to_generate_message(delta_dict) + ) + + def process_initial_message(self, delta_dict: dict[str, Any]) -> AssistantMessage: + return _convert_to_assistant_message(delta_dict) def update_existing_message(self, delta_dict: dict[str, Any]) -> None: if not delta_dict: @@ -302,7 +303,7 @@ def update_existing_message(self, delta_dict: dict[str, Any]) -> None: if delta_dict.get('tool_calls'): index = delta_dict['tool_calls'][0]['index'] if index >= len(self.message.tool_calls or []): - new_tool_calls_message = convert_openai_message_to_generate_message(delta_dict).tool_calls + new_tool_calls_message = _convert_to_assistant_message(delta_dict).tool_calls assert new_tool_calls_message is not None if self.message.tool_calls is None: self.message.tool_calls = [] diff --git a/generate/chat_completion/models/zhipu.py b/generate/chat_completion/models/zhipu.py index 11e83fc..b4d512e 100644 --- a/generate/chat_completion/models/zhipu.py +++ b/generate/chat_completion/models/zhipu.py @@ -1,23 +1,24 @@ from __future__ import annotations -import time -from typing import Any, AsyncIterator, ClassVar, Iterator, Literal, Optional +import json +from typing import Any, AsyncIterator, ClassVar, Iterator, List, Literal, Optional, Union -import cachetools.func # type: ignore -import jwt from typing_extensions import NotRequired, Self, TypedDict, Unpack, override from generate.chat_completion.base import ChatCompletionModel from generate.chat_completion.message import ( AssistantMessage, + FunctionCall, Message, Messages, MessageTypeError, Prompt, SystemMessage, + ToolCall, UserMessage, ensure_messages, ) +from generate.chat_completion.message.core import ToolMessage from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput, Stream from generate.http import ( HttpClient, @@ -25,70 +26,83 @@ ResponseValue, UnexpectedResponseError, ) -from generate.model import ModelParameters, ModelParametersDict -from generate.platforms.zhipu import ZhipuSettings -from generate.types import Probability, Temperature +from generate.model import ModelInfo, ModelParameters, ModelParametersDict +from generate.platforms.zhipu import ZhipuSettings, generate_zhipu_token +from generate.types import JsonSchema, Probability, Temperature -API_TOKEN_TTL_SECONDS = 3 * 60 -CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 +class Function(TypedDict): + name: str + description: str + parameters: NotRequired[JsonSchema] -class ZhipuRef(TypedDict): + +class Retrieval(TypedDict): + knowledge_id: str + prompt_template: NotRequired[str] + + +class WebSearch(TypedDict): enable: NotRequired[bool] search_query: NotRequired[str] +class ZhipuFunctionTool(TypedDict): + type: Literal['function'] + function: Function + + +class ZhipuRetrievalTool(TypedDict): + type: Literal['retrieval'] + retrieval: Retrieval + + +class ZhipuWebSearchTool(TypedDict): + type: Literal['web_search'] + web_search: WebSearch + + +ZhipuTool = Union[ZhipuFunctionTool, ZhipuRetrievalTool, ZhipuWebSearchTool] + + class ZhipuChatParameters(ModelParameters): temperature: Optional[Temperature] = None top_p: Optional[Probability] = None + do_sample: Optional[bool] = None request_id: Optional[str] = None - search_query: Optional[str] = None - - def custom_model_dump(self) -> dict[str, Any]: - output = super().custom_model_dump() - if self.search_query: - output['ref'] = {'enable': True, 'search_query': self.search_query} - output['return_type'] = 'text' - return output + max_tokens: Optional[int] = None + stop: Optional[List[str]] = None + tools: Optional[List[ZhipuTool]] = None + tool_choice: Optional[str] = None class ZhipuChatParametersDict(ModelParametersDict, total=False): temperature: Optional[Temperature] top_p: Optional[Probability] request_id: Optional[str] - search_query: Optional[str] + max_tokens: Optional[int] + stop: Optional[list[str]] + tools: Optional[list[ZhipuTool]] + tool_choice: Optional[str] -class ZhipuMeta(TypedDict): - user_info: str - bot_info: str - bot_name: str - user_name: str +class ZhipuToolCall(TypedDict): + id: str + type: str + index: int + function: NotRequired[ZhipuFunctionCall] -class ZhipuCharacterChatParameters(ModelParameters): - meta: ZhipuMeta = { - 'user_info': '我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。', - 'bot_info': '苏梦远,本名苏远心,是一位当红的国内女歌手及演员。', - 'bot_name': '苏梦远', - 'user_name': '陆星辰', - } - request_id: Optional[str] = None - - def custom_model_dump(self) -> dict[str, Any]: - output = super().custom_model_dump() - output['return_type'] = 'text' - return output - - -class ZhipuCharacterChatParametersDict(ModelParametersDict, total=False): - meta: ZhipuMeta - request_id: Optional[str] +class ZhipuFunctionCall(TypedDict): + name: str + arguments: str class ZhipuMessage(TypedDict): - role: Literal['user', 'assistant'] - content: str + role: Literal['user', 'assistant', 'system', 'tool'] + content: NotRequired[str] + tool_calls: NotRequired[list[ZhipuToolCall]] + tool_call_id: NotRequired[str] def convert_to_zhipu_message(message: Message) -> ZhipuMessage: @@ -99,106 +113,220 @@ def convert_to_zhipu_message(message: Message) -> ZhipuMessage: } if isinstance(message, AssistantMessage): + if message.tool_calls is not None: + dict_format_toll_calls: list[ZhipuToolCall] = [] + for index, tool_call in enumerate(message.tool_calls): + tool_type = tool_call.type + if tool_type not in {'function', 'retrieval', 'web_search'}: + raise ValueError(f'invalid tool type: {tool_type}, should be one of function, retrieval, web_search') + dict_format_toll_call: ZhipuToolCall = { + 'id': tool_call.id, + 'type': tool_type, + 'index': index, + } + if tool_type == 'function': + function_dict: ZhipuFunctionCall = { + 'name': tool_call.function.name, + 'arguments': tool_call.function.arguments, + } + dict_format_toll_call['function'] = function_dict + dict_format_toll_calls.append(dict_format_toll_call) + return { + 'role': 'assistant', + 'tool_calls': dict_format_toll_calls, + } return { 'role': 'assistant', 'content': message.content, } - raise MessageTypeError(message, (UserMessage, AssistantMessage)) + if isinstance(message, SystemMessage): + return { + 'role': 'system', + 'content': message.content, + } + if isinstance(message, ToolMessage): + return { + 'role': 'tool', + 'content': message.content or '', + 'tool_call_id': message.tool_call_id, + } -@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) -def generate_token(api_key: str) -> str: - try: - api_key, secret = api_key.split('.') - except Exception as e: - raise ValueError('invalid api_key') from e + raise MessageTypeError(message, (UserMessage, AssistantMessage)) - payload = { - 'api_key': api_key, - 'exp': int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, - 'timestamp': int(round(time.time() * 1000)), - } - return jwt.encode( # type: ignore - payload, - secret, - algorithm='HS256', - headers={'alg': 'HS256', 'sign_type': 'SIGN'}, +def _convert_to_assistant_message(zhiput_message_dict: dict[str, Any]) -> AssistantMessage: + if 'tool_calls' in zhiput_message_dict: + dict_format_tool_calls = zhiput_message_dict['tool_calls'] + dict_format_tool_calls.sort(key=lambda x: x['index']) + tool_calls = [] + for tool_call_dict in zhiput_message_dict['tool_calls']: + if tool_call_dict['type'] != 'function': + raise ValueError(f'invalid tool type: {tool_call_dict["type"]}, should be function') + tool_calls.append( + ToolCall( + id=tool_call_dict['id'], + type='function', + function=FunctionCall( + name=tool_call_dict['function']['name'], + arguments=tool_call_dict['function']['arguments'], + ), + ) + ) + return AssistantMessage( + role='assistant', + content='', + tool_calls=tool_calls, + ) + return AssistantMessage( + role='assistant', + content=zhiput_message_dict['content'], ) -class BaseZhipuChat(ChatCompletionModel): +def _calculate_cost(model_name: str, usage: dict[str, Any]) -> float | None: + if model_name == 'glm-4': + return 0.1 * (usage['total_tokens'] / 1000) + if model_name == 'glm-3-turbo': + return 0.005 * (usage['total_tokens'] / 1000) + if model_name == 'characterglm': + return 0.015 * (usage['total_tokens'] / 1000) + return None + + +class _StreamResponseProcessor: + def __init__(self) -> None: + self.message: AssistantMessage | None = None + self.is_start = True + + def process(self, stream_line: str) -> ChatCompletionStreamOutput | None: + if not stream_line.strip(): + return None + + line = self._preprocess_stream_line(stream_line) + if not line: + return None + response = json.loads(line) + delta_dict = response['choices'][0]['delta'] + + if self.message is None: + if self._is_contains_content(delta_dict): + self.message = self.process_initial_message(delta_dict) + else: + return None + else: + self.update_existing_message(delta_dict) + extra = self.extract_extra_info(response) + cost = cost = self.calculate_response_cost(response) + finish_reason = self.determine_finish_reason(response) + stream_control = 'finish' if finish_reason else 'start' if self.is_start else 'continue' + self.is_start = False + return ChatCompletionStreamOutput( + model_info=ModelInfo(task='chat_completion', type='zhipu', name=response['model']), + message=self.message, + finish_reason=finish_reason, + cost=cost, + extra=extra, + stream=Stream(delta=delta_dict.get('content') or '', control=stream_control), + ) + + @staticmethod + def _preprocess_stream_line(line: str) -> str: + line = line.replace('data:', '') + return line.strip() + + def _is_contains_content(self, delta_dict: dict[str, Any]) -> bool: + return not ( + delta_dict.get('content') is None + and delta_dict.get('tool_calls') is None + and delta_dict.get('function_call') is None + ) + + def process_initial_message(self, delta_dict: dict[str, Any]) -> AssistantMessage: + return _convert_to_assistant_message(delta_dict) + + def update_existing_message(self, delta_dict: dict[str, Any]) -> None: + if not delta_dict: + return + assert self.message is not None + + delta_content = delta_dict.get('content', '') + self.message.content += delta_content + + if delta_dict.get('tool_calls'): + index = delta_dict['tool_calls'][0]['index'] + if index >= len(self.message.tool_calls or []): + new_tool_calls_message = _convert_to_assistant_message(delta_dict).tool_calls + assert new_tool_calls_message is not None + if self.message.tool_calls is None: + self.message.tool_calls = [] + self.message.tool_calls.append(new_tool_calls_message[0]) + else: + assert self.message.tool_calls is not None + self.message.tool_calls[index].function.arguments += delta_dict['tool_calls'][0]['function']['arguments'] + + def extract_extra_info(self, response: ResponseValue) -> dict[str, Any]: + extra = {} + if usage := response.get('usage'): + extra['usage'] = usage + if system_fingerprint := response.get('system_fingerprint'): + extra['system_fingerprint'] = system_fingerprint + return extra + + @staticmethod + def calculate_response_cost(response: ResponseValue) -> float | None: + if usage := response.get('usage'): + return _calculate_cost(response['model'], usage) + return None + + def determine_finish_reason(self, response: ResponseValue) -> str | None: + return response['choices'][0].get('finish_reason') + + +class ZhipuChat(ChatCompletionModel): + model_type: ClassVar[str] = 'zhipu' + def __init__( self, - model: str, - parameters: ModelParameters, + model: str = 'glm-4', + parameters: ZhipuChatParameters | None = None, settings: ZhipuSettings | None = None, http_client: HttpClient | None = None, ) -> None: self.model = model - self.parameters = parameters + self.parameters = parameters or ZhipuChatParameters() self.settings = settings or ZhipuSettings() # type: ignore - self.http_client = http_client or HttpClient() + self.http_client = http_client or HttpClient(stream_strategy='basic') def _get_request_parameters(self, messages: Messages, parameters: ModelParameters) -> HttpxPostKwargs: zhipu_messages = self._convert_messages(messages) headers = { - 'Authorization': generate_token(self.settings.api_key.get_secret_value()), + 'Authorization': generate_zhipu_token(self.settings.api_key.get_secret_value()), } - params = {'prompt': zhipu_messages, **parameters.custom_model_dump()} + params = {'messages': zhipu_messages, 'model': self.model, **parameters.custom_model_dump()} return { - 'url': f'{self.settings.api_base}/{self.model}/invoke', + 'url': f'{self.settings.v4_api_base}/chat/completions', 'headers': headers, 'json': params, } - def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]: - return [convert_to_zhipu_message(message) for message in messages] - def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput: - if response['success']: - text = response['data']['choices'][0]['content'] - return ChatCompletionOutput( - model_info=self.model_info, - message=AssistantMessage(content=text), - cost=self.calculate_cost(response['data']['usage']), - extra={'usage': response['data']['usage']}, - ) - - raise UnexpectedResponseError(response) + message_dict = response['choices'][0]['message'] + return ChatCompletionOutput( + model_info=self.model_info, + message=_convert_to_assistant_message(message_dict), + cost=_calculate_cost(self.model, response['usage']), + extra={'usage': response['usage']}, + ) def _get_stream_request_parameters(self, messages: Messages, parameters: ModelParameters) -> HttpxPostKwargs: http_parameters = self._get_request_parameters(messages, parameters) - http_parameters['url'] = f'{self.settings.api_base}/{self.model}/sse-invoke' + http_parameters['json']['stream'] = True return http_parameters - def calculate_cost(self, usage: dict[str, Any]) -> float | None: - if self.name == 'chatglm_turbo': - return 0.005 * (usage['total_tokens'] / 1000) - if self.name == 'characterglm': - return 0.015 * (usage['total_tokens'] / 1000) - return None - - -class ZhipuChat(BaseZhipuChat): - model_type: ClassVar[str] = 'zhipu' - - def __init__( - self, - model: str = 'chatglm_turbo', - parameters: ZhipuChatParameters | None = None, - settings: ZhipuSettings | None = None, - http_client: HttpClient | None = None, - ) -> None: - parameters = parameters or ZhipuChatParameters() - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) - - @override def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]: - if isinstance(system_message := messages[0], SystemMessage): - messages = [UserMessage(content=system_message.content), AssistantMessage(content='好的')] + messages[1:] - return super()._convert_messages(messages) + return [convert_to_zhipu_message(message) for message in messages] @override def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuChatParametersDict]) -> ChatCompletionOutput: @@ -223,22 +351,16 @@ def stream_generate( messages = ensure_messages(prompt) parameters = self.parameters.update_with_validate(**kwargs) request_parameters = self._get_stream_request_parameters(messages, parameters) - message = AssistantMessage(content='') - is_start = True + stream_processor = _StreamResponseProcessor() + is_finish = False for line in self.http_client.stream_post(request_parameters=request_parameters): - message.content += line - yield ChatCompletionStreamOutput( - model_info=self.model_info, - message=message, - stream=Stream(delta=line, control='start' if is_start else 'continue'), - ) - is_start = False - yield ChatCompletionStreamOutput( - model_info=self.model_info, - message=message, - finish_reason='stop', - stream=Stream(delta='', control='finish'), - ) + if is_finish: + continue + output = stream_processor.process(line) + if output is None: + continue + is_finish = output.is_finish + yield output @override async def async_stream_generate( @@ -247,22 +369,16 @@ async def async_stream_generate( messages = ensure_messages(prompt) parameters = self.parameters.update_with_validate(**kwargs) request_parameters = self._get_stream_request_parameters(messages, parameters) - message = AssistantMessage(content='') - is_start = True + stream_processor = _StreamResponseProcessor() + is_finish = False async for line in self.http_client.async_stream_post(request_parameters=request_parameters): - message.content += line - yield ChatCompletionStreamOutput( - model_info=self.model_info, - message=message, - stream=Stream(delta=line, control='start' if is_start else 'continue'), - ) - is_start = False - yield ChatCompletionStreamOutput( - model_info=self.model_info, - message=message, - finish_reason='stop', - stream=Stream(delta='', control='finish'), - ) + if is_finish: + continue + output = stream_processor.process(line) + if output is None: + continue + is_finish = output.is_finish + yield output @property @override @@ -275,7 +391,34 @@ def from_name(cls, name: str) -> Self: return cls(model=name) -class ZhipuCharacterChat(BaseZhipuChat): +class ZhipuMeta(TypedDict): + user_info: str + bot_info: str + bot_name: str + user_name: str + + +class ZhipuCharacterChatParameters(ModelParameters): + meta: ZhipuMeta = { + 'user_info': '我是陆星辰,是一个男性,是一位知名导演,也是苏梦远的合作导演。', + 'bot_info': '苏梦远,本名苏远心,是一位当红的国内女歌手及演员。', + 'bot_name': '苏梦远', + 'user_name': '陆星辰', + } + request_id: Optional[str] = None + + def custom_model_dump(self) -> dict[str, Any]: + output = super().custom_model_dump() + output['return_type'] = 'text' + return output + + +class ZhipuCharacterChatParametersDict(ModelParametersDict, total=False): + meta: ZhipuMeta + request_id: Optional[str] + + +class ZhipuCharacterChat(ChatCompletionModel): model_type: ClassVar[str] = 'zhipu-character' def __init__( @@ -285,8 +428,42 @@ def __init__( settings: ZhipuSettings | None = None, http_client: HttpClient | None = None, ) -> None: - parameters = parameters or ZhipuCharacterChatParameters() - super().__init__(model=model, parameters=parameters, settings=settings, http_client=http_client) + self.model = model + self.parameters = parameters or ZhipuCharacterChatParameters() + self.settings = settings or ZhipuSettings() # type: ignore + self.http_client = http_client or HttpClient() + + def _get_request_parameters(self, messages: Messages, parameters: ModelParameters) -> HttpxPostKwargs: + zhipu_messages = self._convert_messages(messages) + headers = { + 'Authorization': generate_zhipu_token(self.settings.api_key.get_secret_value()), + } + params = {'prompt': zhipu_messages, **parameters.custom_model_dump()} + return { + 'url': f'{self.settings.v3_api_base}/{self.model}/invoke', + 'headers': headers, + 'json': params, + } + + def _convert_messages(self, messages: Messages) -> list[ZhipuMessage]: + return [convert_to_zhipu_message(message) for message in messages] + + def _parse_reponse(self, response: ResponseValue) -> ChatCompletionOutput: + if response['success']: + text = response['data']['choices'][0]['content'] + return ChatCompletionOutput( + model_info=self.model_info, + message=AssistantMessage(content=text), + cost=_calculate_cost(self.name, response['data']['usage']), + extra={'usage': response['data']['usage']}, + ) + + raise UnexpectedResponseError(response) + + def _get_stream_request_parameters(self, messages: Messages, parameters: ModelParameters) -> HttpxPostKwargs: + http_parameters = self._get_request_parameters(messages, parameters) + http_parameters['url'] = f'{self.settings.v3_api_base}/{self.model}/sse-invoke' + return http_parameters @override def generate(self, prompt: Prompt, **kwargs: Unpack[ZhipuCharacterChatParametersDict]) -> ChatCompletionOutput: diff --git a/generate/image_generation/__init__.py b/generate/image_generation/__init__.py index a6a0cdd..3903526 100644 --- a/generate/image_generation/__init__.py +++ b/generate/image_generation/__init__.py @@ -10,6 +10,7 @@ OpenAIImageGenerationParameters, QianfanImageGeneration, QianfanImageGenerationParameters, + ZhipuImageGeneration, ) from generate.model import ModelParameters @@ -33,4 +34,5 @@ 'BaiduImageGenerationParameters', 'QianfanImageGeneration', 'QianfanImageGenerationParameters', + 'ZhipuImageGeneration', ] diff --git a/generate/image_generation/models/__init__.py b/generate/image_generation/models/__init__.py index 37c6e2b..d7fabfb 100644 --- a/generate/image_generation/models/__init__.py +++ b/generate/image_generation/models/__init__.py @@ -1,6 +1,7 @@ from generate.image_generation.models.baidu import BaiduImageGeneration, BaiduImageGenerationParameters from generate.image_generation.models.openai import OpenAIImageGeneration, OpenAIImageGenerationParameters from generate.image_generation.models.qianfan import QianfanImageGeneration, QianfanImageGenerationParameters +from generate.image_generation.models.zhipu import ZhipuImageGeneration __all__ = [ 'OpenAIImageGeneration', @@ -9,4 +10,5 @@ 'BaiduImageGenerationParameters', 'QianfanImageGeneration', 'QianfanImageGenerationParameters', + 'ZhipuImageGeneration', ] diff --git a/generate/image_generation/models/zhipu.py b/generate/image_generation/models/zhipu.py new file mode 100644 index 0000000..31ccdcd --- /dev/null +++ b/generate/image_generation/models/zhipu.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from httpx import Response +from typing_extensions import Self, override + +from generate.http import HttpClient, HttpxPostKwargs +from generate.image_generation.base import GeneratedImage, ImageGenerationModel, ImageGenerationOutput +from generate.platforms.zhipu import ZhipuSettings, generate_zhipu_token + + +class ZhipuImageGeneration(ImageGenerationModel): + model_type = 'zhipu' + + def __init__( + self, + model: str = 'cogview-3', + settings: ZhipuSettings | None = None, + http_client: HttpClient | None = None, + ) -> None: + self.model = model + self.settings = settings or ZhipuSettings() # type: ignore + self.http_client = http_client or HttpClient() + + def _get_request_parameters(self, prompt: str) -> HttpxPostKwargs: + headers = { + 'Content-Type': 'application/json', + 'Authorization': generate_zhipu_token(api_key=self.settings.api_key.get_secret_value()), + } + json_data = { + 'model': self.model, + 'prompt': prompt, + } + return { + 'url': self.settings.v4_api_base + 'images/generations', + 'json': json_data, + 'headers': headers, + } + + @override + def generate(self, prompt: str) -> ImageGenerationOutput: + request_parameters = self._get_request_parameters(prompt) + response = self.http_client.post(request_parameters=request_parameters) + return self._construct_model_output(prompt, response) + + @override + async def async_generate(self, prompt: str) -> ImageGenerationOutput: + request_parameters = self._get_request_parameters(prompt) + response = await self.http_client.async_post(request_parameters=request_parameters) + return self._construct_model_output(prompt, response) + + def _construct_model_output(self, prompt: str, response: Response) -> ImageGenerationOutput: + response_data = response.json() + generated_images: list[GeneratedImage] = [] + for image_data in response_data['data']: + url = image_data['url'] + content = self.http_client.get({'url': url}).content + generated_images.append( + GeneratedImage( + url=url, + prompt=prompt, + image_format='png', + content=content, + ) + ) + return ImageGenerationOutput( + model_info=self.model_info, + images=generated_images, + cost=0.25, + ) + + @property + @override + def name(self) -> str: + return self.model + + @classmethod + @override + def from_name(cls, name: str) -> Self: + return cls(model=name) diff --git a/generate/platforms/zhipu.py b/generate/platforms/zhipu.py index 0d20d39..32dd9a6 100644 --- a/generate/platforms/zhipu.py +++ b/generate/platforms/zhipu.py @@ -1,9 +1,38 @@ +import time + +import cachetools.func # type: ignore +import jwt from pydantic import SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict +API_TOKEN_TTL_SECONDS = 3 * 60 +CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 + class ZhipuSettings(BaseSettings): model_config = SettingsConfigDict(extra='ignore', env_prefix='zhipu_', env_file='.env') api_key: SecretStr - api_base: str = 'https://open.bigmodel.cn/api/paas/v3/model-api' + v3_api_base: str = 'https://open.bigmodel.cn/api/paas/v3/model-api' + v4_api_base: str = 'https://open.bigmodel.cn/api/paas/v4/' + + +@cachetools.func.ttl_cache(ttl=CACHE_TTL_SECONDS) +def generate_zhipu_token(api_key: str) -> str: + try: + api_key, secret = api_key.split('.') + except Exception as e: + raise ValueError('invalid api_key') from e + + payload = { + 'api_key': api_key, + 'exp': int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, + 'timestamp': int(round(time.time() * 1000)), + } + + return jwt.encode( # type: ignore + payload, + secret, + algorithm='HS256', + headers={'alg': 'HS256', 'sign_type': 'SIGN'}, + ) From 8d84f60e2e84399a4237e78533bd6cc99052b754 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Thu, 18 Jan 2024 19:02:37 +0800 Subject: [PATCH 2/4] Update version to 0.3.1 and add project description --- generate/version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/generate/version.py b/generate/version.py index 0404d81..e1424ed 100644 --- a/generate/version.py +++ b/generate/version.py @@ -1 +1 @@ -__version__ = '0.3.0' +__version__ = '0.3.1' diff --git a/pyproject.toml b/pyproject.toml index 5130452..c4bec02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "generate-core" -version = "0.3.0" +version = "0.3.1" description = "文本生成,图像生成,语音生成" authors = ["wangyuxin "] license = "MIT" From 56800f4ef4f74d35aa4def90620b6e0d04a3dfe1 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Thu, 18 Jan 2024 19:05:26 +0800 Subject: [PATCH 3/4] Update ZhipuCharacterChat model_type in ZhipuCharacterChat class --- generate/chat_completion/models/zhipu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generate/chat_completion/models/zhipu.py b/generate/chat_completion/models/zhipu.py index b4d512e..14efe9d 100644 --- a/generate/chat_completion/models/zhipu.py +++ b/generate/chat_completion/models/zhipu.py @@ -419,11 +419,11 @@ class ZhipuCharacterChatParametersDict(ModelParametersDict, total=False): class ZhipuCharacterChat(ChatCompletionModel): - model_type: ClassVar[str] = 'zhipu-character' + model_type: ClassVar[str] = 'zhipu_character' def __init__( self, - model: str = 'characterglm', + model: str = 'charglm-3', parameters: ZhipuCharacterChatParameters | None = None, settings: ZhipuSettings | None = None, http_client: HttpClient | None = None, From 7a73671c55aa40345b82bd8af5a8c09c814bc2a3 Mon Sep 17 00:00:00 2001 From: wangyuxin Date: Thu, 18 Jan 2024 19:10:23 +0800 Subject: [PATCH 4/4] 0.3.1 to 0.3.0 --- generate/version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/generate/version.py b/generate/version.py index e1424ed..0404d81 100644 --- a/generate/version.py +++ b/generate/version.py @@ -1 +1 @@ -__version__ = '0.3.1' +__version__ = '0.3.0' diff --git a/pyproject.toml b/pyproject.toml index c4bec02..5130452 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "generate-core" -version = "0.3.1" +version = "0.3.0" description = "文本生成,图像生成,语音生成" authors = ["wangyuxin "] license = "MIT"