From 916c88357fdbee5107574da156585addd17b31bb Mon Sep 17 00:00:00 2001 From: yang <7129+yang@users.noreply.github.com> Date: Sat, 4 May 2024 11:25:27 -0700 Subject: [PATCH] Add megablocks dropless MoE (#1192) * Add megablocks dropless MoE * pre-commit --------- Co-authored-by: Yang Zhang Co-authored-by: Quentin Anthony --- README.md | 75 +++++++++++ configs/125M-dmoe.yml | 101 ++++++++++++++ configs/125M-moe.yml | 16 +-- megatron/data/helpers.cpp | 12 +- megatron/model/megablocks_utils.py | 34 +++++ megatron/model/transformer.py | 169 +++++++++++++++++++++--- megatron/neox_arguments/arguments.py | 6 +- megatron/neox_arguments/neox_args.py | 30 ++++- megatron/training.py | 62 ++++++++- tools/ckpts/convert_hf_to_sequential.py | 2 +- 10 files changed, 464 insertions(+), 43 deletions(-) create mode 100644 configs/125M-dmoe.yml create mode 100644 megatron/model/megablocks_utils.py diff --git a/README.md b/README.md index 7b7bf1739..e7f61bf20 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA + [Containerized Setup](#containerized-setup) * [Usage](#usage) - [Configuration](#configuration) + * [Mixture of Experts](#mixture-of-experts) - [Datasets](#datasets) * [Preconfigured Datasets](#preconfigured-datasets) * [Using Custom Data](#using-custom-data) @@ -322,6 +323,80 @@ These files are generally complete, but non-optimal. For example, depending on y For a more detailed guide to the features available and how to configure them, see [the configuration README](configs/README.md), and for documentation of every possible argument, see [configs/neox_arguments.md](configs/neox_arguments.md). +## Mixture of Experts + +GPT-NeoX includes multiple expert implementations for MoE. To select between them, specify `moe_type` of `megablocks` (default) or `deepspeed`. + +Both are based on the DeepSpeed MoE parallelism framework, which supports tensor-expert-data parallelism. +Both allow you to toggle between token-dropping and dropless (default, and this is what Megablocks was designed for). +Sinkhorn routing to come soon! + +For an example of a basic complete configuration, see configs/125M-dmoe.yml (for Megablocks dropless) or configs/125M-moe.yml. + +Most MoE related configuration arguments are prefixed with `moe`. Some common configuration parameters and their defaults are as follows: + +``` +moe_type: megablocks +moe_num_experts: 1 # 1 disables MoE. 8 is a reasonable value. +moe_loss_coeff: 0.1 +expert_interval: 2 # See details below +enable_expert_tensor_parallelism: false # See details below +moe_expert_parallel_size: 1 # See details below +moe_token_dropping: false +``` + +DeepSpeed can be further configured with the following: + +``` +moe_top_k: 1 +moe_min_capacity: 4 +moe_train_capacity_factor: 1.0 # Setting to 1.0 +moe_eval_capacity_factor: 1.0 # Setting to 1.0 +``` + +One MoE layer is present every `expert_interval` transformer layers including the first, so with 12 layers total: + +``` +0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 +``` + +Experts would be in these layers: + +``` +0, 2, 4, 6, 8, 10 +``` + +By default, we use expert-data parallelism, so any available tensor parallelism (`model_parallel_size`) will be used for expert routing. For instance, given the following: + +``` +expert_parallel_size: 4 +model_parallel_size: 2 # aka tensor parallelism +``` + +With 32 GPUs, the behavior will be look like: + +- In non-expert layers: + - Tensor parallelism is 2. 
(There are 32 / 2 = 16 such tensor parallel groups, each of size 2.) + - Data parallelism implicitly becomes 32 / 2 = 16. +- In expert layers: + - There is no tensor parallelism. + - Expert parallelism is 4. (There are 32 / 4 = 8 expert parallel groups, each of size 4.) + - Data parallelism implicitly becomes 32 / 4 = 8. Some cross-node token routing happens as a result of this redivision of data parallelism between 16 and 8. To avoid it, ensure that `expert_parallel_size == model_parallel_size`. + +Setting `enable_expert_tensor_parallelism` enables tensor-expert-data (TED) parallelism. The way to interpret the above would then be: + +- In non-expert layers: same as before. +- In expert layers: + - Tensor parallelism is 2. (There are 32 / 2 = 16 tensor parallel groups, each of size 2.) + - Expert parallelism is 4. (There are 32 / 4 = 8 expert parallel groups, each of size 4.) + - Data parallelism implicitly becomes 32 / (2 * 4) = 4. Again, cross-node token routing happens. To avoid, ensure `expert_parallel_size == 1` or `model_parallel_size == 1`. + +So note that DP must be divisible by (MP * EP). For more details, see the [TED paper]. + +Pipeline parallelism is not yet supported - coming soon! + +[TED paper]: https://arxiv.org/abs/2303.06318 + # Datasets ## Preconfigured Datasets diff --git a/configs/125M-dmoe.yml b/configs/125M-dmoe.yml new file mode 100644 index 000000000..229191b4d --- /dev/null +++ b/configs/125M-dmoe.yml @@ -0,0 +1,101 @@ +# GPT-2 pretraining setup +{ + # See README for MoE config docs! + "moe_type": "megablocks", + "moe_token_dropping": false, + # Have 4 experts per layer (every 2 layers by default) + "moe_num_experts": 4, + # parallelism settings + "enable_expert_tensor_parallelism": true, + "pipe_parallel_size": 1, # not yet supported for MoE + "model_parallel_size": 1, + "moe_expert_parallel_size": 1, + + # model settings + "num_layers": 12, + "hidden_size": 768, + "num_attention_heads": 12, + "seq_length": 2048, + "max_position_embeddings": 2048, + "norm": "layernorm", + "pos_emb": "rotary", + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0006, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00006, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 0, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "train_micro_batch_size_per_gpu": 4, + "data_impl": "mmap", + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0.0, + "attention_dropout": 0.0, + + # precision settings + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + + # misc. 
training settings + "train_iters": 320000, + "lr_decay_iters": 320000, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.01, + "checkpoint_factor": 10000, + "eval_interval": 1000, + "eval_iters": 10, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "keep_last_n_checkpoints": 4, + "wall_clock_breakdown": true, + + # networking + "hostfile": "/mock_path" +} diff --git a/configs/125M-moe.yml b/configs/125M-moe.yml index 27ebf216b..1d08d78a4 100644 --- a/configs/125M-moe.yml +++ b/configs/125M-moe.yml @@ -1,15 +1,13 @@ # GPT-2 pretraining setup { + # See README for MoE config docs! + "moe_type": "deepspeed", + "moe_token_dropping": true, # Have 4 experts per layer (every 2 layers by default) - # So with 12 layers total: - # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 - # Experts would be in layers: - # 0, 2, 4, 6, 8, 10 - "num_experts": 4, - - # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages - # across the node boundaries ) - "pipe_parallel_size": 1, + "moe_num_experts": 4, + # parallelism settings + "enable_expert_tensor_parallelism": true, + "pipe_parallel_size": 1, # not yet supported for MoE "model_parallel_size": 1, "moe_expert_parallel_size": 1, diff --git a/megatron/data/helpers.cpp b/megatron/data/helpers.cpp index 9b062b050..aca290854 100644 --- a/megatron/data/helpers.cpp +++ b/megatron/data/helpers.cpp @@ -428,9 +428,9 @@ py::array build_mapping_impl(const py::array_t& docs_, } } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { @@ -660,9 +660,9 @@ py::array build_blocks_mapping_impl(const py::array_t& docs_, num_sent = 0; } } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { if (!second) { if (verbose) { diff --git a/megatron/model/megablocks_utils.py b/megatron/model/megablocks_utils.py new file mode 100644 index 000000000..6f94b2b2c --- /dev/null +++ b/megatron/model/megablocks_utils.py @@ -0,0 +1,34 @@ +"""Adapter to expose MegaBlocks package, if available.""" + +try: + import megablocks +except ImportError: + megablocks = None + + +def megablocks_is_available(): + return megablocks is not None + + +def assert_megablocks_is_available(): + assert ( + megablocks_is_available() + ), "MegaBlocks not available. Please run `pip install megablocks`." 
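+
+# Rough intended usage (a sketch, not a prescribed API): callers should check
+# availability first, then build MegaBlocks arguments from the NeoX config, e.g.
+#
+#   megablocks_utils.assert_megablocks_is_available()
+#   args = megablocks_utils.as_megablocks_args(neox_args)
+#   layer = megablocks_utils.dmoe.dMoE(args)
+#
+# (the adapter in megatron/model/transformer.py additionally fills in
+# args.device and the init methods before constructing the layer). The aliases
+# below are None when megablocks is not importable, so only dereference them
+# after such a check.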
+ + +moe = megablocks.layers.moe if megablocks_is_available() else None +dmoe = megablocks.layers.dmoe if megablocks_is_available() else None +arguments = megablocks.layers.arguments if megablocks_is_available() else None + + +def as_megablocks_args(neox_args): + import copy + + tmp = copy.copy(neox_args) + delattr(tmp, "mlp_type") + tmp.mlp_type = "mlp" + args = arguments.from_megatron(tmp) + args.moe_lbl_in_fp32 = True + args.fp16 = neox_args.precision == "fp16" + args.moe_loss_weight = neox_args.moe_loss_coeff + return args diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c6bf619f8..c154b09f4 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -26,6 +26,7 @@ from .norms import get_norm from megatron import mpu +from megatron.model import megablocks_utils from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.activations import get_activation from megatron.model.utils import exists, get_fusion_type @@ -276,6 +277,55 @@ def forward(self, hidden_states): return self.final_linear(hidden_states) +class _MegablocksAdapter(nn.Module): + def __init__( + self, neox_args, layer_cls, init_method, output_layer_init_method, ep_group + ): + super().__init__() + megablocks_utils.assert_megablocks_is_available() + args = megablocks_utils.as_megablocks_args(neox_args) + args.device = torch.cuda.current_device() + args.init_method = init_method + args.output_layer_init_method = output_layer_init_method + + # NOTE: Shard the MoE layers over the data parallel group. Expert + # parallel sharding and data parallel sharding could be decoupled + # by extending the optimizer to handle data parallel reductions for + # MoE and non-MoE parameters separately. + if args.moe_expert_model_parallelism: + args.expert_parallel_group = ep_group + + if neox_args.moe_glu: + args.mlp_type = "glu" + + self.moe = layer_cls(args) + + def forward(self, x): + return self.moe.forward(x) + + +class MbMoE(_MegablocksAdapter): + def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): + super().__init__( + neox_args, + megablocks_utils.moe.MoE, + init_method, + output_layer_init_method, + ep_group, + ) + + +class dMoE(_MegablocksAdapter): + def __init__(self, neox_args, init_method, output_layer_init_method, ep_group): + super().__init__( + neox_args, + megablocks_utils.dmoe.dMoE, + init_method, + output_layer_init_method, + ep_group, + ) + + class ParallelSelfAttention(nn.Module): """Parallel self-attention layer abstract class. 
@@ -958,6 +1008,7 @@ def __init__( super().__init__() self.layer_number = layer_number + self.neox_args = neox_args norm, eps = get_norm(neox_args) @@ -970,6 +1021,7 @@ def __init__( self.gpt_j_residual = neox_args.gpt_j_residual self.gpt_j_tied = neox_args.gpt_j_tied self.mlp_type = neox_args.mlp_type + self.moe_type = neox_args.moe_type if self.gpt_j_residual: self.reduce = mpu.mappings.reduce_from_model_parallel_region @@ -1014,7 +1066,7 @@ def get_mlp(mlp_type, **kw): raise KeyError(mlp_type) self.num_experts = ( - neox_args.num_experts + neox_args.moe_num_experts if layer_number % neox_args.expert_interval == 0 else 1 ) @@ -1029,23 +1081,87 @@ def get_mlp(mlp_type, **kw): else: moe_mp_size = dist.get_world_size() // self.num_experts - self.mlp = MoE( - args.hidden_size, - get_mlp( - "regular", - MOE=True, - MoE_mp_size=moe_mp_size, - ), - num_experts=self.num_experts, - ep_size=args.moe_expert_parallel_size, - k=args.moe_top_k, - use_residual=args.moe_use_residual, - capacity_factor=args.moe_train_capacity_factor, - eval_capacity_factor=args.moe_eval_capacity_factor, - min_capacity=args.moe_min_capacity, - drop_tokens=args.moe_token_dropping, - use_tutel=args.use_tutel, - ) + if neox_args.moe_type == "deepspeed": + self.mlp = MoE( + args.hidden_size, + get_mlp( + "regular", + MOE=True, + MoE_mp_size=moe_mp_size, + ), + num_experts=self.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.moe_top_k, + use_residual=args.moe_use_residual, + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, + ) + elif neox_args.moe_type == "megablocks": + + def integrate_megablocks_with_ds_expert_parallelism(): + # We make megablocks work with DS parallelism. + # + # We fool DS into accepting these MoE parameters as its own DS MoE params, + # which makes things work with the underlying expert parallelism, + # including TED parallelism. + # + # Effectively, we want to: + # + # - Make DS's data parallel gradient all-reduction skip these params. + # - But make these params participate in the expert parallel all-reduction! + # + # Further background: + # + # Normally, with the original megablocks demo codebase, it + # only supports 1 copy of any expert throughout + # the network, since it uses EP group = DP group. + # + # First, we trigger DS initialization of the MoE expert parallel groups and internal state. + throwaway = MoE( + args.hidden_size, + get_mlp( + "regular", + MOE=True, + MoE_mp_size=moe_mp_size, + ), + num_experts=self.num_experts, + ep_size=args.moe_expert_parallel_size, + k=args.moe_top_k, + use_residual=args.moe_use_residual, + capacity_factor=args.moe_train_capacity_factor, + eval_capacity_factor=args.moe_eval_capacity_factor, + min_capacity=args.moe_min_capacity, + drop_tokens=args.moe_token_dropping, + use_tutel=args.use_tutel, + enable_expert_tensor_parallelism=args.enable_expert_tensor_parallelism, + ) + throwaway.set_deepspeed_parallelism() + + ep_group = throwaway.deepspeed_moe.ep_group + if args.moe_token_dropping: + self.mlp = MbMoE( + neox_args, init_method, output_layer_init_method, ep_group + ) + else: + self.mlp = dMoE( + neox_args, init_method, output_layer_init_method, ep_group + ) + + # Next, we trick DS into seeing these as its own MoE params. 
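+ # Concretely, for the two attributes set below: DeepSpeed's
+ # deepspeed.moe.utils.is_moe_param() treats any parameter with
+ # allreduce == False as an expert parameter, which keeps it out of the dense
+ # data-parallel gradient all-reduce, and group_name tells the engine which
+ # expert data-parallel group to reduce its gradients over (here, the group
+ # created by the throwaway DeepSpeed MoE layer above).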
+ for param in self.mlp.parameters(): + if getattr(param, "expert_model_parallel", None) is not None: + # is_moe_param looks for this attr. + param.allreduce = False + param.group_name = throwaway.expert_group_name + + integrate_megablocks_with_ds_expert_parallelism() + + else: + raise KeyError(neox_args.moe_type) self.layer_past = None # used to cache k/v pairs in inference @@ -1152,11 +1268,22 @@ def forward(self, x, attention_mask, layer_past=None): if self.num_experts == 1: mlp_output, mlp_bias = self.mlp(layernorm_output) else: - mlp_output, moe_loss, _ = self.mlp(layernorm_output) - mlp_bias = None # deepspeed.moe.layer.MoE.forward ignores the bias term + if self.moe_type == "deepspeed": + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + mlp_bias = ( + None # deepspeed.moe.layer.MoE.forward ignores the bias term + ) + elif self.moe_type == "megablocks": + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + raise KeyError(self.moe_type) with torch.enable_grad(): - if self.mlp_type == "llama" or self.num_experts > 1: + if ( + self.mlp_type == "llama" + or self.num_experts > 1 + and self.moe_type == "deepspeed" + ): # No dropout either assert mlp_bias is None output = mlp_output + attention_output diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index bf6e3f3e8..d96c48af3 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1034,13 +1034,15 @@ def calculate_derived(self): # if we set pipe_parallel_size to 0 or 1, GPT2ModelPipe.to_sequential() is called, and we run training with # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs self.update_value( - "is_pipe_parallel", self.pipe_parallel_size > 1 and self.num_experts == 1 + "is_pipe_parallel", + self.pipe_parallel_size > 1 and self.moe_num_experts == 1, ) - if self.num_experts > 1: + if self.moe_num_experts > 1: assert not ( self.is_pipe_parallel or self.pipe_parallel_size > 1 ), "MoE not supported with pipeline parallelism" assert self.zero_optimization["stage"] != 3, "MoE not compatible with zero3" + assert self.mlp_type == "regular", "MoE not compatible with LLaMA" # Attention config if self.attention_config is None: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 16d6456b4..8a216a25b 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -1275,7 +1275,7 @@ class NeoXArgsTextgen(NeoXArgsTemplate): Use Tutel optimizations in MoE """ - num_experts: int = 1 + moe_num_experts: int = 1 """ Number of MoE experts """ @@ -1300,7 +1300,7 @@ class NeoXArgsTextgen(NeoXArgsTemplate): The minimum capacity per expert regardless of the capacity_factor """ - moe_token_dropping: bool = True + moe_token_dropping: bool = False """ Whether to drop tokens when exceeding capacity """ @@ -1319,3 +1319,29 @@ class NeoXArgsTextgen(NeoXArgsTemplate): """ Number of parallel experts in MoE """ + + moe_type: str = "megablocks" + """ + Either `deepspeed` or `megablocks` + """ + + moe_glu: bool = False + """ + Use gated linear units in MoE + """ + + moe_lbl_in_fp32: bool = False + """ + Whether to compute the load balancing loss in fp32. + """ + + moe_jitter_eps: float = None + """ + Coefficient for MoE routing jitter. 
Jitter is + not used if set to None + """ + + enable_expert_tensor_parallelism: bool = False + """ + Enable expert tensor parallelism + """ diff --git a/megatron/training.py b/megatron/training.py index 4ce5994a5..6a4e843ab 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -365,6 +365,56 @@ def get_batch_sequential(forward_input, neox_args): return (forward_input[0], forward_input[1], attention_mask) +def average_losses_across_data_parallel_group(losses): + """Reduce a tensor of losses across all GPUs.""" + averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) + torch.distributed.all_reduce(averaged_losses, group=mpu.get_data_parallel_group()) + averaged_losses = averaged_losses / torch.distributed.get_world_size( + group=mpu.get_data_parallel_group() + ) + + return averaged_losses + + +def mb_moe_loss_func(args, loss_mask, output_tensor=None): + from megatron.model import megablocks_utils + from megatron.model.megablocks_utils import moe + + # NOTE: For pipeline parallelism this function will be run on the + # non-final stages to calculate load balancing loss contribution + # for the MoE layers within the stage. For these cases, output_tensor + # will be None. + loss, loss_dict = (None, {}) + if False: + assert output_tensor is not None + loss, loss_dict = loss_func(loss_mask, output_tensor) + assert loss.numel() == 1 + + # NOTE: If recompute is enabled we will collect duplicate load + # balancing loss contributions. Prune these before calculating + # the load balancing loss. + if args.checkpoint_activations: + # Ignore load balancing loss contributions compute during + # the forward pass if recompute is turned on. + load_balancing_loss_data = moe.get_load_balancing_loss() + if args.num_layers * 2 == len(load_balancing_loss_data): + load_balancing_loss_data = load_balancing_loss_data[args.num_layers :] + moe.clear_load_balancing_loss() + for x in load_balancing_loss_data: + moe.save_load_balancing_loss(x) + + # Compute the load balancing loss for all MoE layers. + megablocks_args = args = megablocks_utils.as_megablocks_args(args) + lbl = moe.batched_load_balancing_loss(megablocks_args) + moe.clear_load_balancing_loss() + + # Average the load balancing loss across data parallel + # replicas and save for logging. + averaged_lbl = average_losses_across_data_parallel_group([lbl]) + loss_dict["load balancing loss"] = averaged_lbl[0] + return averaged_lbl, loss_dict + + def forward_step( data_iterator, model, neox_args, timers, return_logits=False, is_train=False ): @@ -405,8 +455,13 @@ def forward_step( main_loss = cross_entropy( outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy ) - if neox_args.num_experts > 1: - moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) + if neox_args.moe_num_experts > 1: + if neox_args.moe_type == "deepspeed": + moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) + elif neox_args.moe_type == "megablocks": + moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + else: + raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") else: moe_loss = 0.0 loss = main_loss + moe_loss @@ -710,6 +765,9 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) + if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": + # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE. 
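+ # (has_moe_layers is the attribute the DeepSpeed engine checks before running
+ # its expert-aware paths, e.g. reducing expert gradients over the expert
+ # data-parallel groups and saving expert parameters separately in checkpoints;
+ # since the MegaBlocks layers are not DeepSpeed MoE modules, the engine would
+ # not set it on its own.)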
+ model.has_moe_layers = True model.total_params = get_total_params(model.module) print_rank_0(f' > total params: {"{:,}".format(model.total_params)}') diff --git a/tools/ckpts/convert_hf_to_sequential.py b/tools/ckpts/convert_hf_to_sequential.py index 5a5f3bbad..c53f28391 100644 --- a/tools/ckpts/convert_hf_to_sequential.py +++ b/tools/ckpts/convert_hf_to_sequential.py @@ -526,7 +526,7 @@ def get_non_existing_dir(tmp_dir): dist_init_required=False, model_parameters=None, config_params=neox_args.deepspeed_config, - mpu=mpu if not neox_args.is_pipe_parallel else None, + mpu=mpu, ) if os.environ.get("OMPI_COMM_WORLD_RANK", "1") == "0":
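
# Usage sketch (an illustration, assuming a standard GPT-NeoX checkout with
# MegaBlocks installed via `pip install megablocks`): the dropless-MoE config
# added above is launched like any other NeoX config, typically together with
# a local setup config that supplies data and checkpoint paths, e.g.
#
#   python ./deepy.py train.py configs/125M-dmoe.yml configs/local_setup.yml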