Fix RNN-T loss memory usage (NVIDIA#11144)
* Fix RNN-T memory usage

Signed-off-by: artbataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
artbataev authored and XuesongYang committed Jan 18, 2025
1 parent cd52873 commit 9450f55
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
@@ -80,15 +80,16 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_
         if grads is not None:
             grads /= minibatch_size

-        ctx.grads = grads
+        ctx.save_for_backward(grads)

         return costs

     @staticmethod
     def backward(ctx, grad_output):
-        if grad_output is not None and ctx.grads is not None:
-            grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
-            return ctx.grads.mul_(grad_output), None, None, None, None, None, None, None
+        (grads,) = ctx.saved_tensors
+        if grad_output is not None and grads is not None:
+            grad_output = grad_output.view(-1, 1, 1, 1).to(grads)
+            return grads.mul_(grad_output), None, None, None, None, None, None, None


 class _TDTNumba(Function):
@@ -170,18 +171,18 @@ def forward(
             label_grads /= minibatch_size
             duration_grads /= minibatch_size

-        ctx.label_grads = label_grads
-        ctx.duration_grads = duration_grads
+        ctx.save_for_backward(label_grads, duration_grads)

         return costs

     @staticmethod
     def backward(ctx, grad_output):
-        if grad_output is not None and ctx.label_grads is not None:
-            grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.label_grads)
+        label_grads, duration_grads = ctx.saved_tensors
+        if grad_output is not None and label_grads is not None:
+            grad_output = grad_output.view(-1, 1, 1, 1).to(label_grads)
             return (
-                ctx.label_grads.mul_(grad_output),
-                ctx.duration_grads.mul_(grad_output),
+                label_grads.mul_(grad_output),
+                duration_grads.mul_(grad_output),
                 None,
                 None,
                 None,
@@ -251,15 +252,16 @@ def forward(
         if grads is not None:
             grads /= minibatch_size

-        ctx.grads = grads
+        ctx.save_for_backward(grads)

         return costs

     @staticmethod
     def backward(ctx, grad_output):
-        if grad_output is not None and ctx.grads is not None:
-            grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
-            return ctx.grads.mul_(grad_output), None, None, None, None, None, None, None, None, None, None
+        (grads,) = ctx.saved_tensors
+        if grad_output is not None and grads is not None:
+            grad_output = grad_output.view(-1, 1, 1, 1).to(grads)
+            return grads.mul_(grad_output), None, None, None, None, None, None, None, None, None, None


 def rnnt_loss(
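Every hunk in this commit applies the same pattern: the precomputed gradient tensors are handed to autograd via ctx.save_for_backward() in forward() and retrieved through ctx.saved_tensors in backward(), rather than being stashed as plain attributes on ctx. PyTorch's guidance for custom autograd Functions is to save tensors this way so their lifetime is managed together with the graph; keeping large gradient buffers alive as ctx attributes is presumably what inflated peak memory here. The snippet below is a minimal, self-contained sketch of the pattern, not the NeMo implementation — the class name _ToyLoss and its toy cost/gradient computation are invented for illustration.

import torch
from torch.autograd import Function


class _ToyLoss(Function):
    # Toy stand-in for the RNN-T loss Functions above: forward() already has
    # the gradients in hand and backward() only needs to replay them.

    @staticmethod
    def forward(ctx, acts):
        # Pretend a fused kernel returned per-sample costs plus ready-made grads.
        costs = acts.sum(dim=(1, 2, 3))
        grads = torch.ones_like(acts)
        # Hand the gradients to autograd instead of `ctx.grads = grads`,
        # so they are freed with the graph once backward() has run.
        ctx.save_for_backward(grads)
        return costs

    @staticmethod
    def backward(ctx, grad_output):
        (grads,) = ctx.saved_tensors
        if grad_output is not None and grads is not None:
            # Scale the stored per-element grads by the incoming cost grads.
            grad_output = grad_output.view(-1, 1, 1, 1).to(grads)
            return grads.mul_(grad_output)


acts = torch.randn(2, 3, 4, 5, requires_grad=True)
_ToyLoss.apply(acts).sum().backward()  # acts.grad is now populated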
