diff --git a/nemo/backends/pytorch/nm.py b/nemo/backends/pytorch/nm.py index c7d568ed1b30..2b1c33b48e43 100644 --- a/nemo/backends/pytorch/nm.py +++ b/nemo/backends/pytorch/nm.py @@ -142,8 +142,7 @@ def __init__(self, name=None): def __call__(self, force_pt=False, *input, **kwargs): pt_call = len(input) > 0 or force_pt if pt_call: - with t.no_grad(): - return self.forward(*input, **kwargs) + return self.forward(*input, **kwargs) else: return NeuralModule.__call__(self, **kwargs)