From 90a99d3602b432eefef5624dd6af5f15ca4ce417 Mon Sep 17 00:00:00 2001
From: Guyue Huang
Date: Fri, 3 Jan 2025 17:21:09 -0800
Subject: [PATCH] Fix nemo2 interface for ucc

Signed-off-by: Guyue Huang
---
 nemo/lightning/_strategy_lib.py                        | 2 +-
 nemo/lightning/pytorch/strategies/megatron_strategy.py | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py
index 4dec5200450e..6834e5133584 100644
--- a/nemo/lightning/_strategy_lib.py
+++ b/nemo/lightning/_strategy_lib.py
@@ -82,7 +82,7 @@ def init_parallel_ranks(
         tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
         expert_model_parallel_size=parallel_config.expert_model_parallel_size,
         pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
-        pipeline_model_parallel_comm_backend=parallel_config.get('pipeline_model_parallel_comm_backend', 'nccl'),
+        pipeline_model_parallel_comm_backend=parallel_config.pipeline_model_parallel_comm_backend,
         virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
         context_parallel_size=parallel_config.context_parallel_size,
         encoder_tensor_model_parallel_size=getattr(parallel_config, "encoder_tensor_model_parallel_size", 0),
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index 9ddf8241fab2..f3434e06bd48 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -905,6 +905,7 @@ def parallelism(self) -> ParallelismConfig:
         return ParallelismConfig(
             tensor_model_parallel_size=self.tensor_model_parallel_size,
             pipeline_model_parallel_size=self.pipeline_model_parallel_size,
+            pipeline_model_parallel_comm_backend=self.pipeline_model_parallel_comm_backend,
             virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
             microbatch_group_size_per_vp_stage=self.microbatch_group_size_per_vp_stage,
             context_parallel_size=self.context_parallel_size,