diff --git a/examples/nlp/language_modeling/megatron_gpt_pretraining.py b/examples/nlp/language_modeling/megatron_gpt_pretraining.py
index 1ac12b422108..596e3c59e753 100644
--- a/examples/nlp/language_modeling/megatron_gpt_pretraining.py
+++ b/examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import torch.multiprocessing as mp
 from omegaconf.omegaconf import OmegaConf, open_dict
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import TorchElasticEnvironment
@@ -29,6 +30,8 @@
 from nemo.utils import logging
 from nemo.utils.exp_manager import exp_manager
 
+mp.set_start_method("spawn", force=True)
+
 
 @hydra_runner(config_path="conf", config_name="megatron_gpt_config")
 def main(cfg) -> None:
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/indexed_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/indexed_dataset.py
index f45d9b7155a8..b2c5f4976f32 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/indexed_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/indexed_dataset.py
@@ -500,10 +500,10 @@ def __init__(self, path, skip_warmup=False):
     def __getstate__(self):
         return self._path
 
-    # def __setstate__(self, state):
-    #     self._do_init(state)
+    def __setstate__(self, state):
+        self._do_init(state)
 
-    def _do_init(self, path, skip_warmup):
+    def _do_init(self, path, skip_warmup=True):
         self._path = path
         self._index = self.Index(index_file_path(self._path), skip_warmup)
 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index db3f37ae3c90..2b1edc954972 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import itertools
-from typing import Any, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -149,8 +149,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             self._nsys_profile_start_step *= grad_accum_steps
             self._nsys_profile_end_step *= grad_accum_steps
 
-        self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', False)
-
     def set_inference_config(self, inference_config):
         self._inference_config = inference_config
 
@@ -231,6 +229,18 @@ def setup_optimizer_param_groups(self):
         else:
             self._optimizer_param_groups = get_params_for_weight_decay_optimization(self.model)
 
+    def setup_optimization(
+        self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
+        if self.with_distributed_adam:
+
+            # Enable overlapped param sync by default
+            if 'overlap_param_sync' not in optim_kwargs:
+                optim_kwargs['overlap_param_sync'] = True
+
+        return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)
+
     def configure_optimizers(self):
 
         if self.with_distributed_adam:
@@ -522,43 +532,25 @@ def allreduce_first_last_embeddings(self):
 
     def get_forward_output_and_loss_func(self, validation_step=False):
         def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
-            batch = next(dataloader_iter)
+            # GPT3 uses only causal mask, which doesn't need attention mask
             if parallel_state.get_pipeline_model_parallel_world_size() == 1:
+                batch = next(dataloader_iter)
                 for k in batch.keys():
-                    if self.get_attention_mask_from_fusion:
-                        batch[k] = batch[k].cuda(non_blocking=True) if k not in ['attention_mask'] else None
-                    else:
-                        batch[k] = batch[k].cuda(non_blocking=True)
+                    batch[k] = batch[k].cuda(non_blocking=True) if k not in ['attention_mask'] else None
             else:
                 if parallel_state.is_pipeline_first_stage():
-                    # First pipeline stage needs tokens, position_ids, and attention_mask
+                    batch = next(dataloader_iter)
+                    # First pipeline stage needs only the tokens and position_ids
                     for k in batch.keys():
-                        if self.get_attention_mask_from_fusion:
-                            batch[k] = batch[k].cuda(non_blocking=True) if k in ['tokens', 'position_ids'] else None
-                        else:
-                            batch[k] = (
-                                batch[k].cuda(non_blocking=True)
-                                if k in ['tokens', 'position_ids', 'attention_mask']
-                                else None
-                            )
+                        batch[k] = batch[k].cuda(non_blocking=True) if k in ['tokens', 'position_ids'] else None
                 elif parallel_state.is_pipeline_last_stage():
-                    # Last pipeline stage needs the labels, loss_mask, and attention_mask
+                    batch = next(dataloader_iter)
+                    # Last pipeline stage needs only the labels and loss_mask
                     for k in batch.keys():
-                        if self.get_attention_mask_from_fusion:
-                            batch[k] = batch[k].cuda(non_blocking=True) if k in ['labels', 'loss_mask'] else None
-                        else:
-                            batch[k] = (
-                                batch[k].cuda(non_blocking=True)
-                                if k in ['labels', 'loss_mask', 'attention_mask']
-                                else None
-                            )
+                        batch[k] = batch[k].cuda(non_blocking=True) if k in ['labels', 'loss_mask'] else None
                 else:
-                    # Intermediate pipeline stage only needs attention_mask
-                    if self.get_attention_mask_from_fusion:
-                        batch = {k: None for k in ['tokens', 'position_ids', 'attention_mask', 'labels']}
-                    else:
-                        for k in batch.keys():
-                            batch[k] = batch[k].cuda(non_blocking=True) if k in ['attention_mask'] else None
+                    # Intermediate pipeline stage doesn't need any inputs
+                    batch = {k: None for k in ['tokens', 'position_ids', 'attention_mask', 'labels']}
 
             output_tensor = model(
                 batch['tokens'],
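
Context for the indexed_dataset.py change above: with mp.set_start_method("spawn", force=True), DataLoader worker processes receive a pickled copy of the dataset, so memory-mapped state must be rebuilt on unpickling from the stored path instead of being serialized directly. Below is a minimal standalone sketch of that pattern; the PicklableMMapDataset class and the data.bin path are illustrative assumptions, not NeMo code.

# Minimal sketch (hypothetical class, not part of NeMo): pickle only the path,
# rebuild the memory map when the object is unpickled in a spawned worker.
import numpy as np

class PicklableMMapDataset:
    def __init__(self, path):
        self._do_init(path)

    def _do_init(self, path):
        self._path = path
        # np.memmap holds an OS-level file handle that cannot be pickled directly.
        self._data = np.memmap(path, dtype=np.uint16, mode='r')

    def __getstate__(self):
        # Serialize only the path, not the mmap handle.
        return self._path

    def __setstate__(self, state):
        # Re-open the memory map in the child process.
        self._do_init(state)

    def __len__(self):
        return len(self._data)

# Example usage (assumes 'data.bin' exists):
#   ds = PicklableMMapDataset('data.bin')
#   loader = torch.utils.data.DataLoader(ds, num_workers=2)  # spawned workers unpickle ds via __setstate__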