diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py index 7115eebf41af..faad7a28436f 100644 --- a/nemo/backends/pytorch/actions.py +++ b/nemo/backends/pytorch/actions.py @@ -1239,7 +1239,9 @@ def train( # By default, disable broadcast_buffers. This disables batch norm synchronization on forward # pass - pmodule = DDP(pmodule, device_ids=[self.local_rank], broadcast_buffers=False) + pmodule = DDP( + pmodule, device_ids=[self.local_rank], broadcast_buffers=False, find_unused_parameters=True + ) # # Convert batchnorm modules to synced if applicable # if synced_batchnorm and isinstance(pmodule, torch.nn.Module):