-
Notifications
You must be signed in to change notification settings - Fork 5
/
train.py
145 lines (115 loc) · 5.92 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.utils.data
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import os
import json
import valid
from utils import utils
from utils import sam
from utils import option
from data import dataset
from model import HTR_VT
from functools import partial
def compute_loss(args, model, image, batch_size, criterion, text, length):
    """Run a masked forward pass and return the batch-mean CTC loss.

    Args:
        args: Parsed options; ``args.mask_ratio`` and ``args.max_span_length``
            are forwarded to ``model`` together with ``use_masking=True``.
        model: Callable invoked as
            ``model(image, mask_ratio, max_span_length, use_masking=True)``.
            Its output is treated as (batch, time, classes) below -- TODO
            confirm against HTR_VT.
        image: Input batch, passed straight to ``model``.
        batch_size: Number of samples; one input length is emitted per sample.
        criterion: Loss called as
            ``criterion(log_probs, targets, input_lengths, target_lengths)``;
            expected to return per-sample losses, which are averaged here.
        text: Encoded target tensor; moved to the model output's device.
        length: Target-length tensor; moved to the model output's device.

    Returns:
        Scalar tensor holding the mean of the per-sample losses.
    """
    preds = model(image, args.mask_ratio, args.max_span_length, use_masking=True)
    preds = preds.float()
    # Follow the model's device instead of hard-coding .cuda(); identical on
    # GPU, and also works when the model runs on CPU.
    device = preds.device
    # Every sample uses the full output sequence length as its input length.
    preds_size = torch.full((batch_size,), preds.size(1), dtype=torch.int32, device=device)
    # Reorder to (time, batch, classes) log-probabilities for the CTC criterion.
    preds = preds.permute(1, 0, 2).log_softmax(2)

    # The cuDNN flag is process-global: disable it only around the loss call
    # and always restore the caller's previous setting, even if the criterion
    # raises (the old code forced it back to True and skipped the restore on
    # error).
    cudnn_was_enabled = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = False
    try:
        loss = criterion(preds, text.to(device), preds_size, length.to(device)).mean()
    finally:
        torch.backends.cudnn.enabled = cudnn_was_enabled
    return loss
def _save_checkpoint(path, model, model_ema, optimizer):
    """Save the raw model, EMA model and optimizer state dicts to ``path``."""
    checkpoint = {
        'model': model.state_dict(),
        'state_dict_ema': model_ema.ema.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)


def _train(args, logger, writer):
    """Build model, data and optimizer, then run the training/validation loop.

    Every ``args.print_iter`` iterations the averaged training loss and the
    learning rate are logged. Every ``args.eval_iter`` iterations the EMA model
    is validated, and ``best_CER.pth`` / ``best_WER.pth`` are written into
    ``args.save_dir`` whenever the respective metric improves.
    """
    model = HTR_VT.create_model(nb_cls=args.nb_cls, img_size=args.img_size[::-1])
    total_param = sum(p.numel() for p in model.parameters())
    logger.info('total_param is {}'.format(total_param))

    model.train()
    model = model.cuda()
    model_ema = utils.ModelEma(model, args.ema_decay)
    model.zero_grad()

    logger.info('Loading train loader...')
    train_dataset = dataset.myLoadDS(args.train_data_list, args.data_path, args.img_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_bs,
        shuffle=True,
        pin_memory=True,
        num_workers=args.num_workers,
        collate_fn=partial(dataset.SameTrCollate, args=args),
    )
    train_iter = dataset.cycle_data(train_loader)

    logger.info('Loading val loader...')
    # Validation reuses the training set's `ralph` (presumably the character
    # alphabet), the same mapping the label converter is built from below.
    val_dataset = dataset.myLoadDS(args.val_data_list, args.data_path, args.img_size, ralph=train_dataset.ralph)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.val_bs,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    # lr=1e-7 is only the initial value; utils.update_lr_cos is called every
    # iteration and presumably overwrites it (cosine schedule with warm-up).
    optimizer = sam.SAM(model.parameters(), torch.optim.AdamW, lr=1e-7, betas=(0.9, 0.99), weight_decay=args.weight_decay)
    criterion = torch.nn.CTCLoss(reduction='none', zero_infinity=True)
    converter = utils.CTCLabelConverter(train_dataset.ralph.values())

    best_cer, best_wer = 1e+6, 1e+6
    train_loss = 0.0

    #### ---- train & eval ---- ####
    for nb_iter in range(1, args.total_iter):
        optimizer, current_lr = utils.update_lr_cos(nb_iter, args.warm_up_iter, args.total_iter, args.max_lr, optimizer)

        optimizer.zero_grad()
        batch = next(train_iter)
        image = batch[0].cuda()
        text, length = converter.encode(batch[1])
        batch_size = image.size(0)

        # Two forward/backward passes per step, one before first_step and one
        # before second_step (assumes utils.sam follows the usual SAM API).
        loss = compute_loss(args, model, image, batch_size, criterion, text, length)
        loss.backward()
        optimizer.first_step(zero_grad=True)
        compute_loss(args, model, image, batch_size, criterion, text, length).backward()
        optimizer.second_step(zero_grad=True)
        model.zero_grad()
        model_ema.update(model, num_updates=nb_iter / 2)
        # Only the first-pass loss is accumulated for logging.
        train_loss += loss.item()

        if nb_iter % args.print_iter == 0:
            train_loss_avg = train_loss / args.print_iter
            logger.info(f'Iter : {nb_iter} \t LR : {current_lr:0.5f} \t training loss : {train_loss_avg:0.5f} \t ' )
            writer.add_scalar('./Train/lr', current_lr, nb_iter)
            writer.add_scalar('./Train/train_loss', train_loss_avg, nb_iter)
            train_loss = 0.0

        if nb_iter % args.eval_iter == 0:
            # NOTE(review): validation runs on model_ema.ema, not on `model`;
            # whether the EMA copy is in eval mode depends on utils.ModelEma /
            # valid.validation -- confirm.
            model.eval()
            with torch.no_grad():
                val_loss, val_cer, val_wer, _, _ = valid.validation(model_ema.ema,
                                                                    criterion,
                                                                    val_loader,
                                                                    converter)

            if val_cer < best_cer:
                logger.info(f'CER improved from {best_cer:.4f} to {val_cer:.4f}!!!')
                best_cer = val_cer
                _save_checkpoint(os.path.join(args.save_dir, 'best_CER.pth'), model, model_ema, optimizer)

            if val_wer < best_wer:
                logger.info(f'WER improved from {best_wer:.4f} to {val_wer:.4f}!!!')
                best_wer = val_wer
                _save_checkpoint(os.path.join(args.save_dir, 'best_WER.pth'), model, model_ema, optimizer)

            logger.info(
                f'Val. loss : {val_loss:0.3f} \t CER : {val_cer:0.4f} \t WER : {val_wer:0.4f} \t ')
            writer.add_scalar('./VAL/CER', val_cer, nb_iter)
            writer.add_scalar('./VAL/WER', val_wer, nb_iter)
            writer.add_scalar('./VAL/bestCER', best_cer, nb_iter)
            writer.add_scalar('./VAL/bestWER', best_wer, nb_iter)
            writer.add_scalar('./VAL/val_loss', val_loss, nb_iter)
            model.train()


def main():
    """Script entry point: parse options, set up logging, then train.

    Creates ``args.save_dir`` (``args.out_dir`` joined with ``args.exp_name``),
    logs the parsed options as JSON, and opens a TensorBoard ``SummaryWriter``
    in that directory. The writer is closed when training finishes or raises,
    so buffered events are flushed to disk.
    """
    args = option.get_args_parser()
    torch.manual_seed(args.seed)

    args.save_dir = os.path.join(args.out_dir, args.exp_name)
    os.makedirs(args.save_dir, exist_ok=True)
    logger = utils.get_logger(args.save_dir)
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))
    writer = SummaryWriter(args.save_dir)
    try:
        _train(args, logger, writer)
    finally:
        # Previously never closed: pending events could be lost on exit or
        # on an interrupted run.
        writer.close()
# Start training only when this file is executed as a script, so importing
# the module (e.g. to reuse compute_loss) has no side effects.
if __name__ == '__main__':
    main()