diff --git a/Dockerfile b/Dockerfile index 1bfac37ab35cb..85ba2353d5b5b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -90,7 +90,7 @@ RUN --mount=from=nemo-src,target=/tmp/nemo cd /tmp/nemo && pip install ".[all]" # TODO: Try to remove once 21.07 container is the base container # install pinned numba version -RUN conda install -c numba numba=0.54.1 +RUN conda install -c conda-forge numba=0.54.1 # copy scripts/examples/tests into container for end user WORKDIR /workspace/nemo diff --git a/README.rst b/README.rst index 15f5ec8e6e6b9..417517ffae6ab 100644 --- a/README.rst +++ b/README.rst @@ -164,7 +164,7 @@ Note that RNNT requires numba to be installed from conda. conda remove numba pip uninstall numba - conda install -c numba numba + conda install -c conda-forge numba Megatron GPT ~~~~~~~~~~~~ diff --git a/examples/asr/conf/conformer/conformer_transducer_bpe.yaml b/examples/asr/conf/conformer/conformer_transducer_bpe.yaml index a145202e2b7f4..7e17566d0443a 100644 --- a/examples/asr/conf/conformer/conformer_transducer_bpe.yaml +++ b/examples/asr/conf/conformer/conformer_transducer_bpe.yaml @@ -182,6 +182,7 @@ model: # FastEmit regularization: https://arxiv.org/abs/2010.11148 # You may enable FastEmit to reduce the latency of the model for streaming fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. # Adds Gaussian noise to the gradients of the decoder to avoid overfitting variational_noise: diff --git a/examples/asr/conf/conformer/conformer_transducer_char.yaml b/examples/asr/conf/conformer/conformer_transducer_char.yaml index 36c6e9fe545c9..e7378338e2e77 100644 --- a/examples/asr/conf/conformer/conformer_transducer_char.yaml +++ b/examples/asr/conf/conformer/conformer_transducer_char.yaml @@ -177,6 +177,7 @@ model: # FastEmit regularization: https://arxiv.org/abs/2010.11148 # You may enable FastEmit to reduce the latency of the model for streaming fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. # Adds Gaussian noise to the gradients of the decoder to avoid overfitting variational_noise: diff --git a/examples/asr/conf/contextnet_rnnt/config_rnnt.yaml b/examples/asr/conf/contextnet_rnnt/config_rnnt.yaml index c3337fd60b037..ddebcbce9f538 100644 --- a/examples/asr/conf/contextnet_rnnt/config_rnnt.yaml +++ b/examples/asr/conf/contextnet_rnnt/config_rnnt.yaml @@ -204,7 +204,8 @@ model: loss: loss_name: "default" warprnnt_numba_kwargs: - fastemit_lambda: 0.0 + fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. optim: name: adam diff --git a/examples/asr/conf/contextnet_rnnt/config_rnnt_bpe.yaml b/examples/asr/conf/contextnet_rnnt/config_rnnt_bpe.yaml index 942257e62aaac..c621dbb21e6c9 100644 --- a/examples/asr/conf/contextnet_rnnt/config_rnnt_bpe.yaml +++ b/examples/asr/conf/contextnet_rnnt/config_rnnt_bpe.yaml @@ -204,7 +204,8 @@ model: loss: loss_name: "default" warprnnt_numba_kwargs: - fastemit_lambda: 0.0 + fastemit_lambda: 0.0 # Recommended values to be in range [1e-4, 1e-2], 0.001 is a good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. 
optim: name: adam diff --git a/examples/asr/conf/contextnet_rnnt/contextnet_rnnt.yaml b/examples/asr/conf/contextnet_rnnt/contextnet_rnnt.yaml index 9d99dc193fc10..44eb15ab55ac3 100644 --- a/examples/asr/conf/contextnet_rnnt/contextnet_rnnt.yaml +++ b/examples/asr/conf/contextnet_rnnt/contextnet_rnnt.yaml @@ -451,6 +451,7 @@ model: warprnnt_numba_kwargs: # FastEmit regularization: https://arxiv.org/abs/2010.11148 fastemit_lambda: 0.001 # Values can be in range [1e-4, 1e-2]. Generally, 0.001 is good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. optim: name: novograd diff --git a/examples/asr/conf/contextnet_rnnt/contextnet_rnnt_char.yaml b/examples/asr/conf/contextnet_rnnt/contextnet_rnnt_char.yaml index 805989c82e27f..4de58c7f3c06e 100644 --- a/examples/asr/conf/contextnet_rnnt/contextnet_rnnt_char.yaml +++ b/examples/asr/conf/contextnet_rnnt/contextnet_rnnt_char.yaml @@ -453,6 +453,7 @@ model: warprnnt_numba_kwargs: # FastEmit regularization: https://arxiv.org/abs/2010.11148 fastemit_lambda: 0.001 # Values can be in range [1e-4, 1e-2]. Generally, 0.001 is good start. + clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only. optim: name: novograd diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index c8b5d23da6240..dc8e1e6f3175f 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -161,7 +161,8 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) elif loss_name == 'warprnnt_numba': fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0) - loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda) + clamp = loss_kwargs.pop('clamp', -1.0) + loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda, clamp=clamp) _warn_unused_additional_kwargs(loss_name, loss_kwargs) else: diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py index 4acdb680cd136..769f1f907c2a9 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py @@ -45,6 +45,7 @@ def rnnt_loss_cpu( grads: torch.Tensor, blank_label: int, fastemit_lambda: float, + clamp: float, num_threads: int, ): """ @@ -62,6 +63,7 @@ def rnnt_loss_cpu( blank_label: Index of the blank token in the vocabulary. fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. num_threads: Number of threads for OpenMP. """ # aliases @@ -96,6 +98,7 @@ def rnnt_loss_cpu( workspace=cpu_workspace, blank=blank_label, fastemit_lambda=fastemit_lambda, + clamp=clamp, num_threads=num_threads, batch_first=True, ) @@ -141,6 +144,7 @@ def rnnt_loss_gpu( grads: torch.Tensor, blank_label: int, fastemit_lambda: float, + clamp: float, num_threads: int, ): """ @@ -158,6 +162,7 @@ def rnnt_loss_gpu( blank_label: Index of the blank token in the vocabulary. fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. num_threads: Number of threads for OpenMP. 
""" minibatch_size = acts.shape[0] @@ -197,6 +202,7 @@ def rnnt_loss_gpu( workspace=gpu_workspace, blank=blank_label, fastemit_lambda=fastemit_lambda, + clamp=clamp, num_threads=num_threads, stream=stream, ) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py index 8a47b1a4041d7..3d418d36988fd 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py @@ -90,6 +90,25 @@ def _assert_no_grad(tensor): ) +class LogSoftmaxGradModification(Function): + @staticmethod + def forward(ctx, acts, clamp): + if clamp < 0: + raise ValueError("`clamp` must be 0.0 or positive float.") + + res = acts.new(acts) + ctx.clamp = clamp + return res + + @staticmethod + def backward(ctx, grad_output): + grad_output = torch.clamp(grad_output, -ctx.clamp, ctx.clamp) + return ( + grad_output, + None, + ) + + def forward_pass(log_probs, labels, blank): """ Computes probability of the forward variable alpha. @@ -300,7 +319,8 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, fastemit_lambda): @staticmethod def backward(ctx, grad_output): - return ctx.grads, None, None, None, None, None + grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads) + return ctx.grads.mul(grad_output), None, None, None, None, None class RNNTLoss(Module): @@ -310,10 +330,11 @@ class RNNTLoss(Module): fastemit_lambda: Float scaling factor for FastEmit regularization. """ - def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0): + def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0, clamp: float = -1.0): super(RNNTLoss, self).__init__() self.blank = blank self.fastemit_lambda = fastemit_lambda + self.clamp = float(clamp) if clamp > 0 else 0.0 self.rnnt = _RNNT.apply def forward(self, acts, labels, act_lens, label_lens): @@ -323,6 +344,9 @@ def forward(self, acts, labels, act_lens, label_lens): _assert_no_grad(label_lens) certify_inputs(acts, labels, act_lens, label_lens) + if self.clamp > 0.0: + acts = LogSoftmaxGradModification.apply(acts, self.clamp) + acts = torch.nn.functional.log_softmax(acts, -1) return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py index 10d9073e7c819..f28bbdb720aef 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py @@ -32,13 +32,14 @@ from torch.nn import Module from nemo.collections.asr.parts.numba.rnnt_loss import rnnt +from nemo.collections.asr.parts.numba.rnnt_loss.utils.cpu_utils import cpu_rnnt __all__ = ['rnnt_loss', 'RNNTLossNumba'] class _RNNTNumba(Function): @staticmethod - def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda): + def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda, clamp): """ log_probs: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network labels: 2 dimensional Tensor containing all the targets of the batch with zero padded @@ -50,6 +51,8 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_ is_cuda = acts.is_cuda certify_inputs(acts, labels, act_lens, label_lens) + if clamp < 0: + raise ValueError("`clamp` must be 0.0 or positive float value.") loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu grads = 
torch.zeros_like(acts) if acts.requires_grad else None @@ -65,6 +68,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_ grads=grads, blank_label=blank, fastemit_lambda=fastemit_lambda, + clamp=clamp, num_threads=0, ) @@ -84,11 +88,13 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_ def backward(ctx, grad_output): if grad_output is not None and ctx.grads is not None: grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads) - return ctx.grads.mul_(grad_output), None, None, None, None, None, None + return ctx.grads.mul_(grad_output), None, None, None, None, None, None, None -def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction='mean'): - """RNN Transducer Loss +def rnnt_loss( + acts, labels, act_lens, label_lens, blank=0, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = 0.0 +): + """RNN Transducer Loss (functional form) Args: acts: Tensor of (batch x seqLength x labelLength x outputDim) containing output from network labels: 2 dimensional Tensor containing all the targets of the batch with zero padded @@ -101,9 +107,19 @@ def rnnt_loss(acts, labels, act_lens, label_lens, blank=0, reduction='mean'): then the mean over the batch is taken. Default: 'mean' """ if not acts.is_cuda: + # Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping + # *after* we have obtained the gradients of loss(logsoftmax()). + # This is highly wasteful since it requires a copy of the entire joint tensor which is expensive. + # CUDA version is much more efficient since it performs an inplace logsoftmax, and therefore + # can inplace clamp the gradient. + if clamp > 0.0: + acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, clamp) + + # NOTE manually done log_softmax for CPU version, + # log_softmax is computed within GPU version. acts = torch.nn.functional.log_softmax(acts, -1) - return _RNNTNumba.apply(acts, labels, act_lens, label_lens, blank, reduction) + return _RNNTNumba.apply(acts, labels, act_lens, label_lens, blank, reduction, fastemit_lambda, clamp) class RNNTLossNumba(Module): @@ -114,12 +130,16 @@ class RNNTLossNumba(Module): 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 'mean': the output losses will be divided by the target lengths and then the mean over the batch is taken. Default: 'mean' + fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to + FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. """ - def __init__(self, blank=0, reduction='mean', fastemit_lambda: float = 0.0): + def __init__(self, blank=0, reduction='mean', fastemit_lambda: float = 0.0, clamp: float = -1): super(RNNTLossNumba, self).__init__() self.blank = blank self.fastemit_lambda = fastemit_lambda + self.clamp = float(clamp) if clamp > 0 else 0.0 self.reduction = reduction self.loss = _RNNTNumba.apply @@ -131,11 +151,21 @@ def forward(self, acts, labels, act_lens, label_lens): label_lens: Tensor of (batch) containing label length of each example """ if not acts.is_cuda: + # Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping + # *after* we have obtained the gradients of loss(logsoftmax()). + # This is highly wasteful since it requires a copy of the entire joint tensor which is expensive. 
+ # CUDA version is much more efficient since it performs an inplace logsoftmax, and therefore + # can inplace clamp the gradient. + if self.clamp > 0.0: + acts = cpu_rnnt.LogSoftmaxGradModification.apply(acts, self.clamp) + + # NOTE manually done log_softmax for CPU version, + # log_softmax is computed within GPU version. acts = torch.nn.functional.log_softmax(acts, -1) - return self.loss(acts, labels, act_lens, label_lens, self.blank, self.reduction, self.fastemit_lambda) + return self.loss( + acts, labels, act_lens, label_lens, self.blank, self.reduction, self.fastemit_lambda, self.clamp + ) def check_type(var, t, name): diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py index 6bf148148ac12..1528606716e10 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py @@ -32,6 +32,7 @@ import numba import torch +from torch.autograd import Function from nemo.collections.asr.parts.numba.rnnt_loss.utils import global_constants @@ -137,6 +138,29 @@ def setup_probs( self.log_probs2[offset + 1] = log_probs[idx(t, u, labels[u])] +class LogSoftmaxGradModification(Function): + @staticmethod + def forward(ctx, acts, clamp): + if clamp < 0: + raise ValueError("`clamp` must be 0.0 or positive float.") + + # This is needed for correctness (inplace is problematic), + # but it wastes a lot of memory. + res = acts.new(acts) + ctx.clamp = clamp + return res + + @staticmethod + def backward(ctx, grad_output): + # Clamp the gradients of loss(logsoftmax(...)) + # CPU computes logsoftmax explicitly, so we need to override the gradient here. + grad_output = torch.clamp(grad_output, -ctx.clamp, ctx.clamp) + return ( + grad_output, + None, + ) + + class CPURNNT: def __init__( self, @@ -147,6 +171,7 @@ def __init__( workspace: torch.Tensor, blank: int, fastemit_lambda: float, + clamp: float, num_threads: int, batch_first: bool, ): @@ -163,6 +188,7 @@ def __init__( blank: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. num_threads: Number of OMP threads to launch. batch_first: Bool that decides if batch dimension is first or third. """ @@ -173,6 +199,7 @@ def __init__( self.workspace = workspace # a flat vector of floatX numbers that represents allocated memory slices self.blank_ = blank self.fastemit_lambda_ = fastemit_lambda + self.clamp_ = abs(clamp) self.num_threads_ = num_threads self.batch_first = batch_first diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py index 9277c530056cf..42890821405e1 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py @@ -46,7 +46,8 @@ def __init__( alphabet_size: int, workspace, blank: int, - fastemit_lambda, + fastemit_lambda: float, + clamp: float, num_threads: int, stream, ): @@ -63,6 +64,7 @@ def __init__( blank: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. fastemit_lambda: Float scaling factor for FastEmit regularization.
Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. num_threads: Number of OMP threads to launch. stream: Numba Cuda Stream. """ @@ -75,6 +77,7 @@ def __init__( ) # a flat vector of floatX numbers that represents allocated memory slices self.blank_ = blank self.fastemit_lambda_ = fastemit_lambda + self.clamp_ = abs(clamp) self.num_threads_ = num_threads self.stream_ = stream # type: cuda.cudadrv.driver.Stream @@ -220,6 +223,7 @@ def compute_cost_and_score( self.alphabet_size_, self.blank_, self.fastemit_lambda_, + self.clamp_, ) # // cost diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index bcca5bf33b8aa..db92147b3362c 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -33,7 +33,7 @@ from nemo.collections.asr.parts.numba.rnnt_loss.utils import rnnt_helper -GPU_RNNT_THREAD_SIZE = 128 +GPU_RNNT_THREAD_SIZE = 256 @cuda.jit(device=True, inline=True) @@ -278,6 +278,7 @@ def compute_grad_kernel( alphabet_size: int, blank_: int, fastemit_lambda: float, + clamp: float, ): """ Compute gradients over the transduction step. @@ -305,6 +306,7 @@ def compute_grad_kernel( blank_: Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab. fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization. + clamp: Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp]. Updates: Kernel inplace updates the following inputs: @@ -385,6 +387,13 @@ def compute_grad_kernel( # update grads[b, t, u, v] = grad grads[col * alphabet_size + idx] = grad + # clamp gradient (if needed) + if clamp > 0.0: + g = grads[col * alphabet_size + idx] + g = min(g, clamp) + g = max(g, -clamp) + grads[col * alphabet_size + idx] = g + # update internal index through the thread_buffer; # until idx < V + 1, such that entire vocabulary has been updated. idx += GPU_RNNT_THREAD_SIZE diff --git a/reinstall.sh b/reinstall.sh index 76ede0acc5c85..808f7685bc1f1 100755 --- a/reinstall.sh +++ b/reinstall.sh @@ -33,11 +33,11 @@ echo 'Installing additional nemo_text_processing conda dependency' bash nemo_text_processing/setup.sh > /dev/null 2>&1 && echo "nemo_text_processing installed!" || echo "nemo_text_processing could not be installed!" if [ -x "$(command -v conda)" ]; then - # we need at least numba .53, and .54 breaks the PyTorch 21.06 container - echo 'Installing numba=0.53.1' - conda install -y -c numba numba=0.53. + + echo 'Installing numba=0.55.0' + conda install -y -c conda-forge numba==0.55 # echo 'Attempting update to numba installation via conda' - # conda update -c numba numba -y > /dev/null 2>&1 && echo "Numba updated!" || echo "Numba could not be updated!" + # conda update -c conda-forge numba -y > /dev/null 2>&1 && echo "Numba updated!" || echo "Numba could not be updated!" fi echo 'All done!' 
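A minimal usage sketch (not part of the patch itself) of the new `clamp` argument through `RNNTLossNumba`, using the signatures and shapes documented in the hunks above. The sizes, values, and blank index are arbitrary. With `clamp > 0`, every element of the gradient with respect to the joint/activation tensor is bounded to `[-clamp, clamp]`; `clamp <= 0` (the default written into the configs above) leaves the gradient untouched.

import torch
from nemo.collections.asr.parts.numba.rnnt_loss.rnnt_pytorch import RNNTLossNumba

# Joint tensor of shape (B, T, U, V); U is the target length + 1.
B, T, U, V = 2, 8, 4, 16
acts = torch.randn(B, T, U, V, requires_grad=True)
labels = torch.randint(1, V, (B, U - 1), dtype=torch.int32)   # targets exclude the blank index 0
act_lens = torch.full((B,), T, dtype=torch.int32)
label_lens = torch.full((B,), U - 1, dtype=torch.int32)

loss_fn = RNNTLossNumba(blank=0, reduction='mean', fastemit_lambda=0.0, clamp=0.1)
loss = loss_fn(acts, labels, act_lens, label_lens)
loss.backward()

# Gradients w.r.t. the activations are clamped elementwise to [-0.1, 0.1].
assert acts.grad.abs().max() <= 0.1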
diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index e959bfb8ee6b0..daceaf460a1ed 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -141,7 +141,7 @@ def test_case_small_random_fastemit_reg(self, device, fastemit_lambda): @pytest.mark.unit @pytest.mark.parametrize('device', DEVICES) - def big_test(self, device): + def test_case_big_tensor(self, device): if device == 'cuda': numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) @@ -280,6 +280,84 @@ def test_case_large_random(self, device): assert np.allclose(pt_cost, np_cost, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch." assert np.allclose(pt_grads, np_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch." + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + def test_case_small_clamp(self, device): + if device == 'cuda': + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + GRAD_CLAMP = 0.1 + acts = np.array( + [ + [ + [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]], + ] + ] + ) + labels = [[1, 2]] + + fn_pt = RNNTLossNumba(blank=0, reduction='sum', clamp=GRAD_CLAMP) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + + fn_np = RNNTLoss_Numpy(blank=0, clamp=GRAD_CLAMP) + np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device) + + expected_cost = 4.495666 + expected_grads = np.array( + [ + [ + [ + [-0.1, -0.1, 0.1, 0.1, 0.1], + [-0.1, 0.1, -0.1, 0.1, 0.1], + [-0.1, 0.06269141, 0.06928472, 0.1, 0.06269141], + ], + [ + [0.05456069, -0.1, 0.05456069, 0.05456069, 0.05456069], + [0.1, 0.1, -0.1, 0.1, 0.1], + [-0.1, 0.1, 0.1, 0.1, 0.1], + ], + ] + ] + ) + + assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "small_test costs mismatch." + assert np.allclose(pt_grads, expected_grads), "small_test gradient mismatch." + + assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_test costs mismatch." + assert np.allclose(pt_grads, np_grads), "small_test gradient mismatch." + + @pytest.mark.unit + @pytest.mark.parametrize('device', DEVICES) + @pytest.mark.parametrize('fastemit_lambda', [1.0, 0.01, 0.00001]) + def test_case_small_fastemit_clamp(self, device, fastemit_lambda): + if device == 'cuda': + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + GRAD_CLAMP = 0.1 + acts = np.array( + [ + [ + [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]], + ] + ] + ) + labels = [[1, 2]] + + fn_pt = RNNTLossNumba(blank=0, reduction='sum', fastemit_lambda=fastemit_lambda, clamp=GRAD_CLAMP) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + + fn_np = RNNTLoss_Numpy(blank=0, fastemit_lambda=fastemit_lambda, clamp=GRAD_CLAMP) + np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device) + + expected_cost = 4.495666 + expected_cost += expected_cost * fastemit_lambda + + assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "small_test costs mismatch." + assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_test costs mismatch." + assert np.allclose(pt_grads, np_grads), "small_test gradient mismatch." 
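For intuition on the expected gradients in the new clamp tests above (which saturate at +/-0.1), a small sketch of the mechanism used on the CPU path: `LogSoftmaxGradModification` is an identity in the forward pass and clamps the incoming gradient in the backward pass. The import path follows the `cpu_rnnt.py` hunk in this diff; the tensor shape and upstream gradient are illustrative.

import torch
from nemo.collections.asr.parts.numba.rnnt_loss.utils.cpu_utils.cpu_rnnt import LogSoftmaxGradModification

x = torch.randn(2, 3, 4, 5, requires_grad=True)
y = LogSoftmaxGradModification.apply(x, 0.1)   # forward pass returns a copy of x
(y * 10.0).sum().backward()                    # upstream gradient is 10.0 everywhere ...
assert torch.allclose(x.grad, torch.full_like(x, 0.1))   # ... but reaches x clamped to +0.1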
+ if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py index a0cb84cbb77ae..a45c7e8929c1c 100644 --- a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py +++ b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py @@ -26,11 +26,18 @@ def log_softmax(x, axis=-1): x = torch.from_numpy(x) # zero-copy - x = torch.log_softmax(x, axis) + x = torch.log_softmax(x, dim=axis) x = x.numpy() return x +def log_softmax_grad(x, axis=-1): + x = torch.tensor(x, requires_grad=True) # alloc memory + y = torch.log_softmax(x, dim=axis) + y.sum().backward() + return x.grad.numpy() + + class TestRNNTCUDAKernels: @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") @pytest.mark.unit @@ -167,3 +174,333 @@ def test_compute_betas_kernel(self): assert np.abs(ll_diff).mean() <= 1e-5 assert np.square(ll_diff).mean() <= 1e-10 + + @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") + @pytest.mark.unit + def test_compute_grads_kernel(self): + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + fastemit_lambda = 0.0 + clamp = 0.0 + + random = np.random.RandomState(0) + original_shape = [1, 5, 11, 3] + B, T, U, V = original_shape + + # Numpy kernel + x = random.randn(*original_shape) + labels = torch.from_numpy(np.array([[1, 1, 1, 2, 2, 2, 1, 2, 2, 1]], dtype=np.int32)) # [1, 10] + audio_len = torch.from_numpy(np.array([T], dtype=np.int32)) + label_len = torch.from_numpy(np.array([U - 1], dtype=np.int32)) + blank_idx = 0 + + x_np = torch.from_numpy(x) + x_np.requires_grad_(True) + + """ + Here we will directly utilize the numpy variant of the loss without explicitly calling + the numpy functions for alpha, beta and grads. + + This is because the grads returned by the rnnt_numpy.transduce_batch() are : + d/dx (alpha + beta alignment)(log_softmax(x)). + But according to the chain rule, we'd still need to compute the gradient of log_softmax(x) + and update the alignments by hand. Instead, we will rely on pytorch to compute the gradient + of the log_softmax(x) step and propagate it backwards. 
+ """ + loss_func = rnnt_numpy.RNNTLoss(blank_idx, fastemit_lambda=fastemit_lambda, clamp=clamp) + loss_val = loss_func(x_np, labels, audio_len, label_len) + loss_val.sum().backward() + true_grads = x_np.grad + + # Pytorch kernel + device = torch.device('cuda') + if hasattr(cuda, 'external_stream'): + stream = cuda.external_stream(torch.cuda.current_stream(device).cuda_stream) + else: + stream = cuda.default_stream() + + x_c = torch.tensor(x, device=device, dtype=torch.float32) + labels_c = torch.tensor(labels, device=device, dtype=torch.int32) + + # Allocate workspace memory + denom = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + alphas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + betas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + llForward = torch.zeros(B, device=device, dtype=x_c.dtype) + llBackward = torch.zeros(B, device=device, dtype=x_c.dtype) + input_lengths = torch.tensor([T], dtype=torch.int32, device=device) + label_lengths = torch.tensor([len(labels[0])], dtype=torch.int32, device=device) + + # certify input data + certify_inputs(x_c, labels_c, input_lengths, label_lengths) + + # flatten activation tensor (for pointer based indexing) + x_c = x_c.view([-1]) + grads = torch.zeros_like(x_c, requires_grad=False) + + # call kernel + # log softmax reduction + reduce.reduce_max(x_c, denom, rows=V, cols=B * T * U, minus=False, stream=stream) + reduce.reduce_exp(x_c, denom, rows=V, cols=B * T * U, minus=True, stream=stream) + + # alpha kernel + gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0]( + x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + ) + + # beta kernel + gpu_rnnt_kernel.compute_betas_kernel[B, U, stream, 0]( + x_c, denom, betas, llBackward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + ) + + # gamma kernel + grad_blocks_per_grid = B * T * U + grad_threads_per_block = gpu_rnnt_kernel.GPU_RNNT_THREAD_SIZE + gpu_rnnt_kernel.compute_grad_kernel[grad_blocks_per_grid, grad_threads_per_block, stream, 0]( + grads, + x_c, + denom, + alphas, + betas, + llForward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, + fastemit_lambda, + clamp, + ) + + # sync kernel + stream.synchronize() + + # reshape grads + grads = grads.view([B, T, U, V]) + diff = true_grads - grads[0].cpu().numpy() + + assert np.abs(diff).mean() <= 1e-5 + assert np.square(diff).mean() <= 1e-10 + + @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") + @pytest.mark.unit + def test_compute_grads_kernel_fastemit(self): + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + fastemit_lambda = 0.001 + clamp = 0.0 + + random = np.random.RandomState(0) + original_shape = [1, 5, 11, 3] + B, T, U, V = original_shape + + # Numpy kernel + x = random.randn(*original_shape) + labels = torch.from_numpy(np.array([[1, 1, 1, 2, 2, 2, 1, 2, 2, 1]], dtype=np.int32)) # [1, 10] + audio_len = torch.from_numpy(np.array([T], dtype=np.int32)) + label_len = torch.from_numpy(np.array([U - 1], dtype=np.int32)) + blank_idx = 0 + + x_np = torch.from_numpy(x) + x_np.requires_grad_(True) + + """ + Here we will directly utilize the numpy variant of the loss without explicitly calling + the numpy functions for alpha, beta and grads. + + This is because the grads returned by the rnnt_numpy.transduce_batch() are : + d/dx (alpha + beta alignment)(log_softmax(x)). 
+ But according to the chain rule, we'd still need to compute the gradient of log_softmax(x) + and update the alignments by hand. Instead, we will rely on pytorch to compute the gradient + of the log_softmax(x) step and propagate it backwards. + """ + loss_func = rnnt_numpy.RNNTLoss(blank_idx, fastemit_lambda=fastemit_lambda, clamp=clamp) + loss_val = loss_func(x_np, labels, audio_len, label_len) + loss_val.sum().backward() + true_grads = x_np.grad + + # Pytorch kernel + device = torch.device('cuda') + if hasattr(cuda, 'external_stream'): + stream = cuda.external_stream(torch.cuda.current_stream(device).cuda_stream) + else: + stream = cuda.default_stream() + + x_c = torch.tensor(x, device=device, dtype=torch.float32) + labels_c = torch.tensor(labels, device=device, dtype=torch.int32) + + # Allocate workspace memory + denom = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + alphas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + betas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + llForward = torch.zeros(B, device=device, dtype=x_c.dtype) + llBackward = torch.zeros(B, device=device, dtype=x_c.dtype) + input_lengths = torch.tensor([T], dtype=torch.int32, device=device) + label_lengths = torch.tensor([len(labels[0])], dtype=torch.int32, device=device) + + # certify input data + certify_inputs(x_c, labels_c, input_lengths, label_lengths) + + # flatten activation tensor (for pointer based indexing) + x_c = x_c.view([-1]) + grads = torch.zeros_like(x_c, requires_grad=False) + + # call kernel + # log softmax reduction + reduce.reduce_max(x_c, denom, rows=V, cols=B * T * U, minus=False, stream=stream) + reduce.reduce_exp(x_c, denom, rows=V, cols=B * T * U, minus=True, stream=stream) + + # alpha kernel + gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0]( + x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + ) + + # beta kernel + gpu_rnnt_kernel.compute_betas_kernel[B, U, stream, 0]( + x_c, denom, betas, llBackward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + ) + + # gamma kernel + grad_blocks_per_grid = B * T * U + grad_threads_per_block = gpu_rnnt_kernel.GPU_RNNT_THREAD_SIZE + gpu_rnnt_kernel.compute_grad_kernel[grad_blocks_per_grid, grad_threads_per_block, stream, 0]( + grads, + x_c, + denom, + alphas, + betas, + llForward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, + fastemit_lambda, + clamp, + ) + + # sync kernel + stream.synchronize() + + # reshape grads + grads = grads.view([B, T, U, V]) + diff = true_grads - grads[0].cpu().numpy() + + assert np.abs(diff).mean() <= 1e-5 + assert np.square(diff).mean() <= 1e-10 + + @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available") + @pytest.mark.unit + def test_compute_grads_kernel_clamp(self): + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + fastemit_lambda = 0.0 + clamp = 0.1 + + random = np.random.RandomState(0) + original_shape = [1, 5, 11, 3] + B, T, U, V = original_shape + + # Numpy kernel + x = random.randn(*original_shape) + labels = torch.from_numpy(np.array([[1, 1, 1, 2, 2, 2, 1, 2, 2, 1]], dtype=np.int32)) # [1, 10] + audio_len = torch.from_numpy(np.array([T], dtype=np.int32)) + label_len = torch.from_numpy(np.array([U - 1], dtype=np.int32)) + blank_idx = 0 + + x_np = torch.from_numpy(x) + x_np.requires_grad_(True) + + """ + Here we will directly utilize the numpy variant of the loss without explicitly calling + the numpy 
functions for alpha, beta and grads. + + This is because the grads returned by the rnnt_numpy.transduce_batch() are : + d/dx (alpha + beta alignment)(log_softmax(x)). + But according to the chain rule, we'd still need to compute the gradient of log_softmax(x) + and update the alignments by hand. Instead, we will rely on pytorch to compute the gradient + of the log_softmax(x) step and propagate it backwards. + """ + loss_func = rnnt_numpy.RNNTLoss(blank_idx, fastemit_lambda=fastemit_lambda, clamp=clamp) + loss_val = loss_func(x_np, labels, audio_len, label_len) + loss_val.sum().backward() + true_grads = x_np.grad + + # Pytorch kernel + device = torch.device('cuda') + if hasattr(cuda, 'external_stream'): + stream = cuda.external_stream(torch.cuda.current_stream(device).cuda_stream) + else: + stream = cuda.default_stream() + + x_c = torch.tensor(x, device=device, dtype=torch.float32) + labels_c = torch.tensor(labels, device=device, dtype=torch.int32) + + # Allocate workspace memory + denom = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + alphas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + betas = torch.zeros(B * T * U, device=device, dtype=x_c.dtype) + llForward = torch.zeros(B, device=device, dtype=x_c.dtype) + llBackward = torch.zeros(B, device=device, dtype=x_c.dtype) + input_lengths = torch.tensor([T], dtype=torch.int32, device=device) + label_lengths = torch.tensor([len(labels[0])], dtype=torch.int32, device=device) + + # certify input data + certify_inputs(x_c, labels_c, input_lengths, label_lengths) + + # flatten activation tensor (for pointer based indexing) + x_c = x_c.view([-1]) + grads = torch.zeros_like(x_c, requires_grad=False) + + # call kernel + # log softmax reduction + reduce.reduce_max(x_c, denom, rows=V, cols=B * T * U, minus=False, stream=stream) + reduce.reduce_exp(x_c, denom, rows=V, cols=B * T * U, minus=True, stream=stream) + + # alpha kernel + gpu_rnnt_kernel.compute_alphas_kernel[B, U, stream, 0]( + x_c, denom, alphas, llForward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + ) + + # beta kernel + gpu_rnnt_kernel.compute_betas_kernel[B, U, stream, 0]( + x_c, denom, betas, llBackward, input_lengths, label_lengths, labels_c, B, T, U, V, blank_idx, + ) + + # gamma kernel + grad_blocks_per_grid = B * T * U + grad_threads_per_block = gpu_rnnt_kernel.GPU_RNNT_THREAD_SIZE + gpu_rnnt_kernel.compute_grad_kernel[grad_blocks_per_grid, grad_threads_per_block, stream, 0]( + grads, + x_c, + denom, + alphas, + betas, + llForward, + input_lengths, + label_lengths, + labels_c, + B, + T, + U, + V, + blank_idx, + fastemit_lambda, + clamp, + ) + + # sync kernel + stream.synchronize() + + # reshape grads + grads = grads.view([B, T, U, V]) + diff = true_grads - grads[0].cpu().numpy() + + assert np.abs(diff).mean() <= 1e-5 + assert np.square(diff).mean() <= 1e-10 diff --git a/tutorials/asr/ASR_with_Transducers.ipynb b/tutorials/asr/ASR_with_Transducers.ipynb index c14ea61dfb2f1..e6f60a943dffa 100644 --- a/tutorials/asr/ASR_with_Transducers.ipynb +++ b/tutorials/asr/ASR_with_Transducers.ipynb @@ -64,9 +64,9 @@ "source": [ "# In a conda environment, you would use the following command\n", "# Update Numba to > 0.53\n", - "# conda install -c numba numba\n", + "# conda install -c conda-forge numba\n", "# or\n", - "# conda update -c numba numba\n", + "# conda update -c conda-forge numba\n", "\n", "# For pip based environments,\n", "# Update Numba to > 0.53\n", diff --git a/tutorials/asr/Intro_to_Transducers.ipynb 
b/tutorials/asr/Intro_to_Transducers.ipynb index ee506b9a793f3..81788d71049be 100644 --- a/tutorials/asr/Intro_to_Transducers.ipynb +++ b/tutorials/asr/Intro_to_Transducers.ipynb @@ -57,9 +57,9 @@ "source": [ "# In a conda environment, you would use the following command\n", "# Update Numba to > 0.53\n", - "# conda install -c numba numba\n", + "# conda install -c conda-forge numba\n", "# or\n", - "# conda update -c numba numba\n", + "# conda update -c conda-forge numba\n", "\n", "# For pip based environments,\n", "# Update Numba to > 0.54\n",
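Finally, a minimal sketch (not part of the patch) of how the `clamp` entry in the `warprnnt_numba_kwargs` blocks of the YAML configs above reaches the loss, per the `nemo/collections/asr/losses/rnnt.py` hunk: the dict is forwarded as `loss_kwargs`, `resolve_rnnt_loss` pops `fastemit_lambda` and `clamp` before constructing `RNNTLossNumba`, and any leftover keys only trigger a warning. The blank index here is illustrative.

from nemo.collections.asr.losses.rnnt import resolve_rnnt_loss

# Mirrors a warprnnt_numba_kwargs block with gradient clamping enabled.
loss_kwargs = {'fastemit_lambda': 0.001, 'clamp': 0.1}
loss_fn = resolve_rnnt_loss('warprnnt_numba', blank_idx=1024, loss_kwargs=loss_kwargs)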