# losses.py

import torch
import torch.nn.functional as F

from helpers import get_device


def relu_evidence(y):
    # Evidence as the non-negative part of the raw logits.
    return F.relu(y)


def exp_evidence(y):
    # Exponential evidence, with logits clamped to [-10, 10] for numerical stability.
    return torch.exp(torch.clamp(y, -10, 10))


def softplus_evidence(y):
    # Softplus evidence: a smooth, strictly positive transform of the logits.
    return F.softplus(y)
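

# Illustrative check (an addition for clarity; not part of the original file):
# each transform above maps raw logits to non-negative per-class evidence,
# which is what lets `evidence + 1` below be a valid Dirichlet parameter.
def _demo_evidence_transforms():
    logits = torch.tensor([[-2.0, 0.0, 3.0]])
    for fn in (relu_evidence, exp_evidence, softplus_evidence):
        assert (fn(logits) >= 0).all()  # evidence is never negative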


def kl_divergence(alpha, num_classes, device=None):
    # KL divergence between Dirichlet(alpha) and the uniform Dirichlet(1, ..., 1),
    # computed in closed form from log-gamma and digamma terms.
    if not device:
        device = get_device()
    ones = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    sum_alpha = torch.sum(alpha, dim=1, keepdim=True)
    first_term = (
        torch.lgamma(sum_alpha)
        - torch.lgamma(alpha).sum(dim=1, keepdim=True)
        + torch.lgamma(ones).sum(dim=1, keepdim=True)
        - torch.lgamma(ones.sum(dim=1, keepdim=True))
    )
    second_term = (
        (alpha - ones)
        .mul(torch.digamma(alpha) - torch.digamma(sum_alpha))
        .sum(dim=1, keepdim=True)
    )
    kl = first_term + second_term
    return kl
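

# Sanity-check sketch (an addition; not part of the original file): since
# kl_divergence computes KL( Dir(alpha) || Dir(1, ..., 1) ) in closed form,
# it should agree with PyTorch's built-in Dirichlet KL.
def _demo_kl_check():
    alpha = torch.tensor([[2.0, 1.0, 1.5]])
    p = torch.distributions.Dirichlet(alpha)
    q = torch.distributions.Dirichlet(torch.ones_like(alpha))
    builtin = torch.distributions.kl_divergence(p, q)
    closed_form = kl_divergence(alpha, num_classes=3, device="cpu")
    assert torch.allclose(builtin, closed_form.squeeze(1))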


def loglikelihood_loss(y, alpha, device=None):
    # Expected sum of squared errors between the one-hot target y and the class
    # probabilities of Dirichlet(alpha): a squared-error term plus a variance term.
    if not device:
        device = get_device()
    y = y.to(device)
    alpha = alpha.to(device)
    S = torch.sum(alpha, dim=1, keepdim=True)
    loglikelihood_err = torch.sum((y - (alpha / S)) ** 2, dim=1, keepdim=True)
    loglikelihood_var = torch.sum(
        alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True
    )
    loglikelihood = loglikelihood_err + loglikelihood_var
    return loglikelihood
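

# Behavioral check (an addition for illustration; not in the original file):
# evidence concentrated on the true class should give a lower expected
# squared error than a flat, zero-evidence Dirichlet.
def _demo_loglikelihood_loss():
    y = torch.tensor([[1.0, 0.0, 0.0]])
    confident = torch.tensor([[100.0, 1.0, 1.0]])  # strong evidence for class 0
    flat = torch.tensor([[1.0, 1.0, 1.0]])         # no evidence at all
    assert (
        loglikelihood_loss(y, confident, device="cpu")
        < loglikelihood_loss(y, flat, device="cpu")
    ).all()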


def mse_loss(y, alpha, epoch_num, num_classes, annealing_step, device=None):
    # Evidential MSE loss: expected squared error plus an annealed KL penalty
    # that pushes misleading evidence back toward the uniform Dirichlet.
    if not device:
        device = get_device()
    y = y.to(device)
    alpha = alpha.to(device)
    loglikelihood = loglikelihood_loss(y, alpha, device=device)
    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )
    # Remove the evidence assigned to the true class before the KL term,
    # so only evidence for wrong classes is penalized.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return loglikelihood + kl_div


def edl_loss(func, y, alpha, epoch_num, num_classes, annealing_step, device=None):
    # Generic evidential loss; `func` is torch.log for the log loss and
    # torch.digamma for the digamma loss.
    if not device:
        device = get_device()
    y = y.to(device)
    alpha = alpha.to(device)
    S = torch.sum(alpha, dim=1, keepdim=True)
    A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)
    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )
    # As in mse_loss: strip true-class evidence before the KL penalty.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return A + kl_div
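

# Illustration (an addition; not in the original file): the annealing
# coefficient shared by mse_loss and edl_loss ramps linearly from 0 to 1
# over `annealing_step` epochs, then saturates at 1.
def _demo_annealing():
    annealing_step = 10
    coefs = [
        torch.min(
            torch.tensor(1.0, dtype=torch.float32),
            torch.tensor(epoch / annealing_step, dtype=torch.float32),
        ).item()
        for epoch in (0, 5, 10, 20)
    ]
    assert coefs == [0.0, 0.5, 1.0, 1.0]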


def edl_mse_loss(output, target, epoch_num, num_classes, annealing_step, device=None):
    # Full EDL MSE loss on raw network outputs: evidence -> Dirichlet alpha -> loss.
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        mse_loss(target, alpha, epoch_num, num_classes, annealing_step, device=device)
    )
    return loss


def edl_log_loss(output, target, epoch_num, num_classes, annealing_step, device=None):
    # EDL loss with the expected cross-entropy (log) formulation.
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        edl_loss(
            torch.log, target, alpha, epoch_num, num_classes, annealing_step, device
        )
    )
    return loss


def edl_digamma_loss(
    output, target, epoch_num, num_classes, annealing_step, device=None
):
    # EDL loss with the digamma (expected log-probability) formulation.
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        edl_loss(
            torch.digamma, target, alpha, epoch_num, num_classes, annealing_step, device
        )
    )
    return loss
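

# Usage sketch (an assumption about intended use; not part of the original
# file): targets are one-hot encoded, and per-sample uncertainty can be read
# off the Dirichlet strength S = sum(alpha) as u = num_classes / S.
if __name__ == "__main__":
    num_classes = 3
    logits = torch.randn(4, num_classes)  # stand-in for a model's raw outputs
    labels = torch.randint(0, num_classes, (4,))
    target = F.one_hot(labels, num_classes).float()

    loss = edl_mse_loss(
        logits, target, epoch_num=1, num_classes=num_classes,
        annealing_step=10, device="cpu",
    )

    alpha = relu_evidence(logits) + 1
    uncertainty = num_classes / torch.sum(alpha, dim=1, keepdim=True)
    print(loss.item(), uncertainty.squeeze(1))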