[NeMo-UX] Adding context- & expert-parallelism to MegatronStrategy #9525

Merged 1 commit on Jun 24, 2024
45 changes: 42 additions & 3 deletions nemo/lightning/pytorch/strategies.py
@@ -47,20 +47,53 @@
class MegatronStrategy(DDPStrategy, io.IOMixin):
"""Megatron plugin for Pytorch Lightning.

This strategy implements model parallelism using NVIDIA's Megatron-LM framework. It supports
various forms of parallelism including tensor model parallelism, pipeline model parallelism,
sequence parallelism, and expert parallelism for efficient training of large language models.

Args:
no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2
with FP32 gradient accumulation.
tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks.
Defaults to 1.
pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers
across GPU ranks. Defaults to 1.
virtual_pipeline_model_parallel_size (Optional[int]): Interleaved pipeline parallelism used to
improve performance by reducing the pipeline bubble. Defaults to None.
context_parallel_size (int): Splits network input along sequence dimension across GPU ranks.
Defaults to 1.
sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
parallelizing layer norms and dropout sequentially. Defaults to False.
expert_model_parallel_size (int): Distributes MoE Experts across sub data parallel dimension.
Defaults to 1.
moe_extended_tp (bool): Alternative parallelization strategy for expert parallelism. Defaults to False.
data_sampler (Optional['DataSampler']): Custom data sampler for distributed training. Defaults to None.
parallel_devices (Optional[List[torch.device]]): List of devices to use for parallelism. Defaults to None.
cluster_environment: Cluster environment for distributed training. Defaults to None.
checkpoint_io: Checkpoint I/O handler. Defaults to None.
find_unused_parameters (bool): Find unused parameters in DDP. Defaults to False.
enable_nemo_ckpt_io (bool): Enable NeMo checkpoint I/O. Defaults to True.
ckpt_type (TrainerCkptProtocol): Checkpoint type. Defaults to TrainerCheckpoint.
ckpt_include_optimizer (bool): Include optimizer state in checkpoint. Defaults to False.
ddp (Union[DDPLiteral, DistributedDataParallelConfig]): DDP configuration. Defaults to "megatron".
lazy_init (bool): Use lazy initialization for model parallel parameters. Defaults to False.
pipeline_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. Defaults to None.
**kwargs: Additional keyword arguments.

Note:
This strategy is designed to work with NVIDIA's Megatron-LM framework and requires
specific model implementations that are compatible with Megatron's parallelism techniques.
"""

trainer: pl.Trainer

## TODO: support context parallel
def __init__(
self,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
context_parallel_size: int = 1,
sequence_parallel: bool = False,
expert_model_parallel_size: int = 1,
moe_extended_tp: bool = False,
data_sampler: Optional['DataSampler'] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment=None, # TODO: Add type-hint
@@ -86,6 +119,9 @@ def __init__(
self.data_sampler: Optional['DataSampler'] = data_sampler
self.tensor_model_parallel_size = tensor_model_parallel_size
self.pipeline_model_parallel_size = pipeline_model_parallel_size
self.context_parallel_size = context_parallel_size
self.expert_model_parallel_size = expert_model_parallel_size
self.moe_extended_tp = moe_extended_tp
self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
self.sequence_parallel = sequence_parallel
self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
@@ -125,6 +161,9 @@ def connect(self, model: pl.LightningModule) -> None:
config.tensor_model_parallel_size = self.tensor_model_parallel_size
config.pipeline_model_parallel_size = self.pipeline_model_parallel_size
config.virtual_pipeline_model_parallel_size = self.virtual_pipeline_model_parallel_size
config.context_parallel_size = self.context_parallel_size
config.expert_model_parallel_size = self.expert_model_parallel_size
config.moe_extended_tp = self.moe_extended_tp
config.sequence_parallel = self.sequence_parallel
self._mcore_config = config

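As a sanity check on how the sizes compose: the data-parallel size is whatever remains of the world size after the tensor-, pipeline-, and context-parallel dimensions, and, per the docstring above, expert parallelism subdivides that data-parallel dimension. The helper below is an illustrative assumption, not code from NeMo or Megatron-Core.

```python
def infer_data_parallel_size(world_size: int, tp: int, pp: int, cp: int, ep: int) -> int:
    """Illustrative only: derive the data-parallel size left after the model-parallel
    dimensions and check that the expert-parallel size fits inside it (assumed constraint)."""
    model_parallel = tp * pp * cp
    assert world_size % model_parallel == 0, "world size must be divisible by TP * PP * CP"
    dp = world_size // model_parallel
    # Per the docstring above, experts are distributed across a sub data-parallel
    # dimension, so ep is assumed to divide dp.
    assert dp % ep == 0, "expert_model_parallel_size is assumed to divide the data-parallel size"
    return dp

# e.g. 16 GPUs with TP=2, PP=2, CP=2 leaves DP=2, which EP=2 fills exactly.
print(infer_data_parallel_size(16, tp=2, pp=2, cp=2, ep=2))  # -> 2
```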