diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py
index 919705fc02f1..8b1c8d62ebd6 100644
--- a/nemo/collections/llm/gpt/model/mixtral.py
+++ b/nemo/collections/llm/gpt/model/mixtral.py
@@ -70,7 +70,8 @@ class MixtralConfig(GPTConfig):
     rotary_base: float = 1000000.0
     bf16: bool = True
     params_dtype: torch.dtype = torch.bfloat16
-
+    apply_rope_fusion: bool = True
+    bias_activation_fusion: bool = True
 
 @dataclass
 class MixtralConfig8x3B(MixtralConfig):
diff --git a/nemo/collections/llm/recipes/nemotron4_15b.py b/nemo/collections/llm/recipes/nemotron4_15b.py
index 49f92fcc1616..1db257ece889 100644
--- a/nemo/collections/llm/recipes/nemotron4_15b.py
+++ b/nemo/collections/llm/recipes/nemotron4_15b.py
@@ -25,6 +25,9 @@
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
+from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
+    userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192
+)
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.utils.exp_manager import TimingCallback
 
@@ -202,6 +205,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
         run.Config(
             MegatronCommOverlapCallback,
             tp_comm_overlap=True,
+            tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp2_mbs1_seqlen8192,
         )
     )
     return recipe
diff --git a/nemo/collections/llm/recipes/nemotron4_340b.py b/nemo/collections/llm/recipes/nemotron4_340b.py
index 14d4c0f32d11..e397c7ed75bd 100644
--- a/nemo/collections/llm/recipes/nemotron4_340b.py
+++ b/nemo/collections/llm/recipes/nemotron4_340b.py
@@ -25,6 +25,9 @@
 from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
 from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer
 from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
+from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
+    userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096
+)
 from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
 from nemo.utils.exp_manager import TimingCallback
 
@@ -209,6 +212,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
         run.Config(
             MegatronCommOverlapCallback,
             tp_comm_overlap=True,
+            tp_comm_overlap_cfg=userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096,
             defer_embedding_wgrad_compute=True,
             wgrad_deferral_limit=22,
             overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
diff --git a/nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py b/nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py
index e9cc0b5825c7..de7ee4c9c889 100644
--- a/nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py
+++ b/nemo/collections/llm/recipes/tp_overlap_configs/userbuffers.py
@@ -182,3 +182,17 @@ class TransformerLayerTPOverlapCfg:
     proj_fprop=PipelineOverlapCfg(num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True, fp8_buf=True),
     fc2_fprop=RingExchangeOverlapCfg(num_sm=1, set_sm_margin=True),
 )
+
+# Nemotron 340B
+userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096 = TransformerLayerTPOverlapCfg(
+    qkv_dgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False),
+    qkv_wgrad=BulkOverlapCfg(num_sm=32, cga_size=2, set_sm_margin=False),
+    fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False),
+    fc1_wgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False),
+    qkv_fprop=RingExchangeOverlapCfg(aggregate=False),
+    proj_dgrad=RingExchangeOverlapCfg(aggregate=False),
+    fc1_fprop=RingExchangeOverlapCfg(aggregate=False),
+    fc2_dgrad=RingExchangeOverlapCfg(aggregate=False),
+    proj_fprop=PipelineOverlapCfg(num_sm=32, cga_size=2, num_splits=2, set_sm_margin=True, fp8_buf=True),
+    fc2_fprop=PipelineOverlapCfg(num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True, fp8_buf=True),
+)
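For reference, the recipe hunks above wire the userbuffers layouts into training through NeMo's MegatronCommOverlapCallback. A minimal sketch of that wiring follows; it mirrors the nemotron4_340b.py hunk, assumes the pretrain recipe already exposes trainer.callbacks (as these recipes do), and the helper name attach_tp_overlap is illustrative, not part of this change.

import nemo_run as run

from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
    userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096,
)
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback


def attach_tp_overlap(recipe: run.Partial) -> run.Partial:
    # Illustrative helper (not part of the diff): attach the comm-overlap callback
    # configured with the Nemotron 340B userbuffers layout added in userbuffers.py above.
    recipe.trainer.callbacks.append(
        run.Config(
            MegatronCommOverlapCallback,
            tp_comm_overlap=True,
            tp_comm_overlap_cfg=userbuffers_bf16_h100_h18432_tp8_mbs1_seqlen4096,
        )
    )
    return recipe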