-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathuce.py
65 lines (51 loc) · 2.53 KB
/
uce.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from uce_utils import nentr
def uceloss(softmaxes, labels, n_bins=15):
    """Compute the Uncertainty Calibration Error (UCE) over equal-width bins.

    Each sample's uncertainty is obtained from ``nentr(softmaxes, base=C)``,
    where ``C = softmaxes.size(1)``. The interval [0, 1] is split into
    ``n_bins`` equal-width bins, and for every non-empty bin the absolute gap
    between the mean uncertainty and the error rate is accumulated, weighted
    by the fraction of samples that fall into the bin.

    Args:
        softmaxes: Tensor whose dimension 1 indexes the classes (presumably
            of shape (N, C) holding class probabilities -- confirm against
            callers).
        labels: Tensor of target class indices, compared element-wise with
            the arg-max predictions.
        n_bins: Number of equal-width bins covering [0, 1].

    Returns:
        A 4-tuple ``(uce, err_in_bin, avg_entropy_in_bin, uncertainties)``:
        ``uce`` is a tensor of shape (1,); ``err_in_bin`` and
        ``avg_entropy_in_bin`` are 1-D tensors with one entry per non-empty
        bin, in increasing bin order; ``uncertainties`` is the per-sample
        output of ``nentr``.
    """
    d = softmaxes.device
    bin_boundaries = torch.linspace(0, 1, n_bins + 1, device=d)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    _, predictions = torch.max(softmaxes, 1)
    errors = predictions.ne(labels)
    # NOTE(review): nentr is assumed to return one value in [0, 1] per sample
    # (entropy normalised by log(C)) -- confirm in uce_utils.
    uncertainties = nentr(softmaxes, base=softmaxes.size(1))
    errors_in_bin_list = []
    avg_entropy_in_bin_list = []

    uce = torch.zeros(1, device=d)
    for bin_index, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
        lower = bin_lower.item()
        upper = bin_upper.item()
        # Bins are (lower, upper], except the first one, which is closed on
        # the left: otherwise samples with an uncertainty of exactly 0 would
        # belong to no bin and be silently dropped from the metric.
        if bin_index == 0:
            above_lower = uncertainties.ge(lower)
        else:
            above_lower = uncertainties.gt(lower)
        in_bin = above_lower * uncertainties.le(upper)
        prop_in_bin = in_bin.float().mean()  # |B_m| / n
        # For empty input the mean is NaN, which also fails this comparison.
        if prop_in_bin.item() > 0.0:
            errors_in_bin = errors[in_bin].float().mean()  # err(B_m)
            avg_entropy_in_bin = uncertainties[in_bin].mean()  # uncert(B_m)
            uce += torch.abs(avg_entropy_in_bin - errors_in_bin) * prop_in_bin

            errors_in_bin_list.append(errors_in_bin)
            avg_entropy_in_bin_list.append(avg_entropy_in_bin)

    err_in_bin = torch.tensor(errors_in_bin_list, device=d)
    avg_entropy_in_bin = torch.tensor(avg_entropy_in_bin_list, device=d)

    return uce, err_in_bin, avg_entropy_in_bin, uncertainties
def eceloss(softmaxes, labels, n_bins=15):
    """Compute the Expected Calibration Error (ECE) over equal-width bins.

    Adapted from
    https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py

    The confidence of a sample is the maximum of ``softmaxes`` along
    dimension 1. Samples are grouped into ``n_bins`` half-open bins
    ``(lower, upper]`` covering [0, 1]; for every non-empty bin the absolute
    gap between mean confidence and accuracy is accumulated, weighted by the
    fraction of samples in the bin.

    Args:
        softmaxes: Tensor whose dimension 1 indexes the classes.
        labels: Tensor of target class indices, compared element-wise with
            the arg-max predictions.
        n_bins: Number of equal-width bins covering [0, 1].

    Returns:
        A 4-tuple ``(ece, acc_in_bin, avg_conf_in_bin, confidences)``:
        ``ece`` is a tensor of shape (1,); ``acc_in_bin`` and
        ``avg_conf_in_bin`` are 1-D tensors with one entry per non-empty bin,
        in increasing bin order; ``confidences`` holds the per-sample maximum
        along dimension 1.
    """
    device = softmaxes.device
    # Bin edges as plain Python floats; consecutive pairs delimit one bin.
    edges = torch.linspace(0, 1, n_bins + 1, device=device).tolist()

    confidences, predicted = softmaxes.max(dim=1)
    correct = predicted.eq(labels)

    per_bin_accuracy = []
    per_bin_confidence = []
    ece = torch.zeros(1, device=device)

    for low, high in zip(edges[:-1], edges[1:]):
        members = confidences.gt(low) & confidences.le(high)
        weight = members.float().mean()
        # Skip empty bins (a NaN weight from empty input is skipped as well).
        if not weight.item() > 0.0:
            continue

        bin_accuracy = correct[members].float().mean()
        bin_confidence = confidences[members].mean()
        ece += (bin_confidence - bin_accuracy).abs() * weight

        per_bin_accuracy.append(bin_accuracy)
        per_bin_confidence.append(bin_confidence)

    return (
        ece,
        torch.tensor(per_bin_accuracy, device=device),
        torch.tensor(per_bin_confidence, device=device),
        confidences,
    )