forked from shibing624/MedicalGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgradio_demo.py
130 lines (118 loc) · 4.94 KB
/
gradio_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
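@description: Gradio chat demo that streams answers from a causal language model, optionally with LoRA weights applied.
Requires: pip install "gradio>=3.50.2"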
"""
import argparse
import os
from threading import Thread
import gradio as gr
import torch
from peft import PeftModel
from transformers import (
AutoModel,
AutoTokenizer,
AutoModelForCausalLM,
BloomForCausalLM,
BloomTokenizerFast,
LlamaTokenizer,
LlamaForCausalLM,
GenerationConfig,
TextIteratorStreamer,
)
from supervised_finetuning import get_conv_template
# Maps the value of the --model_type CLI option to the
# (model class, tokenizer class) pair that main() uses to load the checkpoint
# and its tokenizer through their from_pretrained constructors.
MODEL_CLASSES = {
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "chatglm": (AutoModel, AutoTokenizer),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "baichuan": (AutoModelForCausalLM, AutoTokenizer),
    "auto": (AutoModelForCausalLM, AutoTokenizer),
}
def main():
    """Load a model from CLI options and serve it through a streaming Gradio chat UI.

    Steps performed:
      1. Parse command-line arguments (model type, base model path, optional
         LoRA adapter, tokenizer path, prompt template, CPU-only flag, ...).
      2. Load the tokenizer and base model with the classes registered in
         ``MODEL_CLASSES`` and, when ``--lora_model`` is given, wrap the base
         model with the LoRA adapter.
      3. Build the prompt template and launch a ``gr.ChatInterface`` whose
         callback streams the generated text chunk by chunk.

    The function blocks in ``launch()`` until the Gradio server is stopped.
    """
    parser = argparse.ArgumentParser()
    # `choices` turns an unknown model type into a clean argparse error instead
    # of a KeyError traceback from the MODEL_CLASSES lookup below.
    parser.add_argument('--model_type', default=None, type=str, required=True, choices=list(MODEL_CLASSES))
    parser.add_argument('--base_model', default=None, type=str, required=True)
    parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
    parser.add_argument('--tokenizer_path', default=None, type=str)
    parser.add_argument('--template_name', default="vicuna", type=str,
                        help="Prompt template name, eg: alpaca, vicuna, baichuan2, chatglm2 etc.")
    parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
    parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
    parser.add_argument('--share', action='store_true', help='Share gradio')
    parser.add_argument('--port', default=8081, type=int, help='Port of gradio demo')
    args = parser.parse_args()
    print(args)

    load_type = torch.float16
    use_gpu = torch.cuda.is_available() and not args.only_cpu
    device = torch.device(0) if use_gpu else torch.device('cpu')
    # device_map='auto' places weights on any visible GPU, which would ignore
    # --only_cpu and leave the model on the GPU while the inputs are sent to the
    # CPU. Only request automatic placement when the GPU is actually wanted;
    # without a device_map the weights are loaded on the CPU.
    placement_kwargs = {'device_map': 'auto'} if use_gpu else {}

    if args.tokenizer_path is None:
        args.tokenizer_path = args.base_model

    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
    base_model = model_class.from_pretrained(
        args.base_model,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        **placement_kwargs,
    )
    try:
        base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
    except OSError:
        # Best effort: the checkpoint has no loadable generation config, keep the model's default.
        print("Failed to load generation config, use default.")

    if args.resize_emb:
        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        print(f"Vocab of the base model: {model_vocab_size}")
        print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            print("Resize model embeddings to fit tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)

    if args.lora_model:
        model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, **placement_kwargs)
        print("loaded lora model")
    else:
        model = base_model

    if device == torch.device('cpu'):
        # Weights were loaded as float16; convert to float32 for CPU inference.
        model.float()
    model.eval()

    prompt_template = get_conv_template(args.template_name)
    # Prefer the tokenizer's EOS token; fall back to the template's stop string.
    stop_str = tokenizer.eos_token or prompt_template.stop_str

    def predict(message, history):
        """Generate answer from prompt with GPT and stream the output.

        Args:
            message: The new user message.
            history: Previous turns supplied by ``gr.ChatInterface``
                (assumed to be a list of ``[user, assistant]`` pairs, as
                required by the list concatenation below).

        Yields:
            The assistant reply accumulated so far, once per streamed chunk.
        """
        history_messages = history + [[message, ""]]
        prompt = prompt_template.get_prompt(messages=history_messages)
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        input_ids = tokenizer(prompt).input_ids
        context_len = 2048
        max_new_tokens = 512
        # Keep only the most recent prompt tokens so that prompt plus generation
        # stays within the context window (8 tokens of slack).
        max_src_len = context_len - max_new_tokens - 8
        input_ids = input_ids[-max_src_len:]
        generation_kwargs = dict(
            input_ids=torch.as_tensor([input_ids]).to(device),
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            repetition_penalty=1.0,
        )
        # generate() blocks until finished, so run it in a worker thread and
        # consume the streamer from this generator.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        partial_message = ""
        for new_token in streamer:
            # Skip a chunk that is exactly the stop string.
            if new_token != stop_str:
                partial_message += new_token
                yield partial_message

    gr.ChatInterface(
        predict,
        chatbot=gr.Chatbot(),
        textbox=gr.Textbox(placeholder="Ask me question", lines=4, scale=9),
        title="MedicalGPT",
        description="为了促进医疗行业大模型的开放研究,本项目开源了[MedicalGPT](https://github.com/shibing624/MedicalGPT)医疗大模型",
        theme="soft",
    ).queue().launch(share=args.share, inbrowser=True, server_name='0.0.0.0', server_port=args.port)
# Script entry point: start the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()