-
Notifications
You must be signed in to change notification settings - Fork 115
/
trainer.py
executable file
·207 lines (177 loc) · 9.3 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from utils import calc_psnr_and_ssim
from model import Vgg19
import os
import numpy as np
from imageio import imread, imsave
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as utils
class Trainer():
    """Training, evaluation and single-image test driver for the super-resolution model.

    The model is called two ways in this class:
    ``model(lr=, lrsr=, ref=, refsr=)`` -> ``(sr, S, T_lv3, T_lv2, T_lv1)`` and
    ``model(sr=)`` -> ``(sr_lv1, sr_lv2, sr_lv3)`` (used by the texture loss). The bare
    network must expose ``MainNet`` and ``LTE`` sub-modules; it may be wrapped in
    ``nn.DataParallel`` / ``DistributedDataParallel``.

    Images are fed to the model as ``pixel / 127.5 - 1`` and mapped back with
    ``(x + 1) * 127.5`` when saved.
    """

    def __init__(self, args, logger, dataloader, model, loss_all):
        """Set up the VGG19 feature extractor, Adam optimizer and StepLR scheduler.

        Args:
            args: run configuration; this method reads ``cpu``, ``num_gpu``, ``lr_rate``,
                ``lr_rate_lte``, ``beta1``, ``beta2``, ``eps``, ``decay`` and ``gamma``.
            logger: object with an ``info(str)`` method.
            dataloader: dict with a ``'train'`` iterable and a ``['test']['1']`` iterable,
                each yielding dicts with keys ``LR``, ``LR_sr``, ``HR``, ``Ref``, ``Ref_sr``.
            model: the network, optionally wrapped in a DataParallel-style wrapper.
            loss_all: dict of loss callables: ``'rec_loss'`` and optionally
                ``'per_loss'``, ``'tpl_loss'``, ``'adv_loss'``.
        """
        self.args = args
        self.logger = logger
        self.dataloader = dataloader
        self.model = model
        self.loss_all = loss_all
        self.device = torch.device('cpu') if args.cpu else torch.device('cuda')

        # Frozen VGG19 (requires_grad=False) used for the perceptual loss.
        self.vgg19 = Vgg19.Vgg19(requires_grad=False).to(self.device)
        if (not self.args.cpu) and (self.args.num_gpu > 1):
            self.vgg19 = nn.DataParallel(self.vgg19, list(range(self.args.num_gpu)))

        # Detect the wrapper by inspecting the model instead of trusting args.num_gpu:
        # if the caller (like the vgg19 code above) wraps only for multi-GPU CUDA runs,
        # a CPU run with num_gpu > 1 is unwrapped and ``model.module`` would not exist.
        net = self._unwrap_model()
        # Separate parameter groups so the LTE branch gets its own learning rate.
        # Plain lists (not one-shot ``filter`` iterators) keep ``self.params`` reusable.
        self.params = [
            {"params": [p for p in net.MainNet.parameters() if p.requires_grad],
             "lr": args.lr_rate},
            {"params": [p for p in net.LTE.parameters() if p.requires_grad],
             "lr": args.lr_rate_lte},
        ]
        self.optimizer = optim.Adam(self.params, betas=(args.beta1, args.beta2), eps=args.eps)
        self.scheduler = optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.decay, gamma=self.args.gamma)

        # Best evaluation scores seen so far and the epochs at which they were reached.
        self.max_psnr = 0.
        self.max_psnr_epoch = 0
        self.max_ssim = 0.
        self.max_ssim_epoch = 0

    # ------------------------------------------------------------------ helpers

    def _unwrap_model(self):
        """Return the bare network underneath any (Distributed)DataParallel wrapper."""
        wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
        if isinstance(self.model, wrappers):
            return self.model.module
        return self.model

    def _checkpoint_path(self, epoch):
        """Return ``<save_dir>/model/model_<epoch, zero-padded to 5>.pt``.

        ``os.path.join`` keeps absolute ``save_dir`` values absolute (the previous
        ``save_dir.strip('/')`` silently turned them into relative paths).
        """
        return os.path.join(self.args.save_dir, 'model',
                            'model_' + str(epoch).zfill(5) + '.pt')

    def _save_checkpoint(self, epoch):
        """Save the unwrapped model's state dict, excluding ``SearchNet`` and ``*_copy`` entries.

        Keys come from the unwrapped network, so they never carry a ``module.`` prefix.
        The target ``model/`` directory is expected to exist already.
        """
        self.logger.info('saving the model...')
        state = self._unwrap_model().state_dict()
        model_state_dict = {key: value for key, value in state.items()
                            if ('SearchNet' not in key) and ('_copy' not in key)}
        torch.save(model_state_dict, self._checkpoint_path(epoch))

    def _image_to_tensor(self, img):
        """Map an H x W x C image to a 1 x C x H x W float tensor on ``self.device``.

        Values are rescaled with ``/ 127.5 - 1`` (assumes 0-255 pixel values).
        """
        img = img.astype(np.float32) / 127.5 - 1.
        return torch.from_numpy(img.transpose((2, 0, 1))).unsqueeze(0).float().to(self.device)

    @staticmethod
    def _tensor_to_image(sr):
        """Map a 1 x 3 x H x W model output back to an H x W x 3 uint8 array.

        Values are clamped to [0, 255] before the uint8 cast; without the clamp,
        out-of-range outputs would wrap around (e.g. 256 -> 0).
        """
        sr_save = ((sr + 1.) * 127.5).clamp(0., 255.)
        return np.transpose(sr_save.squeeze().round().cpu().numpy(), (1, 2, 0)).astype(np.uint8)

    @staticmethod
    def _to_rgb(img):
        """Return *img* as an H x W x 3 array.

        Grayscale (H x W, or H x W x 1 / x 2 with alpha) is replicated to three channels
        and a 4th (alpha) channel is dropped; 3-channel input is returned unchanged.
        """
        img = np.asarray(img)
        if img.ndim == 3 and img.shape[2] in (1, 2):   # gray, or gray + alpha
            img = img[:, :, 0]
        if img.ndim == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img.shape[2] == 4:                        # RGBA -> RGB
            img = img[:, :, :3]
        return np.ascontiguousarray(img)

    # ------------------------------------------------------------------ public API

    def load(self, model_path=None):
        """Load weights from *model_path* into the model; do nothing if it is falsy.

        The checkpoint is merged into the current state dict, so keys it omits (e.g. the
        ``SearchNet`` / ``_copy`` entries dropped when saving) keep their current values.
        A leading ``module.`` prefix is stripped so checkpoints from a wrapped model also
        load, and weights go into the unwrapped network so a DataParallel-wrapped model
        accepts the un-prefixed keys this class writes.
        """
        if not model_path:
            return
        self.logger.info('load_model_path: ' + model_path)
        prefix = 'module.'
        checkpoint = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in torch.load(model_path, map_location=self.device).items()
        }
        net = self._unwrap_model()
        model_state_dict = net.state_dict()
        model_state_dict.update(checkpoint)
        net.load_state_dict(model_state_dict)

    def prepare(self, sample_batched):
        """Move every tensor in the batch dict to ``self.device`` (in place) and return it."""
        for key in list(sample_batched.keys()):
            sample_batched[key] = sample_batched[key].to(self.device)
        return sample_batched

    def train(self, current_epoch=0, is_init=False):
        """Run one training epoch.

        With ``is_init=True`` only the reconstruction loss is used, the scheduler is not
        stepped and no checkpoint is written. Otherwise the optional perceptual, texture
        and adversarial losses are added, and a checkpoint is saved every
        ``args.save_every`` epochs.
        """
        self.model.train()
        if not is_init:
            # NOTE: stepped at the start of the epoch, before any optimizer.step()
            # (pre-PyTorch-1.1 convention; newer PyTorch warns about this order). Kept
            # as-is because moving it would change the epochs at which the LR decays.
            self.scheduler.step()
        self.logger.info('Current epoch learning rate: %e' % (self.optimizer.param_groups[0]['lr']))

        for i_batch, sample_batched in enumerate(self.dataloader['train']):
            self.optimizer.zero_grad()

            sample_batched = self.prepare(sample_batched)
            lr = sample_batched['LR']
            lr_sr = sample_batched['LR_sr']
            hr = sample_batched['HR']
            ref = sample_batched['Ref']
            ref_sr = sample_batched['Ref_sr']
            sr, S, T_lv3, T_lv2, T_lv1 = self.model(lr=lr, lrsr=lr_sr, ref=ref, refsr=ref_sr)

            ### calc loss
            is_print = ((i_batch + 1) % self.args.print_every == 0)  # log this batch?

            rec_loss = self.args.rec_w * self.loss_all['rec_loss'](sr, hr)
            loss = rec_loss
            if is_print:
                self.logger.info(('init ' if is_init else '') + 'epoch: ' + str(current_epoch) +
                                 '\t batch: ' + str(i_batch + 1))
                self.logger.info('rec_loss: %.10f' % (rec_loss.item()))

            if not is_init:
                # Out-of-place adds: ``loss += x`` would also mutate the rec_loss tensor.
                if 'per_loss' in self.loss_all:
                    # VGG features on images mapped from [-1, 1] to [0, 1]; the HR branch
                    # needs no gradient.
                    sr_relu5_1 = self.vgg19((sr + 1.) / 2.)
                    with torch.no_grad():
                        hr_relu5_1 = self.vgg19((hr.detach() + 1.) / 2.)
                    per_loss = self.args.per_w * self.loss_all['per_loss'](sr_relu5_1, hr_relu5_1)
                    loss = loss + per_loss
                    if is_print:
                        self.logger.info('per_loss: %.10f' % (per_loss.item()))
                if 'tpl_loss' in self.loss_all:
                    sr_lv1, sr_lv2, sr_lv3 = self.model(sr=sr)
                    tpl_loss = self.args.tpl_w * self.loss_all['tpl_loss'](sr_lv3, sr_lv2, sr_lv1,
                                                                          S, T_lv3, T_lv2, T_lv1)
                    loss = loss + tpl_loss
                    if is_print:
                        self.logger.info('tpl_loss: %.10f' % (tpl_loss.item()))
                if 'adv_loss' in self.loss_all:
                    adv_loss = self.args.adv_w * self.loss_all['adv_loss'](sr, hr)
                    loss = loss + adv_loss
                    if is_print:
                        self.logger.info('adv_loss: %.10f' % (adv_loss.item()))

            loss.backward()
            self.optimizer.step()

        if (not is_init) and current_epoch % self.args.save_every == 0:
            self._save_checkpoint(current_epoch)

    def evaluate(self, current_epoch=0):
        """Compute mean PSNR/SSIM on the ``['test']['1']`` loader (CUFED dataset only).

        Updates and logs the best scores seen so far. If ``args.eval_save_results`` is set,
        each output is written to ``<save_dir>/save_results/<index>.png``. An empty test
        loader is logged and skipped instead of raising ZeroDivisionError.
        """
        self.logger.info('Epoch ' + str(current_epoch) + ' evaluation process...')

        if self.args.dataset == 'CUFED':
            self.model.eval()
            with torch.no_grad():
                psnr, ssim, cnt = 0., 0., 0
                for i_batch, sample_batched in enumerate(self.dataloader['test']['1']):
                    cnt += 1
                    sample_batched = self.prepare(sample_batched)
                    lr = sample_batched['LR']
                    lr_sr = sample_batched['LR_sr']
                    hr = sample_batched['HR']
                    ref = sample_batched['Ref']
                    ref_sr = sample_batched['Ref_sr']

                    sr, _, _, _, _ = self.model(lr=lr, lrsr=lr_sr, ref=ref, refsr=ref_sr)
                    if self.args.eval_save_results:
                        # Saving assumes one image per batch (the output is squeezed).
                        imsave(os.path.join(self.args.save_dir, 'save_results',
                                            str(i_batch).zfill(5) + '.png'),
                               self._tensor_to_image(sr))

                    ### calculate psnr and ssim
                    _psnr, _ssim = calc_psnr_and_ssim(sr.detach(), hr.detach())

                    psnr += _psnr
                    ssim += _ssim

                if cnt == 0:
                    self.logger.info('No test samples found; skipping PSNR/SSIM update.')
                else:
                    psnr_ave = psnr / cnt
                    ssim_ave = ssim / cnt
                    self.logger.info('Ref  PSNR (now): %.3f \t SSIM (now): %.4f' % (psnr_ave, ssim_ave))
                    if psnr_ave > self.max_psnr:
                        self.max_psnr = psnr_ave
                        self.max_psnr_epoch = current_epoch
                    if ssim_ave > self.max_ssim:
                        self.max_ssim = ssim_ave
                        self.max_ssim_epoch = current_epoch
                    self.logger.info('Ref  PSNR (max): %.3f (%d) \t SSIM (max): %.4f (%d)'
                                     % (self.max_psnr, self.max_psnr_epoch, self.max_ssim, self.max_ssim_epoch))

        self.logger.info('Evaluation over.')

    def test(self):
        """Super-resolve ``args.lr_path`` using ``args.ref_path`` and save the result.

        ``LR_sr`` is the LR image bicubically upsampled 4x. ``Ref`` is cropped to a multiple
        of 4, then down- and up-sampled by 4x to give ``Ref_sr``. Grayscale / RGBA inputs are
        converted to 3 channels first. The output goes to
        ``<save_dir>/save_results/<basename of lr_path>``.
        """
        self.logger.info('Test process...')
        self.logger.info('lr path:     %s' % (self.args.lr_path))
        self.logger.info('ref path:    %s' % (self.args.ref_path))

        ### LR and LR_sr
        LR = self._to_rgb(imread(self.args.lr_path))
        h1, w1 = LR.shape[:2]
        LR_sr = np.array(Image.fromarray(LR).resize((w1 * 4, h1 * 4), Image.BICUBIC))

        ### Ref and Ref_sr
        Ref = self._to_rgb(imread(self.args.ref_path))
        h2, w2 = Ref.shape[:2]
        h2, w2 = h2 // 4 * 4, w2 // 4 * 4
        Ref = Ref[:h2, :w2, :]
        Ref_sr = np.array(Image.fromarray(Ref).resize((w2 // 4, h2 // 4), Image.BICUBIC))
        Ref_sr = np.array(Image.fromarray(Ref_sr).resize((w2, h2), Image.BICUBIC))

        ### rescale and convert to tensors
        LR_t = self._image_to_tensor(LR)
        LR_sr_t = self._image_to_tensor(LR_sr)
        Ref_t = self._image_to_tensor(Ref)
        Ref_sr_t = self._image_to_tensor(Ref_sr)

        self.model.eval()
        with torch.no_grad():
            sr, _, _, _, _ = self.model(lr=LR_t, lrsr=LR_sr_t, ref=Ref_t, refsr=Ref_sr_t)
            sr_save = self._tensor_to_image(sr)
            save_path = os.path.join(self.args.save_dir, 'save_results', os.path.basename(self.args.lr_path))
            imsave(save_path, sr_save)
            self.logger.info('output path: %s' % (save_path))

        self.logger.info('Test over.')