-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
95 lines (72 loc) · 2.7 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import torch.nn.functional as F
import numpy as np
import math
class Grad3d(torch.nn.Module):
    """
    Spatial-gradient (smoothness) loss over the three trailing axes.

    ``forward`` takes forward finite differences of ``y_pred`` along axes
    2, 3 and 4. Its indexing therefore needs at least a 5-D tensor;
    presumably [batch, channels, D, H, W], TODO confirm against callers.
    Each difference is penalised by its absolute value, or by its square
    when ``penalty == 'l2'``. The result is the average of the three
    per-axis means.

    Args:
        penalty: 'l2' squares the differences; any other value falls back to
            the absolute value (L1).
        loss_mult: Optional factor applied to the final loss (skipped when
            None).
    """

    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred, y_true):
        """Return the gradient penalty of ``y_pred``; ``y_true`` is unused."""

        def axis_diff(axis):
            # |y[..., i+1, ...] - y[..., i, ...]| along ``axis``.
            lead = (slice(None),) * axis
            upper = y_pred[lead + (slice(1, None),)]
            lower = y_pred[lead + (slice(None, -1),)]
            return torch.abs(upper - lower)

        means = []
        # Axis order 3, 2, 4 reproduces the dx + dy + dz summation order.
        for axis in (3, 2, 4):
            delta = axis_diff(axis)
            if self.penalty == 'l2':
                delta = delta * delta
            means.append(torch.mean(delta))

        grad = (means[0] + means[1] + means[2]) / 3.0
        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad
class NCC_vxm(torch.nn.Module):
    """
    Local (over window) normalized cross correlation loss.

    For every position, the squared correlation coefficient between
    ``y_true`` and ``y_pred`` is computed over a sliding window. Window sums
    come from a convolution with a ones kernel, zero-padded at the borders.
    The loss is the negative mean of that map. It approaches -1 when the
    inputs are (anti-)correlated in every window with non-negligible
    variance.

    Args:
        win: Window size. Either a sequence with one entry per spatial
            dimension or a single int used for every dimension. Defaults to
            9 in every dimension.
        eps: Constant added to the denominator so flat windows (zero
            variance) do not divide by zero. Defaults to 1e-5.
    """

    def __init__(self, win=None, eps=1e-5):
        super(NCC_vxm, self).__init__()
        self.win = win
        self.eps = eps

    def forward(self, y_true, y_pred):
        """
        Compute the local NCC loss.

        Args:
            y_true, y_pred: Tensors laid out channels-first as
                [batch_size, 1, *vol_shape] with 1 to 3 spatial dimensions.
                They are fed straight into ``F.conv{1,2,3}d`` with a
                single-in/single-out-channel kernel, so channels must be 1.

        Returns:
            Scalar tensor: the negative mean squared local correlation.
        """
        Ii = y_true
        Ji = y_pred

        # Number of spatial dimensions (everything after batch and channel).
        ndims = Ii.dim() - 2
        assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims

        # Resolve the window size per spatial dimension.
        if self.win is None:
            win = [9] * ndims
        elif isinstance(self.win, int):
            win = [self.win] * ndims
        else:
            win = list(self.win)

        # Build the summation kernel on the inputs' device and dtype so the
        # loss runs on CPU, on any GPU, and with any floating dtype.
        sum_filt = torch.ones([1, 1, *win], device=Ii.device, dtype=Ii.dtype)
        # Per-dimension padding keeps the output the same size as the input
        # for odd window sizes, including non-uniform windows.
        padding = tuple(w // 2 for w in win)
        conv_fn = getattr(F, 'conv%dd' % ndims)

        def local_sum(t):
            # Sum of ``t`` over the window centred at each position.
            return conv_fn(t, sum_filt, stride=1, padding=padding)

        # Window sums of the raw signals and their second-order products.
        I_sum = local_sum(Ii)
        J_sum = local_sum(Ji)
        I2_sum = local_sum(Ii * Ii)
        J2_sum = local_sum(Ji * Ji)
        IJ_sum = local_sum(Ii * Ji)

        win_size = float(np.prod(win))
        u_I = I_sum / win_size
        u_J = J_sum / win_size

        # Expanded forms of sum((I - u_I)(J - u_J)), sum((I - u_I)^2) and
        # sum((J - u_J)^2) over each window.
        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

        cc = cross * cross / (I_var * J_var + self.eps)
        return -torch.mean(cc)