Add distributed Adam support for BERT (NVIDIA#5305)
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: andrusenkoau <[email protected]>
timmoon10 authored and andrusenkoau committed Jan 5, 2023
1 parent 6f55999 commit 686676e
Showing 2 changed files with 44 additions and 7 deletions.
examples/nlp/language_modeling/megatron_bert_pretraining.py (2 additions & 1 deletion)
@@ -36,6 +36,7 @@ def main(cfg) -> None:
     logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
 
     megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
+    with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'
 
     plugins = []
     strategy = NLPDDPStrategy(
@@ -51,7 +52,7 @@ def main(cfg) -> None:
                 init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                 growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
             )
-        if megatron_amp_o2:
+        if megatron_amp_o2 and not with_distributed_adam:
             plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
         else:
             plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
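For reference, the new behavior is selected purely by the optimizer name in the model config: when model.optim.name is 'distributed_fused_adam', the script skips MegatronHalfPrecisionPlugin even under megatron_amp_O2. A minimal sketch of that check (not part of this commit; the config values are illustrative):

from omegaconf import OmegaConf

# Hypothetical config fragment; only model.optim.name matters for the check above.
cfg = OmegaConf.create({'model': {'optim': {'name': 'distributed_fused_adam', 'lr': 2e-4}}})

with_distributed_adam = cfg.model.optim.get('name') == 'distributed_fused_adam'
print(with_distributed_adam)  # True, so PipelineMixedPrecisionPlugin is used instead of MegatronHalfPrecisionPlugin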
Second changed file (42 additions & 6 deletions)
@@ -89,7 +89,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False,)[0]
 
         if self.megatron_amp_o2:
-            self.model.cuda(torch.cuda.current_device())
+            if not self.with_distributed_adam:
+                self.model.cuda(torch.cuda.current_device())
             self.model = Float16Module(module=self.model, precision=cfg.precision)
 
     def model_provider_func(self, pre_process, post_process):
@@ -210,11 +211,23 @@ def training_step(self, batch, batch_idx):
         batch_for_pipeline = self.process_batch(batch)
         tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]
 
-        if self.megatron_amp_o2:
-            custom_sync_context_handler = self._optimizer.no_sync
+        # handle asynchronous grad reduction
+        custom_sync_context_handler = None
+        custom_grad_sync_func = None
+        if self.with_distributed_adam:
+            if self.megatron_amp_o2:
+                # copy grads to main grad
+                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True)
+            else:
+                # keep grad tensors around
+                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
+            custom_grad_sync_func = self.reduce_overlap_gradients
         else:
-            # TODO: enable async grad all reduce for O1/autocast mixed precision training
-            custom_sync_context_handler = None
+            if self.megatron_amp_o2:
+                custom_sync_context_handler = self._optimizer.no_sync
+            else:
+                # TODO: enable async grad all reduce for O1/autocast mixed precision training
+                custom_sync_context_handler = None
 
         # run forward and backwards passes for an entire global batch
         # we do this inside training_step to support pipeline parallelism
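The two handlers set up above are passed to the shared forward-backward loop shown in the next hunk: the sync context suppresses eager gradient reduction while micro-batches accumulate, and the grad-sync function launches the deferred reduction afterwards. A rough standalone sketch of that pattern (illustrative only, not NeMo's forward-backward implementation; model, micro_batches, and loss_fn are placeholders):

import contextlib

def accumulate_and_sync(model, micro_batches, loss_fn,
                        custom_sync_context_handler=None, custom_grad_sync_func=None):
    # Fall back to a no-op context when no handler is provided.
    sync_context = custom_sync_context_handler or contextlib.nullcontext
    with sync_context():
        # Gradients accumulate locally; reduction is deferred by the context.
        for batch in micro_batches:
            loss_fn(model(batch)).backward()
    if custom_grad_sync_func is not None:
        # Explicitly trigger the deferred reduction (e.g. reduce_overlap_gradients).
        custom_grad_sync_func()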
@@ -229,6 +242,7 @@ def training_step(self, batch, batch_idx):
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
             custom_sync_context_handler=custom_sync_context_handler,
+            custom_grad_sync_func=custom_grad_sync_func,
         )
 
         if losses_reduced_per_micro_batch:
@@ -238,7 +252,10 @@ def training_step(self, batch, batch_idx):
         else:
             loss_mean = torch.tensor([0.0, 0.0]).cuda()
 
-        if self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) > 1:
+        if self.with_distributed_adam:
+            # gradients are reduced internally in distributed optimizer
+            pass
+        elif self.megatron_amp_o2 and self.cfg.get('pipeline_model_parallel_size', 1) > 1:
             # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
             self._optimizer.allreduce_main_grads()
         else:
@@ -588,6 +605,25 @@ def setup_optimizer_param_groups(self):
         """ModelPT override. Optimizer will get self._optimizer_param_groups"""
         self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model)
 
+    def configure_optimizers(self):
+
+        if self.with_distributed_adam:
+
+            # Disable overlapped grad sync for embedding grad when
+            # pipeline parallelism is enabled
+            # See: allreduce_first_last_embeddings
+            if parallel_state.get_pipeline_model_parallel_world_size() > 1 and (
+                parallel_state.is_pipeline_first_stage(ignore_virtual=True)
+                or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+            ):
+                module = self.model
+                if module.share_token_embeddings:
+                    word_embeddings_weight = module.word_embeddings_weight()
+                    word_embeddings_weight._disable_greedy_grad_copy = not self.megatron_amp_o2
+                    word_embeddings_weight._disable_overlap_grad_sync = True
+
+        return super().configure_optimizers()
+
     # Required for ONNX export
     @property
     def input_types(self) -> Optional[Dict[str, NeuralType]]:
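The flags set on the shared word-embedding weight in configure_optimizers above are plain Python attributes attached to the parameter tensor; an optimizer that overlaps gradient sync can check them when deciding how to handle each gradient. A hedged sketch of setting and querying them (param is a stand-in parameter, not the BERT embedding, and the getattr reads only illustrate how a consumer might honor the flags):

import torch

param = torch.nn.Parameter(torch.zeros(8, 4))
# Same attribute names as in the diff above; they are ordinary ad-hoc attributes.
param._disable_greedy_grad_copy = True
param._disable_overlap_grad_sync = True

# A consumer (e.g. a distributed optimizer) could check them defensively:
overlap_grad_sync = not getattr(param, '_disable_overlap_grad_sync', False)
greedy_grad_copy = not getattr(param, '_disable_greedy_grad_copy', False)
print(overlap_grad_sync, greedy_grad_copy)  # False False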
