-
Notifications
You must be signed in to change notification settings - Fork 1
/
finetune_kilt_fever.py
157 lines (129 loc) · 6.55 KB
/
finetune_kilt_fever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from logger import LoggingCallback
from custom_checkpoint import CustomCheckpointCallback
import random
import numpy as np
import torch
import argparse
import os
import re
import glob
import pytorch_lightning as pl
from trainer import *
def set_seed(seed):
    """Seed every RNG in play (random, numpy, torch CPU/CUDA) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA has its own per-device generators; seed them all when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def extractValLoss(checkpoint_path):
    """Extract the validation loss embedded in a checkpoint filename.

    Eg checkpoint path format: path_to_dir/checkpoint_epoch=4-val_loss=0.450662.ckpt

    Raises:
        AttributeError: if the filename has no ``val_loss=...`` token
            (``re.search`` returns None) — the caller catches this to fall
            back to step/epoch ordering.
    """
    # Raw string + escaped dot: the original pattern's bare '.' before 'ckpt'
    # matched any character and only worked by accident of the filename layout.
    val_loss = float(re.search(r'val_loss=(.+?)\.ckpt', checkpoint_path).group(1))
    return val_loss
def extractStepOREpochNum(checkpoint_path):
    """Extract the step number (preferred) or epoch number from a checkpoint filename.

    Eg checkpoint path format: path_to_dir/checkpoint_epoch=4.ckpt (or)
    path_to_dir/checkpoint_epoch=4-step=50.ckpt (or)

    Raises:
        AttributeError: if neither token matches (``re.search`` returns None).
    """
    # Raw strings with an escaped dot: the original patterns' bare '.' matched
    # any character, which happened to work only for these exact filenames.
    if "step" in checkpoint_path:
        num = int(re.search(r'step=(.+?)\.ckpt', checkpoint_path).group(1))
    else:
        num = int(re.search(r'epoch=(.+?)\.ckpt', checkpoint_path).group(1))
    return num
def getBestModelCheckpointPath(checkpoint_dir):
    """Return the path of the best checkpoint inside *checkpoint_dir*.

    Preference order:
      1. lowest ``val_loss`` encoded in the filename;
      2. if no val_loss token is present, the highest step (or epoch) number.

    Raises:
        FileNotFoundError: if the directory contains no ``checkpoint_*.ckpt`` files.
    """
    checkpoint_list = glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.ckpt"))
    if not checkpoint_list:
        # Previously this fell through to an opaque IndexError on sorted_list[0].
        raise FileNotFoundError(
            "No checkpoint_*.ckpt files found in %s" % checkpoint_dir)
    try:
        # Get the checkpoint with lowest validation loss.
        # os.path.basename is portable, unlike splitting on "/".
        sorted_list = sorted(
            checkpoint_list, key=lambda x: extractValLoss(os.path.basename(x)))
    except (AttributeError, ValueError):
        # AttributeError: regex did not match (no val_loss in name);
        # ValueError: float() could not parse the captured text.
        # Fall back to the checkpoint with highest step or epoch number.
        sorted_list = sorted(
            checkpoint_list,
            key=lambda x: extractStepOREpochNum(os.path.basename(x)),
            reverse=True)
    return sorted_list[0]
def run():
    """Parse CLI arguments and fine-tune a T5 model on KILT-FEVER.

    Builds the argparse config, seeds all RNGs, sets up checkpoint/logging
    callbacks, optionally resumes from the best checkpoint in
    ``--checkpoint_dir``, then launches PyTorch Lightning training.
    """
    #torch.multiprocessing.freeze_support()
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default="datasets/kilt_fever",
                        help='Path for Data files')
    parser.add_argument('--output_dir', type=str, default="outputs/kilt_fever_outputs",
                        help='Path to save the checkpoints')
    parser.add_argument('--checkpoint_dir', type=str, default="",
                        help='Checkpoint directory')
    parser.add_argument('--save_every_n_steps', type=int, default=-1,
                        help='Interval of training steps to save the model checkpoints. Use -1 to disable this callback')
    parser.add_argument('--model_name_or_path', type=str, default="t5-base",
                        help='Model name or Path')
    parser.add_argument('--tokenizer_name_or_path', type=str, default="t5-base",
                        help='Tokenizer name or Path')
    # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
    # BUGFIX: apex opt levels are the letter O plus a digit ("O0".."O3");
    # the previous default was the string "01" (zero-one), which apex rejects.
    parser.add_argument('--opt_level', type=str, default="O1",
                        help='Optimization level')
    parser.add_argument('--early_stop_callback', type=lambda x: (str(x).lower() == 'true'), default="False",
                        help='Whether to do early stopping?')
    # if you want to enable 16-bit training then install apex and set this to true
    parser.add_argument('--fp_16', type=lambda x: (str(x).lower() == 'true'), default="False",
                        help='Whether to use 16 bit precision floating point operations?')
    parser.add_argument('--learning_rate', type=float, default=2e-5,
                        help='Learning Rate')
    parser.add_argument('--weight_decay', type=float, default=0.0,
                        help='Weight decay')
    parser.add_argument('--adam_epsilon', type=float, default=1e-8,
                        help='Epsilon value for Adam Optimizer')
    # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
    parser.add_argument('--max_grad_norm', type=float, default=1.0,
                        help='Maximum Gradient Norm value for Clipping')
    parser.add_argument('--max_seq_length', type=int, default=256,
                        help='Maximum Sequence Length')
    parser.add_argument('--warmup_steps', type=int, default=400,
                        help='Number of warmup steps')
    parser.add_argument('--train_batch_size', type=int, default=8,
                        help='Batch size for Training')
    parser.add_argument('--eval_batch_size', type=int, default=8,
                        help='Batch size for Evaluation')
    parser.add_argument('--num_train_epochs', type=int, default=10,
                        help='Number of Training epochs')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=32,
                        help='Gradient Accumulation Steps')
    parser.add_argument('--n_gpu', type=int, default=1,
                        help='Number of GPUs to use for computation')
    parser.add_argument('--gpu_nums', type=str, default="0",
                        help='GPU ids separated by "," to use for computation')
    parser.add_argument('--seed', type=int, default=42,
                        help='Manual Seed Value')

    # parse_known_args tolerates extra flags injected by launchers (e.g. DDP).
    args = parser.parse_known_args()[0]
    print(args)

    set_seed(args.seed)

    # Create a folder if output_dir doesn't exists:
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
        print("Creating output directory")

    # Keep only the single best (lowest val_loss) checkpoint on disk.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filepath=args.output_dir + "/{epoch}-{val_loss:.6f}", prefix="checkpoint_", monitor="val_loss", mode="min", save_top_k=1
    )

    trainer_custom_callbacks = [LoggingCallback()]
    if args.save_every_n_steps != -1:
        # Optional extra periodic checkpointing every N training steps.
        custom_checkpoint_callback = CustomCheckpointCallback(
            filepath=args.output_dir, prefix="checkpoint_", save_every_n_steps=args.save_every_n_steps
        )
        trainer_custom_callbacks.append(custom_checkpoint_callback)

    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.gpu_nums,
        max_epochs=args.num_train_epochs,
        early_stop_callback=args.early_stop_callback,
        precision=16 if args.fp_16 else 32,
        amp_level=args.opt_level,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
        callbacks=trainer_custom_callbacks,
        distributed_backend='ddp'
    )

    if len(args.checkpoint_dir) != 0:
        # Resume: load weights from the best checkpoint found in checkpoint_dir.
        best_checkpoint_path = getBestModelCheckpointPath(args.checkpoint_dir)
        print("Using checkpoint = ", str(best_checkpoint_path))
        checkpoint_state = torch.load(best_checkpoint_path, map_location="cpu")
        model = T5FineTuner(args)
        model.load_state_dict(checkpoint_state['state_dict'])
    else:
        model = T5FineTuner(args)

    trainer = pl.Trainer(**train_params)
    trainer.fit(model)
# Standard script entry point: only launch training when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    run()