Add a streaming API endpoint; add a web demo based on the streaming API and Streamlit #808

Open · wants to merge 2 commits into base: main
54 changes: 40 additions & 14 deletions README.md
@@ -95,47 +95,56 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b

## Demo & API

We provide a web demo based on [Gradio](https://gradio.app) and a command-line demo. Using them requires cloning this repository first.

### Web Demo (Gradio)

![web-demo](resources/web-demo-gradio.gif)

First install Gradio with `pip install gradio mdtex2html`, then run [web_demo_gradio.py](demo_and_api/web_demo_gradio.py) from the repository:

```shell
git clone https://github.com/THUDM/ChatGLM-6B
cd ChatGLM-6B
cd demo_and_api
python web_demo_gradio.py
```

The program starts a web server and prints its address; open the printed address in a browser to use the demo. The latest demo implements a typewriter effect, which greatly improves the perceived speed. Note that network access to Gradio is slow from mainland China: when launching with `demo.queue().launch(share=True, inbrowser=True)`, all traffic is relayed through Gradio's servers, which severely degrades the typewriter experience. The default launch mode has therefore been changed to `share=False`; if you need public access, change it back to `share=True`.
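To switch between the two modes, only the launch call at the end of the script needs to change. A minimal sketch, assuming the Gradio app object is named `demo` as in the call quoted above:

```python
# default: local access only, best typewriter latency
demo.queue().launch(share=False, inbrowser=True)

# public access via a Gradio share link; traffic is relayed through
# Gradio's servers, so the typewriter effect degrades noticeably
# demo.queue().launch(share=True, inbrowser=True)
```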

### Web Demo (Streamlit)

First install Streamlit with `pip install streamlit streamlit-chat`, then run [web_demo_streamlit.py](demo_and_api/web_demo_streamlit.py) from the repository:

```shell
cd demo_and_api
streamlit run web_demo_streamlit.py --server.port 6006
```

*Thanks to [@AdamBear](https://github.com/AdamBear) for contributing this implementation; see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117).*

### Command-line Demo

![cli-demo](resources/cli-demo.png)

Run [cli_demo.py](cli_demo.py) from the repository:

```shell
cd demo_and_api
python cli_demo.py
```

The program runs an interactive dialogue in the command line. Type an instruction and press Enter to generate a reply; type `clear` to clear the conversation history, or `stop` to terminate the program.

### API Deployment
First install the extra dependencies with `pip install fastapi uvicorn pydantic`, then run [api.py](demo_and_api/api.py) from the repository:
```shell
cd demo_and_api
python api.py
```

The API exposes a regular endpoint (`/chat`) and a streaming endpoint (`/stream_chat`). The streaming endpoint enables a typewriter effect; for how to call it, see [web_demo_streamlit_with_api.py](demo_and_api/web_demo_streamlit_with_api.py).

By default the API is served on local port 8000 and is called via POST:
```shell
curl -X POST "http://127.0.0.1:8000/chat" \
-H 'Content-Type: application/json' \
-d '{"prompt": "你好", "history": []}'
```
@@ -149,6 +158,23 @@
}
```
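The streaming endpoint can be exercised the same way. A minimal sketch, assuming the default port and the `\ndata: `-prefixed chunk framing implemented in `api.py` (the `-N` flag disables curl's output buffering so chunks appear as they are generated):

```shell
curl -N -X POST "http://127.0.0.1:8000/stream_chat" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
```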

### Web Demo (Streamlit + API)
Streamlit serves as the frontend and the API as the backend, using the API's streaming endpoint.

Start the backend (install the API dependencies as described above):
```shell
cd demo_and_api
python api.py
```

Start the frontend (install the Streamlit dependencies as described above):
```shell
cd demo_and_api
streamlit run web_demo_streamlit_with_api.py --server.port 6006
```
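For reference, the core of how the frontend consumes the stream reduces to the sketch below (assuming the backend from `api.py` is running on `localhost:8000`; note that the server JSON-encodes the payload twice, hence the double `json.loads`):

```python
import requests, json

req = {"prompt": "你好", "history": []}
res = requests.post("http://localhost:8000/stream_chat", json=req, stream=True)

# the server prefixes every chunk with "\ndata: ", so split on that marker
for line in res.iter_lines(delimiter=b'\ndata: '):
    text = line.decode("utf-8")
    if not text.strip():
        continue
    answer = json.loads(json.loads(text))  # payload is double-encoded JSON
    print(answer["response"])
```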

## Low-cost Deployment
### Model Quantization
By default, the model is loaded at FP16 precision, and running the code above requires roughly 13 GB of GPU memory. If your GPU memory is limited, you can try loading the model quantized, as follows:
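The quantization code itself is collapsed in this diff. For orientation, quantized loading in ChatGLM-6B follows this pattern (a sketch based on the model's documented `quantize()` helper; it is not part of this PR's changes):

```python
from transformers import AutoModel

# load with INT8 quantization; use .quantize(4) for INT4 and even lower memory
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()
```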
@@ -216,7 +242,7 @@ model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)

## ChatGLM-6B Examples

Below are some screenshots generated with `web_demo_gradio.py`. More possibilities of ChatGLM-6B are waiting for you to explore!

<details><summary><b>Self-cognition</b></summary>

2 changes: 1 addition & 1 deletion README_en.md
@@ -85,7 +85,7 @@ cd ChatGLM-6B

#### Web Demo

![web-demo](resources/web-demo-gradio.png)

Install Gradio with `pip install gradio`, and run [web_demo.py](web_demo.py):

56 changes: 0 additions & 56 deletions api.py

This file was deleted.

73 changes: 73 additions & 0 deletions demo_and_api/api.py
@@ -0,0 +1,73 @@
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModel
from pydantic import BaseModel, Field
from typing import List
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

app = FastAPI()

class Params(BaseModel):
    prompt: str = 'hello'
    history: List[List[str]] = []  # typing.List keeps compatibility with Python < 3.9
    max_length: int = 2048
    top_p: float = 0.7
    temperature: float = 0.95

class Answer(BaseModel):
    status: int = 200
    # default_factory evaluates per response; a plain default would freeze
    # the timestamp at class-definition time
    time: str = Field(default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    response: str
    history: List[List[str]] = []

def torch_gc():
    # release cached CUDA memory after each request
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

async def create_chat(params: Params):
    global model, tokenizer
    response, history = model.chat(tokenizer,
                                   params.prompt,
                                   history=params.history,
                                   max_length=params.max_length,
                                   top_p=params.top_p,
                                   temperature=params.temperature)
    answer_ok = Answer(response=response, history=history)
    print(answer_ok.json())
    torch_gc()
    return answer_ok

async def create_stream_chat(params: Params):
    global model, tokenizer
    for response, history in model.stream_chat(tokenizer,
                                               params.prompt,
                                               history=params.history,
                                               max_length=params.max_length,
                                               top_p=params.top_p,
                                               temperature=params.temperature):
        answer_ok = Answer(response=response, history=history)
        # answer_ok.json() is already a JSON string, so json.dumps() encodes it
        # a second time; clients must therefore call json.loads() twice
        yield "\ndata: " + json.dumps(answer_ok.json())

    torch_gc()

@app.post("/chat")
async def post_chat(params: Params):
    answer = await create_chat(params)
    return answer

@app.post("/stream_chat")
async def post_stream_chat(params: Params):
    # each yielded chunk is prefixed with "\ndata: " for the client to split on
    return StreamingResponse(create_stream_chat(params))

if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
File renamed without changes.
7 changes: 7 additions & 0 deletions demo_and_api/requirements.txt
@@ -0,0 +1,7 @@
pydantic
fastapi
uvicorn
gradio
mdtex2html
streamlit
streamlit-chat
File renamed without changes.
File renamed without changes.
File renamed without changes.
72 changes: 72 additions & 0 deletions demo_and_api/web_demo_streamlit_with_api.py
@@ -0,0 +1,72 @@
import streamlit as st
from streamlit_chat import message
import requests
import json

st.set_page_config(
    page_title="ChatGLM-6b 演示",
    page_icon=":robot:"
)

MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2
url = "http://localhost:8000/stream_chat"


def predict(input, max_length, top_p, temperature, history=None):
    if history is None:
        history = []

    with container:
        # replay the conversation so far
        if len(history) > 0:
            for i, (query, response) in enumerate(history):
                message(query, avatar_style="big-smile", key=str(i) + "_user")
                message(response, avatar_style="bottts", key=str(i))

        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
        st.write("AI正在回复:")
        with st.empty():
            req = {
                "prompt": input,
                "history": history,
                "max_length": max_length,
                "top_p": top_p,
                "temperature": temperature
            }
            res = requests.post(url=url, json=req, stream=True)
            # the backend prefixes each chunk with "\ndata: ", so split on it
            for line in res.iter_lines(delimiter=b'\ndata: '):
                line = line.decode(encoding='utf-8')
                if line.strip() == '':
                    continue
                # the payload is double-encoded JSON, so decode twice
                response_json = json.loads(json.loads(line))
                response = response_json['response']
                history = response_json['history']
                st.write(response)

    return history


container = st.container()

# create a prompt text for the text generation
prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

max_length = st.sidebar.slider(
    'max_length', 0, 4096, 2048, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.6, step=0.01
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.95, step=0.01
)

if 'state' not in st.session_state:
    st.session_state['state'] = []

if st.button("发送", key="predict"):
    with st.spinner("AI正在思考,请稍等........"):
        # text generation
        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
2 changes: 0 additions & 2 deletions requirements.txt
@@ -2,7 +2,5 @@ protobuf
transformers==4.27.1
cpm_kernels
torch>=1.10
gradio
mdtex2html
sentencepiece
accelerate
File renamed without changes
File renamed without changes