diff --git a/README.md b/README.md
index 31156164..ec737f22 100644
--- a/README.md
+++ b/README.md
@@ -95,47 +95,56 @@ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b
 
 ## Demo & API
 
-We provide a web demo based on [Gradio](https://gradio.app) and a command-line demo. To use them, first clone this repository:
+### Web Demo (Gradio-based)
+
+![web-demo](resources/web-demo-gradio.gif)
+
+First install Gradio: `pip install gradio mdtex2html`, then run [web_demo_gradio.py](demo_and_api/web_demo_gradio.py) from the repository:
 
 ```shell
-git clone https://github.com/THUDM/ChatGLM-6B
-cd ChatGLM-6B
+cd demo_and_api
+python web_demo_gradio.py
 ```
 
-#### Web Demo
+The program starts a web server and prints its address; open that address in a browser to use the demo. The latest demo implements a typewriter (streaming) effect, which greatly improves responsiveness. Note that network access to Gradio is slow from mainland China, and with `demo.queue().launch(share=True, inbrowser=True)` enabled all traffic is relayed through Gradio's servers, which severely degrades the typewriter experience. The default launch mode has therefore been changed to `share=False`; if you need public access, change it back to `share=True`.
 
-![web-demo](resources/web-demo.gif)
+### Web Demo (Streamlit-based)
 
-First install Gradio: `pip install gradio`, then run [web_demo.py](web_demo.py) from the repository:
+First install Streamlit: `pip install streamlit streamlit-chat`, then run [web_demo_streamlit.py](demo_and_api/web_demo_streamlit.py) from the repository:
 
 ```shell
-python web_demo.py
+cd demo_and_api
+streamlit run web_demo_streamlit.py --server.port 6006
 ```
 
-The program starts a web server and prints its address; open that address in a browser to use the demo. The latest demo implements a typewriter (streaming) effect, which greatly improves responsiveness. Note that network access to Gradio is slow from mainland China, and with `demo.queue().launch(share=True, inbrowser=True)` enabled all traffic is relayed through Gradio's servers, which severely degrades the typewriter experience. The default launch mode has therefore been changed to `share=False`; if you need public access, change it back to `share=True`.
-
-Thanks to [@AdamBear](https://github.com/AdamBear) for implementing the Streamlit-based web demo; see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117) for how to run it.
+*Thanks to [@AdamBear](https://github.com/AdamBear) for contributing this implementation; see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117) for details.*
 
-#### Command-Line Demo
+### Command-Line Demo
 
 ![cli-demo](resources/cli-demo.png)
 
 Run [cli_demo.py](cli_demo.py) from the repository:
 
 ```shell
+cd demo_and_api
 python cli_demo.py
 ```
 
 The program holds an interactive conversation in the terminal: type an instruction and press Enter to generate a reply, type `clear` to clear the conversation history, and type `stop` to exit.
 
 ### API Deployment
-First install the additional dependencies `pip install fastapi uvicorn`, then run [api.py](api.py) from the repository:
+First install the additional dependencies `pip install fastapi uvicorn pydantic`, then run [api.py](demo_and_api/api.py) from the repository:
 ```shell
+cd demo_and_api
 python api.py
 ```
+
+The API provides a regular endpoint (/chat) and a streaming endpoint (/stream_chat).
+The streaming endpoint enables the typewriter effect; see [web_demo_streamlit_with_api.py](demo_and_api/web_demo_streamlit_with_api.py) for an example client.
+
 By default the API is served on local port 8000 and is called via POST:
 ```shell
-curl -X POST "http://127.0.0.1:8000" \
+curl -X POST "http://127.0.0.1:8000/chat" \
     -H 'Content-Type: application/json' \
     -d '{"prompt": "你好", "history": []}'
 ```
@@ -149,6 +158,23 @@ curl -X POST "http://127.0.0.1:8000" \
 }
 ```
 
+### Web Demo (Streamlit + API)
+Streamlit serves as the front end and the API as the back end, using the API's streaming endpoint.
+
+Start the back end
+(install the API dependencies as described above):
+```shell
+cd demo_and_api
+python api.py
+```
+
+Start the front end
+(install the Streamlit dependencies as described above):
+```shell
+cd demo_and_api
+streamlit run web_demo_streamlit_with_api.py --server.port 6006
+```
+
 ## Low-Cost Deployment
 ### Model Quantization
 By default, the model is loaded in FP16 precision, and running the code above requires roughly 13GB of GPU memory. If your GPU memory is limited, you can try loading the model in quantized form as follows:
@@ -216,7 +242,7 @@ model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
 
 ## ChatGLM-6B Examples
 
-Below are some example screenshots produced with `web_demo.py`. More possibilities of ChatGLM-6B are waiting for you to explore!
+Below are some example screenshots produced with `web_demo_gradio.py`. More possibilities of ChatGLM-6B are waiting for you to explore!
 
 <details><summary><b>Self Cognition</b></summary>
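The curl call above maps directly onto the `Params` request model defined in `demo_and_api/api.py` later in this patch, and the reply follows its `Answer` model. Below is a minimal Python sketch of the same `/chat` request, assuming the API server is running on its default local port 8000; the explicit parameter values are simply the `Params` defaults, shown for illustration, and the `timeout` value is an arbitrary choice.

```python
import requests

# Assumes the server from demo_and_api/api.py is running locally (`python api.py`).
payload = {
    "prompt": "你好",        # user input
    "history": [],           # list of [query, response] pairs from earlier turns
    "max_length": 2048,      # these three values are just the Params defaults
    "top_p": 0.7,
    "temperature": 0.95,
}

resp = requests.post("http://127.0.0.1:8000/chat", json=payload, timeout=300)
answer = resp.json()         # mirrors the Answer model: response, history, status, time
print(answer["response"])
print(answer["history"])     # pass this back as "history" on the next call
```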
diff --git a/README_en.md b/README_en.md
index 93b2ee29..bc40451f 100644
--- a/README_en.md
+++ b/README_en.md
@@ -85,7 +85,7 @@ cd ChatGLM-6B
 
 #### Web Demo
 
-![web-demo](resources/web-demo.png)
+![web-demo](resources/web-demo-gradio.png)
 
 Install Gradio with `pip install gradio`, and run [web_demo.py](web_demo.py):
diff --git a/api.py b/api.py
deleted file mode 100644
index 693c70ac..00000000
--- a/api.py
+++ /dev/null
@@ -1,56 +0,0 @@
-from fastapi import FastAPI, Request
-from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
-import torch
-
-DEVICE = "cuda"
-DEVICE_ID = "0"
-CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
-
-
-def torch_gc():
-    if torch.cuda.is_available():
-        with torch.cuda.device(CUDA_DEVICE):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
-
-
-app = FastAPI()
-
-
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    response, history = model.chat(tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95)
-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
-    return answer
-
-
-if __name__ == '__main__':
-    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
-    model.eval()
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
diff --git a/demo_and_api/api.py b/demo_and_api/api.py
new file mode 100644
index 00000000..b6de5325
--- /dev/null
+++ b/demo_and_api/api.py
@@ -0,0 +1,73 @@
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer, AutoModel
+from pydantic import BaseModel, Field
+import uvicorn, json, datetime
+import torch
+
+DEVICE = "cuda"
+DEVICE_ID = "0"
+CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
+
+app = FastAPI()
+
+class Params(BaseModel):
+    prompt: str = 'hello'
+    history: list[list[str]] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
+
+class Answer(BaseModel):
+    status: int = 200
+    time: str = Field(default_factory=lambda: datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))  # per-response timestamp
+    response: str
+    history: list[list[str]] = []
+
+def torch_gc():
+    if torch.cuda.is_available():
+        with torch.cuda.device(CUDA_DEVICE):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+async def create_chat(params: Params):
+    global model, tokenizer
+    response, history = model.chat(tokenizer,
+                                   params.prompt,
+                                   history=params.history,
+                                   max_length=params.max_length,
+                                   top_p=params.top_p,
+                                   temperature=params.temperature)
+    answer_ok = Answer(response=response, history=history)
+    print(answer_ok.json())
+    torch_gc()
+    return answer_ok
+
+async def create_stream_chat(params: Params):
+    global model, tokenizer
+    for response, history in model.stream_chat(tokenizer,
+                                               params.prompt,
+                                               history=params.history,
+                                               max_length=params.max_length,
+                                               top_p=params.top_p,
+                                               temperature=params.temperature):
+        answer_ok = Answer(response=response, history=history)
+        # print(answer_ok.json())
+        yield "\ndata: " + json.dumps(answer_ok.json())
+
+    torch_gc()
+
+@app.post("/chat")
+async def post_chat(params: Params):
+    answer = await create_chat(params)
+    return answer
+
+@app.post("/stream_chat")
+async def post_stream_chat(params: Params):
+    return StreamingResponse(create_stream_chat(params))
+
+if __name__ == '__main__':
+    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
+    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+    model.eval()
+    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
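One detail of the stream format worth noting for client authors: each chunk emitted by `create_stream_chat` is the prefix `\ndata: ` followed by `json.dumps(answer_ok.json())`, that is, a JSON-encoded JSON string, so a client has to decode it twice. The Streamlit demo later in this patch does exactly that. A minimal consumption sketch in Python, assuming the server above is running on its default local port 8000:

```python
import json
import requests

payload = {"prompt": "你好", "history": []}

# stream=True keeps the connection open; chunks arrive separated by the "\ndata: " prefix
with requests.post("http://127.0.0.1:8000/stream_chat", json=payload, stream=True) as res:
    for raw in res.iter_lines(delimiter=b"\ndata: "):
        if not raw.strip():
            continue
        # the server sends json.dumps(answer.json()), so decode twice
        chunk = json.loads(json.loads(raw.decode("utf-8")))
        print(chunk["response"])  # cumulative reply so far; the last chunk holds the full answer
```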
diff --git a/cli_demo.py b/demo_and_api/cli_demo.py
similarity index 100%
rename from cli_demo.py
rename to demo_and_api/cli_demo.py
diff --git a/demo_and_api/requirements.txt b/demo_and_api/requirements.txt
new file mode 100644
index 00000000..e0906c33
--- /dev/null
+++ b/demo_and_api/requirements.txt
@@ -0,0 +1,7 @@
+pydantic
+fastapi
+uvicorn
+gradio
+mdtex2html
+streamlit
+streamlit-chat
\ No newline at end of file
diff --git a/web_demo.py b/demo_and_api/web_demo_gradio.py
similarity index 100%
rename from web_demo.py
rename to demo_and_api/web_demo_gradio.py
diff --git a/web_demo_old.py b/demo_and_api/web_demo_old.py
similarity index 100%
rename from web_demo_old.py
rename to demo_and_api/web_demo_old.py
diff --git a/web_demo2.py b/demo_and_api/web_demo_streamlit.py
similarity index 100%
rename from web_demo2.py
rename to demo_and_api/web_demo_streamlit.py
diff --git a/demo_and_api/web_demo_streamlit_with_api.py b/demo_and_api/web_demo_streamlit_with_api.py
new file mode 100644
index 00000000..d202c742
--- /dev/null
+++ b/demo_and_api/web_demo_streamlit_with_api.py
@@ -0,0 +1,72 @@
+import streamlit as st
+from streamlit_chat import message
+import requests
+import json
+
+st.set_page_config(
+    page_title="ChatGLM-6b 演示",
+    page_icon=":robot:"
+)
+
+MAX_TURNS = 20
+MAX_BOXES = MAX_TURNS * 2
+url = "http://localhost:8000/stream_chat"
+
+
+def predict(input, max_length, top_p, temperature, history=None):
+    if history is None:
+        history = []
+
+    with container:
+        if len(history) > 0:
+            for i, (query, response) in enumerate(history):
+                message(query, avatar_style="big-smile", key=str(i) + "_user")
+                message(response, avatar_style="bottts", key=str(i))
+
+        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")
+        st.write("AI正在回复:")
+        with st.empty():
+            req = {
+                "prompt": input,
+                "history": history,
+                "max_length": max_length,
+                "top_p": top_p,
+                "temperature": temperature
+            }
+            res = requests.post(url=url, json=req, stream=True)
+            for line in res.iter_lines(delimiter=b'\ndata: '):
+                line = line.decode(encoding='utf-8')
+                if line.strip() == '':
+                    continue
+                response_json = json.loads(json.loads(line))  # server chunks are double-encoded JSON
+                response = response_json['response']
+                history = response_json['history']
+                st.write(response)
+
+    return history
+
+
+container = st.container()
+
+# create a prompt text for the text generation
+prompt_text = st.text_area(label="用户命令输入",
+                           height=100,
+                           placeholder="请在这儿输入您的命令")
+
+max_length = st.sidebar.slider(
+    'max_length', 0, 4096, 2048, step=1
+)
+top_p = st.sidebar.slider(
+    'top_p', 0.0, 1.0, 0.6, step=0.01
+)
+temperature = st.sidebar.slider(
+    'temperature', 0.0, 1.0, 0.95, step=0.01
+)
+
+if 'state' not in st.session_state:
+    st.session_state['state'] = []
+
+if st.button("发送", key="predict"):
+    with st.spinner("AI正在思考,请稍等........"):
+        # text generation
+        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])
\ No newline at end of file
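The Streamlit front end above keeps a conversation going by storing the `history` returned from the API in `st.session_state['state']` and sending it back with the next prompt. The same pattern works outside Streamlit; a small sketch against the non-streaming `/chat` endpoint, assuming the API server is running locally and using placeholder prompts:

```python
import requests

API = "http://127.0.0.1:8000/chat"
history = []  # list of [query, response] pairs, as expected by the Params model

# each turn feeds the returned history back into the next request
for prompt in ["你好", "请用一句话总结上面的回答"]:
    answer = requests.post(API, json={"prompt": prompt, "history": history}).json()
    history = answer["history"]
    print(prompt, "->", answer["response"])
```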
diff --git a/requirements.txt b/requirements.txt
index fb8d79f7..16a67039 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,5 @@ protobuf
 transformers==4.27.1
 cpm_kernels
 torch>=1.10
-gradio
-mdtex2html
 sentencepiece
 accelerate
\ No newline at end of file
diff --git a/resources/web-demo.gif b/resources/web-demo-gradio.gif
similarity index 100%
rename from resources/web-demo.gif
rename to resources/web-demo-gradio.gif
diff --git a/resources/web-demo.png b/resources/web-demo-gradio.png
similarity index 100%
rename from resources/web-demo.png
rename to resources/web-demo-gradio.png