Skip to content

Commit

Permalink
Add simplified layernorm fusion for Gemma (#20572)
Browse files Browse the repository at this point in the history
Gemma has a `Mul` node right after the `Gather` and before the first
layer norm.
  • Loading branch information
PatriceVignola authored May 7, 2024
1 parent 05b4ad2 commit 478d3e0
Showing 1 changed file with 18 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
[0, 1, 1, 0, 0, 0],
)

# For Gemma from Microsoft custom export, which has a Multiply after the Gather:
#
# SimplifiedLayerNorm
# +-------------------------------------------------------+
# | |
# Mul --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul
# |
# node
sim_ln_nodes_5 = self.model.match_parent_path(
node,
["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Mul"],
[1, 1, 1, 0, 0, 0, 0],
)

add_node, pow_node = None, None
if sim_ln_nodes_1 is not None:
sim_ln_nodes = sim_ln_nodes_1
Expand All @@ -99,6 +113,10 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
# Verify that parent input to Pow node is graph_input
if pow_node.input[0] not in self.model.get_graphs_input_names():
return
elif sim_ln_nodes_5 is not None:
sim_ln_nodes = sim_ln_nodes_5
add_node = sim_ln_nodes[3]
pow_node = sim_ln_nodes[-2]
else:
return

Expand Down

0 comments on commit 478d3e0

Please sign in to comment.