Add support for Gradient Clipping (clamp) in RNNT Numba loss (#3550)
* Implement LogSoftmaxGradModification function

Signed-off-by: smajumdar <[email protected]>

* Implement clamp in cuda kernels

Signed-off-by: smajumdar <[email protected]>

* Update code for clamp support in numpy and pytorch

Signed-off-by: smajumdar <[email protected]>

* Add basic gradient test

Signed-off-by: smajumdar <[email protected]>

* Add basic gradient test

Signed-off-by: smajumdar <[email protected]>

* Add gradient test for fastemit and clamp

Signed-off-by: smajumdar <[email protected]>

* Correct test name for big tensor

Signed-off-by: smajumdar <[email protected]>

* Fix installation links to Numba to be for conda-forge

Signed-off-by: smajumdar <[email protected]>

* Increase GPU grad kernel thread size

Signed-off-by: smajumdar <[email protected]>

* Add newline

Signed-off-by: smajumdar <[email protected]>

* Add support for clamp in loss numba

Signed-off-by: smajumdar <[email protected]>

* Update configs

Signed-off-by: smajumdar <[email protected]>

* Update test name

Signed-off-by: smajumdar <[email protected]>

* Fix issue with conda forge install numba

Signed-off-by: smajumdar <[email protected]>

* Revert inplace mul, we don't need double backprop yet (TODO: investigate memory cost with and without it in the future)

Signed-off-by: smajumdar <[email protected]>

* Address comment about docstring

Signed-off-by: smajumdar <[email protected]>
titu1994 authored and fayejf committed Mar 2, 2022
1 parent aceb2f1 commit 7152327
Showing 20 changed files with 548 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -90,7 +90,7 @@ RUN --mount=from=nemo-src,target=/tmp/nemo cd /tmp/nemo && pip install ".[all]"

# TODO: Try to remove once 21.07 container is the base container
# install pinned numba version
RUN conda install -c numba numba=0.54.1
RUN conda install -c conda-forge numba=0.54.1

# copy scripts/examples/tests into container for end user
WORKDIR /workspace/nemo
2 changes: 1 addition & 1 deletion README.rst
@@ -164,7 +164,7 @@ Note that RNNT requires numba to be installed from conda.
conda remove numba
pip uninstall numba
conda install -c numba numba
conda install -c conda-forge numba
Megatron GPT
~~~~~~~~~~~~
1 change: 1 addition & 0 deletions examples/asr/conf/conformer/conformer_transducer_bpe.yaml
@@ -182,6 +182,7 @@ model:
# FastEmit regularization: https://arxiv.org/abs/2010.11148
# You may enable FastEmit to reduce the latency of the model for streaming
fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start.
clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.

# Adds Gaussian noise to the gradients of the decoder to avoid overfitting
variational_noise:
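A quick illustration of what the new `clamp` option means (a sketch, not part of this diff): when `clamp > 0`, each element of the joint tensor's gradient is limited to the range [-clamp, clamp], exactly as torch.clamp would do.

import torch

# Illustrative only: what clamp = 0.5 does to a gradient tensor, elementwise.
grad = torch.tensor([-2.0, -0.05, 0.3, 1.7])
clamp = 0.5
print(torch.clamp(grad, -clamp, clamp))  # tensor([-0.5000, -0.0500,  0.3000,  0.5000])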
1 change: 1 addition & 0 deletions examples/asr/conf/conformer/conformer_transducer_char.yaml
@@ -177,6 +177,7 @@ model:
# FastEmit regularization: https://arxiv.org/abs/2010.11148
# You may enable FastEmit to reduce the latency of the model for streaming
fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start.
clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.

# Adds Gaussian noise to the gradients of the decoder to avoid overfitting
variational_noise:
3 changes: 2 additions & 1 deletion examples/asr/conf/contextnet_rnnt/config_rnnt.yaml
@@ -204,7 +204,8 @@ model:
loss:
loss_name: "default"
warprnnt_numba_kwargs:
fastemit_lambda: 0.0
fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start.
clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.

optim:
name: adam
3 changes: 2 additions & 1 deletion examples/asr/conf/contextnet_rnnt/config_rnnt_bpe.yaml
@@ -204,7 +204,8 @@ model:
loss:
loss_name: "default"
warprnnt_numba_kwargs:
fastemit_lambda: 0.0
fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start.
clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.

optim:
name: adam
1 change: 1 addition & 0 deletions examples/asr/conf/contextnet_rnnt/contextnet_rnnt.yaml
@@ -451,6 +451,7 @@ model:
warprnnt_numba_kwargs:
# FastEmit regularization: https://arxiv.org/abs/2010.11148
fastemit_lambda: 0.001 # Values can be in range [1e-4, 1e-2]. Generally, 0.001 is good start.
clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.

optim:
name: novograd
@@ -453,6 +453,7 @@ model:
warprnnt_numba_kwargs:
# FastEmit regularization: https://arxiv.org/abs/2010.11148
fastemit_lambda: 0.001 # Values can be in range [1e-4, 1e-2]. Generally, 0.001 is good start.
clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.

optim:
name: novograd
3 changes: 2 additions & 1 deletion nemo/collections/asr/losses/rnnt.py
@@ -161,7 +161,8 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)

elif loss_name == 'warprnnt_numba':
fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda)
clamp = loss_kwargs.pop('clamp', -1.0)
loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda, clamp=clamp)
_warn_unused_additional_kwargs(loss_name, loss_kwargs)

else:
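The resolver above just forwards `fastemit_lambda` and `clamp` from `warprnnt_numba_kwargs` into the loss constructor. A hedged sketch of the equivalent direct call (the kwarg values and blank index are illustrative, not defaults):

from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba

# Illustrative values: in a model config these come from loss.warprnnt_numba_kwargs,
# and the blank index is the model's blank_idx (28 here is just an example).
loss_func = RNNTLossNumba(blank=28, reduction='none', fastemit_lambda=0.001, clamp=0.05)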
6 changes: 6 additions & 0 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
@@ -45,6 +45,7 @@ def rnnt_loss_cpu(
grads: torch.Tensor,
blank_label: int,
fastemit_lambda: float,
clamp: float,
num_threads: int,
):
"""
@@ -62,6 +63,7 @@ def rnnt_loss_cpu(
blank_label: Index of the blank token in the vocabulary.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
num_threads: Number of threads for OpenMP.
"""
# aliases
@@ -96,6 +98,7 @@ def rnnt_loss_cpu(
workspace=cpu_workspace,
blank=blank_label,
fastemit_lambda=fastemit_lambda,
clamp=clamp,
num_threads=num_threads,
batch_first=True,
)
@@ -141,6 +144,7 @@ def rnnt_loss_gpu(
grads: torch.Tensor,
blank_label: int,
fastemit_lambda: float,
clamp: float,
num_threads: int,
):
"""
@@ -158,6 +162,7 @@ def rnnt_loss_gpu(
blank_label: Index of the blank token in the vocabulary.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
num_threads: Number of threads for OpenMP.
"""
minibatch_size = acts.shape[0]
@@ -197,6 +202,7 @@ def rnnt_loss_gpu(
workspace=gpu_workspace,
blank=blank_label,
fastemit_lambda=fastemit_lambda,
clamp=clamp,
num_threads=num_threads,
stream=stream,
)
28 changes: 26 additions & 2 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py
@@ -90,6 +90,25 @@ def _assert_no_grad(tensor):
)


class LogSoftmaxGradModification(Function):
@staticmethod
def forward(ctx, acts, clamp):
if clamp < 0:
raise ValueError("`clamp` must be 0.0 or positive float.")

res = acts.new(acts)
ctx.clamp = clamp
return res

@staticmethod
def backward(ctx, grad_output):
grad_output = torch.clamp(grad_output, -ctx.clamp, ctx.clamp)
return (
grad_output,
None,
)


def forward_pass(log_probs, labels, blank):
"""
Computes probability of the forward variable alpha.
@@ -300,7 +319,8 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, fastemit_lambda):

@staticmethod
def backward(ctx, grad_output):
return ctx.grads, None, None, None, None, None
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
return ctx.grads.mul(grad_output), None, None, None, None, None


class RNNTLoss(Module):
@@ -310,10 +330,11 @@ class RNNTLoss(Module):
fastemit_lambda: Float scaling factor for FastEmit regularization.
"""

def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0):
def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0, clamp: float = -1.0):
super(RNNTLoss, self).__init__()
self.blank = blank
self.fastemit_lambda = fastemit_lambda
self.clamp = float(clamp) if clamp > 0 else 0.0
self.rnnt = _RNNT.apply

def forward(self, acts, labels, act_lens, label_lens):
@@ -323,6 +344,9 @@ def forward(self, acts, labels, act_lens, label_lens):
_assert_no_grad(label_lens)
certify_inputs(acts, labels, act_lens, label_lens)

if self.clamp > 0.0:
acts = LogSoftmaxGradModification.apply(acts, self.clamp)

acts = torch.nn.functional.log_softmax(acts, -1)
return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda)

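A minimal sketch of the new LogSoftmaxGradModification wrapper in isolation (shapes are illustrative): the forward pass is an identity copy, and only the backward pass is clamped.

import torch
from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_numpy import LogSoftmaxGradModification

acts = torch.randn(2, 8, 5, 29, requires_grad=True)  # (B, T, U+1, V) joint tensor
out = LogSoftmaxGradModification.apply(acts, 0.1)     # identity copy in the forward pass

out.backward(torch.randn_like(out))                   # arbitrary upstream gradient
assert acts.grad.abs().max() <= 0.1                   # elementwise clamp to [-0.1, 0.1]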
44 changes: 37 additions & 7 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
@@ -32,13 +32,14 @@
from torch.nn import Module

from nemo.collections.asr.parts.numba.rnnt_loss import rnnt
from nemo.collections.asr.parts.numba.rnnt_loss.utils.cpu_utils import cpu_rnnt

__all__ = ['rnnt_loss', 'RNNTLossNumba']


class _RNNTNumba(Function):
@staticmethod
def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda):
def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda, clamp):
"""
log_probs: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network
labels: 2 dimensional Tensor containing all the targets of the batch with zero padded
@@ -50,6 +51,8 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_
is_cuda = acts.is_cuda

certify_inputs(acts, labels, act_lens, label_lens)
if clamp < 0:
raise ValueError("`clamp` must be 0.0 or positive float value.")

loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu
grads = torch.zeros_like(acts) if acts.requires_grad else None
@@ -65,6 +68,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_
grads=grads,
blank_label=blank,
fastemit_lambda=fastemit_lambda,
clamp=clamp,
num_threads=0,
)

@@ -84,11 +88,13 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_
def backward(ctx, grad_output):
if grad_output is not None and ctx.grads is not None:
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
return ctx.grads.mul_(grad_output), None, None, None, None, None, None
return ctx.grads.mul_(grad_output), None, None, None, None, None, None, None


def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction='mean'):
"""RNN Transducer Loss
def rnnt_loss(
acts, labels, act_lens, label_lens, blank=0, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = 0.0
):
"""RNN Transducer Loss (functional form)
Args:
acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network
labels: 2 dimensional Tensor containing all the targets of the batch with zero padded
@@ -101,9 +107,19 @@ def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction='mean'):
then the mean over the batch is taken. Default: 'mean'
"""
if not acts.is_cuda:
# Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping
# *after* we have obtained the gradients of loss(logsoftmax()).
# This is highly wasteful since it requires a copy of the entire joint tensor which is expensive.
# CUDA version is much more efficient since it performs an inplace logsoftmax, and therefore
# can inplace clamp the gradient.
if clamp > 0.0:
acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, clamp)

# NOTE manually done log_softmax for CPU version,
# log_softmax is computed within GPU version.
acts = torch.nn.functional.log_softmax(acts, -1)

return _RNNTNumba.apply(acts, labels, act_lens, label_lens, blank, reduction)
return _RNNTNumba.apply(acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda, clamp)


class RNNTLossNumba(Module):
Expand All @@ -114,12 +130,16 @@ class RNNTLossNumba(Module):
'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
'mean': the output losses will be divided by the target lengths and
then the mean over the batch is taken. Default: 'mean'
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
"""

def __init__(self, blank=0, reduction='mean', fastemit_lambda: float = 0.0):
def __init__(self, blank=0, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = -1):
super(RNNTLossNumba, self).__init__()
self.blank = blank
self.fastemit_lambda = fastemit_lambda
self.clamp = float(clamp) if clamp > 0 else 0.0
self.reduction = reduction
self.loss = _RNNTNumba.apply

@@ -131,11 +151,21 @@ def forward(self, acts, labels, act_lens, label_lens):
label_lens: Tensor of (batch) containing label length of each example
"""
if not acts.is_cuda:
# Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping
# *after* we have obtained the gradients of loss(logsoftmax()).
# This is highly wasteful since it requires a copy of the entire joint tensor which is expensive.
# CUDA version is much more efficient since it performs an inplace logsoftmax, and therefore
# can inplace clamp the gradient.
if self.clamp > 0.0:
acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, self.clamp)

# NOTE manually done log_softmax for CPU version,
# log_softmax is computed within GPU version.
acts = torch.nn.functional.log_softmax(acts, -1)

return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction, self.fastemit_lambda)
return self.loss(
acts, labels, act_lens, label_lens, self.blank, self.reduction, self.fastemit_lambda, self.clamp
)


def check_type(var, t, name):
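Putting the pieces together, a hedged end-to-end sketch of the CPU path with clamping enabled (shapes and values are illustrative; int32 labels/lengths are assumed here, matching the int32 cast NeMo applies before calling the loss, and the Numba backend must be installed):

import torch
from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba

B, T, U, V = 2, 8, 4, 16                                  # batch, frames, target length, vocab incl. blank
acts = torch.randn(B, T, U + 1, V, requires_grad=True)    # joint network output (pre log-softmax on CPU)
labels = torch.randint(1, V, (B, U), dtype=torch.int32)   # targets; index 0 is used as blank here
act_lens = torch.full((B,), T, dtype=torch.int32)
label_lens = torch.full((B,), U, dtype=torch.int32)

loss_fn = RNNTLossNumba(blank=0, reduction='mean', fastemit_lambda=0.0, clamp=0.5)
loss = loss_fn(acts, labels, act_lens, label_lens)
loss.backward()

# On the CPU path the clamp is applied via LogSoftmaxGradModification, so the
# gradient w.r.t. the raw joint tensor is limited elementwise to [-0.5, 0.5].
assert acts.grad.abs().max() <= 0.5

On the GPU path no such wrapper is needed: log-softmax and the clamp are applied inside the Numba CUDA kernels, as the inline comments above note.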
@@ -32,6 +32,7 @@

import numba
import torch
from torch.autograd import Function

from nemo.collections.asr.parts.numba.rnnt_loss.utils import global_constants

@@ -137,6 +138,29 @@ def setup_probs(
self.log_probs2[offset + 1] = log_probs[idx(t, u, labels[u])]


class LogSoftmaxGradModification(Function):
@staticmethod
def forward(ctx, acts, clamp):
if clamp < 0:
raise ValueError("`clamp` must be 0.0 or positive float.")

# This is needed for correctness (inplace is problematic),
# but it wastes a lot of memory.
res = acts.new(acts)
ctx.clamp = clamp
return res

@staticmethod
def backward(ctx, grad_output):
# Clamp the gradients of loss(logsoftmax(...))
# CPU computes logsoftmax explicitly, so we need to override the gradient here.
grad_output = torch.clamp(grad_output, -ctx.clamp, ctx.clamp)
return (
grad_output,
None,
)


class CPURNNT:
def __init__(
self,
@@ -147,6 +171,7 @@ def __init__(
workspace: torch.Tensor,
blank: int,
fastemit_lambda: float,
clamp: float,
num_threads: int,
batch_first: bool,
):
@@ -163,6 +188,7 @@ def __init__(
blank: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
num_threads: Number of OMP threads to launch.
batch_first: Bool that decides if batch dimension is first or third.
"""
@@ -173,6 +199,7 @@ def __init__(
self.workspace = workspace # a flat vector of floatX numbers that represents allocated memory slices
self.blank_ = blank
self.fastemit_lambda_ = fastemit_lambda
self.clamp_ = abs(clamp)
self.num_threads_ = num_threads
self.batch_first = batch_first

@@ -46,7 +46,8 @@ def __init__(
alphabet_size: int,
workspace,
blank: int,
fastemit_lambda,
fastemit_lambda: float,
clamp: float,
num_threads: int,
stream,
):
@@ -63,6 +64,7 @@ def __init__(
blank: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab.
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
num_threads: Number of OMP threads to launch.
stream: Numba Cuda Stream.
"""
@@ -75,6 +77,7 @@ def __init__(
) # a flat vector of floatX numbers that represents allocated memory slices
self.blank_ = blank
self.fastemit_lambda_ = fastemit_lambda
self.clamp_ = abs(clamp)
self.num_threads_ = num_threads
self.stream_ = stream # type: cuda.cudadrv.driver.Stream

@@ -220,6 +223,7 @@ def compute_cost_and_score(
self.alphabet_size_,
self.blank_,
self.fastemit_lambda_,
self.clamp_,
)

# // cost