From 2cf5ce460e98e12ad97d2272293ae9c60ea9d3f4 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Tue, 8 Jun 2021 15:39:08 -0700 Subject: [PATCH 01/11] Temp commit Signed-off-by: smajumdar --- .../asr/parts/numba/rnnt_loss/rnnt_numpy.py | 136 ++++++++++++++++-- 1 file changed, 125 insertions(+), 11 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py index aa814718cd9b..f04c02158a39 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -33,6 +33,57 @@ from torch.nn import Module +def check_type(var, t, name): + if var.dtype is not t: + raise TypeError("{} must be {}".format(name, t)) + + +def check_contiguous(var, name): + if not var.is_contiguous(): + raise ValueError("{} must be contiguous".format(name)) + + +def check_dim(var, dim, name): + if len(var.shape) != dim: + raise ValueError("{} must be {}D".format(name, dim)) + + +def certify_inputs(log_probs, labels, lengths, label_lengths): + # check_type(log_probs, torch.float32, "log_probs") + check_type(labels, torch.int32, "labels") + check_type(label_lengths, torch.int32, "label_lengths") + check_type(lengths, torch.int32, "lengths") + check_contiguous(log_probs, "log_probs") + check_contiguous(labels, "labels") + check_contiguous(label_lengths, "label_lengths") + check_contiguous(lengths, "lengths") + + if lengths.shape[0] != log_probs.shape[0]: + raise ValueError( + f"Must have a length per example. " + f"Given lengths dim: {lengths.shape[0]}, " + f"Log probs dim : {log_probs.shape[0]}" + ) + if label_lengths.shape[0] != log_probs.shape[0]: + raise ValueError( + "Must have a label length per example. " + f"Given label lengths dim : {label_lengths.shape[0]}, " + f"Log probs dim : {log_probs.shape[0]}" + ) + + check_dim(log_probs, 4, "log_probs") + check_dim(labels, 2, "labels") + check_dim(lengths, 1, "lenghts") + check_dim(label_lengths, 1, "label_lenghts") + max_T = torch.max(lengths) + max_U = torch.max(label_lengths) + T, U = log_probs.shape[1:3] + if T != max_T: + raise ValueError(f"Input length mismatch! Given T: {T}, Expected max T from input lengths: {max_T}") + if U != max_U + 1: + raise ValueError(f"Output length mismatch! Given U: {U}, Expected max U from target lengths: {max_U} + 1") + + def _assert_no_grad(tensor): assert not tensor.requires_grad, ( "gradients only computed for log_probs - please " "mark other tensors as not requiring gradients" @@ -102,7 +153,7 @@ def backward_pass(log_probs, labels, blank): return betas, betas[0, 0] -def compute_gradient(log_probs, alphas, betas, labels, blank): +def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda): """ Computes the gradients of the log_probs with respect to the log probability of this step occuring. @@ -129,10 +180,52 @@ def compute_gradient(log_probs, alphas, betas, labels, blank): grads[:, u, l] = alphas[:, u] + betas[:, u + 1] grads = -np.exp(grads + log_probs - log_like) + + if fastemit_lambda > 0.0: + for u, l in enumerate(labels): + grads[:, u, l] = (1.0 + fastemit_lambda) * grads[:, u, l] + return grads -def transduce(log_probs, labels, blank=0): +def fastemit_regularization(log_probs, labels, alphas, betas, blank, fastemit_lambda): + """ + Computes probability of the forward variable alpha. + + Args: + log_probs: Tensor of shape [T, U, V+1] + labels: Labels of shape [B, U] + blank: Index of the blank token. 
+ + Returns: + A tuple of the forward variable probabilities - alpha of shape [T, U] + and the log likelihood of this forward step. + """ + T, U, _ = log_probs.shape + alignment = np.zeros((T, U), dtype='f') + + for u in range(U): + for n in range(0, T + U): + t = n - u + + + + # log_like = betas[0, 0] + + # // grad to last blank transition + # grads[T - 1, U - 1, blank] = alphas[T - 1, U - 1] + # + # grads[: T - 1, :, blank] = alphas[: T - 1, :] + betas[1:, :] + # for u, l in enumerate(labels): + # grads[:, u, l] = alphas[:, u] + betas[:, u + 1] + # + # grads = -np.exp(grads + log_probs - log_like) + + loglike = alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank] + return alphas, loglike + + +def transduce(log_probs, labels, blank=0, fastemit_lambda=0.0): """ Args: log_probs: 3D array with shape @@ -145,11 +238,11 @@ def transduce(log_probs, labels, blank=0): """ alphas, ll_forward = forward_pass(log_probs, labels, blank) betas, ll_backward = backward_pass(log_probs, labels, blank) - grads = compute_gradient(log_probs, alphas, betas, labels, blank) - return -ll_forward, grads + grads = compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda) + return -ll_forward, grads, alphas, betas -def transduce_batch(log_probs, labels, flen, glen, blank=0): +def transduce_batch(log_probs, labels, flen, glen, blank=0, fastemit_lambda=0.0): """ Compute the transducer loss of the batch. @@ -168,17 +261,25 @@ def transduce_batch(log_probs, labels, flen, glen, blank=0): for b in range(log_probs.shape[0]): t = int(flen[b]) u = int(glen[b]) + 1 - ll, g = transduce(log_probs[b, :t, :u, :], labels[b, : u - 1], blank) + ll, g, alphas, betas = transduce(log_probs[b, :t, :u, :], labels[b, : u - 1], blank, fastemit_lambda) grads[b, :t, :u, :] = g + + _ = fastemit_regularization(log_probs[b, :t, :u, :], labels[b, : u - 1], alphas, betas, blank, fastemit_lambda) + costs.append(ll) return costs, grads class _RNNT(Function): @staticmethod - def forward(ctx, acts, labels, act_lens, label_lens, blank): + def forward(ctx, acts, labels, act_lens, label_lens, blank, fastemit_lambda): costs, grads = transduce_batch( - acts.detach().cpu().numpy(), labels.cpu().numpy(), act_lens.cpu().numpy(), label_lens.cpu().numpy(), blank, + acts.detach().cpu().numpy(), + labels.cpu().numpy(), + act_lens.cpu().numpy(), + label_lens.cpu().numpy(), + blank, + fastemit_lambda, ) costs = torch.FloatTensor([sum(costs)]) @@ -189,7 +290,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank): @staticmethod def backward(ctx, grad_output): - return ctx.grads, None, None, None, None + return ctx.grads, None, None, None, None, None class RNNTLoss(Module): @@ -198,9 +299,10 @@ class RNNTLoss(Module): `blank_label` (int): default 0 - label index of blank token """ - def __init__(self, blank: int = 0): + def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0): super(RNNTLoss, self).__init__() self.blank = blank + self.fastemit_lambda = fastemit_lambda self.rnnt = _RNNT.apply def forward(self, acts, labels, act_lens, label_lens): @@ -208,6 +310,18 @@ def forward(self, acts, labels, act_lens, label_lens): _assert_no_grad(labels) _assert_no_grad(act_lens) _assert_no_grad(label_lens) + certify_inputs(acts, labels, act_lens, label_lens) acts = torch.nn.functional.log_softmax(acts, -1) - return self.rnnt(acts, labels, act_lens, label_lens, self.blank) + return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda) + + +if __name__ == '__main__': + loss = RNNTLoss(fastemit_lambda=0.01) + + acts = 
torch.randn(1, 10, 11, 3) + labels = torch.tensor([[0, 1, 1, 2, 1, 1, 2, 1, 1, 2]], dtype=torch.int32) + act_lens = torch.tensor([10], dtype=torch.int32) + label_lens = torch.tensor([len(labels[0])], dtype=torch.int32) + + loss_val = loss(acts, labels, act_lens, label_lens) From 9a832a61407f3d9929001a9fa77cd8135d2c8caa Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Sat, 12 Jun 2021 02:34:03 -0700 Subject: [PATCH 02/11] Initial code for fastemit forward pass Signed-off-by: smajumdar --- .../asr/parts/numba/rnnt_loss/rnnt_numpy.py | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py index f04c02158a39..9a7d0a551fbb 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -170,12 +170,13 @@ def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda): """ T, U, _ = log_probs.shape grads = np.full(log_probs.shape, -float("inf")) - log_like = betas[0, 0] + log_like = betas[0, 0] # == alphas[T - 1, U - 1] + betas[T - 1, U - 1] # // grad to last blank transition grads[T - 1, U - 1, blank] = alphas[T - 1, U - 1] - grads[: T - 1, :, blank] = alphas[: T - 1, :] + betas[1:, :] + + # // grad to label transition for u, l in enumerate(labels): grads[:, u, l] = alphas[:, u] + betas[:, u + 1] @@ -202,27 +203,23 @@ def fastemit_regularization(log_probs, labels, alphas, betas, blank, fastemit_la and the log likelihood of this forward step. """ T, U, _ = log_probs.shape - alignment = np.zeros((T, U), dtype='f') - - for u in range(U): - for n in range(0, T + U): - t = n - u - + alignment = np.zeros((T, U), dtype='float32') + for t in range(0, T): + alignment[t, U - 1] = alphas[t, U - 1] + betas[t, U - 1] - # log_like = betas[0, 0] + for t in range(0, T): + for u in range(0, U - 1): + emit = alphas[t, u] + log_probs[t, u, labels[u]] + betas[t, u + 1] + alignment[t, u] = emit - # // grad to last blank transition - # grads[T - 1, U - 1, blank] = alphas[T - 1, U - 1] - # - # grads[: T - 1, :, blank] = alphas[: T - 1, :] + betas[1:, :] - # for u, l in enumerate(labels): - # grads[:, u, l] = alphas[:, u] + betas[:, u + 1] - # - # grads = -np.exp(grads + log_probs - log_like) + # to compute likelihood. 
+ # loglike = betas[0, 0] + # a = np.exp(alignment - loglike) + # print(a) - loglike = alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank] - return alphas, loglike + reg = fastemit_lambda * (alignment[T - 1, U - 1]) + return alignment, reg def transduce(log_probs, labels, blank=0, fastemit_lambda=0.0): @@ -319,9 +316,11 @@ def forward(self, acts, labels, act_lens, label_lens): if __name__ == '__main__': loss = RNNTLoss(fastemit_lambda=0.01) - acts = torch.randn(1, 10, 11, 3) + torch.manual_seed(0) + + acts = torch.randn(1, 3, 11, 3) labels = torch.tensor([[0, 1, 1, 2, 1, 1, 2, 1, 1, 2]], dtype=torch.int32) - act_lens = torch.tensor([10], dtype=torch.int32) + act_lens = torch.tensor([3], dtype=torch.int32) label_lens = torch.tensor([len(labels[0])], dtype=torch.int32) loss_val = loss(acts, labels, act_lens, label_lens) From 1661c66ec4c9a345d08d15e4570d4fe11ef01cdd Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Sun, 13 Jun 2021 01:03:39 -0700 Subject: [PATCH 03/11] Correct return reg value Signed-off-by: smajumdar --- .../asr/parts/numba/rnnt_loss/rnnt_numpy.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py index 9a7d0a551fbb..a450aa5a0773 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -219,7 +219,8 @@ def fastemit_regularization(log_probs, labels, alphas, betas, blank, fastemit_la # print(a) reg = fastemit_lambda * (alignment[T - 1, U - 1]) - return alignment, reg + # reg = fastemit_lambda * (alphas[T - 1, U - 1] + betas[T - 1, U - 1]) + return -reg def transduce(log_probs, labels, blank=0, fastemit_lambda=0.0): @@ -258,11 +259,12 @@ def transduce_batch(log_probs, labels, flen, glen, blank=0, fastemit_lambda=0.0) for b in range(log_probs.shape[0]): t = int(flen[b]) u = int(glen[b]) + 1 + ll, g, alphas, betas = transduce(log_probs[b, :t, :u, :], labels[b, : u - 1], blank, fastemit_lambda) grads[b, :t, :u, :] = g - _ = fastemit_regularization(log_probs[b, :t, :u, :], labels[b, : u - 1], alphas, betas, blank, fastemit_lambda) - + reg = fastemit_regularization(log_probs[b, :t, :u, :], labels[b, : u - 1], alphas, betas, blank, fastemit_lambda) + ll += reg costs.append(ll) return costs, grads @@ -318,9 +320,9 @@ def forward(self, acts, labels, act_lens, label_lens): torch.manual_seed(0) - acts = torch.randn(1, 3, 11, 3) - labels = torch.tensor([[0, 1, 1, 2, 1, 1, 2, 1, 1, 2]], dtype=torch.int32) - act_lens = torch.tensor([3], dtype=torch.int32) + acts = torch.randn(1, 2, 5, 3) + labels = torch.tensor([[0, 2, 1, 2]], dtype=torch.int32) + act_lens = torch.tensor([2], dtype=torch.int32) label_lens = torch.tensor([len(labels[0])], dtype=torch.int32) loss_val = loss(acts, labels, act_lens, label_lens) From 7f67eeb6eab4791eb883524bcfdf43c89157d9c2 Mon Sep 17 00:00:00 2001 From: Samuel Kriman Date: Thu, 17 Jun 2021 21:03:31 -0700 Subject: [PATCH 04/11] Initial cpu impl Signed-off-by: smajumdar --- .../asr/parts/numba/rnnt_loss/rnnt.py | 8 +++ .../asr/parts/numba/rnnt_loss/rnnt_numpy.py | 50 +++++++++++-------- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 12 +++-- .../rnnt_loss/utils/cpu_utils/cpu_rnnt.py | 12 ++++- .../rnnt_loss/utils/cuda_utils/gpu_rnnt.py | 18 ++++++- .../utils/cuda_utils/gpu_rnnt_kernel.py | 21 +++++++- .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 29 +++++++++++ 7 files changed, 122 insertions(+), 28 deletions(-) diff --git 
a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py index 5caf2fea2fc2..4acdb680cd13 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py @@ -44,6 +44,7 @@ def rnnt_loss_cpu( costs: torch.Tensor, grads: torch.Tensor, blank_label: int, + fastemit_lambda: float, num_threads: int, ): """ @@ -59,6 +60,8 @@ def rnnt_loss_cpu( costs: Zero vector of length [B] in which costs will be set. grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set. blank_label: Index of the blank token in the vocabulary. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. num_threads: Number of threads for OpenMP. """ # aliases @@ -92,6 +95,7 @@ def rnnt_loss_cpu( alphabet_size=alphabet_size, workspace=cpu_workspace, blank=blank_label, + fastemit_lambda=fastemit_lambda, num_threads=num_threads, batch_first=True, ) @@ -136,6 +140,7 @@ def rnnt_loss_gpu( costs: torch.Tensor, grads: torch.Tensor, blank_label: int, + fastemit_lambda: float, num_threads: int, ): """ @@ -151,6 +156,8 @@ def rnnt_loss_gpu( costs: Zero vector of length [B] in which costs will be set. grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set. blank_label: Index of the blank token in the vocabulary. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. num_threads: Number of threads for OpenMP. """ minibatch_size = acts.shape[0] @@ -189,6 +196,7 @@ def rnnt_loss_gpu( alphabet_size=alphabet_size, workspace=gpu_workspace, blank=blank_label, + fastemit_lambda=fastemit_lambda, num_threads=num_threads, stream=stream, ) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py index a450aa5a0773..689da70e743d 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -191,35 +191,38 @@ def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda): def fastemit_regularization(log_probs, labels, alphas, betas, blank, fastemit_lambda): """ - Computes probability of the forward variable alpha. + Describes the computation of FastEmit regularization from the paper - + [FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization](https://arxiv.org/abs/2010.11148) Args: log_probs: Tensor of shape [T, U, V+1] - labels: Labels of shape [B, U] + labels: Unused. Labels of shape [B, U] + alphas: Tensor of shape [T, U] which represents the forward variable. + betas: Unused. Tensor of shape [T, U] which represents the backward variable. blank: Index of the blank token. + fastemit_lambda: Float scaling factor for FastEmit regularization. Returns: - A tuple of the forward variable probabilities - alpha of shape [T, U] - and the log likelihood of this forward step. 
+ The regularized negative log likelihood - lambda * P˜(At, u|x) """ + # General calculation of the fastemit regularization alignments T, U, _ = log_probs.shape - alignment = np.zeros((T, U), dtype='float32') - - for t in range(0, T): - alignment[t, U - 1] = alphas[t, U - 1] + betas[t, U - 1] - - for t in range(0, T): - for u in range(0, U - 1): - emit = alphas[t, u] + log_probs[t, u, labels[u]] + betas[t, u + 1] - alignment[t, u] = emit - - # to compute likelihood. - # loglike = betas[0, 0] - # a = np.exp(alignment - loglike) - # print(a) - - reg = fastemit_lambda * (alignment[T - 1, U - 1]) + # alignment = np.zeros((T, U), dtype='float32') + # + # for t in range(0, T): + # alignment[t, U - 1] = alphas[t, U - 1] + betas[t, U - 1] + # + # for t in range(0, T): + # for u in range(0, U - 1): + # emit = alphas[t, u] + log_probs[t, u, labels[u]] + betas[t, u + 1] + # alignment[t, u] = emit + # reg = fastemit_lambda * (alignment[T - 1, U - 1]) + + # The above is equivalent to below, without need of computing above # reg = fastemit_lambda * (alphas[T - 1, U - 1] + betas[T - 1, U - 1]) + + # The above is also equivalent to below, without need of computing the betas alignment matrix + reg = fastemit_lambda * (alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank]) return -reg @@ -229,10 +232,15 @@ def transduce(log_probs, labels, blank=0, fastemit_lambda=0.0): log_probs: 3D array with shape [input len, output len + 1, vocab size] labels: 1D array with shape [output time steps] + blank: Index of the blank token. + fastemit_lambda: Float scaling factor for FastEmit regularization. + Returns: float: The negative log-likelihood 3D array: Gradients with respect to the unnormalized input actications + 2d arrays: Alphas matrix (TxU) + 2d array: Betas matrix (TxU) """ alphas, ll_forward = forward_pass(log_probs, labels, blank) betas, ll_backward = backward_pass(log_probs, labels, blank) @@ -250,6 +258,7 @@ def transduce_batch(log_probs, labels, flen, glen, blank=0, fastemit_lambda=0.0) flen: Length vector of the acoustic sequence. glen: Length vector of the target sequence. blank: Id of the blank token. + fastemit_lambda: Float scaling factor for FastEmit regularization. Returns: Batch of transducer forward log probabilities (loss) and the gradients of the activation matrix. @@ -296,6 +305,7 @@ class RNNTLoss(Module): """ Parameters: `blank_label` (int): default 0 - label index of blank token + fastemit_lambda: Float scaling factor for FastEmit regularization. """ def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0): diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 016f2dce9d8a..8623ebbc1ccb 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -38,12 +38,14 @@ class _RNNTNumba(Function): @staticmethod - def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): + def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda): """ log_probs: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network labels: 2 dimensional Tensor containing all the targets of the batch with zero padded act_lens: Tensor of size (batch) containing size of each output sequence from the network label_lens: Tensor of (batch) containing label length of each example + fastemit_lambda: Float scaling factor for FastEmit regularization. 
Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. """ is_cuda = acts.is_cuda @@ -62,6 +64,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): costs=costs, grads=grads, blank_label=blank, + fastemit_lambda=fastemit_lambda, num_threads=0, ) @@ -82,7 +85,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction): def backward(ctx, grad_output): if grad_output is not None and ctx.grads is not None: grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads) - return ctx.grads.mul_(grad_output), None, None, None, None, None + return ctx.grads.mul_(grad_output), None, None, None, None, None, None def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction='mean'): @@ -114,9 +117,10 @@ class RNNTLossNumba(Module): then the mean over the batch is taken. Default: 'mean' """ - def __init__(self, blank=0, reduction='mean'): + def __init__(self, blank=0, reduction='mean', fastemit_lambda: float = 0.0): super(RNNTLossNumba, self).__init__() self.blank = blank + self.fastemit_lambda = fastemit_lambda self.reduction = reduction self.loss = _RNNTNumba.apply @@ -132,7 +136,7 @@ def forward(self, acts, labels, act_lens, label_lens): # log_softmax is computed within GPU version. acts = torch.nn.functional.log_softmax(acts, -1) - return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction) + return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction, self.fastemit_lambda) def check_type(var, t, name): diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py index 3e4cac7a0969..5e6a0247d449 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py @@ -146,6 +146,7 @@ def __init__( alphabet_size: int, workspace: torch.Tensor, blank: int, + fastemit_lambda: float, num_threads: int, batch_first: bool, ): @@ -160,6 +161,8 @@ def __init__( workspace: An allocated chunk of memory that will be sliced off and reshaped into required blocks used as working memory. blank: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. num_threads: Number of OMP threads to launch. batch_first: Bool that decides if batch dimension is first or third. 
""" @@ -169,6 +172,7 @@ def __init__( self.alphabet_size_ = alphabet_size self.workspace = workspace # a flat vector of floatX numbers that represents allocated memory slices self.blank_ = blank + self.fastemit_lambda_ = fastemit_lambda self.num_threads_ = num_threads self.batch_first = batch_first @@ -199,6 +203,10 @@ def cost_and_grad_kernel( grad, rnntm.log_probs2, T, U, rnntm.alphas, rnntm.betas, labels, llForward ) + # Scale llForward by FastEmit lambda + llForward *= (1.0 + self.fastemit_lambda_) + llBackward *= (1.0 + self.fastemit_lambda_) + diff = (llForward - llBackward).abs() if diff > 0.1: print(f"WARNING: Forward backward likelihood mismatch : {diff}") @@ -291,7 +299,9 @@ def compute_betas_and_grads( if u < U - 1: g = alphas[idx(t, u)] + betas[idx(t, u + 1)] - grad[idx(t, u, labels[u])] = -torch.exp(log_probs[idx(t, u) * 2 + 1] + g - loglike) + grad[idx(t, u, labels[u])] = -torch.exp( + math.log1p(self.fastemit_lambda_) + log_probs[idx(t, u) * 2 + 1] + g - loglike + ) # // gradient to the last blank transition grad[idx(T - 1, U - 1, self.blank_)] = -torch.exp( diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py index ce940c3b3925..801458ea3301 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py @@ -39,7 +39,16 @@ class GPURNNT: def __init__( - self, minibatch: int, maxT: int, maxU: int, alphabet_size: int, workspace, blank: int, num_threads: int, stream + self, + minibatch: int, + maxT: int, + maxU: int, + alphabet_size: int, + workspace, + blank: int, + fastemit_lambda, + num_threads: int, + stream, ): """ Helper class to launch the CUDA Kernels to compute the Transducer Loss. @@ -52,6 +61,8 @@ def __init__( workspace: An allocated chunk of memory that will be sliced off and reshaped into required blocks used as working memory. blank: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. num_threads: Number of OMP threads to launch. stream: Numba Cuda Stream. """ @@ -63,6 +74,7 @@ def __init__( workspace ) # a flat vector of floatX numbers that represents allocated memory slices self.blank_ = blank + self.fastemit_lambda_ = fastemit_lambda self.num_threads_ = num_threads self.stream_ = stream # type: cuda.cudadrv.driver.Stream @@ -207,6 +219,7 @@ def compute_cost_and_score( self.maxU_, self.alphabet_size_, self.blank_, + self.fastemit_lambda_, ) # // cost @@ -215,7 +228,8 @@ def compute_cost_and_score( # compute negative log likelihood. for mb in range(self.minibatch_): - costs[mb] = -costs[mb] + # Scale llForward by FastEmit lambda + costs[mb] = (1.0 + self.fastemit_lambda_) * -costs[mb] return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 7f7d0227aaf6..8ce94e6fd1ec 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -99,6 +99,8 @@ def compute_alphas_kernel( maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. 
alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. Updates: Kernel inplace updates the following inputs: @@ -277,6 +279,7 @@ def compute_grad_kernel( maxU: int, alphabet_size: int, blank_: int, + fastemit_lambda: float, ): """ Compute gradients over the transduction step. @@ -302,6 +305,8 @@ def compute_grad_kernel( maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. Updates: Kernel inplace updates the following inputs: @@ -350,12 +355,26 @@ def compute_grad_kernel( # grad of blank across t < T; # grad[b, t Date: Thu, 17 Jun 2021 21:56:48 -0700 Subject: [PATCH 05/11] Try gpu impl Signed-off-by: smajumdar --- .../numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 8ce94e6fd1ec..b7e74aab4375 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -99,8 +99,6 @@ def compute_alphas_kernel( maxU: The maximum possible target sequence length. Represents U in the logprobs tensor. alphabet_size: The vocabulary dimension V+1 (inclusive of RNNT blank). blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. - fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to - FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. 
Updates: Kernel inplace updates the following inputs: @@ -371,7 +369,9 @@ def compute_grad_kernel( # print(mb, t, u, idx, "init u grad", grad) # math.log1p(fastemit_lambda) + - grad -= math.exp(math.log1p(fastemit_lambda) + alphas[col] + logpk - logll[mb] + betas[col + 1]) + grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + 1]) + + # print("LABEL", mb, t, u, idx, "label", labels[u], "grad", grad) # if mb == 0 and t == 0 and u == 0: # print(mb, t, u, idx, "final u grad", grad) From 6bf087b2078bade5ac4ce83910ebb20400f9f4ee Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 18 Jun 2021 01:53:06 -0700 Subject: [PATCH 06/11] Try gpu impl Signed-off-by: smajumdar --- nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py | 2 ++ .../collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 1 + .../parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py | 3 ++- .../numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py | 8 +++++++- .../collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py | 3 ++- 5 files changed, 14 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py index 689da70e743d..09b60da2f7c4 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -186,6 +186,8 @@ def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda): for u, l in enumerate(labels): grads[:, u, l] = (1.0 + fastemit_lambda) * grads[:, u, l] + print("numpy", alphas[0, 0], betas[1, 0], log_probs[0, 0, 0], log_like, "final ", grads[0, 0, 0]) + return grads diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 8623ebbc1ccb..aa9204a6db43 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -76,6 +76,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_ if grads is not None: grads /= minibatch_size + print("final cuda grad", grads[0, 0, 0, 0]) # costs = costs.to(log_probs.device) ctx.grads = grads diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py index 801458ea3301..9277c530056c 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py @@ -229,7 +229,8 @@ def compute_cost_and_score( # compute negative log likelihood. 
for mb in range(self.minibatch_): # Scale llForward by FastEmit lambda - costs[mb] = (1.0 + self.fastemit_lambda_) * -costs[mb] + costs[mb] = -costs[mb] + costs[mb] = (1.0 + self.fastemit_lambda_) * costs[mb] return global_constants.RNNTStatus.RNNT_STATUS_SUCCESS diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index b7e74aab4375..aeac7150fe85 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -349,6 +349,7 @@ def compute_grad_kernel( # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u) + logpk - logll[b]) if (idx == blank_) and (t == T - 1) and (u == U - 1): grad -= math.exp(alphas[col] + logpk - logll[mb]) + pass # grad of blank across t < T; # grad[b, t Date: Fri, 18 Jun 2021 16:00:46 -0700 Subject: [PATCH 07/11] Correct few impl Signed-off-by: smajumdar --- .../asr/parts/numba/rnnt_loss/rnnt_numpy.py | 2 +- .../utils/cuda_utils/gpu_rnnt_kernel.py | 24 +++++++++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py index 09b60da2f7c4..6b720ab28aee 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -295,7 +295,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, fastemit_lambda): costs = torch.FloatTensor([sum(costs)]) grads = torch.Tensor(grads).to(acts) - ctx.grads = Variable(grads) + ctx.grads = grads return costs @staticmethod diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index aeac7150fe85..324129e4adda 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -355,14 +355,20 @@ def compute_grad_kernel( # grad[b, t 0.0: + grad += math.exp(alphas[col] + betas[col + 1] + logpk - logll[mb] + + logp(denom, acts, maxT, maxU, alphabet_size, mb, t, u, u + 1)) - grad_t = math.exp(alphas[col] + logpk - logll[mb] + betas[col + maxU]) - grad += -grad_t + grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + maxU]) - if mb == 0 and t == 0 and u == 0 and idx == 0: - print("cuda", alphas[col], betas[col + maxU], logpk, logll[mb], "final ", grad) - # print("cuda check", -math.exp(alphas[col] + logpk - logll[mb] + betas[col + maxU])) + # if mb == 0 and t == 0 and u == 0 and idx == 0: + # print("cuda", alphas[col], betas[col + maxU], logpk, logll[mb], "final ", grad) + # print("cuda check", -math.exp(alphas[col] + logpk - logll[mb] + betas[col + maxU])) # if mb == 0 and t == 0 and u == 0 and idx == 0: # print(mb, t, u, idx, "init t grad", grad) @@ -375,7 +381,11 @@ def compute_grad_kernel( # print(mb, t, u, idx, "init u grad", grad) # math.log1p(fastemit_lambda) + - grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + 1]) + if fastemit_lambda > 0.0: + grad += math.exp(alphas[col] + betas[col + 1] + logpk - logll[mb] + + logp(denom, acts, maxT, maxU, alphabet_size, mb, t, u, u + 1)) + + grad -= math.exp(math.log1p(fastemit_lambda) + alphas[col] + logpk - logll[mb] + betas[col + 1]) # print("LABEL", mb, t, u, idx, "label", labels[u], "grad", grad) From 
25a105dc93bf8c9520e40291578e38f089817a03 Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 18 Jun 2021 16:10:54 -0700 Subject: [PATCH 08/11] Update fastemit scaling Signed-off-by: smajumdar --- .../utils/cuda_utils/gpu_rnnt_kernel.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 324129e4adda..82ba926381d7 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -345,9 +345,24 @@ def compute_grad_kernel( # initialize the grad of the sample acts[b, t, u, v] grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb]) + if fastemit_lambda > 0.0: + if u < U - 1: + fastemit_grad = fastemit_lambda * math.exp( + alphas[col] + + betas[col + 1] + + logpk + - logll[mb] + + logp(denom, acts, maxT, maxU, alphabet_size, mb, t, u, u + 1) + ) + else: + fastemit_grad = 0.0 + else: + fastemit_grad = 0.0 + # // grad to last blank transition # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u) + logpk - logll[b]) if (idx == blank_) and (t == T - 1) and (u == U - 1): + grad -= math.exp(alphas[col] + logpk - logll[mb]) pass @@ -360,10 +375,7 @@ def compute_grad_kernel( # math.exp(logll[mb]), "final ", grad) # grad -= fastemit_lambda * (math.exp(alphas[col] + betas[col + 1] + logpk - logll[mb])) - if fastemit_lambda > 0.0: - grad += math.exp(alphas[col] + betas[col + 1] + logpk - logll[mb] + - logp(denom, acts, maxT, maxU, alphabet_size, mb, t, u, u + 1)) - + grad += fastemit_grad grad -= math.exp(alphas[col] + logpk - logll[mb] + betas[col + maxU]) # if mb == 0 and t == 0 and u == 0 and idx == 0: @@ -381,10 +393,7 @@ def compute_grad_kernel( # print(mb, t, u, idx, "init u grad", grad) # math.log1p(fastemit_lambda) + - if fastemit_lambda > 0.0: - grad += math.exp(alphas[col] + betas[col + 1] + logpk - logll[mb] + - logp(denom, acts, maxT, maxU, alphabet_size, mb, t, u, u + 1)) - + grad += fastemit_grad grad -= math.exp(math.log1p(fastemit_lambda) + alphas[col] + logpk - logll[mb] + betas[col + 1]) # print("LABEL", mb, t, u, idx, "label", labels[u], "grad", grad) From 96f27d9924d5f7a56559001a4f0f481439f19cbd Mon Sep 17 00:00:00 2001 From: Somshubra Majumdar Date: Fri, 18 Jun 2021 16:54:06 -0700 Subject: [PATCH 09/11] Cleanup fastemit Signed-off-by: smajumdar --- .../rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 82ba926381d7..ab16b7f53840 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -349,22 +349,23 @@ def compute_grad_kernel( if u < U - 1: fastemit_grad = fastemit_lambda * math.exp( alphas[col] + + (denom[col] + acts[col * alphabet_size + labels[u]]) + betas[col + 1] + logpk - logll[mb] - + logp(denom, acts, maxT, maxU, alphabet_size, mb, t, u, u + 1) ) else: fastemit_grad = 0.0 else: fastemit_grad = 0.0 + if mb == 0 and t == 0 and u == 1: + print(mb, t, u, idx, "grad", grad, "fastemit grad", fastemit_grad) + # // grad to last blank transition # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u) + logpk - 
logll[b]) if (idx == blank_) and (t == T - 1) and (u == U - 1): - grad -= math.exp(alphas[col] + logpk - logll[mb]) - pass # grad of blank across t < T; # grad[b, t Date: Fri, 18 Jun 2021 17:21:40 -0700 Subject: [PATCH 10/11] Finalize FastEmit regularization PR Signed-off-by: smajumdar --- .../asr/parts/numba/rnnt_loss/rnnt_numpy.py | 6 +- .../asr/parts/numba/rnnt_loss/rnnt_pytorch.py | 2 - .../rnnt_loss/utils/cpu_utils/cpu_rnnt.py | 4 +- .../utils/cuda_utils/gpu_rnnt_kernel.py | 62 ++++++------------- .../asr/numba/rnnt_loss/test_rnnt_pytorch.py | 10 +-- 5 files changed, 25 insertions(+), 59 deletions(-) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py index 6b720ab28aee..8a47b1a4041d 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -186,8 +186,6 @@ def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda): for u, l in enumerate(labels): grads[:, u, l] = (1.0 + fastemit_lambda) * grads[:, u, l] - print("numpy", alphas[0, 0], betas[1, 0], log_probs[0, 0, 0], log_like, "final ", grads[0, 0, 0]) - return grads @@ -274,7 +272,9 @@ def transduce_batch(log_probs, labels, flen, glen, blank=0, fastemit_lambda=0.0) ll, g, alphas, betas = transduce(log_probs[b, :t, :u, :], labels[b, : u - 1], blank, fastemit_lambda) grads[b, :t, :u, :] = g - reg = fastemit_regularization(log_probs[b, :t, :u, :], labels[b, : u - 1], alphas, betas, blank, fastemit_lambda) + reg = fastemit_regularization( + log_probs[b, :t, :u, :], labels[b, : u - 1], alphas, betas, blank, fastemit_lambda + ) ll += reg costs.append(ll) return costs, grads diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index aa9204a6db43..10d9073e7c81 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -76,8 +76,6 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_ if grads is not None: grads /= minibatch_size - print("final cuda grad", grads[0, 0, 0, 0]) - # costs = costs.to(log_probs.device) ctx.grads = grads return costs diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py index 5e6a0247d449..6bf148148ac1 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py @@ -204,8 +204,8 @@ def cost_and_grad_kernel( ) # Scale llForward by FastEmit lambda - llForward *= (1.0 + self.fastemit_lambda_) - llBackward *= (1.0 + self.fastemit_lambda_) + llForward *= 1.0 + self.fastemit_lambda_ + llBackward *= 1.0 + self.fastemit_lambda_ diff = (llForward - llBackward).abs() if diff > 0.1: diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index ab16b7f53840..bcca5bf33b8a 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -345,22 +345,24 @@ def compute_grad_kernel( # initialize the grad of the sample acts[b, t, u, v] grad = math.exp(alphas[col] + betas[col] + logpk - logll[mb]) - if fastemit_lambda > 0.0: - if u < U - 1: - 
fastemit_grad = fastemit_lambda * math.exp( - alphas[col] - + (denom[col] + acts[col * alphabet_size + labels[u]]) - + betas[col + 1] - + logpk - - logll[mb] - ) - else: - fastemit_grad = 0.0 + # If FastEmit regularization is enabled, calculate the gradeint of probability of predicting the next label + # at the current timestep. + # The formula for this is Equation 9 in https://arxiv.org/abs/2010.11148, multiplied by the log probability + # of the current step (t, u), normalized by the total log likelihood. + # Once the gradient has been calculated, scale it by `fastemit_lambda`, as in Equation 10. + if fastemit_lambda > 0.0 and u < U - 1: + fastemit_grad = fastemit_lambda * math.exp( + alphas[col] # alphas(t, u) + + (denom[col] + acts[col * alphabet_size + labels[u]]) # y_hat(t, u) + + betas[col + 1] # betas(t, u+1) + + logpk # log Pr(k|t, u) + - logll[mb] # total log likelihood for normalization + ) else: fastemit_grad = 0.0 - if mb == 0 and t == 0 and u == 1: - print(mb, t, u, idx, "grad", grad, "fastemit grad", fastemit_grad) + # Update the gradient of act[b, t, u, v] with the gradient from FastEmit regularization + grad = grad + fastemit_grad # // grad to last blank transition # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u) + logpk - logll[b]) @@ -370,44 +372,16 @@ def compute_grad_kernel( # grad of blank across t < T; # grad[b, t Date: Fri, 18 Jun 2021 17:44:09 -0700 Subject: [PATCH 11/11] Refactor code to support fastemit regularization Signed-off-by: smajumdar --- .../contextnet_rnnt/config_rnnt.yaml | 5 +++++ .../contextnet_rnnt/config_rnnt_bpe.yaml | 5 +++++ nemo/collections/asr/losses/rnnt.py | 21 +++++++++++++++++-- .../asr/test_asr_rnnt_encdec_model.py | 3 +++ .../asr/test_asr_rnnt_encoder_model_bpe.py | 3 +++ 5 files changed, 35 insertions(+), 2 deletions(-) diff --git a/examples/asr/experimental/contextnet_rnnt/config_rnnt.yaml b/examples/asr/experimental/contextnet_rnnt/config_rnnt.yaml index 306ffe094c3e..6360104423cf 100644 --- a/examples/asr/experimental/contextnet_rnnt/config_rnnt.yaml +++ b/examples/asr/experimental/contextnet_rnnt/config_rnnt.yaml @@ -191,6 +191,11 @@ model: tsd_max_sym_exp: 50 # for Time Synchronous Decoding alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + loss: + loss_name: "default" + warprnnt_numba_kwargs: + fastemit_lambda: 0.0 + optim: name: adam # _target_: nemo.core.optim.optimizers.Adam diff --git a/examples/asr/experimental/contextnet_rnnt/config_rnnt_bpe.yaml b/examples/asr/experimental/contextnet_rnnt/config_rnnt_bpe.yaml index c259ba581c8c..777f9d484438 100644 --- a/examples/asr/experimental/contextnet_rnnt/config_rnnt_bpe.yaml +++ b/examples/asr/experimental/contextnet_rnnt/config_rnnt_bpe.yaml @@ -192,6 +192,11 @@ model: tsd_max_sym_exp: 50 # for Time Synchronous Decoding alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding + loss: + loss_name: "default" + warprnnt_numba_kwargs: + fastemit_lambda: 0.0 + optim: name: adam # _target_: nemo.core.optim.optimizers.Adam diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 24496565ef3b..ee5512149089 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -32,6 +32,7 @@ from typing import Optional import torch +from omegaconf import DictConfig, OmegaConf from nemo.core.classes import Loss, typecheck from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType @@ -144,6 +145,9 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, 
loss_kwargs: dict = None) # Resolve loss functions sequentially loss_kwargs = {} if loss_kwargs is None else loss_kwargs + if isinstance(loss_kwargs, DictConfig): + loss_kwargs = OmegaConf.to_container(loss_kwargs, resolve=True) + # Get actual loss name for `default` if loss_name == 'default': loss_name = loss_config.loss_name @@ -156,7 +160,8 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) _warn_unused_additional_kwargs(loss_name, loss_kwargs) elif loss_name == 'warprnnt_numba': - loss_func = RNNTLossNumba(blank=blank_idx, reduction='none') + fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0) + loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda) _warn_unused_additional_kwargs(loss_name, loss_kwargs) else: @@ -194,7 +199,19 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str = albiet there is a small speed penalty for JIT numba compile. Note: - Requires the pytorch bindings to be installed prior to calling this class. + Requires Numba 0.53.0 or later to be installed to use this loss function. + + Losses can be selected via the config, and optionally be passed keyword arguments as follows. + + Examples: + .. code-block:: yaml + + model: # RNNT Model config + ... + loss: + loss_name: "warprnnt_numba" + warprnnt_numba_kwargs: + fastemit_lambda: 0.0 Warning: In the case that GPU memory is exhausted in order to compute RNNTLoss, it might cause diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index 15124157a6a7..16663a036d42 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -76,6 +76,8 @@ def asr_model(): decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}} + loss = {'loss_name': 'default', 'warprnnt_numba_kwargs': {'fastemit_lambda': 0.001}} + modelConfig = DictConfig( { 'labels': ListConfig(labels), @@ -85,6 +87,7 @@ def asr_model(): 'decoder': DictConfig(decoder), 'joint': DictConfig(joint), 'decoding': DictConfig(decoding), + 'loss': DictConfig(loss), } ) diff --git a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py index 294b03dd7d25..ea08642f0eb6 100644 --- a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py @@ -73,6 +73,8 @@ def asr_model(test_data_dir): tokenizer = {'dir': os.path.join(test_data_dir, "asr", "tokenizers", "an4_wpe_128"), 'type': 'wpe'} + loss = {'loss_name': 'default', 'warprnnt_numba_kwargs': {'fastemit_lambda': 0.001}} + modelConfig = DictConfig( { 'preprocessor': DictConfig(preprocessor), @@ -82,6 +84,7 @@ def asr_model(test_data_dir): 'joint': DictConfig(joint), 'tokenizer': DictConfig(tokenizer), 'decoding': DictConfig(decoding), + 'loss': DictConfig(loss), } )
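
For reference, a minimal usage sketch of the FastEmit-regularized loss added by this series. This is an illustration, not part of the patches: it assumes NeMo with these commits applied and Numba installed (0.53.0 or later, per the RNNTLoss docstring above); the module path is taken from the diffed file `rnnt_pytorch.py`, and the tensor shapes mirror the `__main__` example in `rnnt_numpy.py`.

import torch

from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba

# Joint network activations of shape [B, T, U+1, V+1] = [1, 10, 11, 3], blank index 0.
# On CPU the module applies log_softmax internally; the GPU path does it inside the kernel.
acts = torch.randn(1, 10, 11, 3, requires_grad=True)
labels = torch.tensor([[0, 1, 1, 2, 1, 1, 2, 1, 1, 2]], dtype=torch.int32)
act_lens = torch.tensor([10], dtype=torch.int32)
label_lens = torch.tensor([labels.shape[1]], dtype=torch.int32)

# fastemit_lambda > 0.0 enables the FastEmit penalty (Eqs. 9-10 of arXiv:2010.11148).
loss_fn = RNNTLossNumba(blank=0, reduction='mean', fastemit_lambda=0.001)
loss = loss_fn(acts, labels, act_lens, label_lens)
loss.backward()

Setting `fastemit_lambda=0.0` recovers the unregularized RNNT loss, which matches the default written into the example YAML configs and the `loss` resolver added in the final patch.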