Skip to content

Commit

Permalink
add moonshot (#30)
Browse files Browse the repository at this point in the history
* add moonshot
* Update version

---------

Co-authored-by: wangyuxin <[email protected]>
  • Loading branch information
wangyuxinwhy and wangyuxin authored Feb 6, 2024
1 parent 690fbf1 commit 3fbc9a2
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 14 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ Generate Package 允许用户通过统一的 api 访问跨平台的生成式模
* [Minimax](https://api.minimax.chat/document/guides/chat)
* [百度智能云](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t)
* [智谱](https://open.bigmodel.cn/dev/api)
* [月之暗面](https://platform.moonshot.cn/docs)
* ...

## Features

* **多模态**,支持文本生成,多模态文本生成,结构体生成,图像生成,语音生成...
* **跨平台**,完整支持 OpenAI,Azure,Minimax,智谱,文心一言 在内的国内外多家平台
* **跨平台**,完整支持 OpenAI,Azure,Minimax,智谱,月之暗面,文心一言 在内的国内外多家平台
* **One API**,统一了不同平台的消息格式,推理参数,接口封装,返回解析,让用户无需关心不同平台的差异
* **异步和流式**,提供流式调用,非流式调用,同步调用,异步调用,异步批量调用,适配不同的应用场景
* **自带电池**,提供 UI,输入检查,参数检查,计费,速率控制,*ChatEngine*, *function call* 等功能
Expand Down
4 changes: 4 additions & 0 deletions generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
MinimaxChatParameters,
MinimaxProChat,
MinimaxProChatParameters,
MoonshotChat,
MoonshotParameters,
OpenAIChat,
OpenAIChatParameters,
RemoteChatCompletionModel,
Expand Down Expand Up @@ -95,6 +97,8 @@
'DashScopeChatParameters',
'DashScopeMultiModalChat',
'DashScopeMultiModalChatParameters',
'MoonshotChat',
'MoonshotParameters',
'OpenAISpeech',
'OpenAISpeechParameters',
'MinimaxSpeech',
Expand Down
7 changes: 7 additions & 0 deletions generate/chat_completion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
MinimaxChatParameters,
MinimaxProChat,
MinimaxProChatParameters,
MoonshotChat,
MoonshotParameters,
OpenAIChat,
OpenAIChatParameters,
WenxinChat,
Expand All @@ -46,6 +48,7 @@
(BailianChat, BailianChatParameters),
(DashScopeChat, DashScopeChatParameters),
(DashScopeMultiModalChat, DashScopeMultiModalChatParameters),
(MoonshotChat, MoonshotParameters),
]

ChatModelRegistry: dict[str, tuple[Type[ChatCompletionModel], Type[ModelParameters]]] = {
Expand Down Expand Up @@ -80,6 +83,10 @@
'BailianChatParameters',
'DashScopeChat',
'DashScopeChatParameters',
'DashScopeMultiModalChat',
'DashScopeMultiModalChatParameters',
'MoonshotChat',
'MoonshotParameters',
'MessagePrinter',
'SimpleMessagePrinter',
'get_json_schema',
Expand Down
3 changes: 3 additions & 0 deletions generate/chat_completion/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from generate.chat_completion.models.hunyuan import HunyuanChat, HunyuanChatParameters
from generate.chat_completion.models.minimax import MinimaxChat, MinimaxChatParameters
from generate.chat_completion.models.minimax_pro import MinimaxProChat, MinimaxProChatParameters
from generate.chat_completion.models.moonshot import MoonshotChat, MoonshotParameters
from generate.chat_completion.models.openai import OpenAIChat, OpenAIChatParameters
from generate.chat_completion.models.wenxin import WenxinChat, WenxinChatParameters
from generate.chat_completion.models.zhipu import (
Expand Down Expand Up @@ -46,4 +47,6 @@
'DashScopeChatParameters',
'DashScopeMultiModalChat',
'DashScopeMultiModalChatParameters',
'MoonshotChat',
'MoonshotParameters',
]
4 changes: 2 additions & 2 deletions generate/chat_completion/models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -
parameters = self.parameters.clone_with_changes(**kwargs)
request_parameters = self._get_request_parameters(messages, parameters)
response = self.http_client.post(request_parameters)
output = parse_openai_model_reponse(response.json())
output = parse_openai_model_reponse(response.json(), model_type=self.model_type)
output.model_info.type = self.model_type
return output

Expand All @@ -70,7 +70,7 @@ async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParame
parameters = self.parameters.clone_with_changes(**kwargs)
request_parameters = self._get_request_parameters(messages, parameters)
response = await self.http_client.async_post(request_parameters=request_parameters)
output = parse_openai_model_reponse(response.json())
output = parse_openai_model_reponse(response.json(), model_type=self.model_type)
output.model_info.type = self.model_type
return output

Expand Down
64 changes: 64 additions & 0 deletions generate/chat_completion/models/moonshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from typing import AsyncIterator, ClassVar, Iterator, Optional

from pydantic import PositiveInt
from typing_extensions import Unpack, override

from generate.chat_completion.message import Prompt
from generate.chat_completion.model_output import ChatCompletionOutput, ChatCompletionStreamOutput
from generate.chat_completion.models.openai import OpenAIChat
from generate.http import HttpClient
from generate.model import ModelParameters, ModelParametersDict
from generate.platforms import MoonshotSettings
from generate.types import Probability, Temperature


class MoonshotParameters(ModelParameters):
    """Sampling parameters accepted by the Moonshot chat completion API.

    Every field defaults to ``None``, meaning the platform-side default is used.
    """

    # Sampling temperature; higher values produce more random output.
    temperature: Optional[Temperature] = None
    # Nucleus-sampling cutoff: only the smallest token set whose cumulative
    # probability exceeds this value is considered.
    top_p: Optional[Probability] = None
    # Upper bound on the number of tokens generated in the reply.
    max_tokens: Optional[PositiveInt] = None


class MoonshotParametersDict(ModelParametersDict, total=False):
    """Keyword-argument mirror of ``MoonshotParameters``.

    Used with ``Unpack`` to type ``**kwargs`` on the generate methods;
    ``total=False`` makes every key optional at the call site.
    """

    temperature: Temperature
    top_p: Probability
    max_tokens: PositiveInt


class MoonshotChat(OpenAIChat):
    """Chat completion model for the Moonshot platform.

    Moonshot exposes an OpenAI-compatible HTTP API, so this class reuses
    ``OpenAIChat``'s request building and response handling unchanged and only
    swaps in Moonshot-specific settings and sampling parameters.
    """

    model_type: ClassVar[str] = 'moonshot'

    parameters: MoonshotParameters
    settings: MoonshotSettings

    def __init__(
        self,
        model: str = 'moonshot-v1-8k',
        parameters: MoonshotParameters | None = None,
        settings: MoonshotSettings | None = None,
        http_client: HttpClient | None = None,
    ) -> None:
        self.model = model
        self.parameters = parameters if parameters else MoonshotParameters()
        # Settings are resolved from the environment when not passed explicitly.
        self.settings = settings if settings else MoonshotSettings()  # type: ignore
        self.http_client = http_client if http_client else HttpClient()

    @override
    def generate(self, prompt: Prompt, **kwargs: Unpack[MoonshotParametersDict]) -> ChatCompletionOutput:
        """Synchronously generate a completion for *prompt*."""
        return super().generate(prompt, **kwargs)

    @override
    async def async_generate(self, prompt: Prompt, **kwargs: Unpack[MoonshotParametersDict]) -> ChatCompletionOutput:
        """Asynchronously generate a completion for *prompt*."""
        return await super().async_generate(prompt, **kwargs)

    @override
    def stream_generate(self, prompt: Prompt, **kwargs: Unpack[MoonshotParametersDict]) -> Iterator[ChatCompletionStreamOutput]:
        """Stream a completion for *prompt*, yielding incremental outputs."""
        yield from super().stream_generate(prompt, **kwargs)

    @override
    async def async_stream_generate(
        self, prompt: Prompt, **kwargs: Unpack[MoonshotParametersDict]
    ) -> AsyncIterator[ChatCompletionStreamOutput]:
        """Asynchronously stream a completion for *prompt*."""
        async for chunk in super().async_stream_generate(prompt, **kwargs):
            yield chunk
15 changes: 11 additions & 4 deletions generate/chat_completion/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> fl
return (0.03 * dollar_to_yuan) * (input_tokens / 1000) + (0.06 * dollar_to_yuan) * (output_tokens / 1000)
if 'gpt-3.5-turbo' in model_name:
return (0.001 * dollar_to_yuan) * (input_tokens / 1000) + (0.002 * dollar_to_yuan) * (output_tokens / 1000)
if 'moonshot' in model_name:
if '8k' in model_name:
return 0.012 * (input_tokens / 1000) + 0.012 * (output_tokens / 1000)
if '32k' in model_name:
return 0.024 * (input_tokens / 1000) + 0.024 * (output_tokens / 1000)
if '128k' in model_name:
return 0.06 * (input_tokens / 1000) + 0.06 * (output_tokens / 1000)
return None


Expand Down Expand Up @@ -242,7 +249,7 @@ def _convert_to_assistant_message(message: dict[str, Any]) -> AssistantMessage:
return AssistantMessage(content=message.get('content') or '', function_call=function_call, tool_calls=tool_calls)


def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput:
def parse_openai_model_reponse(response: ResponseValue, model_type: str) -> ChatCompletionOutput:
message = _convert_to_assistant_message(response['choices'][0]['message'])
extra = {'usage': response['usage']}
if system_fingerprint := response.get('system_fingerprint'):
Expand All @@ -253,7 +260,7 @@ def parse_openai_model_reponse(response: ResponseValue) -> ChatCompletionOutput:
finish_reason = finish_details['type'] if (finish_details := choice.get('finish_details')) else None

return ChatCompletionOutput(
model_info=ModelInfo(task='chat_completion', type='openai', name=response['model']),
model_info=ModelInfo(task='chat_completion', type=model_type, name=response['model']),
message=message,
finish_reason=finish_reason or '',
cost=calculate_cost(response['model'], response['usage']['prompt_tokens'], response['usage']['completion_tokens']),
Expand Down Expand Up @@ -380,15 +387,15 @@ def generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -
parameters = self.parameters.clone_with_changes(**kwargs)
request_parameters = self._get_request_parameters(messages, parameters)
response = self.http_client.post(request_parameters)
return parse_openai_model_reponse(response.json())
return parse_openai_model_reponse(response.json(), model_type=self.model_type)

@override
async def async_generate(self, prompt: Prompt, **kwargs: Unpack[OpenAIChatParametersDict]) -> ChatCompletionOutput:
messages = ensure_messages(prompt)
parameters = self.parameters.clone_with_changes(**kwargs)
request_parameters = self._get_request_parameters(messages, parameters)
response = await self.http_client.async_post(request_parameters=request_parameters)
return parse_openai_model_reponse(response.json())
return parse_openai_model_reponse(response.json(), model_type=self.model_type)

def _get_stream_request_parameters(self, messages: Messages, parameters: OpenAIChatParameters) -> HttpxPostKwargs:
http_parameters = self._get_request_parameters(messages, parameters)
Expand Down
2 changes: 2 additions & 0 deletions generate/platforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from generate.platforms.dashscope import DashScopeSettings
from generate.platforms.hunyuan import HunyuanSettings
from generate.platforms.minimax import MinimaxSettings
from generate.platforms.moonshot import MoonshotSettings
from generate.platforms.openai import OpenAISettings
from generate.platforms.zhipu import ZhipuSettings

Expand All @@ -19,4 +20,5 @@
'BailianSettings',
'HunyuanSettings',
'DashScopeSettings',
'MoonshotSettings',
]
9 changes: 9 additions & 0 deletions generate/platforms/moonshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict


class MoonshotSettings(BaseSettings):
    """Connection settings for the Moonshot platform.

    Values are read from environment variables with the ``MOONSHOT_`` prefix
    (or a local ``.env`` file); unrelated environment keys are ignored.
    """

    model_config = SettingsConfigDict(extra='ignore', env_prefix='moonshot_', env_file='.env')

    # Required API key; wrapped in SecretStr so it is masked in reprs/logs.
    api_key: SecretStr
    # Base URL of Moonshot's OpenAI-compatible API.
    api_base: str = 'https://api.moonshot.cn/v1'
6 changes: 4 additions & 2 deletions generate/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,16 @@ def chat_model(self) -> ChatCompletionModel:
def get_avatars() -> List[Avatar]:
avatar_map = {
'openai': 'https://mrvian.com/wp-content/uploads/2023/02/logo-open-ai.png',
'wenxin': 'https://nlp-eb.cdn.bcebos.com/static/eb/asset/robin.e9dc83e5.png',
'wenxin': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/wzkPgl.png',
'bailian': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/kHZBrw.png',
'dashscope_multimodal': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/kHZBrw.png',
'dashscope': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/kHZBrw.png',
'hunyuan': 'https://cdn-portal.hunyuan.tencent.com/public/static/logo/logo.png',
'minimax': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/lvMJ2T.png',
'minimax_pro': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/lvMJ2T.png',
'zhipu': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/HIntEu.png',
'baichuan': 'https://ai.tboxn.com/wp-content/uploads/2023/08/%E7%99%BE%E5%B7%9D%E5%A4%A7%E6%A8%A1%E5%9E%8B.png',
'baichuan': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/fODcq1.png',
'moonshot': 'https://yuxin-wang.oss-cn-beijing.aliyuncs.com/uPic/hc2Ygt.png',
}
return [Avatar(name=k, url=v) for k, v in avatar_map.items()]

Expand All @@ -75,6 +76,7 @@ def get_generate_settings() -> List[Any]:
'wenxin',
'baichuan',
'minimax_pro',
'moonshot',
],
)
model_id = TextInput(
Expand Down
2 changes: 1 addition & 1 deletion generate/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.1'
__version__ = '0.3.2'
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "generate-core"
version = "0.3.1"
version = "0.3.2"
description = "文本生成,图像生成,语音生成"
authors = ["wangyuxin <[email protected]>"]
license = "MIT"
Expand Down Expand Up @@ -40,7 +40,6 @@ select = [
"T",
"PT",
"RET",
"PL",
"TRY",
"PERF",
]
Expand All @@ -52,10 +51,8 @@ ignore = [
"A003", # shadow builtin
"ANN1", # self and cls
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed in
"PLR0913", # Too many arguments in function definition
"TRY003", # Avoid specifying long messages outside the exception class
"PLC0414", # reimport
'PLR0912', # too many branches
]
exclude = ["playground", "api_docs"]
target-version = "py38"
Expand Down

0 comments on commit 3fbc9a2

Please sign in to comment.