From 595447cda3e110af052dc79ad171522258e375f5 Mon Sep 17 00:00:00 2001
From: kaixih
Date: Tue, 9 Jul 2024 21:42:46 +0000
Subject: [PATCH] Use separate einsum instances in MoE

---
 praxis/layers/grok.py         |  3 ---
 praxis/layers/transformers.py | 21 +++++++++++++++------
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/praxis/layers/grok.py b/praxis/layers/grok.py
index 238b01b2..95507b11 100644
--- a/praxis/layers/grok.py
+++ b/praxis/layers/grok.py
@@ -185,9 +185,6 @@ def GrokStackedTransformerHParams(
         fp8_ops.Fp8EinsumOp
     )
     p.moe_layer_tpl.einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
-    # p.moe_layer_tpl.einsum_tpl = (
-    #     pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
-    # )
   return p
 
 
diff --git a/praxis/layers/transformers.py b/praxis/layers/transformers.py
index 79b8ebb1..099d4510 100644
--- a/praxis/layers/transformers.py
+++ b/praxis/layers/transformers.py
@@ -825,7 +825,12 @@ def setup(self) -> None:
     )
     logging.debug('moe wo WeightHParams %s', wo_pc)
     self.create_variable('wo_0', wo_pc)
-    self.create_child('einsum', self.einsum_tpl.clone())
+    self.create_child('dispatch_einsum', self.einsum_tpl.clone())
+    if self._is_ffn1_gated:
+      self.create_child('gated_ffn1_hidden0_einsum', self.einsum_tpl.clone())
+      self.create_child('gated_ffn1_hidden1_einsum', self.einsum_tpl.clone())
+    self.create_child('ffn2_einsum', self.einsum_tpl.clone())
+    self.create_child('combine_einsum', self.einsum_tpl.clone())
 
   def _split(self, t_in, sharding):
     return base_layer.maybe_shard(t_in, sharding, self.mesh_axis_names)
@@ -1031,7 +1036,7 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     if self.gating_func in ['top2', 'expert_choice_v2']:
       combine_tensor = self._split(combine_tensor, ap.gsec)
       dispatch_tensor = self._split(dispatch_tensor, ap.gsec)
-      expert_inputs = self.einsum(
+      expert_inputs = self.dispatch_einsum(
           'gsec,gsm->egcm', dispatch_tensor, reshaped_inputs
       )
     elif self.gating_func == 'expert_choice':
@@ -1045,8 +1050,12 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     expert_inputs = self._split(expert_inputs, ap.egcm)
 
     if self._is_ffn1_gated:
-      hidden0 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi)
-      hidden1 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi_gated)
+      hidden0 = self.gated_ffn1_hidden0_einsum(
+          'egcm,emh->egch', expert_inputs, theta_wi
+      )
+      hidden1 = self.gated_ffn1_hidden1_einsum(
+          'egcm,emh->egch', expert_inputs, theta_wi_gated
+      )
       if self.gating_func in ['top2', 'expert_choice_v2']:
         self._count_dead_neurons(hidden1, dispatch_tensor)
       hidden1 = self.activation(hidden1)
@@ -1061,13 +1070,13 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
       # Dropout.
       hidden = self.relu_dropout(hidden)
     # Output.
-    expert_output = self.einsum('egch,ehm->egcm', hidden, theta_wo)
+    expert_output = self.ffn2_einsum('egch,ehm->egcm', hidden, theta_wo)
     expert_output = self._split(expert_output, ap.egcm)
     # Now transpose and reshard.
     transposed_expert_output = jnp.einsum('egcm->gecm', expert_output)
     transposed_expert_output = self._split(transposed_expert_output, ap.gecm)
     if self.gating_func in ['top2', 'expert_choice_v2']:
-      combined_output = self.einsum(
+      combined_output = self.combine_einsum(
           'gecm,gsec->gsm', transposed_expert_output, combine_tensor
       )
     elif self.gating_func == 'expert_choice':
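
Why separate instances: the benefit is easiest to see when the einsum op
carries quantization state. The sketch below is a toy illustration, not
Praxis code: `ToyFp8Einsum` is a hypothetical stand-in, and the assumption
(in line with how FP8 delayed scaling usually works) is that
fp8_ops.Fp8EinsumOp keeps per-instance scaling statistics. With one shared
instance, the 0/1-valued dispatch tensor, the activations, and the weights
all calibrate a single scale; with separate instances (dispatch_einsum,
ffn2_einsum, combine_einsum, ...), each scale tracks only its own operands.

import jax.numpy as jnp


class ToyFp8Einsum:
  """Toy stateful einsum: keeps one running amax statistic per instance."""

  def __init__(self):
    self.amax = 0.0

  def __call__(self, eqn, lhs, rhs):
    # Per-instance calibration; a real FP8 op would derive its scale from this.
    self.amax = max(self.amax,
                    float(jnp.max(jnp.abs(lhs))),
                    float(jnp.max(jnp.abs(rhs))))
    return jnp.einsum(eqn, lhs, rhs)


dispatch = jnp.ones((2, 4, 8, 2))        # gsec: 0/1-like routing values.
inputs = jnp.ones((2, 4, 16))            # gsm: token activations.
weights = 30.0 * jnp.ones((8, 16, 32))   # emh: much wider value range.

# Shared instance: the dispatch scale is polluted by the weight range.
shared = ToyFp8Einsum()
expert_in = shared('gsec,gsm->egcm', dispatch, inputs)
shared('egcm,emh->egch', expert_in, weights)
print(shared.amax)  # 30.0

# Separate instances, as in this patch: each amax fits its own operands.
dispatch_einsum, ffn1_einsum = ToyFp8Einsum(), ToyFp8Einsum()
expert_in = dispatch_einsum('gsec,gsm->egcm', dispatch, inputs)
ffn1_einsum('egcm,emh->egch', expert_in, weights)
print(dispatch_einsum.amax, ffn1_einsum.amax)  # 1.0 30.0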