-
Notifications
You must be signed in to change notification settings - Fork 7
/
lscloss.py
151 lines (134 loc) · 7.19 KB
/
lscloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import torch.nn.functional as F
class LocalSaliencyCoherence(torch.nn.Module):
    """
    Local saliency coherence loss.

    For every pixel, the absolute difference between its predicted class
    probabilities and those of each neighbour inside a (2r+1)x(2r+1) window is
    weighted by a Gaussian affinity kernel built from the modalities listed in
    ``kernels_desc`` (e.g. RGB and XY). The loss is the mean of these weighted
    differences, summed over the window.

    This loss function is based on the following paper.
    Please consider using the following bibtex for citation:
    @article{obukhov2019gated,
        author={Anton Obukhov and Stamatios Georgoulis and Dengxin Dai and Luc {Van Gool}},
        title={Gated {CRF} Loss for Weakly Supervised Semantic Image Segmentation},
        journal={CoRR},
        volume={abs/1906.04651},
        year={2019},
        url={http://arxiv.org/abs/1906.04651},
    }
    """

    def forward(
        self, y_hat_softmax, kernels_desc, kernels_radius, sample, height_input, width_input,
        mask_src=None, mask_dst=None, compatibility=None, custom_modality_downsamplers=None, out_kernels_vis=False
    ):
        """
        Performs the forward pass of the loss.

        :param y_hat_softmax: A tensor of predicted per-pixel class probabilities of size NxCxHxW.
        :param kernels_desc: A list of dictionaries, each describing one Gaussian kernel composition from modalities.
            The final kernel is a weighted sum of individual kernels. The following example is a composition of
            RGBXY and XY kernels:
                kernels_desc: [{
                    'weight': 0.9,  # Weight of RGBXY kernel
                    'xy': 6,        # Sigma for XY
                    'rgb': 0.1,     # Sigma for RGB
                },{
                    'weight': 0.1,  # Weight of XY kernel
                    'xy': 6,        # Sigma for XY
                }]
        :param kernels_radius: Radius r of the (2r+1)x(2r+1) neighbourhood around each pixel in which the kernel
            is constructed.
        :param sample: A dictionary with the modalities (except 'xy') used in kernels_desc. The tensors are
            concatenated with each other (and with the XY mesh, if used) along dim 1, so they are expected to be
            NCHW with the same N as y_hat_softmax. A modality whose spatial size differs from y_hat_softmax is
            downsampled first. The default method is F.adaptive_avg_pool2d; it can be overridden per modality via
            custom_modality_downsamplers. The tensors are never modified in place.
        :param height_input, width_input: Dimensions of the full-scale resolution of the modalities. They must be
            integer multiples of the prediction size with the same aspect ratio. They are also used as the output
            size of the kernel visualization.
        :param mask_src: Currently unused (accepted but ignored by this implementation).
        :param mask_dst: Currently unused (accepted but ignored by this implementation).
        :param compatibility: Currently unused (accepted but ignored by this implementation).
        :param custom_modality_downsamplers: A dictionary mapping a modality name to a callable
            f(img, (height, width)) used instead of F.adaptive_avg_pool2d for that modality.
        :param out_kernels_vis: Whether to also return a tensor with kernels visualized with some step.
        :return: A dict with key 'loss' (a scalar tensor). When out_kernels_vis is True, the dict also contains
            'kernels_vis', a Nx1x(height_input)x(width_input) tensor.
        """
        assert y_hat_softmax.dim() == 4, 'Prediction must be a NCHW batch'
        N, C, height_pred, width_pred = y_hat_softmax.shape
        device = y_hat_softmax.device

        assert width_input % width_pred == 0 and height_input % height_pred == 0 and \
            width_input * height_pred == height_input * width_pred, \
            f'[{width_input}x{height_input}] !~= [{width_pred}x{height_pred}]'

        # kernels: N x 1 x d x d x H x W affinities between each pixel and its neighbours.
        kernels = self._create_kernels(
            kernels_desc, kernels_radius, sample, N, height_pred, width_pred, device, custom_modality_downsamplers
        )

        diameter = 2 * kernels_radius + 1
        y_hat_unfolded = self._unfold(y_hat_softmax, kernels_radius)
        # Centre pixel's probabilities, broadcast against every neighbour in its window.
        y_hat_center = y_hat_unfolded[:, :, kernels_radius, kernels_radius, :, :].view(
            N, C, 1, 1, height_pred, width_pred
        )
        y_hat_diff = torch.abs(y_hat_center - y_hat_unfolded)

        # Sum the kernel-weighted differences over the window, then average over N, C, H, W.
        loss = (kernels * y_hat_diff).view(
            N, C, diameter ** 2, height_pred, width_pred
        ).sum(dim=2, keepdim=True).mean()

        out = {
            'loss': loss,
        }

        if out_kernels_vis:
            out['kernels_vis'] = self._visualize_kernels(
                kernels, kernels_radius, height_input, width_input, height_pred, width_pred
            )

        return out

    @staticmethod
    def _downsample(img, modality, height_dst, width_dst, custom_modality_downsamplers):
        """Resize img to (height_dst, width_dst) with the modality's custom downsampler or adaptive avg pooling."""
        if custom_modality_downsamplers is not None and modality in custom_modality_downsamplers:
            f_down = custom_modality_downsamplers[modality]
        else:
            f_down = F.adaptive_avg_pool2d
        return f_down(img, (height_dst, width_dst))

    @staticmethod
    def _create_kernels(
        kernels_desc, kernels_radius, sample, N, height_pred, width_pred, device, custom_modality_downsamplers
    ):
        """
        Build the weighted sum of Gaussian kernels described by kernels_desc.

        Each modality is divided by its sigma. The scaled modalities of one descriptor are concatenated along the
        channel dim, turned into a kernel, and scaled by the descriptor's 'weight'.

        :return: A tensor of shape N x 1 x d x d x height_pred x width_pred, where d = 2 * kernels_radius + 1.
        """
        kernels = None
        for i, desc in enumerate(kernels_desc):
            weight = desc['weight']
            features = []
            for modality, sigma in desc.items():
                if modality == 'weight':
                    continue
                if modality == 'xy':
                    feature = LocalSaliencyCoherence._get_mesh(N, height_pred, width_pred, device)
                else:
                    assert modality in sample, \
                        f'Modality {modality} is listed in {i}-th kernel descriptor, but not present in the sample'
                    feature = sample[modality]
                    # Bring larger modalities down to the prediction resolution. Same-size inputs are used as-is.
                    if feature.shape[-2:] != (height_pred, width_pred):
                        feature = LocalSaliencyCoherence._downsample(
                            feature, modality, height_pred, width_pred, custom_modality_downsamplers
                        )
                # Out-of-place division: never mutate the caller's sample tensors.
                feature = feature / sigma
                features.append(feature)
            features = torch.cat(features, dim=1)
            kernel = weight * LocalSaliencyCoherence._create_kernels_from_features(features, kernels_radius)
            kernels = kernel if kernels is None else kernel + kernels
        return kernels

    @staticmethod
    def _create_kernels_from_features(features, radius):
        """
        Compute exp(-0.5 * ||f_neighbour - f_centre||^2) for every pixel's (2r+1)x(2r+1) neighbourhood.

        :param features: An NCHW tensor of sigma-scaled features.
        :return: A tensor of shape N x 1 x d x d x H x W, where d = 2 * radius + 1.
        """
        assert features.dim() == 4, 'Features must be a NCHW batch'
        N, C, H, W = features.shape
        kernels = LocalSaliencyCoherence._unfold(features, radius)
        kernels = kernels - kernels[:, :, radius, radius, :, :].view(N, C, 1, 1, H, W)
        kernels = (-0.5 * kernels ** 2).sum(dim=1, keepdim=True).exp()
        # The centre weight stays at exp(0) = 1. It does not affect the loss, because |p_i - p_i| = 0 there.
        return kernels

    @staticmethod
    def _get_mesh(N, H, W, device):
        """Return an N x 2 x H x W float32 grid of pixel coordinates: channel 0 is x (column), channel 1 is y (row)."""
        return torch.cat((
            torch.arange(0, W, 1, dtype=torch.float32, device=device).view(1, 1, 1, W).repeat(N, 1, H, 1),
            torch.arange(0, H, 1, dtype=torch.float32, device=device).view(1, 1, H, 1).repeat(N, 1, 1, W)
        ), 1)

    @staticmethod
    def _unfold(img, radius):
        """
        Extract every pixel's (2r+1)x(2r+1) neighbourhood as an N x C x d x d x H x W tensor.

        F.unfold is called with padding=radius, so neighbours outside the image are zero-padded.
        """
        assert img.dim() == 4, 'Unfolding requires NCHW batch'
        N, C, H, W = img.shape
        diameter = 2 * radius + 1
        return F.unfold(img, diameter, 1, radius).view(N, C, diameter, diameter, H, W)

    @staticmethod
    def _visualize_kernels(kernels, radius, height_input, width_input, height_pred, width_pred):
        """
        Tile the kernels of every diameter-th pixel into an image and upscale it.

        The tiled image is cropped or zero-padded to the prediction size. It is then nearest-neighbour upscaled to
        (height_input, width_input).
        """
        diameter = 2 * radius + 1
        vis = kernels[:, :, :, :, radius::diameter, radius::diameter]
        vis_nh, vis_nw = vis.shape[-2:]
        vis = vis.permute(0, 1, 4, 2, 5, 3).contiguous().view(kernels.shape[0], 1, diameter * vis_nh, diameter * vis_nw)
        if vis.shape[2] > height_pred:
            vis = vis[:, :, :height_pred, :]
        if vis.shape[3] > width_pred:
            vis = vis[:, :, :, :width_pred]
        if vis.shape[2:] != (height_pred, width_pred):
            vis = F.pad(vis, [0, width_pred - vis.shape[3], 0, height_pred - vis.shape[2]])
        vis = F.interpolate(vis, (height_input, width_input), mode='nearest')
        return vis