-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathengine.py
111 lines (79 loc) · 3.23 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import time
import hashlib
from typing import Iterable
import imageio
import util.misc as utils
import datetime
import numpy as np
import matplotlib.pyplot as plt
from util.metric import nmse, psnr, ssim, AverageMeter
from collections import defaultdict
import torch
import torch.nn.functional as F
def train_one_epoch_null_nohead(model, criterion, data_loader, optimizer, device):
    """Run one training epoch of ``model`` over ``data_loader``.

    Each batch must unpack as ``(image, target, mean, std, fname, slice_num)``;
    only ``image`` and ``target`` are used here. Both get a new axis inserted
    at dim 1 before the forward pass / loss.

    Args:
        model: Module called as ``model(image)`` on the axis-expanded batch.
        criterion: Callable ``criterion(outputs, target)`` returning a dict
            whose ``'loss'`` entry is the scalar tensor that is backpropagated.
        data_loader: Iterable of batches (a ``DataLoader`` or any iterable;
            ``len()`` is not required).
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        device: Device the image/target tensors are moved to.

    Returns:
        dict with ``'loss'`` -- the mean of the per-batch ``'loss'`` values
        (``0.0`` if the loader yielded no batches) -- and ``'global_step'``,
        the number of batches processed.
    """
    model.train()
    loss_sum = 0.0
    num_batches = 0
    for data in data_loader:
        num_batches += 1
        # mean/std/fname/slice_num are only needed at evaluation time.
        image, target, _mean, _std, _fname, _slice_num = data
        # Insert a channel axis, e.g. (8, 320, 320) -> (8, 1, 320, 320);
        # the input rank is assumed from the original comment -- TODO confirm.
        image = image.unsqueeze(1).to(device)
        target = target.unsqueeze(1).to(device)

        outputs = model(image)
        loss = criterion(outputs, target)

        optimizer.zero_grad()
        loss['loss'].backward()
        optimizer.step()
        loss_sum += loss['loss'].item()

    # Average over the batches actually seen: works for iterables without
    # __len__ and avoids ZeroDivisionError on an empty loader.
    loss_avg = loss_sum / num_batches if num_batches else 0.0
    return {"loss": loss_avg, "global_step": num_batches}
@torch.no_grad()
def server_evaluate(model, criterion, data_loaders, device):
    """Evaluate ``model`` on every loader in ``data_loaders``.

    Batches unpack as ``(image, target, mean, std, fname, slice_num)``.
    Predictions and targets are de-normalised with ``x * std + mean`` and the
    criterion loss is averaged over all batches. Slices are then regrouped per
    file name, stacked in ascending ``slice_num`` order, and NMSE / PSNR / SSIM
    are computed once per file and averaged over files.

    Args:
        model: Module called on the axis-expanded image batch; its output has
            a singleton axis at dim 1 that is squeezed away.
        criterion: Object supporting ``.eval()`` / ``.to(device)`` and called
            as ``criterion(outputs, target)``, returning a dict with a scalar
            ``'loss'`` tensor.
        data_loaders: Iterable of loaders (e.g. one per client/site).
        device: Device used for inference.

    Returns:
        dict with ``'total_time'`` (``str`` of a ``timedelta``), ``'loss'``
        (mean per-batch loss, ``0.0`` if no batches were seen) and the
        per-file averages ``'PSNR'``, ``'SSIM'`` and ``'NMSE'``.
    """
    model.eval()
    criterion.eval()
    criterion.to(device)

    nmse_meter = AverageMeter()
    psnr_meter = AverageMeter()
    ssim_meter = AverageMeter()
    # fname -> {slice index -> de-normalised slice tensor (on CPU)}
    output_dic = defaultdict(dict)
    target_dic = defaultdict(dict)

    start_time = time.time()
    loss_sum = 0.0
    num_batches = 0
    for data_loader in data_loaders:
        for data in data_loader:
            num_batches += 1
            image, target, mean, std, fname, slice_num = data
            image = image.unsqueeze(1).to(device)  # e.g. (8, 1, 320, 320)
            target = target.to(device)
            # Add two trailing axes so per-sample mean/std broadcast over the
            # image; assumes they arrive as shape (B,) -- TODO confirm.
            mean = mean.unsqueeze(1).unsqueeze(2).to(device)
            std = std.unsqueeze(1).unsqueeze(2).to(device)

            outputs = model(image).squeeze(1)
            outputs = outputs * std + mean
            target = target * std + mean

            loss = criterion(outputs, target)
            loss_sum += loss['loss'].item()

            # Keep per-slice results on the CPU so device memory does not
            # grow with the size of the whole evaluation set.
            outputs_cpu = outputs.cpu()
            target_cpu = target.cpu()
            for k, f in enumerate(fname):
                s = slice_num[k].item()
                output_dic[f][s] = outputs_cpu[k]
                target_dic[f][s] = target_cpu[k]

    for name, slices in output_dic.items():
        # Stack in slice order (not arrival order) so the volume is
        # well-formed even if a loader shuffles; target_dic shares the keys.
        order = sorted(slices)
        f_output = torch.stack([slices[s] for s in order]).numpy()
        f_target = torch.stack([target_dic[name][s] for s in order]).numpy()
        nmse_meter.update(nmse(f_target, f_output), 1)
        psnr_meter.update(psnr(f_target, f_output), 1)
        ssim_meter.update(ssim(f_target, f_output), 1)

    total_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
    # Guard against ZeroDivisionError when every loader is empty.
    loss_avg = loss_sum / num_batches if num_batches else 0.0
    return {'total_time': total_time, 'loss': loss_avg, 'PSNR': psnr_meter.avg,
            'SSIM': ssim_meter.avg, 'NMSE': nmse_meter.avg}