From 0d09dd5d20c3a5133e6e6ae1218cdc45ccaf0358 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Thu, 3 Feb 2022 23:59:05 -0800 Subject: [PATCH] Support fusion for TNLR based model (#10432) * support tnlr based offensive V4 model * Update onnx_model_tnlr.py Co-authored-by: Ubuntu --- .../tools/transformers/onnx_model_tnlr.py | 180 ++++++++++++++++++ .../python/tools/transformers/optimizer.py | 18 +- 2 files changed, 190 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/onnx_model_tnlr.py diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py new file mode 100644 index 0000000000000..c99817c410c3e --- /dev/null +++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py @@ -0,0 +1,180 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +#-------------------------------------------------------------------------- +import logging +from fusion_attention import FusionAttention, AttentionMask +from fusion_utils import NumpyHelper +from onnx import helper, numpy_helper, TensorProto, NodeProto +from onnx_model import OnnxModel +from onnx_model_bert import BertOnnxModel +from typing import Union + +logger = logging.getLogger(__name__) + + +class FusionTnlrAttention(FusionAttention): + """ + Fuse TNLR Attention subgraph into one Attention node. + TNLR Attention has extra addtion after qk nodes and adopts [S, B, NH] as I/O shape. + """ + def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask): + super().__init__(model, hidden_size, num_heads, attention_mask) + + def create_attention_node(self, mask_index: str, matmul: NodeProto, add: NodeProto, num_heads: int, + hidden_size: int, input: str, output: str, add_qk_str: str) -> Union[NodeProto, None]: + + assert num_heads > 0 + if hidden_size > 0 and (hidden_size % num_heads) != 0: + logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}") + return None + + weight = self.model.get_initializer(matmul.input[1]) + bias = self.model.get_initializer(add.input[1]) or self.model.get_initializer(add.input[0]) + + if weight is None or bias is None: + return None + + qkv_weight = NumpyHelper.to_array(weight) + qkv_bias = NumpyHelper.to_array(bias) + + attention_node_name = self.model.create_node_name('Attention') + + weight = helper.make_tensor(name=attention_node_name + '_qkv_weight', + data_type=TensorProto.FLOAT, + dims=[hidden_size, 3 * hidden_size], + vals=qkv_weight.flatten().tolist()) + + # Sometimes weights and bias are stored in fp16 + if weight.data_type == 10: + weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) + self.model.add_initializer(weight, self.this_graph_name) + + bias = helper.make_tensor(name=attention_node_name + '_qkv_bias', + data_type=TensorProto.FLOAT, + dims=[3 * hidden_size], + vals=qkv_bias.flatten().tolist()) + if bias.data_type == 10: + bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) + self.model.add_initializer(bias, self.this_graph_name) + + attention_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias'] + if mask_index is not None: + attention_inputs.append(mask_index) + else: + attention_inputs.append("") + + if add_qk_str is not None: + attention_inputs.append("") + attention_inputs.append(add_qk_str) + + attention_node = helper.make_node('Attention', + inputs=attention_inputs, + outputs=[output], + name=attention_node_name) + attention_node.domain = "com.microsoft" + attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) + + return attention_node + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm + # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern + start_node = normalize_node + if normalize_node.op_type != 'SkipLayerNormalization': + return + + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. + qkv_nodes = self.model.match_parent_path(start_node, + ['Where', 'Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], + [1, 1, 1, 0, 0, 0]) + if qkv_nodes is not None: + (_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes + else: + return + + other_inputs = [] + for i, input in enumerate(start_node.input): + if input not in output_name_to_node: + continue + + if input == qkv_nodes[0].output[0]: + continue + other_inputs.append(input) + if len(other_inputs) != 1: + return + + root_input = other_inputs[0] + + v_nodes = self.model.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'], + [1, 0, 0, 0, 1]) + if v_nodes is None: + return + (_, _, _, add, matmul) = v_nodes + + upper_nodes = self.model.match_parent_path(matmul, ['Transpose'], [0]) + transpose = upper_nodes[0] + + qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'MatMul'], [0, 0, 0]) + if qk_nodes is None: + return + (_, add_qk, matmul_qk) = qk_nodes + + q_nodes = self.model.match_parent_path(matmul_qk, ['Mul', 'Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'], + [0, 0, 0, 0, 0, 1]) + if q_nodes is None: + return + add = q_nodes[-2] + matmul = q_nodes[-1] + + k_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'], + [1, 0, 0, 0, 1]) + if k_nodes is None: + return + add = k_nodes[-2] + matmul = k_nodes[-1] + + extra_add_qk_nodes = self.model.match_parent_path(add_qk, ['Reshape', 'Where'], [1, 0]) + if extra_add_qk_nodes is None: + return + + if matmul.input[0] == root_input: + mask_index = None + attention_last_node = reshape_qkv + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately + new_node = self.create_attention_node(mask_index, matmul, add, self.num_heads, self.hidden_size, root_input, + attention_last_node.output[0], extra_add_qk_nodes[0].input[0]) + if new_node is None: + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + # Add a transpose node after the attention node + back_transpose = helper.make_node("Transpose", ["back_transpose_in_" + new_node.name], [new_node.output[0]], + "back_transpose_" + new_node.name, + perm=[1, 0, 2]) + self.model.add_node(back_transpose, self.this_graph_name) + new_node.input[0] = transpose.input[0] + new_node.output[0] = "back_transpose_in_" + new_node.name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) + self.nodes_to_remove.extend(qk_nodes) + self.nodes_to_remove.extend(q_nodes) + self.nodes_to_remove.extend(k_nodes) + self.nodes_to_remove.extend(v_nodes) + + # Use prune graph to remove mask nodes since they are shared by all attention nodes. + #self.nodes_to_remove.extend(mask_nodes) + self.prune_graph = True + + +class TnlrOnnxModel(BertOnnxModel): + def __init__(self, model, num_heads, hidden_size): + super().__init__(model, num_heads, hidden_size) + self.attention_mask = AttentionMask(self) + self.attention_fusion = FusionTnlrAttention(self, self.hidden_size, self.num_heads, self.attention_mask) + + def fuse_attention(self): + self.attention_fusion.apply() diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index d7a25c829d916..55699ae7b0544 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -28,6 +28,7 @@ from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_bert_keras import BertOnnxModelKeras from onnx_model_gpt2 import Gpt2OnnxModel +from onnx_model_tnlr import TnlrOnnxModel from fusion_options import FusionOptions logger = logging.getLogger(__name__) @@ -39,7 +40,8 @@ "bert_tf": (BertOnnxModelTF, "tf2onnx", 0), "bert_keras": (BertOnnxModelKeras, "keras2onnx", 0), "gpt2": (Gpt2OnnxModel, "pytorch", 1), - "gpt2_tf": (Gpt2OnnxModel, 'tf2onnx', 0) # might add a class for GPT2OnnxModel for TF later. + "gpt2_tf": (Gpt2OnnxModel, 'tf2onnx', 0), # might add a class for GPT2OnnxModel for TF later. + "tnlr": (TnlrOnnxModel, "pytorch", 1), } @@ -115,9 +117,9 @@ def optimize_by_fusion(model: ModelProto, model (ModelProto): model object model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. - 0 allows detect the parameter from graph automatically (for model_type "bert" only). + 0 allows detect the parameter from graph automatically (for model_type "bert" only). hidden_size (int, optional): hidden size. Defaults to 0. - 0 allows detect the parameter from graph automatically (for model_type "bert" only). + 0 allows detect the parameter from graph automatically (for model_type "bert" only). optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None. Returns: @@ -159,7 +161,7 @@ def optimize_model(input: str, only_onnxruntime: bool = False): """ Optimize Model by OnnxRuntime and/or python fusion logic. - ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/resources/graph-optimizations.html). + ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/resources/graph-optimizations.html). However, the coverage is limited. We also have graph fusions that implemented in Python to improve the coverage. They can combined: ONNX Runtime will run first when opt_level > 0, then graph fusions in Python will be applied. @@ -170,8 +172,8 @@ def optimize_model(input: str, When opt_level is 0 and only_onnxruntime is False, only python fusion logic is used and onnxruntime is disabled. - When opt_level > 1, use_gpu shall set properly since the optimized graph might contain operators for GPU or CPU only. - If your model is intended for GPU inference only (especially float16 or mixed precision model), it is recommended to + When opt_level > 1, use_gpu shall set properly since the optimized graph might contain operators for GPU or CPU only. + If your model is intended for GPU inference only (especially float16 or mixed precision model), it is recommended to set use_gpu to be True, otherwise the model is not optimized for GPU inference. For BERT model, num_heads and hidden_size are optional. For other model types, you need specify these parameters. @@ -180,9 +182,9 @@ def optimize_model(input: str, input (str): input model path. model_type (str, optional): model type - like bert, bert_tf, bert_keras or gpt2. Defaults to 'bert'. num_heads (int, optional): number of attention heads. Defaults to 0. - 0 allows detect the parameter from graph automatically (for model_type "bert" only). + 0 allows detect the parameter from graph automatically (for model_type "bert" only). hidden_size (int, optional): hidden size. Defaults to 0. - 0 allows detect the parameter from graph automatically (for model_type "bert" only). + 0 allows detect the parameter from graph automatically (for model_type "bert" only). optimization_options (FusionOptions, optional): optimization options that turn on/off some fusions. Defaults to None. opt_level (int, optional): onnxruntime graph optimization level (0, 1, 2 or 99) or None. Defaults to None. When the value is None, default value (1 for bert and gpt2, 0 for other model types) will be used.