Change the sign of the rnnt_loss and add reduction argument #911

Merged · 6 commits · Jan 29, 2022 (diff shown: changes from 5 commits)
CMakeLists.txt (2 changes: 1 addition & 1 deletion)

@@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}")
 
 project(k2 ${languages})
 
-set(K2_VERSION "1.12")
+set(K2_VERSION "1.13")
 
 # ----------------- Supported build types for K2 project -----------------
 set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
k2/python/k2/rnnt_loss.py (184 changes: 139 additions & 45 deletions)

@@ -178,6 +178,7 @@ def rnnt_loss_simple(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
     return_grad: bool = False,
 ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
     """A simple case of the RNN-T loss, where the 'joiner' network is just
@@ -201,25 +202,51 @@ def rnnt_loss_simple(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
+      reduction:
+        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+        `none`: no reduction will be applied.
+        `mean`: apply `torch.mean` over the batches.
+        `sum`: the output will be summed.
+        Default: `mean`
       return_grad:
         Whether to return grads of px and py, this grad standing for the
         occupation probability is the output of the backward with a
         `fake gradient` input (all ones). This is useful to implement the
         pruned version of rnnt loss.
     Returns:
-      If return_grad is False, returns a Tensor of shape (B,), containing the
-      NEGATED total RNN-T loss values for each element of the batch
-      (like log-probs of sequences).
+      If return_grad is False, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch if reduction equals
+      "none", otherwise a scalar with the reduction applied.
Review comment (Collaborator):
    Around line 214 we should clarify that the "fake gradient" is the same as
    the gradient you'd get if you did torch.grad((-loss.sum()), [px, py]).
    Now there is a - sign. When you refer to the fake gradient below, you can
    add "(see above)" to clarify that it's explained above.

       If return_grad is True, the grads of px and py, which are the output of
       backward with a `fake gradient` input, will be returned too. And the
       returned value will be a tuple like (loss, (px_grad, py_grad)).
     """
-    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol, boundary)
-    return mutual_information_recursion(px, py, boundary, return_grad)
+    px, py = get_rnnt_logprobs(
+        lm=lm,
+        am=am,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
+    )
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
+    if reduction == "none":
+        loss = -negated_loss
+    elif reduction == "mean":
+        loss = -torch.mean(negated_loss)
+    elif reduction == "sum":
+        loss = -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+    return (loss, scores_and_grads[1]) if return_grad else loss
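
For illustration, a minimal usage sketch of the new argument (not part of this diff; shapes and values are invented, and it assumes rnnt_loss_simple is exposed as k2.rnnt_loss_simple, as in the k2 Python package):

import torch
import k2

B, S, T, C = 2, 5, 10, 40
lm = torch.randn(B, S + 1, C)  # language-model part of the logits
am = torch.randn(B, T, C)      # acoustic-model part of the logits
symbols = torch.randint(1, C, (B, S))

# Default reduction="mean": a positive scalar, usable directly as a loss.
loss = k2.rnnt_loss_simple(lm, am, symbols, termination_symbol=0)

# reduction="none": per-utterance losses of shape (B,).
per_utt = k2.rnnt_loss_simple(
    lm, am, symbols, termination_symbol=0, reduction="none"
)
assert per_utt.shape == (B,)

# return_grad=True additionally returns (px_grad, py_grad), the occupation
# probabilities; per the review comment above, these equal the gradient of
# the negated summed loss w.r.t. px and py.
loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
    lm, am, symbols, termination_symbol=0, return_grad=True
)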


 def get_rnnt_logprobs_joint(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
@@ -229,7 +256,7 @@ def get_rnnt_logprobs_joint(
     This function is called from rnnt_loss().
 
     Args:
-      joint:
+      logits:
         The output of joiner network, with shape (B, T, S + 1, C),
         i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
Review comment (Collaborator):
    I think we should remove "possibly including EOS", since RNN-T does not
    use EOS.

       symbols:

@@ -262,16 +289,16 @@ def get_rnnt_logprobs_joint(
         we cannot emit any symbols. This is simply a way of incorporating
         the probability of the termination symbol on the last frame.
     """
-    assert joint.ndim == 4
-    (B, T, S1, C) = joint.shape
+    assert logits.ndim == 4
+    (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S)
 
-    normalizers = torch.logsumexp(joint, dim=3)
+    normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))
 
     px = torch.gather(
-        joint, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
+        logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
     ).squeeze(-1)
     px = px.permute((0, 2, 1))
     px = torch.cat(

@@ -287,7 +314,7 @@ def get_rnnt_logprobs_joint(
     px[:, :, :T] -= normalizers[:, :S, :]
 
     py = (
-        joint[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
+        logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
     )  # [B][S+1][T]
     py -= normalizers
     px = px.contiguous()
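
As an aside, the logsumexp/gather pattern above can be seen in isolation in this tiny standalone sketch (shapes invented; not part of the diff):

import torch

B, T, S, C = 1, 3, 2, 5
logits = torch.randn(B, T, S + 1, C)   # toy joiner output
symbols = torch.randint(0, C, (B, S))  # toy reference symbols

# Log-normalizer over the class dimension: shape (B, T, S+1).
normalizers = torch.logsumexp(logits, dim=3)

# Pick out the logit of each reference symbol at every (t, s) position.
# gather allows index.size(2) == S <= logits.size(2) == S + 1, so this
# reads the first S rows along dim 2; result shape (B, T, S).
px = torch.gather(
    logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
).squeeze(-1)

# Normalized log-prob of each symbol, i.e. the picked log_softmax entry.
log_probs = px - normalizers[:, :, :S]
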
@@ -298,16 +325,17 @@ def get_rnnt_logprobs_joint(
 
 
 def rnnt_loss(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
 ) -> Tensor:
     """A normal RNN-T loss, which uses a 'joiner' network output as input,
     i.e. a 4-dimensional tensor.
 
     Args:
-      joint:
+      logits:
         The output of joiner network, with shape (B, T, S + 1, C),
         i.e. batch, time_seq_len, symbol_seq_len+1, num_classes
       symbols:
@@ -320,15 +348,35 @@ def rnnt_loss(
         [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
+      reduction:
+        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+        `none`: no reduction will be applied.
+        `mean`: apply `torch.mean` over the batches.
+        `sum`: the output will be summed.
+        Default: `mean`
 
     Returns:
-      A Tensor of shape (B,), containing the total RNN-T loss values for each
-      element of the batch (like log-probs of sequences).
+      If reduction is `none`, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch, otherwise a scalar
+      with the reduction applied.
     """
     px, py = get_rnnt_logprobs_joint(
-        joint, symbols, termination_symbol, boundary
+        logits=logits,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
     )
-    return mutual_information_recursion(px, py, boundary)
+    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    if reduction == "none":
+        return -negated_loss
+    elif reduction == "mean":
+        return -torch.mean(negated_loss)
+    elif reduction == "sum":
+        return -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"


 def _adjust_pruning_lower_bound(

@@ -464,10 +512,13 @@ def get_rnnt_prune_ranges(
     s_begin = torch.argmax(diff_grad, dim=1)
     s_begin = s_begin[:, :T]
 
-    # handle the values of s_begin in padding positions.
-    # set the s_begin in padding positions to `len(symbols) - s_range + 1`
+    # Handle the values of s_begin in padding positions.
+    # -1 here means we fill the position of the last frame of real data with
+    # the padding value, which is `len(symbols) - s_range + 1`.
+    # This is to guarantee that we reach the last symbol at the last frame of
+    # real data.
     mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
-    mask = mask < boundary[:, 3].reshape(B, 1)
+    mask = mask < boundary[:, 3].reshape(B, 1) - 1
 
     s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
     # handle the cases when `len(symbols) < s_range`
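
A toy trace of the mask change above (invented numbers; not part of the diff): with the extra - 1, the last real frame also counts as a padding position, so the pruning window there is forced to end at the final symbol.

import torch

B, T = 1, 6
boundary = torch.tensor([[0, 0, 4, 5]])  # S=4 symbols, 5 real frames

t = torch.arange(0, T).reshape(1, T).expand(B, T)
old_mask = t < boundary[:, 3].reshape(B, 1)      # [T, T, T, T, T, F]
new_mask = t < boundary[:, 3].reshape(B, 1) - 1  # [T, T, T, T, F, F]

# Frame 4 (the last real frame) is now a "padding position", so s_begin
# there becomes len(symbols) - s_range + 1 (e.g. 2 for s_range=3).
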
@@ -561,7 +612,7 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):


 def get_rnnt_logprobs_pruned(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,

@@ -570,7 +621,7 @@ def get_rnnt_logprobs_pruned(
"""Construct px, py for mutual_information_recursion with pruned output.

Args:
joint:
logits:
The pruned output of joiner network, with shape (B, T, s_range, C)
symbols:
The symbol sequences, a LongTensor of shape [B][S], and elements in
Expand All @@ -589,15 +640,15 @@ def get_rnnt_logprobs_pruned(
     Return the px (B, S, T + 1) and py (B, S + 1, T) needed by
     mutual_information_recursion.
     """
-    # joint (B, T, s_range, C)
+    # logits (B, T, s_range, C)
     # symbols (B, S)
     # ranges (B, T, s_range)
-    assert joint.ndim == 4
-    (B, T, s_range, C) = joint.shape
+    assert logits.ndim == 4
+    (B, T, s_range, C) = logits.shape
     assert ranges.shape == (B, T, s_range)
     (B, S) = symbols.shape
 
-    normalizers = torch.logsumexp(joint, dim=3)
+    normalizers = torch.logsumexp(logits, dim=3)
 
     symbols_with_terminal = torch.cat(
         (

@@ -620,7 +671,7 @@ def get_rnnt_logprobs_pruned(
 
     # (B, T, s_range)
     px = torch.gather(
-        joint, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1)
+        logits, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1)
    ).squeeze(-1)
     px = px - normalizers
 

@@ -652,7 +703,7 @@ def get_rnnt_logprobs_pruned(
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
 
-    py = joint[:, :, :, termination_symbol]  # (B, T, s_range)
+    py = logits[:, :, :, termination_symbol].clone()  # (B, T, s_range)
     py = py - normalizers
 
     # (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf
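
The .clone() added to the py line above is worth a note: integer indexing on one dimension returns a view, so the copy presumably guards against later in-place operations on py silently writing into the caller's logits. A tiny standalone demonstration (not part of the diff):

import torch

logits = torch.zeros(2, 3, 4, 5)
py = logits[:, :, :, 0]          # basic indexing: py is a VIEW of logits
py += 1.0                        # the in-place update also modifies logits!
print(logits[0, 0, 0, 0])        # tensor(1.)

py = logits[:, :, :, 0].clone()  # a decoupled copy, as in the change above
py += 1.0                        # logits is unaffected this time
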
@@ -682,11 +733,12 @@ def get_rnnt_logprobs_pruned(


 def rnnt_loss_pruned(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
     boundary: Tensor = None,
+    reduction: Optional[str] = "mean",
 ) -> Tensor:
     """An RNN-T loss with pruning, which uses a pruned 'joiner' network output
     as input, i.e. a 4-dimensional tensor with shape (B, T, s_range, C),

@@ -708,14 +760,35 @@ def rnnt_loss_pruned(
         [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
+      reduction:
+        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+        `none`: no reduction will be applied.
+        `mean`: apply `torch.mean` over the batches.
+        `sum`: the output will be summed.
+        Default: `mean`
     Returns:
-      A Tensor of shape (B,), containing the total RNN-T loss values for each
-      element of the batch (like log-probs of sequences).
+      If reduction is `none`, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch, otherwise a scalar
+      with the reduction applied.
     """
     px, py = get_rnnt_logprobs_pruned(
-        joint, symbols, ranges, termination_symbol, boundary
+        logits=logits,
+        symbols=symbols,
+        ranges=ranges,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
     )
-    return mutual_information_recursion(px, py, boundary)
+    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    if reduction == "none":
+        return -negated_loss
+    elif reduction == "mean":
+        return -torch.mean(negated_loss)
+    elif reduction == "sum":
+        return -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"


 def get_rnnt_logprobs_smoothed(

@@ -921,6 +994,7 @@ def rnnt_loss_smoothed(
     lm_only_scale: float = 0.1,
     am_only_scale: float = 0.1,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
     return_grad: bool = False,
 ) -> Tensor:
     """A simple case of the RNN-T loss, where the 'joiner' network is just

@@ -951,27 +1025,47 @@ def rnnt_loss_smoothed(
         [0, 0, S, T]
         if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
+      reduction:
+        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+        `none`: no reduction will be applied.
+        `mean`: apply `torch.mean` over the batches.
+        `sum`: the output will be summed.
+        Default: `mean`
       return_grad:
         Whether to return grads of px and py, this grad standing for the
         occupation probability is the output of the backward with a
         `fake gradient` input (all ones). This is useful to implement the
         pruned version of rnnt loss.
 
     Returns:
-      If return_grad is False, returns a Tensor of shape (B,), containing the
-      NEGATED total RNN-T loss values for each element of the batch
-      (like log-probs of sequences).
+      If return_grad is False, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch if reduction equals
+      "none", otherwise a scalar with the reduction applied.
       If return_grad is True, the grads of px and py, which are the output of
       backward with a `fake gradient` input, will be returned too. And the
       returned value will be a tuple like (loss, (px_grad, py_grad)).
     """
     px, py = get_rnnt_logprobs_smoothed(
-        lm,
-        am,
-        symbols,
-        termination_symbol,
-        lm_only_scale,
-        am_only_scale,
-        boundary,
+        lm=lm,
+        am=am,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        lm_only_scale=lm_only_scale,
+        am_only_scale=am_only_scale,
+        boundary=boundary,
     )
-    return mutual_information_recursion(px, py, boundary, return_grad)
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
+    if reduction == "none":
+        loss = -negated_loss
+    elif reduction == "mean":
+        loss = -torch.mean(negated_loss)
+    elif reduction == "sum":
+        loss = -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+    return (loss, scores_and_grads[1]) if return_grad else loss
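
Finally, a sanity-check sketch of the reduction semantics shared by all four loss functions (invented shapes; not part of the diff; assumes k2.rnnt_loss_smoothed is exported):

import torch
import k2

B, S, T, C = 2, 5, 10, 40
lm = torch.randn(B, S + 1, C)
am = torch.randn(B, T, C)
symbols = torch.randint(1, C, (B, S))

per_utt = k2.rnnt_loss_smoothed(
    lm, am, symbols, termination_symbol=0, reduction="none"
)
mean_loss = k2.rnnt_loss_smoothed(
    lm, am, symbols, termination_symbol=0, reduction="mean"
)
sum_loss = k2.rnnt_loss_smoothed(
    lm, am, symbols, termination_symbol=0, reduction="sum"
)

# "mean" and "sum" are plain torch.mean / torch.sum over the "none" output.
assert torch.allclose(per_utt.mean(), mean_loss)
assert torch.allclose(per_utt.sum(), sum_loss)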