
Commit

Define the input and output structures with pydantic.
Committed on May 15, 2023
1 parent 186d656 commit b709499
Showing 1 changed file with 41 additions and 39 deletions.
80 changes: 41 additions & 39 deletions api.py
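Where the old handler read the raw request body and pulled each field out of the JSON by hand, the input is now declared as a `Params` model and the response as an `Answer` model, so FastAPI parses the body, fills in defaults for omitted fields, and rejects ill-typed payloads before `model.chat` is ever called; the route also moves from `/` to `/chat`. A minimal client call against the new endpoint might look like the sketch below (assumptions: the `requests` package is available and the server is still launched with uvicorn on port 8000, as in the part of api.py not shown in this diff; every field except `prompt` can be omitted thanks to the defaults on `Params`):

```python
# Hypothetical client for the new /chat endpoint; host and port are assumptions.
import requests

payload = {"prompt": "Hello, who are you?", "history": []}  # other fields use Params defaults
resp = requests.post("http://127.0.0.1:8000/chat", json=payload)
resp.raise_for_status()
data = resp.json()  # shaped like the Answer model: status, time, response, history
print(data["response"])
```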
@@ -1,14 +1,30 @@
import asyncio

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
from pydantic import BaseModel
import uvicorn, datetime
import torch
import threading
import asyncio


DEVICE = "cuda"
DEVICE_ID = "0"
EXECUTOR_POOL_SIZE = 10
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE

class Params(BaseModel):
prompt: str = 'hello'
history: list[list[str]] = []
max_length: int = 2048
top_p: float = 0.7
temperature: float = 0.95

class Answer(BaseModel):
status: int = 200
time: str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
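# NOTE: this default is evaluated once, when the class body runs, so every Answer
# built without an explicit time reuses that import-time timestamp rather than the
# time of the current request.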
response: str
history: list[list[str]] = []

def torch_gc():
if torch.cuda.is_available():
@@ -21,45 +37,31 @@ def torch_gc():

import concurrent
from functools import partial
pool = concurrent.futures.ThreadPoolExecutor(10)
pool = concurrent.futures.ThreadPoolExecutor(EXECUTOR_POOL_SIZE)

@app.post("/")
async def _create_item(request: Request):
@app.post("/chat")
async def create_chat(params: Params) -> Answer:
global model, tokenizer
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
prompt = json_post_list.get('prompt')
history = json_post_list.get('history')
max_length = json_post_list.get('max_length')
top_p = json_post_list.get('top_p')
temperature = json_post_list.get('temperature')
loop = asyncio.get_event_loop()
response, history = await loop.run_in_executor(pool,partial(model.chat,tokenizer,
prompt,
history=history,
max_length=max_length if max_length else 2048,
top_p=top_p if top_p else 0.7,
temperature=temperature if temperature else 0.95))
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"history": history,
"status": 200,
"time": time
}
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
print(log)
if EXECUTOR_POOL_SIZE != 0:
loop = asyncio.get_event_loop()
response, history = await loop.run_in_executor(pool, partial(model.chat,
tokenizer,
params.prompt,
history=params.history,
max_length=params.max_length,
top_p=params.top_p,
temperature=params.temperature))
else:
response, history = model.chat(tokenizer,
params.prompt,
history=params.history,
max_length=params.max_length,
top_p=params.top_p,
temperature=params.temperature)
answer_ok = Answer(response=response, history=history)
# print(answer_ok.json())
torch_gc()
return answer

async def create_item(request: Request):
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as pool:
result = await loop.run_in_executor(pool,_create_item, request)
print(result)
return result
return answer_ok

if __name__ == '__main__':
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
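The validation behaviour introduced by this commit can be exercised without the model, a GPU, or a running server. The sketch below copies the `Params` class from the diff above and needs only pydantic (and Python 3.9+ for the builtin generic annotations, matching the diff); it shows defaults being filled in for omitted fields and an ill-typed field being rejected, which FastAPI reports to callers as an HTTP 422 response:

```python
# Standalone sketch of the request-side validation this commit adds.
# Params is copied from the diff above; only pydantic is required to run it.
from pydantic import BaseModel, ValidationError

class Params(BaseModel):
    prompt: str = 'hello'
    history: list[list[str]] = []
    max_length: int = 2048
    top_p: float = 0.7
    temperature: float = 0.95

print(Params())                               # every field falls back to its default
print(Params(prompt="你好", temperature=0.8))  # partial payloads are accepted

try:
    Params(max_length="not a number")         # wrong type for max_length
except ValidationError as err:
    print(err)                                # FastAPI would return this as HTTP 422
```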
