From 3d594ba23e8aeceb1f880128153dade95391c6ff Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 6 May 2024 20:07:14 -0700 Subject: [PATCH] Add simplified layernorm fusion for Gemma (#20572) Gemma has a `Mul` node right after the `Gather` and before the first layer norm. --- .../fusion_simplified_layernorm.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py index 6f35fa5617a39..a872b8c2075bc 100644 --- a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -79,6 +79,20 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): [0, 1, 1, 0, 0, 0], ) + # For Gemma from Microsoft custom export, which has a Multiply after the Gather: + # + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Mul --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_5 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Mul"], + [1, 1, 1, 0, 0, 0, 0], + ) + add_node, pow_node = None, None if sim_ln_nodes_1 is not None: sim_ln_nodes = sim_ln_nodes_1 @@ -99,6 +113,10 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): # Verify that parent input to Pow node is graph_input if pow_node.input[0] not in self.model.get_graphs_input_names(): return + elif sim_ln_nodes_5 is not None: + sim_ln_nodes = sim_ln_nodes_5 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] else: return