chat.py
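"""Interactive chat script for Llama-2-13b-chat-hf.

Loads the base model with 4-bit quantization and, by default, attaches a
PEFT/LoRA adapter selected from the ./model/llama-2-* directories (or from
./results when no selection is made). Pass -l/--llama to chat with the base
model only and -r/--response N to generate N responses per prompt.
"""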
import argparse
import glob
import os
import torch
from peft import PeftModel
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)


def chat():
model_name = "meta-llama/Llama-2-13b-chat-hf"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # load the base model in 4-bit
        bnb_4bit_quant_type="nf4",  # quantization type (fp4 or nf4)
        bnb_4bit_compute_dtype=torch.float16,  # compute dtype for the 4-bit base model (float16 or bfloat16)
        bnb_4bit_use_double_quant=False,  # enable nested quantization of the 4-bit base model (double quantization)
    )
    # Prepare the quantized base model
base_model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16,
)
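    # Enumerate fine-tuned adapter directories (./model/llama-2-*) next to this
    # script and number them 1..N for interactive selection.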
model_dict = {
n + 1: model
for n, model in enumerate(
glob.glob(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), "model", "llama-2-*"
)
)
)
}
# model_dict = {1: "./llama-2-13b-skt-eng", 2: "./llama-2-13b-skt-eng-context", 3: "./llama-2-13b-skt-ger-context", 4: "./llama-2-13b-skt-all-context", 5: "./llama-2-13b-skt-grk-lat-context"}
print("モデル選択")
for key, value in model_dict.items():
print(f"[{key}] {os.path.basename(value)}")
model_num = input("Input number: ")
if model_num:
print(f"Loading {model_dict[int(model_num)]} model.")
peftmodel_name = model_dict[int(model_num)]
else:
print(f"Loading recently created model.")
peftmodel_name = "./results"
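    # Attach the selected LoRA/PEFT adapter on top of the quantized base model.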
model = PeftModel.from_pretrained(base_model, peftmodel_name)
    # Prepare the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_name, use_fast=False, add_eos_token=True, trust_remote_code=True
)
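    # Llama-2 has no pad token by default; reuse <unk> for padding and pad on
    # the right (countermeasure for fp16 overflow issues).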
tokenizer.pad_token = tokenizer.unk_token
tokenizer.padding_side = "right"
    # Run inference
try:
while True:
# word = input("Input a Sanskrit word: ")
# prompt = f"[INST]What is the meaning of {word}?[/INST]"
base_prompt = input("Prompt: ")
prompt = f"### Instruction:\n{base_prompt}\n\n### Response:\n"
for i in range(args.response):
inputs = tokenizer(
prompt, add_special_tokens=False, return_tensors="pt"
)
outputs = model.generate(
**inputs.to(model.device),
max_new_tokens=256,
# do_sample=True,
temperature=0.7,
# return_dict_in_generate=True,
)
output = tokenizer.decode(
# outputs.sequences[0, inputs.input_ids.shape[1] :]
outputs[0],
skip_special_tokens=True,
)
print(f"Answer ({i}): ", output)
except KeyboardInterrupt:
print("Quit.")


def chat_base_model():
model_name = "meta-llama/Llama-2-13b-chat-hf"
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # load the base model in 4-bit
        bnb_4bit_quant_type="nf4",  # quantization type (fp4 or nf4)
        bnb_4bit_compute_dtype=torch.float16,  # compute dtype for the 4-bit base model (float16 or bfloat16)
        bnb_4bit_use_double_quant=False,  # enable nested quantization of the 4-bit base model (double quantization)
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,  # model name
        quantization_config=bnb_config,  # quantization parameters
        device_map="auto",
        use_auth_token=True,
    )
    model.config.use_cache = True  # enable the KV cache (set to False during training)
    model.config.pretraining_tp = 2  # tensor-parallel rank used in pretraining (7B: 1, 13B: 2)
    # Prepare the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,  # model name
        use_fast=False,  # whether to use the fast tokenizer
        add_eos_token=True,  # append EOS to tokenized data
        trust_remote_code=True,
        use_auth_token=True,
    )
tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "right"  # countermeasure for fp16 overflow issues
    # Run inference
try:
while True:
# word = input("Input a Sanskrit word: ")
# prompt = f"[INST]What is the meaning of {word}?[/INST]"
base_prompt = input("Prompt: ")
prompt = f"#Instruction:\n{base_prompt}\n\n# Response:\n"
inputs = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
outputs = model.generate(
**inputs.to(model.device),
max_new_tokens=100,
do_sample=True,
temperature=0.7,
return_dict_in_generate=True,
)
output = tokenizer.decode(outputs.sequences[0, inputs.input_ids.shape[1] :])
print("Answer: ", output)
except KeyboardInterrupt:
print("\nQuit.")


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Chat with llama.")
parser.add_argument("-l", "--llama", action="store_true", help="Use base model")
parser.add_argument(
"-r", "--response", type=int, default=1, help="Number of responses (default 1)"
)
args = parser.parse_args()
if args.llama:
chat_base_model()
else:
chat()
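# Example usage (adapter directories are expected under ./model/, per the glob above):
#   python chat.py          # pick a fine-tuned adapter and chat with it
#   python chat.py -r 3     # generate three responses per prompt
#   python chat.py -l       # chat with the plain Llama-2 base model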