-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
42 lines (36 loc) · 1.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from transformers.integrations import TensorBoardCallback
from torch.utils.tensorboard import SummaryWriter
from transformers import TrainingArguments
from transformers import Trainer, HfArgumentParser
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
from dataclasses import dataclass, field
import datasets
import os
class CastOutputToFloat(nn.Sequential):
    """Sequential container whose forward output is cast to ``torch.float32``."""

    def forward(self, x):
        """Run the wrapped modules on ``x`` and return the result as float32."""
        result = super().forward(x)
        return result.to(torch.float32)
def get_data_collator(tokenizer: "AutoTokenizer"):
    """Build a collate function that pads a batch to its longest sequence.

    Args:
        tokenizer: Object exposing ``pad_token_id``; that id is used to
            right-pad ``input_ids`` shorter than the longest one in the batch.

    Returns:
        A ``data_collator(features)`` callable (see its docstring).
    """

    def data_collator(features: list) -> dict:
        """Collate feature dicts into padded ``input_ids`` / ``labels`` tensors.

        Each feature must provide ``"input_ids"`` (a list of ints) and
        ``"seq_len"`` (an int; presumably the prompt length -- confirm
        against the code that builds the dataset).

        Rows are emitted longest-first. For every row the first
        ``seq_len - 1`` label positions and all padding positions are set to
        -100 (the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``);
        the remaining positions copy ``input_ids``.

        Returns:
            ``{"input_ids": LongTensor, "labels": LongTensor}``, both shaped
            ``(len(features), longest_sequence_length)``.

        Raises:
            ValueError: if ``features`` is empty, if padding is required but
                the tokenizer has no ``pad_token_id``, or if ``seq_len - 1``
                exceeds the length of ``input_ids``.
        """
        if not features:
            raise ValueError("data_collator received an empty batch")

        # Longest sequence first; sorted() is stable, so equal-length
        # features keep their incoming order.
        ordered = sorted(features, key=lambda feature: -len(feature["input_ids"]))
        longest = len(ordered[0]["input_ids"])
        pad_id = tokenizer.pad_token_id

        input_rows = []
        label_rows = []
        for feature in ordered:
            ids = feature["input_ids"]
            pad_len = longest - len(ids)
            if pad_len and pad_id is None:
                raise ValueError(
                    "tokenizer.pad_token_id is None; cannot pad sequences "
                    "of different lengths"
                )

            # Clamp at 0: a negative count would make ``ids[masked:]`` slice
            # from the end and leave labels misaligned with input_ids.
            masked = max(feature["seq_len"] - 1, 0)
            if masked > len(ids):
                raise ValueError(
                    f"seq_len - 1 ({masked}) exceeds the number of "
                    f"input_ids ({len(ids)})"
                )

            label_rows.append([-100] * masked + ids[masked:] + [-100] * pad_len)
            input_rows.append(ids + [pad_id] * pad_len)

        return {
            "input_ids": torch.tensor(input_rows, dtype=torch.long),
            "labels": torch.tensor(label_rows, dtype=torch.long),
        }

    return data_collator