From f5325805678c2b9e35aae4528283e0132c5f5bbc Mon Sep 17 00:00:00 2001
From: Quentin Anthony
Date: Wed, 27 Nov 2024 11:04:07 -0800
Subject: [PATCH] undo merge error (#1325)

---
 megatron/model/transformer.py        | 5 ++---
 megatron/neox_arguments/arguments.py | 2 +-
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 5a4586309..c670fd4bf 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -1248,9 +1248,8 @@ def forward(self, x, attention_mask, layer_past=None):
                 raise KeyError(self.moe_type)
 
         with torch.enable_grad() if not self.eval else nullcontext():
-            if (
-                mlp_bias == None,
-                self.num_experts > 1 and self.moe_type == "deepspeed",
+            if mlp_bias == None or (
+                self.num_experts > 1 and self.moe_type == "deepspeed"
             ):
                 # No dropout either
                 assert mlp_bias is None
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 9735a58be..1e5567c80 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -962,7 +962,7 @@ def calculate_derived(self):
         else:
             fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts"
             assert self.precision == "fp16", fp16_conflict
-
+
         if self.bf16 and self.bf16.get("enabled", False):
             if self.precision is None:
                 self.update_value("precision", "bfloat16")
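
For context on the merge error being undone in transformer.py: the removed condition accidentally wrapped the two tests in a tuple, and a non-empty tuple is always truthy in Python, so the branch ran unconditionally. A minimal standalone sketch (hypothetical stand-in values, not GPT-NeoX code) illustrating the difference between the buggy and fixed forms:

    # Sketch with hypothetical values: a tuple used as an `if` condition is
    # always truthy, so the buggy form takes the branch even when both
    # elements evaluate to False.
    mlp_bias = object()          # stand-in for a real bias tensor, i.e. not None
    num_experts = 1
    moe_type = "megablocks"

    # Buggy form left by the merge: a 2-tuple, truthy regardless of contents.
    buggy = (
        mlp_bias == None,
        num_experts > 1 and moe_type == "deepspeed",
    )
    print(bool(buggy))   # True -- branch always taken

    # Fixed form from the patch: an actual boolean `or` of the two tests.
    fixed = mlp_bias == None or (num_experts > 1 and moe_type == "deepspeed")
    print(fixed)         # False -- branch taken only when a condition holds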