Commit 595447c

Use separate einsum instances in MoE

kaixih committed Jul 9, 2024
1 parent c41477c commit 595447c
Showing 2 changed files with 15 additions and 9 deletions.
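
This commit gives each MoE contraction its own einsum child instead of routing the dispatch, gated feed-forward, output projection, and combine contractions through a single shared einsum child. A plausible motivation, assuming the Fp8EinsumOp configured in grok.py keeps per-instance scaling/amax state, is that a shared instance would mix that state across unrelated contractions. The toy Flax sketch below (hypothetical module and collection names, not praxis code) illustrates the one-child-per-contraction pattern:

import jax
import jax.numpy as jnp
from flax import linen as nn


class ScaledEinsum(nn.Module):
  """Toy stateful einsum: tracks a running abs-max of its first operand."""

  @nn.compact
  def __call__(self, eqn, lhs, rhs):
    # Per-instance state; a single shared instance would accumulate amax
    # values from unrelated contractions (dispatch, FFN, combine).
    amax = self.variable('quant_stats', 'amax', lambda: jnp.zeros(()))
    amax.value = jnp.maximum(amax.value, jnp.max(jnp.abs(lhs)))
    return jnp.einsum(eqn, lhs, rhs)


class ToyExpertFfn(nn.Module):
  """Mirrors the commit: one ScaledEinsum child per contraction."""

  @nn.compact
  def __call__(self, x, w1, w2):
    h = ScaledEinsum(name='ffn1_einsum')('egcm,emh->egch', x, w1)
    h = jax.nn.relu(h)
    return ScaledEinsum(name='ffn2_einsum')('egch,ehm->egcm', h, w2)


# Usage: each child ends up with its own 'quant_stats/amax' entry.
x = jnp.ones((1, 2, 3, 4))   # e=1, g=2, c=3, m=4
w1 = jnp.ones((1, 4, 8))     # e=1, m=4, h=8
w2 = jnp.ones((1, 8, 4))     # e=1, h=8, m=4
variables = ToyExpertFfn().init(jax.random.PRNGKey(0), x, w1, w2)
out, new_stats = ToyExpertFfn().apply(
    variables, x, w1, w2, mutable=['quant_stats'])
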
3 changes: 0 additions & 3 deletions praxis/layers/grok.py
@@ -185,9 +185,6 @@ def GrokStackedTransformerHParams(
         fp8_ops.Fp8EinsumOp
     )
     p.moe_layer_tpl.einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
-    # p.moe_layer_tpl.einsum_tpl = (
-    #     pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
-    # )
   return p

21 changes: 15 additions & 6 deletions praxis/layers/transformers.py
@@ -825,7 +825,12 @@ def setup(self) -> None:
     )
     logging.debug('moe wo WeightHParams %s', wo_pc)
     self.create_variable('wo_0', wo_pc)
-    self.create_child('einsum', self.einsum_tpl.clone())
+    self.create_child('dispatch_einsum', self.einsum_tpl.clone())
+    if self._is_ffn1_gated:
+      self.create_child('gated_ffn1_hidden0_einsum', self.einsum_tpl.clone())
+      self.create_child('gated_ffn1_hidden1_einsum', self.einsum_tpl.clone())
+    self.create_child('ffn2_einsum', self.einsum_tpl.clone())
+    self.create_child('combine_einsum', self.einsum_tpl.clone())

   def _split(self, t_in, sharding):
     return base_layer.maybe_shard(t_in, sharding, self.mesh_axis_names)
@@ -1031,7 +1036,7 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     if self.gating_func in ['top2', 'expert_choice_v2']:
       combine_tensor = self._split(combine_tensor, ap.gsec)
       dispatch_tensor = self._split(dispatch_tensor, ap.gsec)
-      expert_inputs = self.einsum(
+      expert_inputs = self.dispatch_einsum(
           'gsec,gsm->egcm', dispatch_tensor, reshaped_inputs
       )
     elif self.gating_func == 'expert_choice':
@@ -1045,8 +1050,12 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     expert_inputs = self._split(expert_inputs, ap.egcm)

     if self._is_ffn1_gated:
-      hidden0 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi)
-      hidden1 = self.einsum('egcm,emh->egch', expert_inputs, theta_wi_gated)
+      hidden0 = self.gated_ffn1_hidden0_einsum(
+          'egcm,emh->egch', expert_inputs, theta_wi
+      )
+      hidden1 = self.gated_ffn1_hidden1_einsum(
+          'egcm,emh->egch', expert_inputs, theta_wi_gated
+      )
       if self.gating_func in ['top2', 'expert_choice_v2']:
         self._count_dead_neurons(hidden1, dispatch_tensor)
       hidden1 = self.activation(hidden1)
@@ -1061,13 +1070,13 @@ def _dispatch_and_combine_expert_outputs(self, inputs, paddings, segment_ids):
     # Dropout.
     hidden = self.relu_dropout(hidden)
     # Output.
-    expert_output = self.einsum('egch,ehm->egcm', hidden, theta_wo)
+    expert_output = self.ffn2_einsum('egch,ehm->egcm', hidden, theta_wo)
     expert_output = self._split(expert_output, ap.egcm)
     # Now transpose and reshard.
     transposed_expert_output = jnp.einsum('egcm->gecm', expert_output)
     transposed_expert_output = self._split(transposed_expert_output, ap.gecm)
     if self.gating_func in ['top2', 'expert_choice_v2']:
-      combined_output = self.einsum(
+      combined_output = self.combine_einsum(
           'gecm,gsec->gsm', transposed_expert_output, combine_tensor
       )
     elif self.gating_func == 'expert_choice':
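
For reference, a hypothetical configuration sketch of how the refactored layer might be pointed at the FP8 einsum op seen in grok.py; the import paths and the TransformerFeedForwardMoe / einsum_tpl names are assumptions inferred from this diff rather than verified against the praxis API:

# Hypothetical configuration sketch; import paths and field names are
# assumptions inferred from this diff, not verified against praxis.
from praxis import pax_fiddle
from praxis.layers import transformers
from praxis.layers.injection import fp8_nvidia_gpu as fp8_ops

moe_p = pax_fiddle.Config(transformers.TransformerFeedForwardMoe)
moe_p.einsum_tpl = pax_fiddle.Config(fp8_ops.Fp8EinsumOp)
# After setup(), the layer owns independent einsum children, each built from
# its own einsum_tpl.clone(): dispatch_einsum, ffn2_einsum, combine_einsum,
# plus gated_ffn1_hidden0_einsum and gated_ffn1_hidden1_einsum when
# _is_ffn1_gated is true.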
