From 43f45fa16d3170d2ccc559e368e69c34ac5bf98e Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Wed, 8 Jan 2025 10:56:21 -0800 Subject: [PATCH 1/7] Revert mixtral config change Signed-off-by: Guyue Huang --- nemo/collections/llm/gpt/model/mixtral.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py index 8b1c8d62ebd6..919705fc02f1 100644 --- a/nemo/collections/llm/gpt/model/mixtral.py +++ b/nemo/collections/llm/gpt/model/mixtral.py @@ -70,8 +70,7 @@ class MixtralConfig(GPTConfig): rotary_base: float = 1000000.0 bf16: bool = True params_dtype: torch.dtype = torch.bfloat16 - apply_rope_fusion: bool = True - bias_activation_fusion: bool = True + @dataclass class MixtralConfig8x3B(MixtralConfig): From 8420b2232a556851ff56464f075bbafb33cc7084 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Wed, 8 Jan 2025 11:04:02 -0800 Subject: [PATCH 2/7] Decide cuda device max connections based on torch.cuda.get_device_capability Signed-off-by: Guyue Huang --- nemo/lightning/run/plugins.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/nemo/lightning/run/plugins.py b/nemo/lightning/run/plugins.py index e4b07e5acb35..42a4eb296af8 100644 --- a/nemo/lightning/run/plugins.py +++ b/nemo/lightning/run/plugins.py @@ -30,6 +30,8 @@ from nemo.utils.import_utils import safe_import +import torch + res_module, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency') # This file contains plugins based on NeMo-Run's run.Plugin API. @@ -342,9 +344,24 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor): """Enable the performance environment settings""" if task.trainer.strategy.__fn_or_cls__ == MegatronStrategy: - # Force program order kernel launch for TP, CP overlap - if self.custom_cuda_device_max_connections: - executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = str(self.custom_cuda_device_max_connections) + if torch.cuda.is_available(): + major, _ = torch.cuda.get_device_capability() + if major > 9: + if self.custom_cuda_device_max_connections is not None: + executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = str(self.custom_cuda_device_max_connections) + else: + # When TP or CP size is larger than 1, need to use a single cuda device connection to enforce + # the kernel queuing order of the host to GPU for their execution. This is needed for the optimal + # overlap between communication and computation kernels. 
+ tp_size = task.trainer.strategy.tensor_model_parallel_size + cp_size = task.trainer.strategy.context_parallel_size + if tp_size > 1 or cp_size > 1: + executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + else: + if self.custom_cuda_device_max_connections is not None: + executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = str( + self.custom_cuda_device_max_connections + ) # Set LayerNorm SM margin to support the overlap with LayerNorm kernel if self.enable_layernorm_sm_margin: From 633b9032c9a06fe6de4c88fee21ecc2a80dfab77 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Wed, 8 Jan 2025 11:05:09 -0800 Subject: [PATCH 3/7] Rename custom_cuda_device_max_connections to num_cuda_device_max_connections Signed-off-by: Guyue Huang --- nemo/lightning/run/plugins.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/lightning/run/plugins.py b/nemo/lightning/run/plugins.py index 42a4eb296af8..d8458692c2e8 100644 --- a/nemo/lightning/run/plugins.py +++ b/nemo/lightning/run/plugins.py @@ -317,7 +317,7 @@ class PerfEnvPlugin(run.Plugin): layernorm_sm_margin: int = 16 enable_vboost: bool = False nccl_pp_comm_chunksize: Optional[int] = None - custom_cuda_device_max_connections: int = None + num_cuda_device_max_connections: int = None def get_vboost_srun_cmd(self, nodes, job_dir): "Create the vboost `sudo nvidia-smi boost-slider --vboost 1` command" @@ -347,8 +347,8 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor): if torch.cuda.is_available(): major, _ = torch.cuda.get_device_capability() if major > 9: - if self.custom_cuda_device_max_connections is not None: - executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = str(self.custom_cuda_device_max_connections) + if self.num_cuda_device_max_connections is not None: + executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = str(self.num_cuda_device_max_connections) else: # When TP or CP size is larger than 1, need to use a single cuda device connection to enforce # the kernel queuing order of the host to GPU for their execution. 
This is needed for the optimal @@ -358,9 +358,9 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor): if tp_size > 1 or cp_size > 1: executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" else: - if self.custom_cuda_device_max_connections is not None: + if self.num_cuda_device_max_connections is not None: executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = str( - self.custom_cuda_device_max_connections + self.num_cuda_device_max_connections ) # Set LayerNorm SM margin to support the overlap with LayerNorm kernel From 5ca96dbed7540c550321debbba4b224fa953f6e0 Mon Sep 17 00:00:00 2001 From: guyueh1 Date: Wed, 8 Jan 2025 19:08:45 +0000 Subject: [PATCH 4/7] Apply isort and black reformatting Signed-off-by: guyueh1 --- nemo/collections/llm/recipes/nemotron4_15b.py | 4 +--- nemo/collections/llm/recipes/nemotron4_340b.py | 2 +- nemo/lightning/run/plugins.py | 4 +--- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/nemo/collections/llm/recipes/nemotron4_15b.py b/nemo/collections/llm/recipes/nemotron4_15b.py index e4765de7fb96..90baeec836f5 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b.py +++ b/nemo/collections/llm/recipes/nemotron4_15b.py @@ -25,9 +25,7 @@ from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing -from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( - userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192 -) +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback diff --git a/nemo/collections/llm/recipes/nemotron4_340b.py b/nemo/collections/llm/recipes/nemotron4_340b.py index e397c7ed75bd..3813caf221f7 100644 --- a/nemo/collections/llm/recipes/nemotron4_340b.py +++ b/nemo/collections/llm/recipes/nemotron4_340b.py @@ -26,7 +26,7 @@ from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( - userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096 + userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096, ) from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback diff --git a/nemo/lightning/run/plugins.py b/nemo/lightning/run/plugins.py index 5d33248384a8..f163b8f37ab2 100644 --- a/nemo/lightning/run/plugins.py +++ b/nemo/lightning/run/plugins.py @@ -19,6 +19,7 @@ from typing import Callable, Optional import nemo_run as run +import torch import yaml from lightning.pytorch import Callback from lightning.pytorch.loggers import WandbLogger @@ -27,11 +28,8 @@ from nemo.lightning.pytorch.callbacks import NsysCallback, PreemptionCallback from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy from nemo.utils import logging - from nemo.utils.import_utils import safe_import -import torch - res_module, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency') # This file contains plugins based on NeMo-Run's run.Plugin API. 
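
Note (editor): the CUDA_DEVICE_MAX_CONNECTIONS selection introduced in [PATCH 2/7] and renamed in [PATCH 3/7] can be summarized by the minimal standalone sketch below. The helper name resolve_cuda_device_max_connections and the free-function form are illustrative only; in the tree this logic lives inside PerfEnvPlugin.setup() and reads the TP/CP sizes from the task's MegatronStrategy.

import os
from typing import Optional

import torch


def resolve_cuda_device_max_connections(
    tp_size: int, cp_size: int, num_cuda_device_max_connections: Optional[int]
) -> Optional[str]:
    """Mirror the PerfEnvPlugin decision added in these patches (sketch only)."""
    if not torch.cuda.is_available():
        return None
    major, _ = torch.cuda.get_device_capability()
    if major > 9:
        # Compute capability above 9 (e.g. Blackwell): only set the variable when the
        # user requests an explicit value.
        return str(num_cuda_device_max_connections) if num_cuda_device_max_connections is not None else None
    # Hopper and older: a single connection enforces the host-to-GPU kernel queuing order,
    # which is needed for optimal TP/CP communication-computation overlap.
    if tp_size > 1 or cp_size > 1:
        return "1"
    return str(num_cuda_device_max_connections) if num_cuda_device_max_connections is not None else None


# Example: TP=2, CP=1 on an H100 (compute capability 9.0) resolves to "1".
value = resolve_cuda_device_max_connections(tp_size=2, cp_size=1, num_cuda_device_max_connections=None)
if value is not None:
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = value
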
From 2b8114bd788eeece306b69b10f699e89d2801727 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Wed, 8 Jan 2025 16:52:14 -0800 Subject: [PATCH 5/7] Remove explicit config of align_param_gather in mixtral recipe and use default --- nemo/collections/llm/recipes/mixtral_8x7b.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py index d06e22fc2180..c101656ce68f 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -210,20 +210,10 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: It may not be suitable for all hardware configurations or use cases. """ - # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically - # by MegatronCommOverlapCallback. They are added here for user's knowledge. - # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. - # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else - # each PP stage launches independently as needed. - recipe.trainer.callbacks.extend( [ run.Config(MegatronTokenDropCallback), - run.Config( - MegatronCommOverlapCallback, - overlap_param_gather_with_optimizer_step=False, # Currently disabled due to issue with checkpointing. - align_param_gather=True, - ), + run.Config(MegatronCommOverlapCallback), ] ) recipe.trainer.strategy.expert_model_parallel_size = 1 From 7c5530bee7ded21e97dff74e40262947bdf82fb9 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Wed, 8 Jan 2025 19:27:54 -0800 Subject: [PATCH 6/7] Revert "Remove explicit config of align_param_gather in mixtral recipe and use default" This reverts commit 2b8114bd788eeece306b69b10f699e89d2801727. --- nemo/collections/llm/recipes/mixtral_8x7b.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py index c101656ce68f..d06e22fc2180 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -210,10 +210,20 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: It may not be suitable for all hardware configurations or use cases. """ + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + recipe.trainer.callbacks.extend( [ run.Config(MegatronTokenDropCallback), - run.Config(MegatronCommOverlapCallback), + run.Config( + MegatronCommOverlapCallback, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to issue with checkpointing. 
+ align_param_gather=True, + ), ] ) recipe.trainer.strategy.expert_model_parallel_size = 1 From e234588a7e413445737245ad9123101cc6022560 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Thu, 9 Jan 2025 14:40:22 -0800 Subject: [PATCH 7/7] Rename ub config; change proj to ring exchange for nemotron 340b Signed-off-by: Guyue Huang --- nemo/collections/llm/recipes/nemotron4_15b.py | 4 +-- .../collections/llm/recipes/nemotron4_340b.py | 4 +-- .../recipes/tp_overlap_configs/userbuffers.py | 34 +++++++++---------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/nemo/collections/llm/recipes/nemotron4_15b.py b/nemo/collections/llm/recipes/nemotron4_15b.py index 90baeec836f5..8eaf7ddc6dbb 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b.py +++ b/nemo/collections/llm/recipes/nemotron4_15b.py @@ -25,7 +25,7 @@ from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing -from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192 +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import userbuffers_bf16_b200_h6144_tp2_mbs1_seqlen4096 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback @@ -203,7 +203,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: run.Config( MegatronCommOverlapCallback, tp_comm_overlap=True, - tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192, + tp_comm_overlap_cfg=userbuffers_bf16_b200_h6144_tp2_mbs1_seqlen4096, ) ) return recipe diff --git a/nemo/collections/llm/recipes/nemotron4_340b.py b/nemo/collections/llm/recipes/nemotron4_340b.py index 3813caf221f7..eb0983088f49 100644 --- a/nemo/collections/llm/recipes/nemotron4_340b.py +++ b/nemo/collections/llm/recipes/nemotron4_340b.py @@ -26,7 +26,7 @@ from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( - userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096, + userbuffers_bf16_b200_h18432_tp8_mbs1_seqlen4096, ) from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback @@ -212,7 +212,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: run.Config( MegatronCommOverlapCallback, tp_comm_overlap=True, - tp_comm_overlap_cfg=userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096, + tp_comm_overlap_cfg=userbuffers_bf16_b200_h18432_tp8_mbs1_seqlen4096, defer_embedding_wgrad_compute=True, wgrad_deferral_limit=22, overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing diff --git a/nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py b/nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py index de7ee4c9c889..820e8561cf20 100644 --- a/nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py +++ b/nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py @@ -88,20 +88,6 @@ class TransformerLayerTPOverlapCfg: fc2_fprop=PipelineOverlapCfg(num_sm=16, cga_size=2, num_splits=4, set_sm_margin=True, fp8_buf=True), ) -# llama3 
70b -userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192 = TransformerLayerTPOverlapCfg( - qkv_dgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False), - qkv_wgrad=BulkOverlapCfg(num_sm=32, cga_size=2, set_sm_margin=False), - fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False), - fc1_wgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False), - qkv_fprop=RingExchangeOverlapCfg(aggregate=False), - proj_dgrad=RingExchangeOverlapCfg(aggregate=False), - fc1_fprop=RingExchangeOverlapCfg(aggregate=False), - fc2_dgrad=RingExchangeOverlapCfg(aggregate=False), - proj_fprop=RingExchangeOverlapCfg(), - fc2_fprop=RingExchangeOverlapCfg(), -) - # llama3.1 405b userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192 = TransformerLayerTPOverlapCfg( qkv_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False), @@ -183,8 +169,22 @@ class TransformerLayerTPOverlapCfg: fc2_fprop=RingExchangeOverlapCfg(num_sm=1, set_sm_margin=True), ) +# Nemotron 15B +userbuffers_bf16_b200_h6144_tp2_mbs1_seqlen4096 = TransformerLayerTPOverlapCfg( + qkv_dgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False), + qkv_wgrad=BulkOverlapCfg(num_sm=32, cga_size=2, set_sm_margin=False), + fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False), + fc1_wgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False), + qkv_fprop=RingExchangeOverlapCfg(aggregate=False), + proj_dgrad=RingExchangeOverlapCfg(aggregate=False), + fc1_fprop=RingExchangeOverlapCfg(aggregate=False), + fc2_dgrad=RingExchangeOverlapCfg(aggregate=False), + proj_fprop=RingExchangeOverlapCfg(num_sm=1, set_sm_margin=True), + fc2_fprop=RingExchangeOverlapCfg(num_sm=1, set_sm_margin=True), +) + # Nemotron 340B -userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096 = TransformerLayerTPOverlapCfg( +userbuffers_bf16_b200_h18432_tp8_mbs1_seqlen4096 = TransformerLayerTPOverlapCfg( qkv_dgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False), qkv_wgrad=BulkOverlapCfg(num_sm=32, cga_size=2, set_sm_margin=False), fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False), @@ -193,6 +193,6 @@ class TransformerLayerTPOverlapCfg: proj_dgrad=RingExchangeOverlapCfg(aggregate=False), fc1_fprop=RingExchangeOverlapCfg(aggregate=False), fc2_dgrad=RingExchangeOverlapCfg(aggregate=False), - proj_fprop=PipelineOverlapCfg(num_sm=32, cga_size=2, num_splits=2, set_sm_margin=True, fp8_buf=True), - fc2_fprop=PipelineOverlapCfg(num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True, fp8_buf=True), + proj_fprop=RingExchangeOverlapCfg(num_sm=1, set_sm_margin=True), + fc2_fprop=PipelineOverlapCfg(num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True), )
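
Note (editor): as a usage sketch (not part of the patch series), the renamed B200 userbuffer config from [PATCH 7/7] and the renamed PerfEnvPlugin field from [PATCH 3/7] would typically be wired into a NeMo-Run pretraining job as below; the surrounding recipe and executor objects are assumed and elided.

import nemo_run as run
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
    userbuffers_bf16_b200_h18432_tp8_mbs1_seqlen4096,
)
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.run.plugins import PerfEnvPlugin

# Callback config mirroring what the Nemotron4 340B performance recipe attaches after this patch.
comm_overlap = run.Config(
    MegatronCommOverlapCallback,
    tp_comm_overlap=True,
    tp_comm_overlap_cfg=userbuffers_bf16_b200_h18432_tp8_mbs1_seqlen4096,
)

# PerfEnvPlugin with the field renamed in [PATCH 3/7]; leaving it as None defers to the
# capability/TP/CP-based CUDA_DEVICE_MAX_CONNECTIONS logic in PerfEnvPlugin.setup().
perf_env = PerfEnvPlugin(num_cuda_device_max_connections=None)

# recipe.trainer.callbacks.append(comm_overlap)            # attach to an existing recipe (assumed)
# run.run(recipe, executor=executor, plugins=[perf_env])   # recipe/executor setup omitted
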