diff --git a/CMakeLists.txt b/CMakeLists.txt
index aa58d1b37..386e412d5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -45,7 +45,7 @@ message(STATUS "Enabled languages: ${languages}")
 
 project(k2 ${languages})
 
-set(K2_VERSION "1.12")
+set(K2_VERSION "1.13")
 
 # ----------------- Supported build types for K2 project -----------------
 set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py
index ca4ffcf8f..2150d4ed2 100644
--- a/k2/python/k2/rnnt_loss.py
+++ b/k2/python/k2/rnnt_loss.py
@@ -26,15 +26,14 @@ def fix_for_boundary(px: Tensor, boundary: Optional[Tensor] = None) -> Tensor:
     """
     Insert -inf's into `px` in appropriate places if `boundary` is not
-    None. If boundary == None and modified == False, px[:,:,-1] will
-    be -infinity, but if boundary is specified, we need px[b,:,boundary[b,3]]
+    None. If boundary == None, px[:,:,-1] will be -infinity,
+    but if boundary is specified, we need px[b,:,boundary[b,3]]
     to be -infinity.
     Args:
-      px: a Tensor of of shape [B][S][T+1] (this function is only
-        called if modified == False, see other docs for `modified`)
-        px is modified in-place and returned.
-      boundary: None, or a Tensor of shape [B][3] containing
-        [s_begin, t_begin, s_end, t_end]; we need only t_end.
+      px: a Tensor of shape [B][S][T+1]; px is modified in-place
+        and returned.
+      boundary: None, or a Tensor of shape [B][4] containing
+        [s_begin, t_begin, s_end, t_end]; we need only t_end.
     """
     if boundary is None:
         return px
@@ -82,7 +81,7 @@ def get_rnnt_logprobs(
         next on this frame.
       symbols:
         A LongTensor of shape [B][S], containing the symbols at each position
-        of the sequence, possibly including EOS
+        of the sequence.
       termination_symbol:
         The identity of the termination symbol, must be in {0..C-1}
       boundary:
@@ -178,10 +177,11 @@ def rnnt_loss_simple(
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
     return_grad: bool = False,
 ) -> Union[Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]]:
     """A simple case of the RNN-T loss, where the 'joiner' network is just
-    addition. Returns negated total loss value.
+    addition.
 
     Args:
       lm:
@@ -201,25 +201,53 @@ def rnnt_loss_simple(
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
+      reduction:
+        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+        `none`: no reduction will be applied.
+        `mean`: apply `torch.mean` over the batches.
+        `sum`: the output will be summed.
+        Default: `mean`
       return_grad:
         Whether to return grads of px and py, this grad standing for the
         occupation probability is the output of the backward with a
-        `fake gradient` input (all ones) This is useful to implement the
-        pruned version of rnnt loss.
+        `fake gradient`. The `fake gradient` is the same as the gradient
+        you'd get if you did `torch.autograd.grad((-loss.sum()), [px, py])`;
+        note that the loss here is the loss with reduction "none".
+        This is useful to implement the pruned version of rnnt loss.
     Returns:
-      If return_grad is False, returns a Tensor of shape (B,), containing the
-      NEGATED total RNN-T loss values for each element of the batch
-      (like log-probs of sequences).
+      If return_grad is False, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch if reduction is
+      "none", otherwise a scalar with the reduction applied.
       If return_grad is True, the grads of px and py, which is the output of
-      backward with a `fake gradient` input, will be returned too. And the
+      backward with a `fake gradient` (see above), will be returned too. And the
       returned value will be a tuple like (loss, (px_grad, py_grad)).
     """
-    px, py = get_rnnt_logprobs(lm, am, symbols, termination_symbol, boundary)
-    return mutual_information_recursion(px, py, boundary, return_grad)
+    px, py = get_rnnt_logprobs(
+        lm=lm,
+        am=am,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
+    )
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
+    )
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
+    if reduction == "none":
+        loss = -negated_loss
+    elif reduction == "mean":
+        loss = -torch.mean(negated_loss)
+    elif reduction == "sum":
+        loss = -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+    return (loss, scores_and_grads[1]) if return_grad else loss
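For orientation, here is a minimal sketch of how the new `reduction` argument behaves. This is an editor's example with toy shapes, not part of the patch, and it assumes k2 is built from this branch:

    import torch
    import k2

    B, S, T, C = 2, 3, 4, 5
    lm = torch.randn(B, S + 1, C)          # language-model-side logits
    am = torch.randn(B, T, C)              # acoustic-model-side logits
    symbols = torch.randint(1, C, (B, S))  # reserve 0 as the termination symbol

    # reduction="none": one positive loss value per utterance, shape (B,).
    per_utt = k2.rnnt_loss_simple(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=0,
        boundary=None,
        reduction="none",
    )

    # The default reduction="mean" returns a scalar (torch.mean over the batch).
    batch_mean = k2.rnnt_loss_simple(
        lm=lm, am=am, symbols=symbols, termination_symbol=0
    )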
 
 
 def get_rnnt_logprobs_joint(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
@@ -229,12 +257,12 @@ def get_rnnt_logprobs_joint(
     This function is called from rnnt_loss().
 
     Args:
-      joint:
+      logits:
         The output of joiner network, with shape (B, T, S + 1, C), i.e.
         batch, time_seq_len, symbol_seq_len+1, num_classes
       symbols:
         A LongTensor of shape [B][S], containing the symbols at each position
-        of the sequence, possibly including EOS
+        of the sequence.
       termination_symbol:
         The identity of the termination symbol, must be in {0..C-1}
       boundary:
@@ -262,16 +290,16 @@ def get_rnnt_logprobs_joint(
       we cannot emit any symbols. This is simply a way of incorporating
      the probability of the termination symbol on the last frame.
     """
-    assert joint.ndim == 4
-    (B, T, S1, C) = joint.shape
+    assert logits.ndim == 4
+    (B, T, S1, C) = logits.shape
     S = S1 - 1
     assert symbols.shape == (B, S)
 
-    normalizers = torch.logsumexp(joint, dim=3)
+    normalizers = torch.logsumexp(logits, dim=3)
     normalizers = normalizers.permute((0, 2, 1))
 
     px = torch.gather(
-        joint, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
+        logits, dim=3, index=symbols.reshape(B, 1, S, 1).expand(B, T, S, 1)
     ).squeeze(-1)
     px = px.permute((0, 2, 1))
     px = torch.cat(
@@ -287,7 +315,7 @@ def get_rnnt_logprobs_joint(
     px[:, :, :T] -= normalizers[:, :S, :]
 
     py = (
-        joint[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
+        logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
     )  # [B][S+1][T]
     py -= normalizers
     px = px.contiguous()
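The (px, py) contract of get_rnnt_logprobs_joint can be summarized with a small shape check. Editor's sketch with toy shapes, assuming the function is exposed as k2.get_rnnt_logprobs_joint:

    import torch
    import k2

    B, T, S, C = 2, 6, 4, 10
    logits = torch.randn(B, T, S + 1, C)   # joiner output
    symbols = torch.randint(1, C, (B, S))

    px, py = k2.get_rnnt_logprobs_joint(
        logits=logits, symbols=symbols, termination_symbol=0
    )
    # px[b][s][t]: log-prob of emitting symbols[b][s] at frame t
    assert px.shape == (B, S, T + 1)
    # py[b][s][t]: log-prob of the termination symbol at frame t
    assert py.shape == (B, S + 1, T)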
@@ -298,16 +326,17 @@ def rnnt_loss(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     termination_symbol: int,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
 ) -> Tensor:
     """A normal RNN-T loss, which uses a 'joiner' network output as input,
     i.e. a 4 dimensions tensor.
 
     Args:
-      joint:
+      logits:
         The output of joiner network, with shape (B, T, S + 1, C), i.e.
         batch, time_seq_len, symbol_seq_len+1, num_classes
       symbols:
@@ -320,15 +349,35 @@ def rnnt_loss(
         [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
+      reduction:
+        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+        `none`: no reduction will be applied.
+        `mean`: apply `torch.mean` over the batches.
+        `sum`: the output will be summed.
+        Default: `mean`
     Returns:
-      A Tensor of shape (B,), containing the total RNN-T loss values for each
-      element of the batch (like log-probs of sequences).
+      If reduction is `none`, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch, otherwise a scalar
+      with the reduction applied.
     """
     px, py = get_rnnt_logprobs_joint(
-        joint, symbols, termination_symbol, boundary
+        logits=logits,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
     )
-    return mutual_information_recursion(px, py, boundary)
+    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    if reduction == "none":
+        return -negated_loss
+    elif reduction == "mean":
+        return -torch.mean(negated_loss)
+    elif reduction == "sum":
+        return -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
 
 
 def _adjust_pruning_lower_bound(
@@ -464,10 +513,13 @@ def get_rnnt_prune_ranges(
     s_begin = torch.argmax(diff_grad, dim=1)
     s_begin = s_begin[:, :T]
 
-    # handle the values of s_begin in padding positions.
-    # set the s_begin in paddding positions to `len(symbols) - s_range + 1`
+    # Handle the values of s_begin in padding positions.
+    # The -1 here means we fill the position of the last frame of real data
+    # with the padding value, which is `len(symbols) - s_range + 1`.
+    # This is to guarantee that we reach the last symbol at the last frame of
+    # real data.
     mask = torch.arange(0, T, device=px_grad.device).reshape(1, T).expand(B, T)
-    mask = mask < boundary[:, 3].reshape(B, 1)
+    mask = mask < boundary[:, 3].reshape(B, 1) - 1
 
     s_begin_padding = boundary[:, 2].reshape(B, 1) - s_range + 1
     # handle the cases when `len(symbols) < s_range`
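The effect of the `- 1` added to the mask above is easiest to see on a toy case. Editor's illustration in plain PyTorch, with values chosen arbitrarily:

    import torch

    B, T = 1, 5
    t_end = torch.tensor([4])  # boundary[:, 3]: four real frames, one padded
    mask = torch.arange(0, T).reshape(1, T).expand(B, T)

    old_mask = mask < t_end.reshape(B, 1)      # [[True, True, True, True, False]]
    new_mask = mask < t_end.reshape(B, 1) - 1  # [[True, True, True, False, False]]

    # With the new mask, the last real frame (index 3) also receives the
    # padding value `len(symbols) - s_range + 1`, which forces the pruning
    # bounds to reach the final symbol by the last real frame.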
@@ -561,7 +613,7 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
 
 
 def get_rnnt_logprobs_pruned(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
@@ -570,7 +622,7 @@ def get_rnnt_logprobs_pruned(
     """Construct px, py for mutual_information_recursion with pruned output.
 
     Args:
-      joint:
+      logits:
         The pruned output of joiner network, with shape (B, T, s_range, C)
       symbols:
         The symbol sequences, a LongTensor of shape [B][S], and elements in
@@ -589,15 +641,15 @@ def get_rnnt_logprobs_pruned(
     Return the px (B, S, T + 1) and py (B, S + 1, T) needed by
     mutual_information_recursion.
     """
-    # joint (B, T, s_range, C)
+    # logits (B, T, s_range, C)
     # symbols (B, S)
     # ranges (B, T, s_range)
-    assert joint.ndim == 4
-    (B, T, s_range, C) = joint.shape
+    assert logits.ndim == 4
+    (B, T, s_range, C) = logits.shape
     assert ranges.shape == (B, T, s_range)
     (B, S) = symbols.shape
 
-    normalizers = torch.logsumexp(joint, dim=3)
+    normalizers = torch.logsumexp(logits, dim=3)
 
     symbols_with_terminal = torch.cat(
         (
@@ -620,7 +672,7 @@ def get_rnnt_logprobs_pruned(
 
     # (B, T, s_range)
     px = torch.gather(
-        joint, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1)
+        logits, dim=3, index=pruning_symbols.reshape(B, T, s_range, 1)
     ).squeeze(-1)
     px = px - normalizers
 
@@ -652,7 +704,7 @@ def get_rnnt_logprobs_pruned(
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
-    py = joint[:, :, :, termination_symbol]  # (B, T, s_range)
+    py = logits[:, :, :, termination_symbol].clone()  # (B, T, s_range)
     py = py - normalizers
 
     # (B, T, S + 1) with index larger than s_range in dim 2 filled with -inf
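The `.clone()` added above is presumably defensive: `logits[:, :, :, termination_symbol]` is a view into `logits`, so any later in-place write to `py` would write through into `logits`, which breaks autograd. A minimal illustration of the failure mode (editor's sketch, not from the patch):

    import torch

    logits = torch.randn(2, 3, 4, 5, requires_grad=True)

    py = logits[:, :, :, 0]          # a view into `logits`
    # py[0] = float("-inf")          # RuntimeError: in-place operation on a
    #                                # view of a leaf tensor requiring grad

    py = logits[:, :, :, 0].clone()  # independent storage
    py[0] = float("-inf")            # safe: only the copy is modified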
@@ -682,23 +734,24 @@ def rnnt_loss_pruned(
-    joint: Tensor,
+    logits: Tensor,
     symbols: Tensor,
     ranges: Tensor,
     termination_symbol: int,
     boundary: Tensor = None,
+    reduction: Optional[str] = "mean",
 ) -> Tensor:
     """A RNN-T loss with pruning, which uses a pruned 'joiner' network output
     as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
     s_range means the symbols number kept for each frame.
 
     Args:
-      joint:
+      logits:
         The pruned output of joiner network, with shape (B, T, s_range, C),
         i.e. batch, time_seq_len, prune_range, num_classes
       symbols:
         A LongTensor of shape [B][S], containing the symbols at each position
-        of the sequence, possibly including EOS
+        of the sequence.
       ranges:
         A tensor containing the symbol ids for each frame that we want to keep.
       termination_symbol:
@@ -708,14 +761,35 @@ def rnnt_loss_pruned(
         [begin_symbol, begin_frame, end_symbol, end_frame] that is treated as
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
+      reduction:
+        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+        `none`: no reduction will be applied.
+        `mean`: apply `torch.mean` over the batches.
+        `sum`: the output will be summed.
+        Default: `mean`
     Returns:
-      A Tensor of shape (B,), containing the total RNN-T loss values for each
-      element of the batch (like log-probs of sequences).
+      If reduction is `none`, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch, otherwise a scalar
+      with the reduction applied.
     """
     px, py = get_rnnt_logprobs_pruned(
-        joint, symbols, ranges, termination_symbol, boundary
+        logits=logits,
+        symbols=symbols,
+        ranges=ranges,
+        termination_symbol=termination_symbol,
+        boundary=boundary,
     )
-    return mutual_information_recursion(px, py, boundary)
+    negated_loss = mutual_information_recursion(px=px, py=py, boundary=boundary)
+    if reduction == "none":
+        return -negated_loss
+    elif reduction == "mean":
+        return -torch.mean(negated_loss)
+    elif reduction == "sum":
+        return -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
 
 
 def get_rnnt_logprobs_smoothed(
@@ -765,7 +839,7 @@ def get_rnnt_logprobs_smoothed(
         next on this frame.
       symbols:
         A LongTensor of shape [B][S], containing the symbols at each position
-        of the sequence, possibly including EOS
+        of the sequence.
       termination_symbol:
         The identity of the termination symbol, must be in {0..C-1}
       lm_only_scale:
@@ -921,10 +995,11 @@ def rnnt_loss_smoothed(
     lm_only_scale: float = 0.1,
     am_only_scale: float = 0.1,
     boundary: Optional[Tensor] = None,
+    reduction: Optional[str] = "mean",
     return_grad: bool = False,
 ) -> Tensor:
     """A simple case of the RNN-T loss, where the 'joiner' network is just
-    addition. Returns negated total loss value.
+    addition.
 
     Args:
       lm:
@@ -951,27 +1026,49 @@ def rnnt_loss_smoothed(
         [0, 0, S, T] if boundary is not supplied.
         Most likely you will want begin_symbol and begin_frame to be zero.
 
+      reduction:
+        Specifies the reduction to apply to the output: `none`, `mean` or `sum`.
+        `none`: no reduction will be applied.
+        `mean`: apply `torch.mean` over the batches.
+        `sum`: the output will be summed.
+        Default: `mean`
       return_grad:
         Whether to return grads of px and py, this grad standing for the
         occupation probability is the output of the backward with a
-        `fake gradient` input (all ones) This is useful to implement the
-        pruned version of rnnt loss.
+        `fake gradient`. The `fake gradient` is the same as the gradient
+        you'd get if you did `torch.autograd.grad((-loss.sum()), [px, py])`;
+        note that the loss here is the loss with reduction "none".
+        This is useful to implement the pruned version of rnnt loss.
     Returns:
-      If return_grad is False, returns a Tensor of shape (B,), containing the
-      NEGATED total RNN-T loss values for each element of the batch
-      (like log-probs of sequences).
+      If return_grad is False, returns a tensor of shape (B,), containing the
+      total RNN-T loss values for each element of the batch if reduction is
+      "none", otherwise a scalar with the reduction applied.
       If return_grad is True, the grads of px and py, which is the output of
-      backward with a `fake gradient` input, will be returned too. And the
+      backward with a `fake gradient` (see above), will be returned too. And the
       returned value will be a tuple like (loss, (px_grad, py_grad)).
     """
     px, py = get_rnnt_logprobs_smoothed(
-        lm,
-        am,
-        symbols,
-        termination_symbol,
-        lm_only_scale,
-        am_only_scale,
-        boundary,
+        lm=lm,
+        am=am,
+        symbols=symbols,
+        termination_symbol=termination_symbol,
+        lm_only_scale=lm_only_scale,
+        am_only_scale=am_only_scale,
+        boundary=boundary,
+    )
+    scores_and_grads = mutual_information_recursion(
+        px=px, py=py, boundary=boundary, return_grad=return_grad
     )
-    return mutual_information_recursion(px, py, boundary, return_grad)
+    negated_loss = scores_and_grads[0] if return_grad else scores_and_grads
+    if reduction == "none":
+        loss = -negated_loss
+    elif reduction == "mean":
+        loss = -torch.mean(negated_loss)
+    elif reduction == "sum":
+        loss = -torch.sum(negated_loss)
+    else:
+        assert (
+            False
+        ), f"reduction should be ('none' | 'mean' | 'sum'), given {reduction}"
+    return (loss, scores_and_grads[1]) if return_grad else loss
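Before moving on to the tests, two API changes above are worth restating for downstream users: the first argument of the joint/pruned losses is renamed from `joint` to `logits`, and all of the losses now return a positive loss (reduced with `mean` by default) rather than the negated total log-prob. A migration sketch (editor's example, not part of the patch):

    import torch
    import k2

    B, S, T, C = 2, 3, 4, 5
    logits = torch.randn(B, T, S + 1, C)   # joiner output
    symbols = torch.randint(1, C, (B, S))

    # Old call sites looked roughly like this and negated the result:
    #     loss = -k2.rnnt_loss(joint, symbols, termination_symbol, boundary)
    # New equivalent, keeping the per-utterance shape (B,):
    loss = k2.rnnt_loss(
        logits=logits,
        symbols=symbols,
        termination_symbol=0,
        boundary=None,
        reduction="none",
    )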
diff --git a/k2/python/tests/rnnt_loss_test.py b/k2/python/tests/rnnt_loss_test.py
index 16cc7875f..d619591a8 100644
--- a/k2/python/tests/rnnt_loss_test.py
+++ b/k2/python/tests/rnnt_loss_test.py
@@ -81,36 +81,55 @@ def test_rnnt_loss_basic(self):
             termination_symbol = 2
             symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device)
 
-            px, py = k2.get_rnnt_logprobs(lm, am, symbols, termination_symbol)
+            px, py = k2.get_rnnt_logprobs(
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+            )
             assert px.shape == (B, S, T + 1)
             assert py.shape == (B, S + 1, T)
             assert symbols.shape == (B, S)
-            m = k2.mutual_information_recursion(px, py)
+            m = k2.mutual_information_recursion(px=px, py=py, boundary=None)
 
             if device == torch.device("cpu"):
-                expected = m
-            assert torch.allclose(m, expected.to(device))
+                expected = -m
+            assert torch.allclose(-m, expected.to(device))
 
             # test rnnt_loss_simple
-            m = k2.rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
+            m = k2.rnnt_loss_simple(
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=None,
+                reduction="none",
+            )
             assert torch.allclose(m, expected.to(device))
 
             # test rnnt_loss_smoothed
             m = k2.rnnt_loss_smoothed(
-                lm,
-                am,
-                symbols,
-                termination_symbol,
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
                 lm_only_scale=0.0,
                 am_only_scale=0.0,
                 boundary=None,
+                reduction="none",
             )
             assert torch.allclose(m, expected.to(device))
 
             probs = am.unsqueeze(2) + lm.unsqueeze(1)
 
             # test rnnt_loss
-            m = k2.rnnt_loss(probs, symbols, termination_symbol, None)
+            m = k2.rnnt_loss(
+                logits=probs,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=None,
+                reduction="none",
+            )
             assert torch.allclose(m, expected.to(device))
 
             # compare with torchaudio rnnt_loss
@@ -129,28 +148,42 @@ def test_rnnt_loss_basic(self):
                     blank=termination_symbol,
                     reduction="none",
                 )
-                assert torch.allclose(-m, expected.to(device))
+                assert torch.allclose(m, expected.to(device))
 
             # should be invariant to adding a constant for any frame.
             lm += torch.randn(B, S + 1, 1, device=device)
             am += torch.randn(B, T, 1, device=device)
 
-            m = k2.rnnt_loss_simple(lm, am, symbols, termination_symbol, None)
+            m = k2.rnnt_loss_simple(
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=None,
+                reduction="none",
+            )
             assert torch.allclose(m, expected.to(device))
 
             m = k2.rnnt_loss_smoothed(
-                lm,
-                am,
-                symbols,
-                termination_symbol,
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
                 lm_only_scale=0.0,
                 am_only_scale=0.0,
                 boundary=None,
+                reduction="none",
             )
             assert torch.allclose(m, expected.to(device))
 
             probs = am.unsqueeze(2) + lm.unsqueeze(1)
-            m = k2.rnnt_loss(probs, symbols, termination_symbol, None)
+            m = k2.rnnt_loss(
+                logits=probs,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=None,
+                reduction="none",
+            )
             assert torch.allclose(m, expected.to(device))
 
     def test_rnnt_loss_random(self):
@@ -182,27 +215,35 @@ def test_rnnt_loss_random(self):
             boundary = boundary_.to(device)
 
             px, py = k2.get_rnnt_logprobs(
-                lm, am, symbols, termination_symbol, boundary
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=boundary,
             )
             assert px.shape == (B, S, T + 1)
             assert py.shape == (B, S + 1, T)
             assert symbols.shape == (B, S)
-            m = k2.mutual_information_recursion(px, py, boundary)
+            m = k2.mutual_information_recursion(px=px, py=py, boundary=boundary)
 
             if device == torch.device("cpu"):
-                expected = m
-            assert torch.allclose(m, expected.to(device))
+                expected = -torch.mean(m)
+            assert torch.allclose(-torch.mean(m), expected.to(device))
 
             m = k2.rnnt_loss_simple(
-                lm, am, symbols, termination_symbol, boundary
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=boundary,
             )
             assert torch.allclose(m, expected.to(device))
 
             m = k2.rnnt_loss_smoothed(
-                lm,
-                am,
-                symbols,
-                termination_symbol,
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
                 lm_only_scale=0.0,
                 am_only_scale=0.0,
                 boundary=boundary,
@@ -210,7 +251,12 @@ def test_rnnt_loss_random(self):
             assert torch.allclose(m, expected.to(device))
 
             probs = am.unsqueeze(2) + lm.unsqueeze(1)
-            m = k2.rnnt_loss(probs, symbols, termination_symbol, boundary)
+            m = k2.rnnt_loss(
+                logits=probs,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=boundary,
+            )
             assert torch.allclose(m, expected.to(device))
 
             # compare with torchaudio rnnt_loss
@@ -223,28 +269,36 @@ def test_rnnt_loss_random(self):
                     logit_lengths=boundary[:, 3].int(),
                     target_lengths=boundary[:, 2].int(),
                     blank=termination_symbol,
-                    reduction="none",
                 )
-                assert torch.allclose(-m, expected.to(device))
+                assert torch.allclose(m, expected.to(device))
 
             # should be invariant to adding a constant for any frame.
             lm += torch.randn(B, S + 1, 1, device=device)
             am += torch.randn(B, T, 1, device=device)
 
             m = k2.rnnt_loss_simple(
-                lm, am, symbols, termination_symbol, boundary
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=boundary,
             )
             assert torch.allclose(m, expected.to(device))
 
             probs = am.unsqueeze(2) + lm.unsqueeze(1)
-            m = k2.rnnt_loss(probs, symbols, termination_symbol, boundary)
+            m = k2.rnnt_loss(
+                logits=probs,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=boundary,
+            )
             assert torch.allclose(m, expected.to(device))
 
             m = k2.rnnt_loss_smoothed(
-                lm,
-                am,
-                symbols,
-                termination_symbol,
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
                 lm_only_scale=0.0,
                 am_only_scale=0.0,
                 boundary=boundary,
@@ -285,31 +339,27 @@ def test_rnnt_loss_gradient(self):
             logprobs = am.unsqueeze(2) + lm.unsqueeze(1)
             logprobs.requires_grad_()
             k2_loss = k2.rnnt_loss(
-                logprobs, symbols, termination_symbol, boundary
-            )
-            k2_grad = torch.autograd.grad(
-                k2_loss, logprobs, -torch.ones_like(k2_loss)
+                logits=logprobs,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
+                boundary=boundary,
             )
+            k2_grad = torch.autograd.grad(k2_loss, logprobs)
             k2_grad = k2_grad[0]
 
             logprobs2 = logprobs.detach().clone().float()
             logprobs2.requires_grad_()
             torch_loss = torchaudio.functional.rnnt_loss(
-                logprobs2,
-                symbols.int(),
-                boundary[:, 3].int(),
-                boundary[:, 2].int(),
+                logits=logprobs2,
+                targets=symbols.int(),
+                logit_lengths=boundary[:, 3].int(),
+                target_lengths=boundary[:, 2].int(),
                 blank=termination_symbol,
-                reduction="none",
-            )
-            torch_grad = torch.autograd.grad(
-                torch_loss, logprobs2, torch.ones_like(torch_loss)
             )
+            torch_grad = torch.autograd.grad(torch_loss, logprobs2)
             torch_grad = torch_grad[0]
 
-            assert torch.allclose(
-                -k2_loss, torch_loss, atol=1e-2, rtol=1e-2
-            )
+            assert torch.allclose(k2_loss, torch_loss, atol=1e-2, rtol=1e-2)
 
             assert torch.allclose(k2_grad, torch_grad, atol=1e-2, rtol=1e-2)
 
@@ -336,10 +386,10 @@ def test_rnnt_loss_smoothed(self):
             symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device)
 
             m = k2.rnnt_loss_smoothed(
-                lm,
-                am,
-                symbols,
-                termination_symbol,
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
                 lm_only_scale=0.0,
                 am_only_scale=0.333,
                 boundary=None,
@@ -354,10 +404,10 @@ def test_rnnt_loss_smoothed(self):
             am += torch.randn(B, T, 1, device=device)
 
             m = k2.rnnt_loss_smoothed(
-                lm,
-                am,
-                symbols,
-                termination_symbol,
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=termination_symbol,
                 lm_only_scale=0.0,
                 am_only_scale=0.333,
                 boundary=None,
@@ -395,19 +445,36 @@ def test_rnnt_loss_pruned(self):
             t_prob = t_am + t_lm
             # nonlinear transform
             t_prob = torch.sigmoid(t_prob)
-            k2_loss = k2.rnnt_loss(t_prob, symbols, terminal_symbol, boundary)
+            k2_loss = k2.rnnt_loss(
+                logits=t_prob,
+                symbols=symbols,
+                termination_symbol=terminal_symbol,
+                boundary=boundary,
+                reduction="none",
+            )
             print("unpruned rnnt loss: ", k2_loss)
 
             # pruning
             k2_simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
-                lm, am, symbols, terminal_symbol, boundary, True
+                lm=lm,
+                am=am,
+                symbols=symbols,
+                termination_symbol=terminal_symbol,
+                boundary=boundary,
+                return_grad=True,
+                reduction="none",
             )
 
             for r in range(2, 50, 5):
-                ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, r)
+                ranges = k2.get_rnnt_prune_ranges(
+                    px_grad=px_grad,
+                    py_grad=py_grad,
+                    boundary=boundary,
+                    s_range=r,
+                )
                 # (B, T, r, C)
-                am_p, lm_p = k2.do_rnnt_pruning(am, lm, ranges)
+                am_p, lm_p = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
 
                 t_prob_p = am_p + lm_p
 
@@ -415,9 +482,14 @@ def test_rnnt_loss_pruned(self):
                 # nonlinear transform
                 t_prob_p = torch.sigmoid(t_prob_p)
 
                 pruning_loss = k2.rnnt_loss_pruned(
-                    t_prob_p, symbols, ranges, terminal_symbol, boundary
+                    logits=t_prob_p,
+                    symbols=symbols,
+                    ranges=ranges,
+                    termination_symbol=terminal_symbol,
+                    boundary=boundary,
+                    reduction="none",
                 )
-                print(f"pruning loss with range {r} : ", pruning_loss)
+                print(f"pruned loss with range {r}: ", pruning_loss)
 
 
 if __name__ == "__main__":
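To tie the pieces together, here is a condensed version of the pruned-loss workflow exercised by test_rnnt_loss_pruned above. Editor's sketch: the shapes and the sigmoid stand-in joiner are arbitrary, and k2 is assumed to be built from this branch:

    import torch
    import k2

    B, T, S, C = 2, 20, 10, 30
    am = torch.randn(B, T, C)
    lm = torch.randn(B, S + 1, C)
    symbols = torch.randint(0, C - 1, (B, S))
    terminal_symbol = C - 1
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

    # 1. Simple loss, also requesting the occupation-probability grads.
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=terminal_symbol,
        boundary=boundary,
        return_grad=True,
        reduction="none",
    )

    # 2. Turn the grads into per-frame symbol ranges and prune am/lm to them.
    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=5
    )
    am_p, lm_p = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

    # 3. A toy "joiner" on the pruned pair, then the pruned loss.
    logits = torch.sigmoid(am_p + lm_p)    # (B, T, s_range, C)
    pruned_loss = k2.rnnt_loss_pruned(
        logits=logits,
        symbols=symbols,
        ranges=ranges,
        termination_symbol=terminal_symbol,
        boundary=boundary,
        reduction="none",
    )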