Add FastEmit support for RNNT Losses #2374

Merged: 11 commits, Jun 22, 2021
5 changes: 5 additions & 0 deletions examples/asr/experimental/contextnet_rnnt/config_rnnt.yaml
@@ -191,6 +191,11 @@ model:
tsd_max_sym_exp: 50 # for Time Synchronous Decoding
alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding

loss:
loss_name: "default"
warprnnt_numba_kwargs:
fastemit_lambda: 0.0

optim:
name: adam
# _target_: nemo.core.optim.optimizers.Adam
@@ -192,6 +192,11 @@ model:
tsd_max_sym_exp: 50 # for Time Synchronous Decoding
alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding

loss:
loss_name: "default"
warprnnt_numba_kwargs:
fastemit_lambda: 0.0

optim:
name: adam
# _target_: nemo.core.optim.optimizers.Adam
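
For reference, the `loss` block added to these configs keeps the default loss selection with FastEmit disabled (`fastemit_lambda: 0.0`). The sketch below builds an equivalent config programmatically with OmegaConf; the explicit `warprnnt_numba` selection and the non-zero lambda are illustrative assumptions, not values taken from this PR.

```python
# Sketch only: mirrors the YAML `loss` block added above, built with OmegaConf.
from omegaconf import OmegaConf

loss_cfg = OmegaConf.create(
    {
        "loss_name": "warprnnt_numba",  # select the Numba RNNT loss explicitly (assumed choice)
        "warprnnt_numba_kwargs": {"fastemit_lambda": 0.001},  # non-zero lambda enables FastEmit
    }
)
print(OmegaConf.to_yaml(loss_cfg))  # prints the same structure as the YAML block above
```
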
21 changes: 19 additions & 2 deletions nemo/collections/asr/losses/rnnt.py
@@ -32,6 +32,7 @@
from typing import Optional

import torch
from omegaconf import DictConfig, OmegaConf

from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
@@ -144,6 +145,9 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
# Resolve loss functions sequentially
loss_kwargs = {} if loss_kwargs is None else loss_kwargs

if isinstance(loss_kwargs, DictConfig):
loss_kwargs = OmegaConf.to_container(loss_kwargs, resolve=True)

# Get actual loss name for `default`
if loss_name == 'default':
loss_name = loss_config.loss_name
@@ -156,7 +160,8 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
_warn_unused_additional_kwargs(loss_name, loss_kwargs)

elif loss_name == 'warprnnt_numba':
loss_func = RNNTLossNumba(blank=blank_idx, reduction='none')
fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda)
_warn_unused_additional_kwargs(loss_name, loss_kwargs)

else:
@@ -194,7 +199,19 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
albeit there is a small speed penalty for Numba JIT compilation.

Note:
Requires the pytorch bindings to be installed prior to calling this class.
Requires Numba 0.53.0 or later to be installed to use this loss function.

Losses can be selected via the config, and optionally be passed keyword arguments as follows.

Examples:
.. code-block:: yaml

model: # RNNT Model config
...
loss:
loss_name: "warprnnt_numba"
warprnnt_numba_kwargs:
fastemit_lambda: 0.0

Warning:
In the case that GPU memory is exhausted in order to compute RNNTLoss, it might cause
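
The `resolve_rnnt_loss` change above converts an OmegaConf `DictConfig` of loss kwargs into a plain dict and pops `fastemit_lambda` (defaulting to 0.0) before constructing `RNNTLossNumba`. A small, self-contained sketch of that pattern is below; the helper name `resolve_fastemit_lambda` is hypothetical and used only for illustration.

```python
# Hypothetical helper illustrating the kwargs handling shown in the diff:
# DictConfig -> plain dict, then pop fastemit_lambda with a 0.0 default.
from omegaconf import DictConfig, OmegaConf


def resolve_fastemit_lambda(loss_kwargs=None) -> float:
    loss_kwargs = {} if loss_kwargs is None else loss_kwargs
    if isinstance(loss_kwargs, DictConfig):
        loss_kwargs = OmegaConf.to_container(loss_kwargs, resolve=True)
    return loss_kwargs.pop("fastemit_lambda", 0.0)


print(resolve_fastemit_lambda(OmegaConf.create({"fastemit_lambda": 0.001})))  # 0.001
print(resolve_fastemit_lambda())                                              # 0.0
```
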
8 changes: 8 additions & 0 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
@@ -44,6 +44,7 @@ def rnnt_loss_cpu(
costs: torch.Tensor,
grads: torch.Tensor,
blank_label: int,
fastemit_lambda: float,
num_threads: int,
):
"""
@@ -59,6 +60,8 @@
costs: Zero vector of length [B] in which costs will be set.
grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set.
blank_label: Index of the blank token in the vocabulary.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
num_threads: Number of threads for OpenMP.
"""
# aliases
@@ -92,6 +95,7 @@ def rnnt_loss_cpu(
alphabet_size=alphabet_size,
workspace=cpu_workspace,
blank=blank_label,
fastemit_lambda=fastemit_lambda,
num_threads=num_threads,
batch_first=True,
)
@@ -136,6 +140,7 @@ def rnnt_loss_gpu(
costs: torch.Tensor,
grads: torch.Tensor,
blank_label: int,
fastemit_lambda: float,
num_threads: int,
):
"""
@@ -151,6 +156,8 @@
costs: Zero vector of length [B] in which costs will be set.
grads: Zero tensor of shape [B, T, U, V+1] where the gradient will be set.
blank_label: Index of the blank token in the vocabulary.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
num_threads: Number of threads for OpenMP.
"""
minibatch_size = acts.shape[0]
@@ -189,6 +196,7 @@
alphabet_size=alphabet_size,
workspace=gpu_workspace,
blank=blank_label,
fastemit_lambda=fastemit_lambda,
num_threads=num_threads,
stream=stream,
)
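
Both `rnnt_loss_cpu` and `rnnt_loss_gpu` now accept `fastemit_lambda` alongside the caller-provided output buffers described in their docstrings. A minimal sketch of those buffers follows; the shapes are arbitrary example values, not taken from this PR.

```python
# Sketch of the pre-allocated buffers described in the docstrings above:
# a zero cost vector of length [B] and a zero gradient tensor of shape [B, T, U, V+1].
import torch

B, T, U, V = 2, 10, 5, 28                    # example batch / time / target-length+1 / vocab sizes
costs = torch.zeros(B, dtype=torch.float32)  # filled in-place with per-sample losses
grads = torch.zeros(B, T, U, V + 1)          # filled in-place with d(loss)/d(activations)
fastemit_lambda = 0.0                        # FastEmit scaling factor; 0.0 disables it
```
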
155 changes: 141 additions & 14 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py
@@ -33,6 +33,57 @@
from torch.nn import Module


def check_type(var, t, name):
if var.dtype is not t:
raise TypeError("{} must be {}".format(name, t))


def check_contiguous(var, name):
if not var.is_contiguous():
raise ValueError("{} must be contiguous".format(name))


def check_dim(var, dim, name):
if len(var.shape) != dim:
raise ValueError("{} must be {}D".format(name, dim))


def certify_inputs(log_probs, labels, lengths, label_lengths):
# check_type(log_probs, torch.float32, "log_probs")
check_type(labels, torch.int32, "labels")
check_type(label_lengths, torch.int32, "label_lengths")
check_type(lengths, torch.int32, "lengths")
check_contiguous(log_probs, "log_probs")
check_contiguous(labels, "labels")
check_contiguous(label_lengths, "label_lengths")
check_contiguous(lengths, "lengths")

if lengths.shape[0] != log_probs.shape[0]:
raise ValueError(
f"Must have a length per example. "
f"Given lengths dim: {lengths.shape[0]}, "
f"Log probs dim : {log_probs.shape[0]}"
)
if label_lengths.shape[0] != log_probs.shape[0]:
raise ValueError(
"Must have a label length per example. "
f"Given label lengths dim : {label_lengths.shape[0]}, "
f"Log probs dim : {log_probs.shape[0]}"
)

check_dim(log_probs, 4, "log_probs")
check_dim(labels, 2, "labels")
check_dim(lengths, 1, "lengths")
check_dim(label_lengths, 1, "label_lengths")
max_T = torch.max(lengths)
max_U = torch.max(label_lengths)
T, U = log_probs.shape[1:3]
if T != max_T:
raise ValueError(f"Input length mismatch! Given T: {T}, Expected max T from input lengths: {max_T}")
if U != max_U + 1:
raise ValueError(f"Output length mismatch! Given U: {U}, Expected max U from target lengths: {max_U} + 1")


def _assert_no_grad(tensor):
assert not tensor.requires_grad, (
"gradients only computed for log_probs - please " "mark other tensors as not requiring gradients"
@@ -102,7 +153,7 @@ def backward_pass(log_probs, labels, blank):
return betas, betas[0, 0]


def compute_gradient(log_probs, alphas, betas, labels, blank):
def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda):
"""
Computes the gradients of the log_probs with respect to the log probability of this step occurring.

@@ -119,37 +170,85 @@ def compute_gradient(log_probs, alphas, betas, labels, blank):
"""
T, U, _ = log_probs.shape
grads = np.full(log_probs.shape, -float("inf"))
log_like = betas[0, 0]
log_like = betas[0, 0] # == alphas[T - 1, U - 1] + betas[T - 1, U - 1]

# // grad to last blank transition
grads[T - 1, U - 1, blank] = alphas[T - 1, U - 1]

grads[: T - 1, :, blank] = alphas[: T - 1, :] + betas[1:, :]

# // grad to label transition
for u, l in enumerate(labels):
grads[:, u, l] = alphas[:, u] + betas[:, u + 1]

grads = -np.exp(grads + log_probs - log_like)

if fastemit_lambda > 0.0:
for u, l in enumerate(labels):
grads[:, u, l] = (1.0 + fastemit_lambda) * grads[:, u, l]

return grads


def transduce(log_probs, labels, blank=0):
def fastemit_regularization(log_probs, labels, alphas, betas, blank, fastemit_lambda):
"""
Describes the computation of FastEmit regularization from the paper -
[FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization](https://arxiv.org/abs/2010.11148)

Args:
log_probs: Tensor of shape [T, U, V+1]
labels: Unused. Labels of shape [B, U]
alphas: Tensor of shape [T, U] which represents the forward variable.
betas: Unused. Tensor of shape [T, U] which represents the backward variable.
blank: Index of the blank token.
fastemit_lambda: Float scaling factor for FastEmit regularization.

Returns:
The FastEmit regularization term -lambda * P̃(A_{t,u} | x), which is added to the negative log likelihood.
"""
# General calculation of the fastemit regularization alignments
T, U, _ = log_probs.shape
# alignment = np.zeros((T, U), dtype='float32')
#
# for t in range(0, T):
# alignment[t, U - 1] = alphas[t, U - 1] + betas[t, U - 1]
#
# for t in range(0, T):
# for u in range(0, U - 1):
# emit = alphas[t, u] + log_probs[t, u, labels[u]] + betas[t, u + 1]
# alignment[t, u] = emit
# reg = fastemit_lambda * (alignment[T - 1, U - 1])

# The above is equivalent to below, without need of computing above
# reg = fastemit_lambda * (alphas[T - 1, U - 1] + betas[T - 1, U - 1])

# The above is also equivalent to below, without need of computing the betas alignment matrix
reg = fastemit_lambda * (alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank])
return -reg


def transduce(log_probs, labels, blank=0, fastemit_lambda=0.0):
"""
Args:
log_probs: 3D array with shape
[input len, output len + 1, vocab size]
labels: 1D array with shape [output time steps]
blank: Index of the blank token.
fastemit_lambda: Float scaling factor for FastEmit regularization.

Returns:
float: The negative log-likelihood
3D array: Gradients with respect to the
unnormalized input activations
2d array: Alphas matrix (TxU)
2d array: Betas matrix (TxU)
"""
alphas, ll_forward = forward_pass(log_probs, labels, blank)
betas, ll_backward = backward_pass(log_probs, labels, blank)
grads = compute_gradient(log_probs, alphas, betas, labels, blank)
return -ll_forward, grads
grads = compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda)
return -ll_forward, grads, alphas, betas


def transduce_batch(log_probs, labels, flen, glen, blank=0):
def transduce_batch(log_probs, labels, flen, glen, blank=0, fastemit_lambda=0.0):
"""
Compute the transducer loss of the batch.

@@ -159,6 +258,7 @@ def transduce_batch(log_probs, labels, flen, glen, blank=0):
flen: Length vector of the acoustic sequence.
glen: Length vector of the target sequence.
blank: Id of the blank token.
fastemit_lambda: Float scaling factor for FastEmit regularization.

Returns:
Batch of transducer forward log probabilities (loss) and the gradients of the activation matrix.
@@ -168,46 +268,73 @@ def transduce_batch(log_probs, labels, flen, glen, blank=0):
for b in range(log_probs.shape[0]):
t = int(flen[b])
u = int(glen[b]) + 1
ll, g = transduce(log_probs[b, :t, :u, :], labels[b, : u - 1], blank)

ll, g, alphas, betas = transduce(log_probs[b, :t, :u, :], labels[b, : u - 1], blank, fastemit_lambda)
grads[b, :t, :u, :] = g

reg = fastemit_regularization(
log_probs[b, :t, :u, :], labels[b, : u - 1], alphas, betas, blank, fastemit_lambda
)
ll += reg
costs.append(ll)
return costs, grads


class _RNNT(Function):
@staticmethod
def forward(ctx, acts, labels, act_lens, label_lens, blank):
def forward(ctx, acts, labels, act_lens, label_lens, blank, fastemit_lambda):
costs, grads = transduce_batch(
acts.detach().cpu().numpy(), labels.cpu().numpy(), act_lens.cpu().numpy(), label_lens.cpu().numpy(), blank,
acts.detach().cpu().numpy(),
labels.cpu().numpy(),
act_lens.cpu().numpy(),
label_lens.cpu().numpy(),
blank,
fastemit_lambda,
)

costs = torch.FloatTensor([sum(costs)])
grads = torch.Tensor(grads).to(acts)

ctx.grads = Variable(grads)
ctx.grads = grads
return costs

@staticmethod
def backward(ctx, grad_output):
return ctx.grads, None, None, None, None
return ctx.grads, None, None, None, None, None


class RNNTLoss(Module):
"""
Parameters:
`blank` (int): default 0 - label index of blank token
fastemit_lambda: Float scaling factor for FastEmit regularization.
"""

def __init__(self, blank: int = 0):
def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0):
super(RNNTLoss, self).__init__()
self.blank = blank
self.fastemit_lambda = fastemit_lambda
self.rnnt = _RNNT.apply

def forward(self, acts, labels, act_lens, label_lens):
assert len(labels.size()) == 2
_assert_no_grad(labels)
_assert_no_grad(act_lens)
_assert_no_grad(label_lens)
certify_inputs(acts, labels, act_lens, label_lens)

acts = torch.nn.functional.log_softmax(acts, -1)
return self.rnnt(acts, labels, act_lens, label_lens, self.blank)
return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda)


if __name__ == '__main__':
loss = RNNTLoss(fastemit_lambda=0.01)

torch.manual_seed(0)

acts = torch.randn(1, 2, 5, 3)
labels = torch.tensor([[0, 2, 1, 2]], dtype=torch.int32)
act_lens = torch.tensor([2], dtype=torch.int32)
label_lens = torch.tensor([len(labels[0])], dtype=torch.int32)

loss_val = loss(acts, labels, act_lens, label_lens)
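
A hedged follow-up to the `__main__` example above (not part of the PR): if the activations are marked as requiring gradients, the custom autograd `Function` defined in this file should propagate the cached gradient tensor back through the log-softmax.

```python
# Assumes the tensors defined in the __main__ block above are still in scope.
acts = acts.clone().requires_grad_(True)  # request gradients on the raw activations
loss_val = loss(acts, labels, act_lens, label_lens)
loss_val.backward()                       # uses the grads cached in _RNNT.forward
print(f"loss: {loss_val.item():.4f}, grad norm: {acts.grad.norm().item():.4f}")
```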