diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 8c34f528f2d9..953edef89820 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -187,6 +187,7 @@ def __init__( bias=False, gather_output=True, init_method=self._get_init_fn(column_init_method), + disable_grad_reduce=self._sequence_parallel, ) if gather_output: self.linear_out = RowParallelLinear(