inference.py
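"""Inference script.

Loads the 4-bit quantized Qwen2.5-32B-Instruct model via Unsloth, runs generation
over the test split, and writes the results to ./saved/outputs.

Example usage (the config path below is illustrative, not part of the repo):
    python inference.py --config_path configs/inference.yaml
"""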
# Standard library
import argparse
import os

# External libraries
# Unsloth is imported before transformers/peft so its runtime patches can apply.
from unsloth import FastLanguageModel
import pandas as pd
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM
from dotenv import load_dotenv
from huggingface_hub import login
import wandb

# Local modules
from data_loader.datasets import BaseDataset
from models.base_model import BaseModel
from utils.utils import load_config, set_seed

# Authenticate with Hugging Face Hub and Weights & Biases using keys from the .env file.
load_dotenv()
hf_api_key = os.getenv('HF_API_KEY')
wandb_api_key = os.getenv('WANDB_API_KEY')
login(hf_api_key)
wandb.login(key=wandb_api_key)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str,
                        help="Path to the config file that defines inference settings.")
    args = parser.parse_args()
    print("Config Path :", args.config_path)  # Check config path

    configs = load_config(args.config_path)
    set_seed(configs.seed)

    # test_model_path_or_name = os.path.join("./saved/models", configs.test_model_path_or_name)
    # print("Inference Model Name :", configs.test_model_path_or_name)  # Check configs model name
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Qwen2.5-32B-Instruct-bnb-4bit",
        max_seq_length=8192,
        dtype=torch.float16,
        load_in_4bit=True,
        device_map="auto",
        trust_remote_code=True,
    )
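
    # Attach a LoRA adapter configuration (rank 16 on k_proj/v_proj) to the base model.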
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=["k_proj", "v_proj"],
        lora_alpha=16,
        lora_dropout=0,  # Supports any, but 0 is optimized
        bias="none",     # Supports any, but "none" is optimized
        # "unsloth" uses 30% less VRAM and fits 2x larger batch sizes
        use_gradient_checkpointing="unsloth",  # True or "unsloth" for very long context
        random_state=42,
        use_rslora=False,   # Rank-stabilized LoRA is also supported
        loftq_config=None,  # As is LoftQ
    )
    # Override the tokenizer's chat template if one is provided in the config.
    if configs.chat_template is not None:
        tokenizer.chat_template = configs.chat_template
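
    # Load the test split and wrap it with the project's dataset/model abstractions.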
    test_data = pd.read_csv(os.path.join(configs.data_dir, 'test.csv'))
    test_dataset = BaseDataset(test_data, tokenizer, configs, False)
    model = BaseModel(configs, tokenizer, model=model)

    # outputs, decoder_output = model.inference(test_dataset)
    generate_output = model.inference_generate(test_dataset)
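
    # Persist the generated outputs as a CSV under ./saved/outputs.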
    os.makedirs("./saved/outputs", exist_ok=True)
    # pd.DataFrame(outputs).to_csv(os.path.join("./saved/outputs", configs.output_file), index=False)
    pd.DataFrame(generate_output).to_csv(
        os.path.join("./saved/outputs", "qwen32b_generation_output.csv"), index=False
    )


if __name__ == "__main__":
    main()