# utils.py
import os
import random

import numpy as np
import torch
import yaml
from easydict import EasyDict

# NOTE: this import path matches older transformers releases; in recent versions
# the module moved to transformers.generation.logits_process.
from transformers.generation_logits_process import LogitsProcessor


def load_config(config_path="./config.yaml"):
    # Read config.yaml and wrap its "CFG" section in an EasyDict
    with open(config_path) as infile:
        SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
    CFG = EasyDict(SAVED_CFG["CFG"])
    return CFG

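
# Hedged usage sketch (not part of the original file): the YAML layout below is an
# assumption for illustration; device_ids, min_prefix_length, and max_prefix_length
# are the only keys this module actually reads.
#
#   # config.yaml
#   CFG:
#     device_ids: 0
#     min_prefix_length: 10
#     max_prefix_length: 512
def _demo_load_config():
    CFG = load_config()
    # EasyDict allows attribute access (CFG.device_ids) as well as CFG["device_ids"]
    print(CFG.device_ids)
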

def load_devices():
    CFG = load_config()
    device_ids = CFG["device_ids"]
    list_devices = []
    # cpu: requested explicitly (-1) or forced because CUDA is unavailable
    if device_ids == -1 or not torch.cuda.is_available():
        list_devices.append(torch.device("cpu"))
    # single gpu: an int device index returns a bare torch.device
    elif isinstance(device_ids, int):
        return torch.device(f"cuda:{device_ids}")
    # multiple gpus: a list of indices returns a list of torch.device objects
    elif isinstance(device_ids, list):
        for device_index in device_ids:
            list_devices.append(torch.device(f"cuda:{device_index}"))
    print("working on", list_devices)
    return list_devices

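
# Hedged usage sketch: load_devices returns either a bare torch.device (single-GPU
# config) or a list of devices (CPU / multi-GPU configs), so callers should handle both.
def _demo_load_devices():
    devices = load_devices()
    main_device = devices[0] if isinstance(devices, list) else devices
    print(main_device)  # e.g. pass to model.to(main_device)
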

def seed_everything(seed):
    # Seed every RNG the stack touches: Python hashing, PyTorch (CPU and all
    # CUDA devices), cuDNN, NumPy's global RNG, and the random module.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


def make_sequence_length(examples):
    examples["sequence_length"] = len(examples["input_ids"])
    return examples


def remove_lengthy_texts(examples):
    # filter predicate: keep raw texts shorter than 4000 characters
    MAX_LENGTH = 4000
    return len(examples["text"]) < MAX_LENGTH


def restrict_token_length_fn(examples):
    # filter predicate: keep examples whose token count falls inside the
    # [min_prefix_length, max_prefix_length] window from config.yaml
    # (note: this re-reads config.yaml on every call)
    CFG = load_config()
    return CFG.min_prefix_length <= len(examples["input_ids"]) <= CFG.max_prefix_length


def get_token_sequence_length(examples):
    # duplicate of make_sequence_length above
    examples["sequence_length"] = len(examples["input_ids"])
    return examples

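
# Hedged usage sketch: these helpers follow the Hugging Face `datasets` map/filter
# callback convention; the `dataset` argument is an assumption for illustration.
def _demo_dataset_pipeline(dataset):
    dataset = dataset.filter(remove_lengthy_texts)      # drop 4000+ character texts
    dataset = dataset.filter(restrict_token_length_fn)  # enforce prefix-length window
    dataset = dataset.map(get_token_sequence_length)    # record per-example token counts
    return dataset
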

def collate_fn(batch):
    # Stack a batch of tokenized examples into (input_ids, attention_mask)
    # LongTensors; assumes every example was padded/truncated to the same length,
    # since torch.tensor() cannot build a tensor from ragged lists.
    return (
        torch.tensor([item["input_ids"] for item in batch], dtype=torch.long),
        torch.tensor([item["attention_mask"] for item in batch], dtype=torch.long),
    )

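
# Hedged usage sketch: a typical DataLoader pairing (illustrative, not from the
# original repo).
def _demo_dataloader(dataset, batch_size=8):
    from torch.utils.data import DataLoader

    return DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
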

class DecayingTemperatureWarper(LogitsProcessor):
    """
    Adapted from @shreyansh26 at https://github.com/shreyansh26/Extracting-Training-Data-from-Large-Langauge-Models/blob/main/extraction_temperature_decay.py
    - Custom LogitsProcessor that decays the temperature from 10.0 to 1.0 over the first 20 tokens
    - Uses a temperature of 1.0 for every token after the 20th
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(
                f"`temperature` has to be a strictly positive float, but is {temperature}"
            )
        self.temperature = temperature
        # map step i in [0, 20] to a temperature decaying linearly from 10.0 to 1.0
        self.temperature_dict = {i: 10.0 - 9.0 * (i / 20) for i in range(21)}

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        self.temperature = self.temperature_dict.get(cur_len, 1.0)
        # Scale the logits so the decayed temperature actually takes effect;
        # returning the raw scores would leave generation unchanged.
        return scores / self.temperature
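
# Hedged usage sketch: the model and inputs are placeholders; logits_processor is a
# standard kwarg of transformers' generate().
def _demo_decaying_temperature(model, input_ids):
    from transformers import LogitsProcessorList

    processors = LogitsProcessorList([DecayingTemperatureWarper(10.0)])
    return model.generate(
        input_ids,
        logits_processor=processors,
        do_sample=True,
        max_length=256,
    )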