diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 5a4586309..c670fd4bf 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -1248,9 +1248,8 @@ def forward(self, x, attention_mask, layer_past=None):
                 raise KeyError(self.moe_type)

         with torch.enable_grad() if not self.eval else nullcontext():
-            if (
-                mlp_bias == None,
-                self.num_experts > 1 and self.moe_type == "deepspeed",
+            if mlp_bias == None or (
+                self.num_experts > 1 and self.moe_type == "deepspeed"
             ):
                 # No dropout either
                 assert mlp_bias is None
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 9735a58be..1e5567c80 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -962,7 +962,7 @@ def calculate_derived(self):
         else:
             fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts"
             assert self.precision == "fp16", fp16_conflict
-
+
         if self.bf16 and self.bf16.get("enabled", False):
             if self.precision is None:
                 self.update_value("precision", "bfloat16")
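
For context (illustrative only, not part of the patch): the removed condition in transformer.py was a parenthesized, comma-separated expression, which Python parses as a tuple. A non-empty tuple is always truthy, so the branch executed unconditionally regardless of mlp_bias or the MoE settings. A minimal standalone sketch of the difference, using hypothetical stand-in values:

# Stand-in values for illustration; the names mirror the patched condition.
mlp_bias = object()            # any non-None value
num_experts, moe_type = 2, "megablocks"

# Old form: builds a 2-tuple such as (False, False); non-empty tuples are always truthy.
old_condition = (
    mlp_bias == None,
    num_experts > 1 and moe_type == "deepspeed",
)
assert bool(old_condition) is True

# New form: ordinary boolean logic, which is False for these inputs.
new_condition = mlp_bias == None or (
    num_experts > 1 and moe_type == "deepspeed"
)
assert new_condition is False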