-
Notifications
You must be signed in to change notification settings - Fork 4
/
loss.py
43 lines (33 loc) · 1.42 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn.functional as F
def _get_losses_lengths(logits, labels):
probs = F.softmax(logits.transpose(1, 2), dim=1)
losses = F.nll_loss(probs, labels, ignore_index=-100, reduction="none")
lengths = (labels >= 0).sum(1)
return losses, lengths
def samplewise_average_loss(logits, labels):
    """Average each sample's mean per-position loss over the batch.

    Every valid position is weighted by 1 / (its sample's number of
    non-negative labels) / batch_size. Each sample's average loss therefore
    contributes equally, regardless of its length.

    Samples with no valid labels contribute 0 rather than turning the
    result into NaN.
    """
    losses, lengths = _get_losses_lengths(logits, labels)
    # Without the clamp, an all-padding row gives 1/0 = inf. Its losses are
    # all 0 (ignore_index), and 0 * inf = NaN would poison the sum.
    inv_lengths = 1 / lengths.clamp(min=1)
    weights = inv_lengths[:, None].expand(labels.shape) / len(labels)
    return (losses * weights).sum()
def phonewise_average_loss(logits, labels):
    """Average the per-position loss over all valid labels in the batch.

    Every valid position gets the same weight, 1 / (total number of
    non-negative labels in the batch), so longer samples weigh more.

    A batch with no valid labels yields 0 rather than NaN.
    """
    losses, lengths = _get_losses_lengths(logits, labels)
    # Clamp guards against 1/0 when every label in the batch is padding;
    # the losses are then all 0, and 0 * inf would be NaN.
    weights = 1 / lengths.sum().clamp(min=1)
    return (losses * weights).sum()
def ctc_like_loss(logits, labels):
    """Weight each run of repeated labels equally, then average over samples.

    For each sample, take the first ``length`` labels, where ``length`` is
    the count of non-negative labels, and group them into runs of
    consecutive identical values. Each run gets total weight 1 / num_runs,
    split evenly across its positions. Padding positions get weight 0.
    The weighted sums are then averaged over the batch.

    NOTE(review): slicing ``row[:n_valid]`` assumes every valid label
    precedes all padding in its row -- confirm with the data pipeline.
    """
    losses, lengths = _get_losses_lengths(logits, labels)
    max_len = labels.shape[1]
    rows = []
    for row, n_valid in zip(labels, lengths):
        _, run_of_pos, run_sizes = torch.unique_consecutive(
            row[:n_valid], return_inverse=True, return_counts=True)
        n_runs = run_sizes.shape[0]
        # Each position receives 1/(its run's size), scaled by 1/n_runs.
        per_pos = (1 / run_sizes)[run_of_pos] / n_runs
        pad = torch.zeros(max_len - n_valid, device=per_pos.device)
        rows.append(torch.cat((per_pos, pad)))
    weights = torch.stack(rows) / len(labels)
    return (losses * weights).sum()
def get_loss(name):
    """Return the loss function registered under ``name``.

    Known names: "samplewise_average", "phonewise_average", "ctc_like".
    An unknown name raises ``KeyError``.
    """
    registry = dict(
        samplewise_average=samplewise_average_loss,
        phonewise_average=phonewise_average_loss,
        ctc_like=ctc_like_loss,
    )
    return registry[name]