[Feature] Online LLM models: add support for Alibaba Cloud Tongyi Qianwen (阿里云通义千问) #1534

Merged · 3 commits · Sep 20, 2023
1 change: 1 addition & 0 deletions README.md
@@ -133,6 +133,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
- [MiniMax](https://api.minimax.chat)
- [讯飞星火](https://xinghuo.xfyun.cn)
- [百度千帆](https://cloud.baidu.com/product/wenxinworkshop?track=dingbutonglan)
- [阿里云通义千问](https://dashscope.aliyun.com/)

The project uses `THUDM/chatglm2-6b` as the default LLM type. To use a different LLM type, edit `llm_model_dict` and `LLM_MODEL` in [configs/model_config.py].
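
For example, switching the default model to the Qwen API added in this PR might look roughly like the sketch below. This is a minimal sketch, assuming the `LLM_MODEL` / `ONLINE_LLM_MODEL` layout from the `configs/model_config.py.example` diff later in this PR; the `api_key` value is a placeholder:

```python
# configs/model_config.py (sketch)
LLM_MODEL = "qwen-api"  # route chat requests to the Qwen online API worker

ONLINE_LLM_MODEL = {
    "qwen-api": {
        "version": "qwen-turbo",  # or "qwen-plus"
        "api_key": "sk-...",      # placeholder; your DashScope API key
        "provider": "QwenWorker",
    },
}
```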

8 changes: 7 additions & 1 deletion configs/model_config.py.example
@@ -115,14 +115,20 @@ ONLINE_LLM_MODEL = {
"secret_key": "",
"provider": "QianFanWorker",
},
    # Volcano Ark (火山方舟) API
    # Volcano Ark (火山方舟) API; docs: https://www.volcengine.com/docs/82379
    "fangzhou-api": {
        "version": "chatglm-6b-model",  # currently supports "chatglm-6b-model"; see the Ark section of the supported-model list in the docs for more
        "version_url": "",  # alternatively, leave version empty and fill in the API URL of the model endpoint you created in Ark
        "api_key": "",
        "secret_key": "",
        "provider": "FangZhouWorker",
    },
    # Alibaba Cloud Tongyi Qianwen (通义千问) API; docs: https://help.aliyun.com/zh/dashscope/developer-reference/api-details
    "qwen-api": {
        "version": "qwen-turbo",  # options include "qwen-turbo" and "qwen-plus"
        "api_key": "",  # create one on the DashScope API-KEY management page in the Alibaba Cloud console
        "provider": "QwenWorker",
    },
}
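
To sanity-check an API key before wiring it into this config, the DashScope SDK can be called directly. A minimal sketch, mirroring the call shape used in `server/model_workers/qwen.py` later in this diff (the key is a placeholder):

```python
# Standalone check of a DashScope key, outside the chatchat server.
from http import HTTPStatus
import dashscope

responses = dashscope.Generation().call(
    model="qwen-turbo",
    api_key="sk-...",  # placeholder; create one in the Alibaba Cloud console
    messages=[{"role": "user", "content": "hello"}],
    result_format="message",  # return choices as chat messages
    stream=True,              # iterate over incremental responses
)
for resp in responses:
    if resp.status_code == HTTPStatus.OK:
        print(resp.output.choices[0]["message"]["content"])
    else:
        print(f"request failed: {resp.status_code} {resp.message}")
```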


3 changes: 3 additions & 0 deletions configs/server_config.py.example
@@ -79,6 +79,9 @@ FSCHAT_MODEL_WORKERS = {
"fangzhou-api": {
"port": 21005,
},
"qwen-api": {
"port": 21006,
},
}

# fastchat multi model worker server
6 changes: 6 additions & 0 deletions requirements.txt
@@ -24,6 +24,12 @@ pytest
scikit-learn
numexpr

# online api libs
# zhipuai
# dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou

# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
6 changes: 6 additions & 0 deletions requirements_api.txt
@@ -24,6 +24,12 @@ pytest
scikit-learn
numexpr

# online api libs
# zhipuai
# dashscope>=1.10.0 # qwen
# qianfan
# volcengine>=1.0.106 # fangzhou

# uncomment libs if you want to use corresponding vector store
# pymilvus==2.1.3 # requires milvus==2.1.3
# psycopg2
1 change: 1 addition & 0 deletions server/model_workers/__init__.py
@@ -3,3 +3,4 @@
from .xinghuo import XingHuoWorker
from .qianfan import QianFanWorker
from .fangzhou import FangZhouWorker
from .qwen import QwenWorker
2 changes: 1 addition & 1 deletion server/model_workers/base.py
@@ -92,5 +92,5 @@ def prompt_to_messages(self, prompt: str) -> List[Dict]:
                if content := msg[len(ai_start):].strip():
                    result.append({"role": ai_role, "content": content})
            else:
                raise RuntimeError(f"unknow role in msg: {msg}")
                raise RuntimeError(f"unknown role in msg: {msg}")
        return result
123 changes: 123 additions & 0 deletions server/model_workers/qwen.py
@@ -0,0 +1,123 @@
import json
import sys
from configs import TEMPERATURE
from http import HTTPStatus
from typing import List, Literal, Dict

from fastchat import conversation as conv

from server.model_workers.base import ApiModelWorker
from server.utils import get_model_worker_config


def request_qwen_api(
    messages: List[Dict[str, str]],
    api_key: str = None,
    version: str = "qwen-turbo",
    temperature: float = TEMPERATURE,
    model_name: str = "qwen-api",
):
    import dashscope

    config = get_model_worker_config(model_name)
    api_key = api_key or config.get("api_key")
    version = version or config.get("version")

    gen = dashscope.Generation()
    responses = gen.call(
        model=version,
        temperature=temperature,
        api_key=api_key,
        messages=messages,
        result_format='message',  # return the result in chat-message format
        stream=True,
    )

text = ""
for resp in responses:
if resp.status_code != HTTPStatus.OK:
yield {
"code": resp.status_code,
"text": "api not response correctly",
}

if resp["status_code"] == 200:
if choices := resp["output"]["choices"]:
yield {
"code": 200,
"text": choices[0]["message"]["content"],
}
else:
yield {
"code": resp["status_code"],
"text": resp["message"],
}


class QwenWorker(ApiModelWorker):
    def __init__(
        self,
        *,
        version: Literal["qwen-turbo", "qwen-plus"] = "qwen-turbo",
        model_names: List[str] = ["qwen-api"],
        controller_addr: str,
        worker_addr: str,
        **kwargs,
    ):
        kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 16384)
        super().__init__(**kwargs)

        # TODO: confirm whether this conversation template needs adjusting
        self.conv = conv.Conversation(
            name=self.model_names[0],
            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n### ",
            stop_str="###",
        )
        config = self.get_config()
        self.api_key = config.get("api_key")
        self.version = version

    def generate_stream_gate(self, params):
        messages = self.prompt_to_messages(params["prompt"])

        for resp in request_qwen_api(messages=messages,
                                     api_key=self.api_key,
                                     version=self.version,
                                     temperature=params.get("temperature")):
            # success and error frames share the same layout, so build them once
            error_code = 0 if resp["code"] == 200 else resp["code"]
            yield json.dumps(
                {"error_code": error_code, "text": resp["text"]},
                ensure_ascii=False,
            ).encode() + b"\0"

    def get_embeddings(self, params):
        # TODO: support embeddings
        print("embedding")
        print(params)


if __name__ == "__main__":
import uvicorn
from server.utils import MakeFastAPIOffline
from fastchat.serve.model_worker import app

worker = QwenWorker(
controller_addr="http://127.0.0.1:20001",
worker_addr="http://127.0.0.1:20007",
)
sys.modules["fastchat.serve.model_worker"].worker = worker
MakeFastAPIOffline(app)
uvicorn.run(app, port=20007)
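
For a quick manual test of the standalone worker started by the `__main__` block above, FastChat model workers expose a `/worker_generate_stream` endpoint that feeds `generate_stream_gate`. The sketch below assumes that endpoint and the `prompt`/`temperature` params used above; it is illustration only, not part of this PR:

```python
# Sketch: stream one reply from the worker running on port 20007.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:20007/worker_generate_stream",
    json={"prompt": "hello", "temperature": 0.7},
    stream=True,
)
for chunk in resp.iter_lines(delimiter=b"\0"):  # frames are b"\0"-terminated (see generate_stream_gate)
    if chunk:
        print(json.loads(chunk.decode())["text"])
```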
19 changes: 19 additions & 0 deletions tests/online_api/test_qwen.py
@@ -0,0 +1,19 @@
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))

from server.model_workers.qwen import request_qwen_api
from pprint import pprint
import pytest


@pytest.mark.parametrize("version", ["qwen-turbo"])
def test_qwen(version):
messages = [{"role": "user", "content": "hello"}]
print("\n" + version + "\n")

for x in request_qwen_api(messages, version=version):
print(type(x))
pprint(x)
assert x["code"] == 200