#!/usr/bin/env python3.7

from functools import reduce
from operator import mul, add
from typing import List, Tuple, cast

import torch
import numpy as np
from torch import Tensor, einsum

from utils import simplex, one_hot


class CrossEntropy():
    def __init__(self, **kwargs):
        # self.idc is used to filter out some classes of the target mask, with fancy indexing
        self.idc: List[int] = kwargs["idc"]
        self.nd: str = kwargs["nd"]

        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def __call__(self, probs: Tensor, target: Tensor, _: Tensor, __) -> Tensor:
        assert simplex(probs) and simplex(target)

        log_p: Tensor = (probs[:, self.idc, ...] + 1e-10).log()
        mask: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32))

        loss = - einsum(f"bk{self.nd},bk{self.nd}->", mask, log_p)
        loss /= mask.sum() + 1e-10

        return loss
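
# --- Illustrative usage (not part of the original file) ----------------------
# A minimal sketch of calling CrossEntropy on 2D predictions, assuming probs
# of shape (B, K, W, H) with nd="wh"; the last two positional arguments are
# unused placeholders matching the __call__ signature above:
#
#     ce = CrossEntropy(idc=[0, 1], nd="wh")
#     probs = torch.softmax(torch.randn(2, 2, 8, 8), dim=1)  # softmax -> simplex
#     target = torch.zeros_like(probs)
#     target[:, 0, ...] = 1                                  # one-hot background mask
#     loss = ce(probs, target, None, None)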


class AbstractConstraints():
    def __init__(self, **kwargs):
        self.idc: List[int] = kwargs["idc"]
        self.nd: str = kwargs["nd"]
        self.C = len(self.idc)
        self.__fn__ = getattr(__import__('utils'), kwargs['fn'])

        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def penalty(self, z: Tensor) -> Tensor:
        raise NotImplementedError

    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor, _) -> Tensor:
        assert simplex(probs)  # The target is only checked for shape; it need not be a simplex
        assert probs.shape == target.shape

        b: int
        b, _, *im_shape = probs.shape
        _, _, k, two = bounds.shape  # scalar or vector bounds
        assert two == 2

        value: Tensor = cast(Tensor, self.__fn__(probs[:, self.idc, ...]))
        lower_b = bounds[:, self.idc, :, 0]
        upper_b = bounds[:, self.idc, :, 1]

        assert value.shape == (b, self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

        upper_z: Tensor = cast(Tensor, (value - upper_b).type(torch.float32)).flatten()
        lower_z: Tensor = cast(Tensor, (lower_b - value).type(torch.float32)).flatten()

        upper_penalty: Tensor = reduce(add, (self.penalty(e) for e in upper_z))
        lower_penalty: Tensor = reduce(add, (self.penalty(e) for e in lower_z))

        res: Tensor = upper_penalty + lower_penalty

        loss: Tensor = res.sum() / reduce(mul, im_shape)
        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation

        return loss


class LogBarrierLoss(AbstractConstraints):
    def __init__(self, **kwargs):
        self.t: float = kwargs["t"]
        super().__init__(**kwargs)

    def penalty(self, z: Tensor) -> Tensor:
        # Extended log-barrier: the exact barrier -log(-z)/t for z <= -1/t**2,
        # and its linear extension t*z - log(1/t**2)/t + 1/t past that point,
        # so the penalty stays finite and differentiable everywhere
        assert z.shape == ()

        if z <= - 1 / self.t**2:
            return - torch.log(-z) / self.t
        else:
            return self.t * z - np.log(1 / (self.t**2)) / self.t + 1 / self.t
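
# --- Illustrative usage (not part of the original file) ----------------------
# A minimal sketch constraining the soft size of class 1 to lie in [30, 40]
# pixels. It assumes utils exposes a function named "soft_size" (whichever
# name is passed as fn= must exist in utils) mapping (B, C, W, H)
# probabilities to per-class values of shape (B, C, 1):
#
#     lb = LogBarrierLoss(idc=[1], nd="wh", t=5.0, fn="soft_size")
#     probs = torch.softmax(torch.randn(2, 2, 8, 8), dim=1)
#     bounds = torch.zeros(2, 2, 1, 2)       # (B, K, k, 2)
#     bounds[..., 0] = 30                    # lower bounds
#     bounds[..., 1] = 40                    # upper bounds
#     loss = lb(probs, probs, bounds, None)  # target is only checked for shape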


class BoxPrior():
    def __init__(self, **kwargs):
        self.idc: List[int] = kwargs["idc"]
        self.t: float = kwargs["t"]

        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def barrier(self, z: Tensor) -> Tensor:
        # Same extended log-barrier as LogBarrierLoss.penalty
        assert z.shape == ()

        if z <= - 1 / self.t**2:
            return - torch.log(-z) / self.t
        else:
            return self.t * z - np.log(1 / (self.t**2)) / self.t + 1 / self.t

    def __call__(self, probs: Tensor, _: Tensor, __: Tensor,
                 box_prior: List[List[Tuple[Tensor, Tensor]]]) -> Tensor:
        assert simplex(probs)

        B: int = probs.shape[0]
        assert len(box_prior) == B

        sublosses = []
        for b in range(B):
            for k in self.idc:
                masks, bounds = box_prior[b][k]

                sizes: Tensor = einsum('wh,nwh->n', probs[b, k], masks)
                assert sizes.shape == bounds.shape == (masks.shape[0],), (sizes.shape, bounds.shape, masks.shape)

                shifted: Tensor = bounds - sizes

                init = torch.zeros((), dtype=torch.float32, requires_grad=probs.requires_grad, device=probs.device)
                sublosses.append(reduce(add, (self.barrier(v) for v in shifted), init))

        loss: Tensor = reduce(add, sublosses)
        assert loss.dtype == torch.float32
        assert loss.shape == (), loss.shape

        return loss
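
# --- Illustrative usage (not part of the original file) ----------------------
# A minimal sketch with one image and two classes, where class 1 carries
# three prior masks whose soft sizes (probability mass under each mask)
# should not fall below the paired bounds; the class-0 entry is a dummy,
# since only the classes in idc are indexed:
#
#     bp = BoxPrior(idc=[1], t=5.0)
#     probs = torch.softmax(torch.randn(1, 2, 8, 8), dim=1)
#     masks = (torch.rand(3, 8, 8) > 0.5).float()  # (n, w, h) binary masks
#     bounds = torch.full((3,), 5.0)               # one lower bound per mask
#     prior = [[(None, None), (masks, bounds)]]    # box_prior[b][k]
#     loss = bp(probs, None, None, prior)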


class NegSizeLoss():
    def __init__(self, **kwargs):
        # self.idc is used to filter out some classes of the target mask, with fancy indexing
        self.idc: List[int] = kwargs["idc"]
        self.t: float = kwargs["t"]
        self.nd: str = kwargs["nd"]

        print(f"Initialized {self.__class__.__name__} with {kwargs}")

    def penalty(self, z: Tensor) -> Tensor:
        # Same extended log-barrier as LogBarrierLoss.penalty
        assert z.shape == ()

        if z <= - 1 / self.t**2:
            return - torch.log(-z) / self.t
        else:
            return self.t * z - np.log(1 / (self.t**2)) / self.t + 1 / self.t

    def __call__(self, probs: Tensor, target: Tensor, _: Tensor, __) -> Tensor:
        assert simplex(probs) and simplex(target)

        b: int
        b, _, *im_shape = probs.shape

        probs_m: Tensor = probs[:, self.idc, ...]
        target_m: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32))

        nd: str = self.nd
        # Compute the size of each class, masked by the target pixels (where target == 1)
        masked_sizes: Tensor = einsum(f"bk{nd},bk{nd}->bk", probs_m, target_m).flatten()

        # We want that size to be <= 0, so no shift of the barrier is needed
        res: Tensor = reduce(add, (self.penalty(e) for e in masked_sizes))  # type: ignore

        loss: Tensor = res / reduce(mul, im_shape)
        assert loss.shape == ()
        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation

        return loss
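
# --- Illustrative usage (not part of the original file) ----------------------
# A minimal sketch penalising the probability mass that class 1 keeps inside
# the class-1 region of the target: the barrier drives the masked soft size
# towards zero, hence the "negative size":
#
#     nsl = NegSizeLoss(idc=[1], t=5.0, nd="wh")
#     probs = torch.softmax(torch.randn(2, 2, 8, 8), dim=1)
#     target = torch.zeros_like(probs)
#     target[:, 1, ...] = 1
#     loss = nsl(probs, target, None, None)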