add norm_by_times param to ctc_loss (#32490)
* add norm_by_times param to ctc_loss

* fix doc,test=develop
LDOUBLEV authored Apr 26, 2021
1 parent 1b9a3bf commit 6c03ea5
Showing 2 changed files with 19 additions and 6 deletions.
6 changes: 4 additions & 2 deletions python/paddle/nn/functional/loss.py
@@ -1023,7 +1023,8 @@ def ctc_loss(log_probs,
              input_lengths,
              label_lengths,
              blank=0,
-             reduction='mean'):
+             reduction='mean',
+             norm_by_times=False):
"""
An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
@@ -1038,6 +1039,7 @@ def ctc_loss(log_probs,
label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by label_lengths, and then the mean of the quotients is returned; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.
+norm_by_times (bool, optional): Whether to normalize the gradients by the number of time steps, which is also the sequence's length. There is no need to normalize the gradients when the reduction mode is ``'mean'``. Default is False.
Returns:
Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.
@@ -1101,7 +1103,7 @@ def ctc_loss(log_probs,
"""

-    loss_out = fluid.layers.warpctc(log_probs, labels, blank, False,
+    loss_out = fluid.layers.warpctc(log_probs, labels, blank, norm_by_times,
                                     input_lengths, label_lengths)

     loss_out = fluid.layers.squeeze(loss_out, [-1])
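As a quick illustration (not part of the diff): a minimal sketch of calling the updated functional API with the new parameter. Shapes and dtypes follow the docstring above; the tensors are random, illustrative data only.

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# Input activations: [max_logit_length, batch_size, num_classes + 1], float32.
log_probs = paddle.to_tensor(np.random.rand(5, 2, 4).astype("float32"))

# Padded labels: [batch_size, max_label_length], int32.
labels = paddle.to_tensor(np.array([[1, 2, 2], [1, 3, 0]], dtype="int32"))

# Per-sequence lengths: [batch_size], int64.
input_lengths = paddle.to_tensor(np.array([5, 5], dtype="int64"))
label_lengths = paddle.to_tensor(np.array([3, 2], dtype="int64"))

# Per the docstring, normalizing by time steps is unnecessary when
# reduction='mean', so it is most useful with 'none' or 'sum'.
loss = F.ctc_loss(log_probs, labels, input_lengths, label_lengths,
                  blank=0, reduction='none', norm_by_times=True)
print(loss.shape)  # [2], one loss per batch element with reduction='none'
```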
19 changes: 15 additions & 4 deletions python/paddle/nn/layer/loss.py
@@ -1060,6 +1060,7 @@ class CTCLoss(fluid.dygraph.Layer):
labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the length of the longest label sequence. The data type must be int32.
input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
+norm_by_times (bool, optional): Whether to normalize the gradients by the number of time steps, which is also the sequence's length. There is no need to normalize the gradients when the reduction mode is ``'mean'``. Default is False.
Returns:
Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.
@@ -1122,10 +1123,20 @@ def __init__(self, blank=0, reduction='mean'):
self.blank = blank
self.reduction = reduction

-    def forward(self, log_probs, labels, input_lengths, label_lengths):
-        return paddle.nn.functional.ctc_loss(log_probs, labels, input_lengths,
-                                             label_lengths, self.blank,
-                                             self.reduction)
+    def forward(self,
+                log_probs,
+                labels,
+                input_lengths,
+                label_lengths,
+                norm_by_times=False):
+        return paddle.nn.functional.ctc_loss(
+            log_probs,
+            labels,
+            input_lengths,
+            label_lengths,
+            self.blank,
+            self.reduction,
+            norm_by_times=norm_by_times)


class SmoothL1Loss(fluid.dygraph.Layer):
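Likewise, a sketch of the layer form, using the same illustrative tensors as the functional example above. Note that norm_by_times is accepted by forward() rather than __init__, which keeps existing constructor calls working and mirrors the functional signature one-to-one.

```python
import numpy as np
import paddle

log_probs = paddle.to_tensor(np.random.rand(5, 2, 4).astype("float32"))
labels = paddle.to_tensor(np.array([[1, 2, 2], [1, 3, 0]], dtype="int32"))
input_lengths = paddle.to_tensor(np.array([5, 5], dtype="int64"))
label_lengths = paddle.to_tensor(np.array([3, 2], dtype="int64"))

# Construction is unchanged; norm_by_times is supplied per call,
# so one layer instance can toggle it between invocations.
ctc = paddle.nn.CTCLoss(blank=0, reduction='none')
loss = ctc(log_probs, labels, input_lengths, label_lengths,
           norm_by_times=True)
```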
