Change the sign of the rnnt_loss and add reduction argument #911
Merged

Commits (6):
f92711e  Add right boundary constraints for s_begin (pkufool)
e03a4c0  Minor fixes to the interface of rnnt_loss to make it return positive … (pkufool)
1a3e29d  Fix comments (pkufool)
3ca1886  Release a new version (pkufool)
0ab2d1f  Minor fixes (pkufool)
bce3965  Minor fixes to the docs (pkufool)
@@ -178,6 +178,7 @@ def rnnt_loss_simple(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
     return_grad: bool = False,
 ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
     """A simple case of the RNN-T loss, where the 'joiner' network is just
@@ -201,25 +202,51 @@ def rnnt_loss_simple(
         [0, 0, S, T]
       if boundary is not supplied.
       Most likely you will want begin_symbol and begin_frame to be zero.
+    reduction:
+      Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+      `none`: no reduction will be applied.
+      `mean`: apply `torch.mean` over the batches.
+      `sum`: the output will be summed.
+      Default: `mean`
     return_grad:
       Whether to return grads of px and py, this grad standing for the
       occupation probability is the output of the backward with a
       `fake gradient` input (all ones). This is useful to implement the
       pruned version of rnnt loss.
   Returns:
-    If return_grad is False, returns a Tensor of shape (B,), containing the
-    NEGATED total RNN-T loss values for each element of the batch
-    (like log-probs of sequences).
+    If return_grad is False, returns a tensor of shape (B,), containing the
+    total RNN-T loss values for each element of the batch if reduction is
+    "none", otherwise a scalar with the reduction applied.
     If return_grad is True, the grads of px and py, which is the output of
     backward with a `fake gradient` input, will be returned too. And the
     returned value will be a tuple like (loss, (px_grad, py_grad)).
   """
-    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol, boundary)
-    return mutual_information_recursion(px, py, boundary, return_grad)
+    px, py = get_rnnt_logprobs(
+        lm=lm,
+        am=am,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
+    )
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
+    if reduction == "none":
+        loss = -negated_loss
+    elif reduction == "mean":
+        loss = -torch.mean(negated_loss)
+    elif reduction == "sum":
+        loss = -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+    return (loss, scores_and_grads[1]) if return_grad else loss
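The practical effect of this hunk: the function now returns a true loss (the negated mutual information, so non-negative for typical inputs) instead of a log-prob, and the default is a batch mean. Below is a minimal usage sketch, assuming k2 is installed and exports `rnnt_loss_simple` at the top level; shapes follow the docstring (`lm` is (B, S+1, C), `am` is (B, T, C)), and all tensor values here are random placeholders:

```python
import torch
import k2

B, S, T, C = 2, 5, 8, 10
lm = torch.randn(B, S + 1, C)   # prediction-network (LM) output
am = torch.randn(B, T, C)       # encoder (AM) output
symbols = torch.randint(1, C, (B, S))

# Default reduction="mean": a positive scalar, suitable for loss.backward().
loss = k2.rnnt_loss_simple(lm, am, symbols, termination_symbol=0)

# reduction="none": per-utterance losses of shape (B,); before this PR the
# same call returned the NEGATED values (log-probs of the sequences).
per_utt = k2.rnnt_loss_simple(
    lm, am, symbols, termination_symbol=0, reduction="none"
)
assert torch.allclose(loss, per_utt.mean())
```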
 def get_rnnt_logprobs_joint(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
@@ -229,7 +256,7 @@ def get_rnnt_logprobs_joint(
     This function is called from rnnt_loss().

     Args:
-      joint:
+      logits:
         The output of joiner network, with shape (B, T, S + 1, C),
         i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
       symbols:

Review comment (on the `num_classes` line): I think we should remove "possibly including EOS", since RNN-T does not use EOS.
@@ -262,16 +289,16 @@ def get_rnnt_logprobs_joint(
       we cannot emit any symbols. This is simply a way of incorporating
       the probability of the termination symbol on the last frame.
     """
-    assert joint.ndim == 4
-    (B, T, S1, C) = joint.shape
+    assert logits.ndim == 4
+    (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S)

-    normalizers = torch.logsumexp(joint, dim=3)
+    normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))

     px = torch.gather(
-        joint, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
+        logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
     ).squeeze(-1)
     px = px.permute((0, 2, 1))
     px = torch.cat(

@@ -287,7 +314,7 @@ def get_rnnt_logprobs_joint(
     px[:, :, :T] -= normalizers[:, :S, :]

     py = (
-        joint[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
+        logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
     )  # [B][S+1][T]
     py -= normalizers
     px = px.contiguous()
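The renamed `logits` parameter is still unnormalized: the function builds its own log-normalizers via `logsumexp` over the class dimension and subtracts them after gathering. An illustrative, self-contained sketch of the two operations these hunks touch (the shapes B, T, S, C and the layout are placeholders, not the library's actual code, which additionally permutes the axes):

```python
import torch

B, T, S, C = 2, 4, 3, 6
logits = torch.randn(B, T, S + 1, C)
symbols = torch.randint(0, C, (B, S))

# Log-normalizer over the class dimension, as in the patched code.
normalizers = torch.logsumexp(logits, dim=3)  # (B, T, S + 1)

# Pick out the logit of each reference symbol at every (t, s) position;
# subtracting the normalizer turns the gathered logit into a log-prob.
px = torch.gather(
    logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
).squeeze(-1)                                  # (B, T, S)
log_probs = px - normalizers[:, :, :S]         # equivalent normalization
```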
@@ -298,16 +325,17 @@ def get_rnnt_logprobs_joint(


 def rnnt_loss(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
 ) -> Tensor:
     """A normal RNN-T loss, which uses a 'joiner' network output as input,
     i.e. a 4 dimensions tensor.

     Args:
-      joint:
+      logits:
         The output of joiner network, with shape (B, T, S + 1, C),
         i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
       symbols:

@@ -320,15 +348,35 @@ def rnnt_loss(
       [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
       [0, 0, S, T] if boundary is not supplied.
       Most likely you will want begin_symbol and begin_frame to be zero.
+    reduction:
+      Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+      `none`: no reduction will be applied.
+      `mean`: apply `torch.mean` over the batches.
+      `sum`: the output will be summed.
+      Default: `mean`

     Returns:
-      A Tensor of shape (B,), containing the total RNN-T loss values for each
-      element of the batch (like log-probs of sequences).
+      If reduction is `none`, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch, otherwise a
+      scalar with the reduction applied.
     """
     px, py = get_rnnt_logprobs_joint(
-        joint, symbols, termination_symbol, boundary
+        logits=logits,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
     )
-    return mutual_information_recursion(px, py, boundary)
+    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    if reduction == "none":
+        return -negated_loss
+    elif reduction == "mean":
+        return -torch.mean(negated_loss)
+    elif reduction == "sum":
+        return -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"


 def _adjust_pruning_lower_bound(
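Note that the renamed `logits` argument is the raw joiner output: no log-softmax should be applied beforehand, since `get_rnnt_logprobs_joint` computes its own normalizers. A hedged usage sketch, again assuming k2 exports `rnnt_loss` at the top level and using random placeholder tensors:

```python
import torch
import k2

B, T, S, C = 2, 8, 5, 10
logits = torch.randn(B, T, S + 1, C, requires_grad=True)  # raw joiner output
symbols = torch.randint(1, C, (B, S))

# Scalar mean loss by default after this PR; minimize it directly.
loss = k2.rnnt_loss(logits, symbols, termination_symbol=0)
loss.backward()

# An unrecognized reduction string now fails fast with an AssertionError:
# k2.rnnt_loss(logits, symbols, termination_symbol=0, reduction="avg")
```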
@@ -464,10 +512,13 @@ def get_rnnt_prune_ranges(
     s_begin = torch.argmax(diff_grad, dim=1)
     s_begin = s_begin[:, :T]

-    # handle the values of s_begin in padding positions.
-    # set the s_begin in padding positions to `len(symbols) - s_range + 1`
+    # Handle the values of s_begin in padding positions.
+    # -1 here means we fill the position of the last frame of real data with
+    # padding value which is `len(symbols) - s_range + 1`.
+    # This is to guarantee that we reach the last symbol at the last frame of
+    # real data.
     mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
-    mask = mask < boundary[:, 3].reshape(B, 1)
+    mask = mask < boundary[:, 3].reshape(B, 1) - 1

     s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
     # handle the cases when `len(symbols) < s_range`
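In other words, the `- 1` makes the mask exclude the last real frame as well, so that frame is also assigned the padding value `len(symbols) - s_range + 1`, pinning the pruning window there to end on the final symbol. A small numeric sketch of just the mask arithmetic (toy values, not the library's code):

```python
import torch

B, T, s_range = 1, 6, 2
boundary = torch.tensor([[0, 0, 4, 5]])  # S=4 symbols, 5 real frames, 1 pad

frames = torch.arange(0, T).reshape(1, T).expand(B, T)
old_mask = frames < boundary[:, 3].reshape(B, 1)      # keeps frames 0..4
new_mask = frames < boundary[:, 3].reshape(B, 1) - 1  # keeps frames 0..3

# Frame 4 (the last real frame) now takes the padding value
# S - s_range + 1 = 3, forcing its window to cover the last symbol.
s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1  # tensor([[3]])
```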
@@ -561,7 +612,7 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):


 def get_rnnt_logprobs_pruned(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
@@ -570,7 +621,7 @@ def get_rnnt_logprobs_pruned(
     """Construct px, py for mutual_information_recursion with pruned output.

     Args:
-      joint:
+      logits:
         The pruned output of joiner network, with shape (B, T, s_range, C)
       symbols:
         The symbol sequences, a LongTensor of shape [B][S], and elements in

@@ -589,15 +640,15 @@ def get_rnnt_logprobs_pruned(
       Return the px (B, S, T + 1) and py (B, S + 1, T) needed by
       mutual_information_recursion.
     """
-    # joint (B, T, s_range, C)
+    # logits (B, T, s_range, C)
     # symbols (B, S)
     # ranges (B, T, s_range)
-    assert joint.ndim == 4
-    (B, T, s_range, C) = joint.shape
+    assert logits.ndim == 4
+    (B, T, s_range, C) = logits.shape
     assert ranges.shape == (B, T, s_range)
     (B, S) = symbols.shape

-    normalizers = torch.logsumexp(joint, dim=3)
+    normalizers = torch.logsumexp(logits, dim=3)

     symbols_with_terminal = torch.cat(
         (
@@ -620,7 +671,7 @@ def get_rnnt_logprobs_pruned(

     # (B, T, s_range)
     px = torch.gather(
-        joint, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1)
+        logits, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1)
     ).squeeze(-1)
     px = px - normalizers

@@ -652,7 +703,7 @@ def get_rnnt_logprobs_pruned(
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..

-    py = joint[:, :, :, termination_symbol]  # (B, T, s_range)
+    py = logits[:, :, :, termination_symbol].clone()  # (B, T, s_range)
     py = py - normalizers

     # (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf
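In the pruned case the class-dimension gather works exactly as in the full case, just over a narrow band of `s_range` symbol positions per frame. An illustrative sketch of the gather this hunk renames (toy shapes; `pruning_symbols` here is a random stand-in for the real rolled symbol tensor the function builds from `ranges`):

```python
import torch

B, T, s_range, C = 2, 4, 3, 8
logits = torch.randn(B, T, s_range, C)  # pruned joiner output

# Stand-in for pruning_symbols: the symbol id kept at each (t, k) slot.
pruning_symbols = torch.randint(0, C, (B, T, s_range))

normalizers = torch.logsumexp(logits, dim=3)             # (B, T, s_range)
px = torch.gather(
    logits, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1)
).squeeze(-1)                                            # (B, T, s_range)
px = px - normalizers                                    # log-probs in the band
```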
@@ -682,11 +733,12 @@ def get_rnnt_logprobs_pruned(


 def rnnt_loss_pruned(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
     boundary: Tensor = None,
+    reduction: Optional[str] = "mean",
 ) -> Tensor:
     """An RNN-T loss with pruning, which uses a pruned 'joiner' network output
     as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
@@ -708,14 +760,35 @@ def rnnt_loss_pruned(
       [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
       [0, 0, S, T] if boundary is not supplied.
       Most likely you will want begin_symbol and begin_frame to be zero.
+    reduction:
+      Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+      `none`: no reduction will be applied.
+      `mean`: apply `torch.mean` over the batches.
+      `sum`: the output will be summed.
+      Default: `mean`
     Returns:
-      A Tensor of shape (B,), containing the total RNN-T loss values for each
-      element of the batch (like log-probs of sequences).
+      If reduction is `none`, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch, otherwise a
+      scalar with the reduction applied.
     """
     px, py = get_rnnt_logprobs_pruned(
-        joint, symbols, ranges, termination_symbol, boundary
+        logits=logits,
+        symbols=symbols,
+        ranges=ranges,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
     )
-    return mutual_information_recursion(px, py, boundary)
+    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    if reduction == "none":
+        return -negated_loss
+    elif reduction == "mean":
+        return -torch.mean(negated_loss)
+    elif reduction == "sum":
+        return -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"


 def get_rnnt_logprobs_smoothed(
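Putting the pieces together, the intended pruned workflow would look roughly like the sketch below. This assumes k2 also exports `get_rnnt_prune_ranges` and a `do_rnnt_pruning` helper for gathering the pruned am/lm frames, and the joiner output is a random stand-in for a real joiner network; treat it as an outline under those assumptions, not a definitive recipe:

```python
import torch
import k2

B, S, T, C, s_range = 2, 5, 8, 10, 3
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(1, C, (B, S))
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

# 1. Simple loss plus the occupation-probability ("fake gradient") grads.
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
    lm, am, symbols, termination_symbol=0, boundary=boundary, return_grad=True
)

# 2. Pruning bounds of width s_range per frame, derived from the grads.
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, s_range)

# 3. Gather pruned am/lm frames; run a (hypothetical) joiner on them.
am_pruned, lm_pruned = k2.do_rnnt_pruning(am, lm, ranges)
logits = torch.randn(B, T, s_range, C)  # stand-in for joiner(am_pruned, lm_pruned)

# 4. Pruned loss: positive and mean-reduced by default after this PR.
pruned_loss = k2.rnnt_loss_pruned(
    logits, symbols, ranges, termination_symbol=0, boundary=boundary
)
```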
@@ -921,6 +994,7 @@ def rnnt_loss_smoothed(
     lm_only_scale: float = 0.1,
     am_only_scale: float = 0.1,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
     return_grad: bool = False,
 ) -> Tensor:
     """A simple case of the RNN-T loss, where the 'joiner' network is just
@@ -951,27 +1025,47 @@ def rnnt_loss_smoothed(
         [0, 0, S, T]
       if boundary is not supplied.
       Most likely you will want begin_symbol and begin_frame to be zero.
+    reduction:
+      Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+      `none`: no reduction will be applied.
+      `mean`: apply `torch.mean` over the batches.
+      `sum`: the output will be summed.
+      Default: `mean`
     return_grad:
       Whether to return grads of px and py, this grad standing for the
       occupation probability is the output of the backward with a
       `fake gradient` input (all ones). This is useful to implement the
       pruned version of rnnt loss.

     Returns:
-      If return_grad is False, returns a Tensor of shape (B,), containing the
-      NEGATED total RNN-T loss values for each element of the batch
-      (like log-probs of sequences).
+      If return_grad is False, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch if reduction is
+      "none", otherwise a scalar with the reduction applied.
       If return_grad is True, the grads of px and py, which is the output of
       backward with a `fake gradient` input, will be returned too. And the
       returned value will be a tuple like (loss, (px_grad, py_grad)).
     """
     px, py = get_rnnt_logprobs_smoothed(
-        lm,
-        am,
-        symbols,
-        termination_symbol,
-        lm_only_scale,
-        am_only_scale,
-        boundary,
+        lm=lm,
+        am=am,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        lm_only_scale=lm_only_scale,
+        am_only_scale=am_only_scale,
+        boundary=boundary,
     )
-    return mutual_information_recursion(px, py, boundary, return_grad)
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
+    if reduction == "none":
+        loss = -negated_loss
+    elif reduction == "mean":
+        loss = -torch.mean(negated_loss)
+    elif reduction == "sum":
+        loss = -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+    return (loss, scores_and_grads[1]) if return_grad else loss
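Since all four losses now share the same reduction logic, the three modes relate in the obvious way. A quick hedged sanity check for the smoothed variant, again assuming top-level k2 exports and random placeholder tensors:

```python
import torch
import k2

B, S, T, C = 2, 5, 8, 10
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(1, C, (B, S))

per_utt = k2.rnnt_loss_smoothed(
    lm, am, symbols, termination_symbol=0, reduction="none"
)                                                   # shape (B,)
mean_loss = k2.rnnt_loss_smoothed(lm, am, symbols, termination_symbol=0)
sum_loss = k2.rnnt_loss_smoothed(
    lm, am, symbols, termination_symbol=0, reduction="sum"
)
assert torch.allclose(mean_loss, per_utt.mean())
assert torch.allclose(sum_loss, per_utt.sum())
```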
Review comment: Around line 214 we should clarify that the "fake gradient" is the same as the gradient you'd get if you did torch.autograd.grad(-loss.sum(), [px, py]). Note there is now a minus sign. When you refer to the fake gradient below, you can add "(see above)" to clarify that it's explained above.
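Concretely, the reviewer's point can be checked with a sketch like the one below, assuming access to k2's `get_rnnt_logprobs` and `mutual_information_recursion`. With the new sign convention `loss = -mutual_information`, so the "fake gradient" grads equal the autograd gradients of `-loss.sum()`:

```python
import torch
import k2

B, S, T, C = 2, 5, 8, 10
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(1, C, (B, S))

px, py = k2.get_rnnt_logprobs(lm, am, symbols, termination_symbol=0)
px = px.detach().requires_grad_(True)
py = py.detach().requires_grad_(True)

loss = -k2.mutual_information_recursion(px, py)  # new sign convention
px_grad, py_grad = torch.autograd.grad(-loss.sum(), [px, py])
# px_grad / py_grad are the occupation probabilities that
# rnnt_loss_simple(..., return_grad=True) returns alongside the loss.
```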