-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
60 lines (43 loc) · 1.8 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from torch.utils.data import DataLoader
from evaluate import load
from transformers import HfArgumentParser, AutoTokenizer, AutoConfig
from peft import PeftModel, PeftConfig
from model.modeling import T5ForClassification
from data import T5Dataset, T5Collator
from args import ModelArgs, DataArgs
from trainer import Trainer
#region Parser
# Build a command-line parser from the two argument dataclasses and parse the
# process arguments into one instance of each.
parser = HfArgumentParser([ModelArgs, DataArgs])
model_args, data_args = parser.parse_args_into_dataclasses()

# Echo the parsed configuration, one dataclass per line.
print(model_args)
print(data_args)
#endregion
#region Model
# The same path is used to read the PEFT adapter config, a model config, and
# later the adapter weights themselves.
adapter_path = model_args.pretrained_model_name_or_path
peft_config = PeftConfig.from_pretrained(adapter_path)
base_config = AutoConfig.from_pretrained(adapter_path)

# Load the base model named by the adapter config. Every non-default field of
# the config found at adapter_path is forwarded as a keyword override.
# load_in_8bit=True and device_map={'': 0} are passed through to
# from_pretrained (per the transformers docs: 8-bit loading, whole model on
# device index 0).
model = T5ForClassification.from_pretrained(
    pretrained_model_name_or_path=peft_config.base_model_name_or_path,
    **base_config.to_diff_dict(),
    load_in_8bit=True,
    device_map={'': 0},
)

# Wrap the base model with the adapter stored at adapter_path, then switch to
# evaluation mode.
model = PeftModel.from_pretrained(model, adapter_path, device_map={'': 0})
model.eval()
#endregion
#region Tokenizer + Data
# The tokenizer is loaded from the same path as the model checkpoint, with its
# maximum length set from the data arguments.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.pretrained_model_name_or_path,
    model_max_length=data_args.seq_length,
)
data_collator = T5Collator(tokenizer)

# The evaluation split is read from the "dev" entry under the data path.
dev_path = data_args.path + "/dev"
eval_data = T5Dataset(dev_path)
eval_dataloader = DataLoader(
    eval_data,
    batch_size=data_args.bs,
    collate_fn=data_collator,
)
#endregion
#region Metrics
# Metric objects are loaded once at import time and reused on every call.
accuracy_metric = load("accuracy")
f1_metric = load("f1")


def compute_metrics(preds, labels):
    """Score predictions against references with accuracy and macro F1.

    Args:
        preds: Predicted labels, forwarded as ``predictions`` to both metrics.
        labels: Reference labels, forwarded as ``references`` to both metrics.

    Returns:
        A single dict merging the accuracy result with the macro-averaged F1
        result. If both results share a key, the F1 value wins.
    """
    shared = {"predictions": preds, "references": labels}
    accuracy_scores = accuracy_metric.compute(**shared)
    f1_scores = f1_metric.compute(**shared, average='macro')
    return {**accuracy_scores, **f1_scores}
#endregion
# Hand the model, evaluation data, tokenizer and metric function to the
# project Trainer.
trainer = Trainer(
    model=model,
    eval_dataloader=eval_dataloader,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Run the evaluation and print whatever the Trainer reports.
metrics = trainer.eval()
print(metrics)