-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathtrainer.py
120 lines (103 loc) · 4.83 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
__author__ = 'ziao'
import os.path
import datetime
import cv2
import numpy as np
#from skimage.measure import compare_ssim
#from skimage.metrics import _structural_similarity
from skimage.metrics import structural_similarity
from utils import preprocess, metrics
#from skimage.metrics.structural_similarity import compare_ssim
def train(model, ims, configs, itr):
    """Run a training step on one batch and return its loss.

    ``model.train(ims, itr)`` is called once. When ``configs.reverse_input``
    is truthy it is called a second time on a copy of ``ims`` flipped along
    axis 1 (presumably the time axis -- confirm against the data loader), and
    the two returned losses are averaged. Whenever ``itr`` is a multiple of
    ``configs.display_interval`` a timestamped progress line and the loss are
    printed.

    Args:
        model: object exposing ``train(ims, itr)`` that returns a loss value.
        ims: batch array handed straight to ``model.train``.
        configs: must provide ``reverse_input`` and ``display_interval``.
        itr: current iteration number.

    Returns:
        The loss from ``model.train``, averaged with the flipped-input loss
        when ``configs.reverse_input`` is set.
    """
    loss = model.train(ims, itr)

    if configs.reverse_input:
        # Flip (not rotate) the order along axis 1 and train on that as well.
        # np.flip returns a negative-stride view; .copy() materialises it.
        flipped = np.flip(ims, axis=1).copy()
        loss += model.train(flipped, itr)
        loss = loss / 2

    due_for_log = itr % configs.display_interval == 0
    if due_for_log:
        stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(stamp, f'itr: {itr!s}')
        print(f'training loss: {loss!s}')

    return loss
def _ssim_channel_kwargs():
    """Return the keyword that tells ``structural_similarity`` the last axis holds channels.

    scikit-image 0.19 introduced ``channel_axis`` and deprecated the older
    ``multichannel`` flag. Pick whichever the installed version accepts so
    the metric works on both older and newer releases.
    """
    import inspect

    params = inspect.signature(structural_similarity).parameters
    if 'channel_axis' in params:
        return {'channel_axis': -1}
    return {'multichannel': True}


def _save_examples(path, test_ims, img_out, total_length, input_length):
    """Save the first sequence of a batch as PNG files (ground truth and prediction).

    Ground-truth frames 1..total_length are written as ``gt<k>.png``. Each
    predicted frame is clipped to [0, 1] and written as ``pd<k>.png``, where k
    is the index of the ground-truth frame it predicts.
    """
    os.makedirs(path, exist_ok=True)
    for i in range(total_length):
        img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
        cv2.imwrite(os.path.join(path, 'gt' + str(i + 1) + '.png'), img_gt)
    for i in range(total_length - input_length):
        img_pd = np.uint8(np.clip(img_out[0, i, :, :, :], 0, 1) * 255)
        name = 'pd' + str(i + 1 + input_length) + '.png'
        cv2.imwrite(os.path.join(path, name), img_pd)


def test(model, test_input_handle, configs, itr):
    """Evaluate ``model`` on the whole test set and print per-frame metrics.

    For every batch the first ``configs.test_input_length`` frames are
    conditioning input. The remaining frames are compared with the model's
    predictions. The first sequence of each of the first
    ``configs.num_save_samples`` batches is saved as PNGs under
    ``configs.gen_frm_dir/<itr>/<batch_id>/``.

    Assumes batches are (batch, seq, height, width, channels) arrays with
    values in [0, 1]. The indexing below fixes the rank, and the value range
    is inferred from the ``* 255`` conversions -- confirm against the loader.

    Args:
        model: object exposing ``test(patched_batch)``.
        test_input_handle: iterator-like handle with ``begin``,
            ``no_batch_left``, ``get_batch`` and ``next``.
        configs: run configuration (lengths, patch size, output dir, ...).
        itr: training iteration, used to name the output directory.

    Returns:
        Tuple ``(avg_mse, ssim, psnr, fmae, sharp)``.
        - ``avg_mse`` is the squared error summed over all predicted frames
          and pixels, averaged over test sequences.
        - ``ssim`` and ``sharp`` are per-frame float32 arrays averaged over
          sequences.
        - ``psnr`` and ``fmae`` are per-frame float32 arrays of the
          ``metrics`` batch results, averaged over batches.

    Raises:
        ValueError: if the test input handle yields no sequences.
    """
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(configs.gen_frm_dir, str(itr))
    # exist_ok: re-testing the same iteration must not crash.
    os.makedirs(res_path, exist_ok=True)

    input_length = configs.test_input_length
    output_length = configs.test_total_length - input_length
    ssim_kwargs = _ssim_channel_kwargs()

    avg_mse = 0
    batch_id = 0
    num_seqs = 0  # sequences actually evaluated (the last batch may be short)
    img_mse = [0] * output_length
    ssim = [0] * output_length
    psnr = [0] * output_length
    fmae = [0] * output_length
    sharp = [0] * output_length

    while not test_input_handle.no_batch_left():
        batch_id += 1
        test_ims = test_input_handle.get_batch()
        test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        img_gen = model.test(test_dat)
        img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        img_out = img_gen[:, -output_length:]
        num_seqs += test_ims.shape[0]

        for i in range(output_length):
            x = test_ims[:, i + input_length, :, :, :]
            gx = img_out[:, i, :, :, :]
            # MAE is measured on the raw, unclipped prediction.
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.clip(gx, 0, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse

            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            # Iterate the real batch, not configs.batch_size, so a short
            # final batch does not index out of range.
            for pred_img, real_img in zip(pred_frm, real_frm):
                ssim[i] += structural_similarity(pred_img, real_img, **ssim_kwargs)
                laplacian = cv2.Laplacian(pred_img, cv2.CV_16S)
                # int(): np.max yields a uint8 scalar, and summing uint8
                # scalars wraps at 256 under NumPy 2 promotion rules.
                sharp[i] += int(np.max(cv2.convertScaleAbs(laplacian)))

        if batch_id <= configs.num_save_samples:
            _save_examples(os.path.join(res_path, str(batch_id)), test_ims,
                           img_out, configs.test_total_length, input_length)
        test_input_handle.next()

    if num_seqs == 0:
        raise ValueError('test_input_handle yielded no test sequences')

    avg_mse = avg_mse / num_seqs
    print('mse per seq: ' + str(avg_mse))
    for frame_mse in img_mse:
        print(frame_mse / num_seqs)

    ssim = np.asarray(ssim, dtype=np.float32) / num_seqs
    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    sharp = np.asarray(sharp, dtype=np.float32) / num_seqs

    for title, values in (('ssim per frame: ', ssim),
                          ('psnr per frame: ', psnr),
                          ('fmae per frame: ', fmae),
                          ('sharpness per frame: ', sharp)):
        print(title + str(np.mean(values)))
        for value in values:
            print(value)

    return avg_mse, ssim, psnr, fmae, sharp