-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathloss.py
115 lines (91 loc) · 3.43 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Define various loss functions and bundle them with appropriate metrics."""
import torch
import numpy as np
class Loss:
    """Abstract interface bundling a training criterion with its target metric.

    Subclasses implement ``__call__`` (the criterion that is minimized) and
    ``metric`` (the quantity that is actually of interest). Both are expected
    to return a ``(value, name, format)`` triple: the computed value, a
    display label and a format spec suitable for printing the value.
    """

    def __init__(self):
        """Init."""
        pass

    def __call__(self, reference, argmin):
        """Return l(x, y) as a ``(value, name, format)`` triple.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError()

    def metric(self, reference, argmin):
        """Return the actually sought metric as a ``(value, name, format)`` triple.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError()
class PSNR(Loss):
"""A classical MSE target.
The minimized criterion is MSE Loss, the actual metric is average PSNR.
"""
def __init__(self):
"""Init with torch MSE."""
self.loss_fn = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
def __call__(self, x=None, y=None):
"""Return l(x, y)."""
name = 'MSE'
pf_fmt = '.6f'
if x is None:
return name, pf_fmt
else:
value = 0.5 * self.loss_fn(x, y)
return value, name, pf_fmt
def metric(self, x=None, y=None):
"""The actually sought metric."""
name = 'avg PSNR'
pf_fmt = '.3f'
if x is None:
return name, pf_fmt
else:
value = self.psnr_compute(x, y)
return value, name, pf_fmt
@staticmethod
def psnr_compute(img_batch, ref_batch, batched=False, factor=1.0):
"""Standard PSNR."""
def get_psnr(img_in, img_ref):
mse = ((img_in - img_ref) ** 2).mean()
if mse > 0 and torch.isfinite(mse):
return (10 * torch.log10(factor ** 2 / mse)).item()
elif not torch.isfinite(mse):
return float('nan')
else:
return float('inf')
if batched:
psnr = get_psnr(img_batch.detach(), ref_batch)
else:
[B, C, m, n] = img_batch.shape
psnrs = []
for sample in range(B):
psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :]))
psnr = np.mean(psnrs)
return psnr
class Classification(Loss):
"""A classical NLL loss for classification. Evaluation has the softmax baked in.
The minimized criterion is cross entropy, the actual metric is total accuracy.
"""
def __init__(self):
"""Init with torch MSE."""
self.loss_fn = torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100,
reduce=None, reduction='mean')
def __call__(self, x=None, y=None):
"""Return l(x, y)."""
name = 'CrossEntropy'
pf_fmt = '1.5f'
if x is None:
return name, pf_fmt
else:
value = self.loss_fn(x, y)
return value, name, pf_fmt
def metric(self, x=None, y=None):
"""The actually sought metric."""
name = 'Accuracy'
pf_fmt = '6.2%'
if x is None:
return name, pf_fmt
else:
value = (x.data.argmax(dim=1) == y).sum().float() / y.shape[0]
return value.detach(), name, pf_fmt