diff --git a/api.py b/api.py
index 3193d2d8..81c919cd 100644
--- a/api.py
+++ b/api.py
@@ -1,14 +1,30 @@
+import asyncio
+
 from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
 from transformers import AutoTokenizer, AutoModel
-import uvicorn, json, datetime
+from pydantic import BaseModel
+import uvicorn, datetime
 import torch
-import threading
-import asyncio
+
 DEVICE = "cuda"
 DEVICE_ID = "0"
+EXECUTOR_POOL_SIZE = 10
 CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
 
+class Params(BaseModel):
+    prompt: str = 'hello'
+    history: list[list[str]] = []
+    max_length: int = 2048
+    top_p: float = 0.7
+    temperature: float = 0.95
+
+class Answer(BaseModel):
+    status: int = 200
+    time: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    response: str
+    history: list[list[str]] = []
 
 def torch_gc():
     if torch.cuda.is_available():
@@ -21,45 +37,31 @@ def torch_gc():
 import concurrent
 from functools import partial
 
-pool = concurrent.futures.ThreadPoolExecutor(10)
+pool = concurrent.futures.ThreadPoolExecutor(EXECUTOR_POOL_SIZE)
 
-@app.post("/")
-async def _create_item(request: Request):
+@app.post("/chat")
+async def create_chat(params: Params) -> Answer:
     global model, tokenizer
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-    max_length = json_post_list.get('max_length')
-    top_p = json_post_list.get('top_p')
-    temperature = json_post_list.get('temperature')
-    loop = asyncio.get_event_loop()
-    response, history = await loop.run_in_executor(pool, partial(model.chat, tokenizer,
-                                   prompt,
-                                   history=history,
-                                   max_length=max_length if max_length else 2048,
-                                   top_p=top_p if top_p else 0.7,
-                                   temperature=temperature if temperature else 0.95))
-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": response,
-        "history": history,
-        "status": 200,
-        "time": time
-    }
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
+    if EXECUTOR_POOL_SIZE != 0:
+        loop = asyncio.get_event_loop()
+        response, history = await loop.run_in_executor(pool, partial(model.chat,
+                                       tokenizer,
+                                       params.prompt,
+                                       history=params.history,
+                                       max_length=params.max_length,
+                                       top_p=params.top_p,
+                                       temperature=params.temperature))
+    else:
+        response, history = model.chat(tokenizer,
+                                       params.prompt,
+                                       history=params.history,
+                                       max_length=params.max_length,
+                                       top_p=params.top_p,
+                                       temperature=params.temperature)
+    answer_ok = Answer(response=response, history=history)
+    # print(answer_ok.json())
     torch_gc()
-    return answer
-
-async def create_item(request: Request):
-    loop = asyncio.get_event_loop()
-    with concurrent.futures.ThreadPoolExecutor() as pool:
-        result = await loop.run_in_executor(pool, _create_item, request)
-        print(result)
-        return result
+    return answer_ok
 
 if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
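
For reference, a minimal client sketch for the new /chat endpoint introduced by this patch. The request body mirrors the Params model (prompt, history, max_length, top_p, temperature) and the reply mirrors Answer; the host and port (127.0.0.1:8000) are assumptions here, since the uvicorn.run(...) call falls outside these hunks.

# Hypothetical client for the /chat endpoint; not part of the patch itself.
# Assumes the API is served at 127.0.0.1:8000 -- adjust to match your uvicorn.run() settings.
import requests

payload = {
    "prompt": "Hello",        # Params.prompt
    "history": [],            # list of [user, assistant] turns (Params.history)
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
}

resp = requests.post("http://127.0.0.1:8000/chat", json=payload, timeout=300)
resp.raise_for_status()
answer = resp.json()          # fields match the Answer model: status, time, response, history
print(answer["response"])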