cross_entropy_loss.py
import paddle
from paddle import nn
import paddle.nn.functional as F

from ..builder import LOSSES

@LOSSES.register()
class CrossEntropyLoss(nn.Layer):
    """
    Implements the cross entropy loss function.

    Args:
        weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight
            given to each class. Its length must equal the number of classes.
            Default ``None``.
        ignore_index (int64, optional): Specifies a target value that is ignored
            and does not contribute to the input gradient. Default ``255``.
        top_k_percent_pixels (float, optional): A value in [0.0, 1.0]. When it is
            less than 1.0, the loss is computed only over the top k percent of
            pixels (e.g. the hardest 20%), which is useful for hard pixel
            mining. Default ``1.0``.
        data_format (str, optional): The tensor format to use, 'NCHW' or 'NHWC'.
            Default ``'NCHW'``.
        loss_weight (float, optional): A scalar multiplier applied to the final
            loss. Default ``1.0``.
        cfg (dict, optional): Accepted for config-driven construction; unused
            here. Default ``None``.
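
    Examples:
        A minimal usage sketch (not part of the original file), assuming a
        4-class segmentation head with NCHW inputs:

        .. code-block:: python

            import paddle

            logit = paddle.randn([2, 4, 8, 8])  # (N, C, H, W)
            label = paddle.randint(0, 4, shape=[2, 8, 8], dtype='int64')
            loss_fn = CrossEntropyLoss(ignore_index=255)
            loss = loss_fn(logit, label)  # scalar Tensor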
"""
    def __init__(self,
                 weight=None,
                 ignore_index=255,
                 top_k_percent_pixels=1.0,
                 data_format='NCHW',
                 loss_weight=1.0,
                 cfg=None):
        super(CrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index
        self.top_k_percent_pixels = top_k_percent_pixels
        self.EPS = 1e-8
        self.data_format = data_format
        if weight is not None:
            self.weight = paddle.to_tensor(weight, dtype='float32')
        else:
            self.weight = None
        self.loss_weight = loss_weight

    def forward(self, logit, label, semantic_weights=None):
        """
        Forward computation.

        Args:
            logit (Tensor): Logit tensor of data type float32 or float64. The
                shape is (N, C), where C is the number of classes; if the input
                has more than 2 dimensions, the shape is (N, C, D1, D2, ..., Dk),
                k >= 1.
            label (Tensor): Label tensor of data type int64. The shape is (N),
                where each value satisfies 0 <= label[i] <= C-1; if the input
                has more than 2 dimensions, the shape is (N, D1, D2, ..., Dk),
                k >= 1.
            semantic_weights (Tensor, optional): Pixel-wise weights for the
                loss, with the same shape as label. Default: None.

        Returns:
            (Tensor): The average loss.
        """
        channel_axis = 1 if self.data_format == 'NCHW' else -1
        if self.weight is not None and logit.shape[channel_axis] != len(
                self.weight):
            raise ValueError(
                'The number of weights = {} must be the same as the number of '
                'classes = {}.'.format(len(self.weight),
                                       logit.shape[channel_axis]))

        if channel_axis == 1:
            # Move channels last so the loss is computed in NHWC layout.
            # Note: this transpose assumes a 4-D (N, C, H, W) input.
            logit = paddle.transpose(logit, [0, 2, 3, 1])
        label = label.astype('int64')

        loss = F.cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            reduction='none',
            weight=self.weight)

        return self._post_process_loss(logit, label, semantic_weights, loss)

    def _post_process_loss(self, logit, label, semantic_weights, loss):
        """
        Apply the ignore mask and the optional top-k selection to produce the
        final loss.

        Args:
            logit (Tensor): Logit tensor of data type float32 or float64,
                already transposed to channels-last layout:
                (N, D1, D2, ..., Dk, C), k >= 1.
            label (Tensor): Label tensor of data type int64. The shape is (N)
                or (N, D1, D2, ..., Dk), k >= 1, with 0 <= label[i] <= C-1.
            semantic_weights (Tensor, optional): Pixel-wise weights for the
                loss, with the same shape as label.
            loss (Tensor): Unreduced loss returned by cross_entropy. With
                soft_label=False its shape matches the label; with
                soft_label=True its shape is (N, D1, D2, ..., Dk, 1).

        Returns:
            (Tensor): The average loss.
        """
        # 1.0 for valid pixels, 0.0 for pixels labeled with ignore_index.
        mask = label != self.ignore_index
        mask = paddle.cast(mask, 'float32')
        label.stop_gradient = True
        mask.stop_gradient = True

        if loss.ndim > mask.ndim:
            loss = paddle.squeeze(loss, axis=-1)
        loss = loss * mask
        if semantic_weights is not None:
            loss = loss * semantic_weights

        if self.weight is not None:
            # Per-pixel class weight. Ignored pixels are mapped to class 0 so
            # that one_hot stays in range; their loss is already zeroed by the
            # mask above. The cast keeps the one_hot input integer-typed.
            _one_hot = F.one_hot(label * paddle.cast(mask, 'int64'),
                                 logit.shape[-1])
            coef = paddle.sum(_one_hot * self.weight, axis=-1)
        else:
            coef = paddle.ones_like(label)
        coef = paddle.cast(coef, 'float32')

        if self.top_k_percent_pixels == 1.0:
            # Weighted mean over valid pixels only; EPS guards against an
            # all-ignored batch.
            avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
        else:
            # Hard pixel mining: keep only the largest top-k per-pixel losses.
            loss = loss.reshape((-1, ))
            top_k_pixels = int(self.top_k_percent_pixels * loss.numel())
            loss, indices = paddle.topk(loss, top_k_pixels)
            coef = coef.reshape((-1, ))
            coef = paddle.gather(coef, indices)
            coef.stop_gradient = True
            avg_loss = loss.mean() / (paddle.mean(coef) + self.EPS)

        return avg_loss * self.loss_weight
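

if __name__ == '__main__':
    # Minimal smoke test, a sketch rather than part of the original module.
    # It assumes the class is importable standalone; in the package it is
    # normally built through the LOSSES registry. Shapes follow the NCHW
    # convention documented above.
    num_classes = 4
    logit = paddle.randn([2, num_classes, 8, 8])  # (N, C, H, W)
    label = paddle.randint(0, num_classes, shape=[2, 8, 8], dtype='int64')
    label[:, 0, 0] = 255  # mark a few pixels as ignored

    loss_fn = CrossEntropyLoss(ignore_index=255)
    print('mean loss:', float(loss_fn(logit, label)))

    # Hard pixel mining: average only the hardest 20% of per-pixel losses.
    topk_loss_fn = CrossEntropyLoss(ignore_index=255, top_k_percent_pixels=0.2)
    print('top-20% loss:', float(topk_loss_fn(logit, label)))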