From de5bc5e04d15a19159ac67011bd2dff88c84f116 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Thu, 12 Oct 2023 13:19:14 -0500 Subject: [PATCH 01/24] Add new fusion Matmul + BN --- .../core/optimizer/graph_transformer_utils.cc | 2 + onnxruntime/core/optimizer/initializer.cc | 75 +++++ onnxruntime/core/optimizer/initializer.h | 2 + .../core/optimizer/matmul_bn_fusion.cc | 318 ++++++++++++++++++ onnxruntime/core/optimizer/matmul_bn_fusion.h | 43 +++ 5 files changed, 440 insertions(+) create mode 100644 onnxruntime/core/optimizer/matmul_bn_fusion.cc create mode 100644 onnxruntime/core/optimizer/matmul_bn_fusion.h diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 54511aa02a57c..13ffca70bb214 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -50,6 +50,7 @@ #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" @@ -127,6 +128,7 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 9cdc0d9ef0473..4ce6a963c32fb 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -201,6 +201,30 @@ struct ScalarAdd { } }; +//template +//struct Broadcast { +// void operator()(Tensor& tensor, const onnxruntime::TensorShape& destShape) const { +// ToNumeric to_numeric; +// +// size_t newSize = Tensor::CalculateTensorStorageSize(tensor.DataType(), destShape); +// std::shared_ptr allocator = std::make_shared(); +// void* newData = nullptr; +// if (len > 0) { +// newData = allocator->Alloc(newSize); +// } +// Tensor newTensor(tensor.DataType(), destShape, newData, allocator); +// +// // because broadcasting only works for 1-D tensor +// const size_t block_size = tensor.Shape().GetDims().front(); +// const size_t num_blocks = destShape.Size() / block_size; +// +// auto span = tensor.MutableDataAsSpan(); +// for (auto& dst : span) { +// dst = T(to_numeric(dst) + v); +// } +// } +//}; + template struct Sqrt { void operator()(Tensor& tensor) const { @@ -280,6 +304,26 @@ Initializer& Initializer::div(const Initializer& other) { return *this; } +/* +* It only broadcast 1-D tensor if the dimension of that 1-d tensor either equals to +* 1st or last dimension of destShape. 
+*/ +//Initializer& Initializer::mulBy1dInitialer(const Initializer& other) { +// ORT_ENFORCE(other.size() == 1, "The multipier tensor should be 1-D tensor"); +// ORT_ENFORCE(other.dims().front() == dims().front() || other.dims().front() == dims().back(), +// "Dimension of the multiplier tensor should be equal to either 1st or last dimension of the multiplicand tensor."); +// +// const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); +// const size_t num_blocks = size() / block_size; +// ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size"); +// utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); +// t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); +// +// utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); +// //data_ = t_disp.Invoke(data_, destShape); +// return *this; +//} + Initializer& Initializer::sqrt() { utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); t_disp.Invoke(data_); @@ -310,6 +354,28 @@ struct ScaleByAxis { } }; +template +struct ScaleToAxis { + void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const { + ToNumeric to_numeric; + const auto scaler_size = scalers.Shape().Size(); + T* dst = data.MutableData(); + const T* scalers_data = scalers.Data(); + if (scaler_size == 1) { + const auto numeric_scaler = to_numeric(scalers_data[0]); + for (size_t block_offset = 0, limit = block_size * num_blocks; block_offset < limit; ++block_offset) { + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } + } else { + for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + const auto numeric_scaler = to_numeric(scalers_data[j]); + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } + } + } + } +}; } // namespace void Initializer::scale_by_axis(const Initializer& scalers, int axis) { @@ -320,5 +386,14 @@ void Initializer::scale_by_axis(const Initializer& scalers, int axis) { utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); } + +void Initializer::scale_to_axis(const Initializer& scalers, int axis) { + ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); + const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); + const size_t num_blocks = size() / block_size; + ORT_ENFORCE(scalers.size() == 1 || scalers.size() == block_size, "Invalid other(scalers) size"); + utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); + t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); +} #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index dfe054ba1aced..cc118eee77c54 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -87,6 +87,8 @@ class Initializer final { Initializer& sqrt(); void scale_by_axis(const Initializer& other, int axis); + + void scale_to_axis(const Initializer& other, int axis); #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc new file mode 100644 index 0000000000000..1eada34d30e74 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -0,0 +1,318 @@ +#include "core/optimizer/matmul_bn_fusion.h" +#include 
"core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + + +namespace onnxruntime +{ + void AddNodesToRemove( + Node::NodeConstIterator currItr, + const NodeIndex& destNodeIndex, + std::vector& nodesToRemove) + { + while (currItr->Index() != destNodeIndex) { + nodesToRemove.push_back(currItr->Index()); + currItr = currItr->OutputNodesBegin(); + } + } + + NodeIndex GetOtherParentOfNode( + const Node& node, + NodeIndex firstParentIndex) + { + NodeIndex otherParentIndex = std::numeric_limits::max(); + if (node.GetInputEdgesCount() != 2) + { + return otherParentIndex; + } + + auto parentNodeItr = node.InputNodesBegin(); + if (parentNodeItr->Index() != firstParentIndex) + { + otherParentIndex = parentNodeItr->Index(); + } + ++parentNodeItr; + if (parentNodeItr->Index() != firstParentIndex) + { + otherParentIndex = parentNodeItr->Index(); + } + return otherParentIndex; + } + + bool MatmulBNFusion::MatchPath( + const Node& parentNode, + const gsl::span>>& path, + const Node& childNode) const + { + if (path.size() == 0) + { + return true; + } + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(childNode, path[0].first, path[0].second) || + childNode.GetExecutionProviderType() != parentNode.GetExecutionProviderType()) + { + return false; + } + + // last node in the path can have more than one output + // because all those outputs will be preserved by the addition of new Gemm node + if (path.size() > 1 && childNode.GetOutputEdgesCount() != 1) + { + return false; + } + + return MatchPath(childNode, path.subspan(1), *childNode.OutputNodesBegin()); + } + + /* + * Given a MatMul node, it will verify the following pattern. + * MatMul + * | + * / \ + * / \ + * / \ + * Reshape Shape + * | | + * Transpose Cast + * | | + * BatchNormalization Cast + * | | + * Transpose | + * | / + * \ / + * \ / + * \ / + * | + * Reshape + * As of writing this fusion, we are being conversative in the pattern because the customer + * model we are targeting has this exact pattern. Above pattern will evolve in the future + * as we tend to add separate fusion to eliminate Transpose around the BatchNormalization, + * update the model optimizer script to eliminate adjacent Cast operator, etc. + * + * We have to match the path (MatMul->Shape->Cast->Cast->Reshape) because sub-merging the + * BatchNormalization into the MatMul will change MatMul's output and thus we have to make + * sure that MatMul's output is not used by any operator to which MatMul's output matters. + * Other Conditions: + * - B tensor of MatMul should be constant. + * - scale, B, mean, var tensors of BatchNormalization should be constant. + * - Every node in the path except first and last node, should have only 1 output edge. 
+ * + */ + bool MatmulBNFusion::SatisfyCondition( + const Graph& graph, + const Node& node, + const logging::Logger&) const + { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", { 1, 9, 13 }) || + node.GetOutputEdgesCount() != 2) + { + return false; + } + + auto childNodeIterator = node.OutputNodesBegin(); + const Node& firstChildNode = *childNodeIterator; + ++childNodeIterator; + const Node& secondChildNode = *childNodeIterator; + + std::vector>> firstPath = + {{"Reshape", {1, 5}}, + {"Transpose", {1}}, + {"BatchNormalization", {1, 6, 7}}, + {"Transpose", {1}}, + {"Reshape", {1, 5}}}; + + std::vector>> secondPath = + {{"Shape", {1}}, + {"Cast", {1, 6}}, + {"Cast", {1, 6}}, + {"Reshape", {1, 5}}}; + + if (!(MatchPath(node, firstPath, firstChildNode) ^ MatchPath(node, secondPath, firstChildNode))) + { + return false; + } + + if (!(MatchPath(node, firstPath, secondChildNode) ^ MatchPath(node, secondPath, secondChildNode))) { + return false; + } + + + const auto& batchNormNode = firstChildNode.OpType() == "Reshape" ? + *firstChildNode.OutputNodesBegin()->OutputNodesBegin() : + *secondChildNode.OutputNodesBegin()->OutputNodesBegin(); + + // Check that the appropriate inputs to the Matmul and BN nodes are constants. + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batchNormNode.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batchNormNode.InputDefs()[2]) || + !graph_utils::NodeArgIsConstant(graph, *batchNormNode.InputDefs()[3]) || + !graph_utils::NodeArgIsConstant(graph, *batchNormNode.InputDefs()[4])) + { + return false; + } + + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. + const auto& output_defs = batchNormNode.OutputDefs(); + if (output_defs.size() > 1) { + for (size_t i = 1, end = output_defs.size(); i < end; ++i) { + if (output_defs[i] != nullptr && output_defs[i]->Exists()) + return false; + } + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + return true; + } + + Status MatmulBNFusion::Apply( + Graph& graph, + Node& matmulNode, + RewriteRuleEffect& rule_effect, + const logging::Logger&) const + { + auto childNodeIterator = matmulNode.OutputNodesBegin(); + const Node& firstChildNode = *childNodeIterator; + ++childNodeIterator; + const Node& secondChildNode = *childNodeIterator; + + const Node& firstReshape = firstChildNode.OpType() == "Reshape" ? 
firstChildNode : secondChildNode; + + NodeIndex batchNormNodeIndex = firstReshape.OutputNodesBegin()->OutputNodesBegin()->Index(); + Node& batchNormNode = *graph.GetNode(batchNormNodeIndex); + + // only perform fusion if eplison is present and is of float_32 type + auto epsilonAttr = batchNormNode.GetAttributes().find("epsilon"); + if (epsilonAttr == batchNormNode.GetAttributes().end() || + epsilonAttr->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) + { + return Status::OK(); + } + const float epsilon = epsilonAttr->second.f(); + + const onnx::TensorProto* scaleTensor = graph_utils::GetConstantInitializer(graph, batchNormNode.InputDefs()[1]->Name()); + ORT_ENFORCE(scaleTensor); + const onnx::TensorProto* biasTensor = graph_utils::GetConstantInitializer(graph, batchNormNode.InputDefs()[2]->Name()); + ORT_ENFORCE(biasTensor); + const onnx::TensorProto* meanTensor = graph_utils::GetConstantInitializer(graph, batchNormNode.InputDefs()[3]->Name()); + ORT_ENFORCE(meanTensor); + const onnx::TensorProto* varTensor = graph_utils::GetConstantInitializer(graph, batchNormNode.InputDefs()[4]->Name()); + ORT_ENFORCE(varTensor); + const onnx::TensorProto* matmulBTensor = graph_utils::GetConstantInitializer(graph, matmulNode.InputDefs()[1]->Name()); + ORT_ENFORCE(matmulBTensor); + + if (!optimizer_utils::IsFloatingPointDataType(*matmulBTensor) || + !optimizer_utils::IsFloatingPointDataType(*scaleTensor) || + !optimizer_utils::IsFloatingPointDataType(*biasTensor) || + !optimizer_utils::IsFloatingPointDataType(*meanTensor) || + !optimizer_utils::IsFloatingPointDataType(*varTensor) || + scaleTensor->dims_size() != 1 || + biasTensor->dims_size() != 1 || + meanTensor->dims_size() != 1 || + varTensor->dims_size() != 1 || + scaleTensor->dims(0) != matmulBTensor->dims(1) || + biasTensor->dims(0) != matmulBTensor->dims(1) || + meanTensor->dims(0) != matmulBTensor->dims(1) || + varTensor->dims(0) != matmulBTensor->dims(1)) + { + return Status::OK(); + } + + /* + * + * perform bn in terms of [N,H,W,C] + * temp = scale / sqrt(var + epsilon) + * output = (temp * Input) - ((temp * mean) + bias) + * Create a copy of the initializer to perform the above described calculation + * + * + * matmulB = [1792, 512] + * scalar of BN = [512] (temp) // 512 is my channel + * perform BN in terms of [N,H,W,C] + * + */ + // creates copy of the initializer + Initializer scale(*scaleTensor, graph.ModelPath()); + Initializer bias(*biasTensor, graph.ModelPath()); + Initializer mean(*meanTensor, graph.ModelPath()); + Initializer var(*varTensor, graph.ModelPath()); + Initializer matmulB(*matmulBTensor, graph.ModelPath()); + + var.add(epsilon); + var.sqrt(); + scale.div(var); // this is the temp + matmulB.scale_to_axis(scale, 1); + + mean.mul(scale); + bias.sub(mean); + + // remove redundant initializer ??? + + // create B tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto newGemmBTensor(*matmulBTensor); + matmulB.ToProto(newGemmBTensor); + const std::string newGemmBName = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmulBTensor->name()); + newGemmBTensor.set_name(newGemmBName); + NodeArg& newGemmBNodeArg = graph_utils::AddInitializer(graph, newGemmBTensor); + + // create bias tensorProto for new Gemm node from initializer. 
+ ONNX_NAMESPACE::TensorProto newGemmBiasTensor(*biasTensor); + bias.ToProto(newGemmBiasTensor); + const std::string newGemmBiasName = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); + newGemmBiasTensor.set_name(newGemmBiasName); + NodeArg& newGemmBiasNodeArg = graph_utils::AddInitializer(graph, newGemmBiasTensor); + + NodeIndex lastReshapeNodeIndex = firstReshape.OutputNodesBegin()->OutputNodesBegin()-> + OutputNodesBegin()->OutputNodesBegin()->Index(); + graph.AddNode( + graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), + "Gemm", + "Generated from Matmul BatchNormalization fusion", + {matmulNode.MutableInputDefs()[0], &newGemmBNodeArg, &newGemmBiasNodeArg}, + graph.GetNode(lastReshapeNodeIndex)->MutableOutputDefs(), + nullptr, + kOnnxDomain); + + // Do we want to verify whether every node in the path should have only output + // because if any of the node's output is used by any other node, then + // we can't remove all of these nodes. + + std::vector nodesToRemove; + nodesToRemove.push_back(matmulNode.Index()); + + // Remove non-Matmul parent of Reshape if and only if + // that parent has only 1 output. + NodeIndex nonMatmulParentOfFirstReshape = GetOtherParentOfNode(firstReshape, matmulNode.Index()); + if (nonMatmulParentOfFirstReshape != std::numeric_limits::max() && + graph.GetNode(nonMatmulParentOfFirstReshape)->GetOutputEdgesCount() == 1) + { + nodesToRemove.push_back(nonMatmulParentOfFirstReshape); + } + + auto currItr = matmulNode.OutputNodesBegin(); + AddNodesToRemove(currItr, lastReshapeNodeIndex, nodesToRemove); + ++currItr; + AddNodesToRemove(currItr, lastReshapeNodeIndex, nodesToRemove); + nodesToRemove.push_back(lastReshapeNodeIndex); + + for (const auto& nodeIndex : nodesToRemove) { + Node* node = graph.GetNode(nodeIndex); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(nodeIndex); + } + + //scale.broadcast(utils::GetTensorShapeFromTensorProto(*matmulBTensor)); + //matmulB.mul(scale); //[1, 512] + //scale.scale_by_axis(matmulB, 0); + //matmulB.scale_by_axis(scale, 1); + //matmulB.mul(scale); + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + + return Status::OK(); + } +} \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h new file mode 100644 index 0000000000000..ed60e3257400a --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -0,0 +1,43 @@ +#pragma once + +#include "core/optimizer/rewrite_rule.h" + + +namespace onnxruntime +{ +/* +* This fusion submerges a BatchNormalization operator to it's super +* precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() +* is true. 
+*/ +class MatmulBNFusion : public RewriteRule +{ +public: + MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") + { + + } + + std::vector TargetOpTypes() const noexcept + { + return {"MatMul"}; + } + +private: + bool SatisfyCondition( + const Graph& graph, + const Node& node, + const logging::Logger& logger) const override; + + Status Apply( + Graph& graph, + Node& matmulNode, + RewriteRuleEffect& rule_effect, + const logging::Logger& logger) const override; + + bool MatchPath( + const Node& parentNode, + const gsl::span>>& path, + const Node& childNode) const; +}; +} \ No newline at end of file From 4cb3d7e85389ee78df7d4d1dc572da88410517dc Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Thu, 12 Oct 2023 13:23:48 -0500 Subject: [PATCH 02/24] Update comments --- .../core/optimizer/matmul_bn_fusion.cc | 28 ++----------------- onnxruntime/core/optimizer/matmul_bn_fusion.h | 2 +- 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index 1eada34d30e74..52b3228614c23 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -98,7 +98,6 @@ namespace onnxruntime * - B tensor of MatMul should be constant. * - scale, B, mean, var tensors of BatchNormalization should be constant. * - Every node in the path except first and last node, should have only 1 output edge. - * */ bool MatmulBNFusion::SatisfyCondition( const Graph& graph, @@ -172,7 +171,7 @@ namespace onnxruntime Status MatmulBNFusion::Apply( Graph& graph, Node& matmulNode, - RewriteRuleEffect& rule_effect, + RewriteRuleEffect& ruleEffect, const logging::Logger&) const { auto childNodeIterator = matmulNode.OutputNodesBegin(); @@ -223,19 +222,9 @@ namespace onnxruntime } /* - * - * perform bn in terms of [N,H,W,C] * temp = scale / sqrt(var + epsilon) * output = (temp * Input) - ((temp * mean) + bias) - * Create a copy of the initializer to perform the above described calculation - * - * - * matmulB = [1792, 512] - * scalar of BN = [512] (temp) // 512 is my channel - * perform BN in terms of [N,H,W,C] - * */ - // creates copy of the initializer Initializer scale(*scaleTensor, graph.ModelPath()); Initializer bias(*biasTensor, graph.ModelPath()); Initializer mean(*meanTensor, graph.ModelPath()); @@ -250,8 +239,6 @@ namespace onnxruntime mean.mul(scale); bias.sub(mean); - // remove redundant initializer ??? - // create B tensorProto for new Gemm node from initializer. ONNX_NAMESPACE::TensorProto newGemmBTensor(*matmulBTensor); matmulB.ToProto(newGemmBTensor); @@ -277,10 +264,6 @@ namespace onnxruntime nullptr, kOnnxDomain); - // Do we want to verify whether every node in the path should have only output - // because if any of the node's output is used by any other node, then - // we can't remove all of these nodes. 
- std::vector nodesToRemove; nodesToRemove.push_back(matmulNode.Index()); @@ -305,14 +288,7 @@ namespace onnxruntime graph.RemoveNode(nodeIndex); } - //scale.broadcast(utils::GetTensorShapeFromTensorProto(*matmulBTensor)); - //matmulB.mul(scale); //[1, 512] - //scale.scale_by_axis(matmulB, 0); - //matmulB.scale_by_axis(scale, 1); - //matmulB.mul(scale); - - rule_effect = RewriteRuleEffect::kRemovedCurrentNode; - + ruleEffect = RewriteRuleEffect::kRemovedCurrentNode; return Status::OK(); } } \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h index ed60e3257400a..d809b107352cb 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.h +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -32,7 +32,7 @@ class MatmulBNFusion : public RewriteRule Status Apply( Graph& graph, Node& matmulNode, - RewriteRuleEffect& rule_effect, + RewriteRuleEffect& ruleEffect, const logging::Logger& logger) const override; bool MatchPath( From c797f402cd6e659e416742baf62d53632a3b9178 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Thu, 12 Oct 2023 13:43:38 -0500 Subject: [PATCH 03/24] Remove redundant code --- onnxruntime/core/optimizer/initializer.cc | 44 ----------------------- 1 file changed, 44 deletions(-) diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 4ce6a963c32fb..f6b0307d13b33 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -201,30 +201,6 @@ struct ScalarAdd { } }; -//template -//struct Broadcast { -// void operator()(Tensor& tensor, const onnxruntime::TensorShape& destShape) const { -// ToNumeric to_numeric; -// -// size_t newSize = Tensor::CalculateTensorStorageSize(tensor.DataType(), destShape); -// std::shared_ptr allocator = std::make_shared(); -// void* newData = nullptr; -// if (len > 0) { -// newData = allocator->Alloc(newSize); -// } -// Tensor newTensor(tensor.DataType(), destShape, newData, allocator); -// -// // because broadcasting only works for 1-D tensor -// const size_t block_size = tensor.Shape().GetDims().front(); -// const size_t num_blocks = destShape.Size() / block_size; -// -// auto span = tensor.MutableDataAsSpan(); -// for (auto& dst : span) { -// dst = T(to_numeric(dst) + v); -// } -// } -//}; - template struct Sqrt { void operator()(Tensor& tensor) const { @@ -304,26 +280,6 @@ Initializer& Initializer::div(const Initializer& other) { return *this; } -/* -* It only broadcast 1-D tensor if the dimension of that 1-d tensor either equals to -* 1st or last dimension of destShape. 
-*/ -//Initializer& Initializer::mulBy1dInitialer(const Initializer& other) { -// ORT_ENFORCE(other.size() == 1, "The multipier tensor should be 1-D tensor"); -// ORT_ENFORCE(other.dims().front() == dims().front() || other.dims().front() == dims().back(), -// "Dimension of the multiplier tensor should be equal to either 1st or last dimension of the multiplicand tensor."); -// -// const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); -// const size_t num_blocks = size() / block_size; -// ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size"); -// utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); -// t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); -// -// utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); -// //data_ = t_disp.Invoke(data_, destShape); -// return *this; -//} - Initializer& Initializer::sqrt() { utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); t_disp.Invoke(data_); From 2024d647f062c67a55a01fca8d6e0d08e704280d Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Thu, 12 Oct 2023 15:18:35 -0500 Subject: [PATCH 04/24] Remove extra method scale_to_axis --- onnxruntime/core/optimizer/initializer.cc | 55 ++++++------------- onnxruntime/core/optimizer/initializer.h | 4 +- .../core/optimizer/matmul_bn_fusion.cc | 2 +- 3 files changed, 19 insertions(+), 42 deletions(-) diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index f6b0307d13b33..df4111182ff76 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -289,7 +289,7 @@ Initializer& Initializer::sqrt() { namespace { template struct ScaleByAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const { + void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks, const bool columnMajor) const { ToNumeric to_numeric; const auto scaler_size = scalers.Shape().Size(); T* dst = data.MutableData(); @@ -301,32 +301,19 @@ struct ScaleByAxis { } } else { for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - const auto numeric_scaler = to_numeric(scalers_data[i]); - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + if (columnMajor) + { + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + const auto numeric_scaler = to_numeric(scalers_data[j]); + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } - } - } - } -}; - -template -struct ScaleToAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const { - ToNumeric to_numeric; - const auto scaler_size = scalers.Shape().Size(); - T* dst = data.MutableData(); - const T* scalers_data = scalers.Data(); - if (scaler_size == 1) { - const auto numeric_scaler = to_numeric(scalers_data[0]); - for (size_t block_offset = 0, limit = block_size * num_blocks; block_offset < limit; ++block_offset) { - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); - } - } else { - for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - const auto numeric_scaler = to_numeric(scalers_data[j]); - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + else + { + const auto numeric_scaler = to_numeric(scalers_data[i]); + for (size_t j 
= 0; j < block_size; ++j, ++block_offset) { + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } } } @@ -334,22 +321,14 @@ struct ScaleToAxis { }; } // namespace -void Initializer::scale_by_axis(const Initializer& scalers, int axis) { - ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); - const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); - const size_t num_blocks = size() / block_size; - ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size"); - utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); -} - -void Initializer::scale_to_axis(const Initializer& scalers, int axis) { +void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool columnMajor) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; - ORT_ENFORCE(scalers.size() == 1 || scalers.size() == block_size, "Invalid other(scalers) size"); + ORT_ENFORCE(scalers.size() == 1 || + (columnMajor ? scalers.size() == block_size : scalers.size() == num_blocks), "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); + t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, columnMajor); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index cc118eee77c54..648fb88d158c2 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -86,9 +86,7 @@ class Initializer final { Initializer& sqrt(); - void scale_by_axis(const Initializer& other, int axis); - - void scale_to_axis(const Initializer& other, int axis); + void scale_by_axis(const Initializer& other, int axis, bool columnMajor = false); #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index 52b3228614c23..80ad424b05f0e 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -234,7 +234,7 @@ namespace onnxruntime var.add(epsilon); var.sqrt(); scale.div(var); // this is the temp - matmulB.scale_to_axis(scale, 1); + matmulB.scale_by_axis(scale, 1, true); mean.mul(scale); bias.sub(mean); From 6ea436f3b2adadf3f682ea13d05193f695feb428 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Thu, 12 Oct 2023 17:53:39 -0500 Subject: [PATCH 05/24] Refactored the code as per ORT style --- onnxruntime/core/optimizer/initializer.cc | 10 +- onnxruntime/core/optimizer/initializer.h | 2 +- .../core/optimizer/matmul_bn_fusion.cc | 542 +++++++++--------- onnxruntime/core/optimizer/matmul_bn_fusion.h | 35 +- 4 files changed, 279 insertions(+), 310 deletions(-) diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index df4111182ff76..132510896e311 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -289,7 +289,7 @@ Initializer& Initializer::sqrt() { namespace { template struct ScaleByAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks, const bool columnMajor) const { + void operator()(Tensor& 
data, const Tensor& scalers, const size_t block_size, const size_t num_blocks, const bool column_major) const { ToNumeric to_numeric; const auto scaler_size = scalers.Shape().Size(); T* dst = data.MutableData(); @@ -301,7 +301,7 @@ struct ScaleByAxis { } } else { for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - if (columnMajor) + if (column_major) { for (size_t j = 0; j < block_size; ++j, ++block_offset) { const auto numeric_scaler = to_numeric(scalers_data[j]); @@ -321,14 +321,14 @@ struct ScaleByAxis { }; } // namespace -void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool columnMajor) { +void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; ORT_ENFORCE(scalers.size() == 1 || - (columnMajor ? scalers.size() == block_size : scalers.size() == num_blocks), "Invalid other(scalers) size"); + (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, columnMajor); + t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index 648fb88d158c2..78e3fd6a3d24e 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -86,7 +86,7 @@ class Initializer final { Initializer& sqrt(); - void scale_by_axis(const Initializer& other, int axis, bool columnMajor = false); + void scale_by_axis(const Initializer& other, int axis, bool column_major = false); #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index 80ad424b05f0e..0bde61fa91bb9 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -6,289 +6,275 @@ namespace onnxruntime { - void AddNodesToRemove( - Node::NodeConstIterator currItr, - const NodeIndex& destNodeIndex, - std::vector& nodesToRemove) - { - while (currItr->Index() != destNodeIndex) { - nodesToRemove.push_back(currItr->Index()); - currItr = currItr->OutputNodesBegin(); - } - } - - NodeIndex GetOtherParentOfNode( - const Node& node, - NodeIndex firstParentIndex) - { - NodeIndex otherParentIndex = std::numeric_limits::max(); - if (node.GetInputEdgesCount() != 2) - { - return otherParentIndex; - } - - auto parentNodeItr = node.InputNodesBegin(); - if (parentNodeItr->Index() != firstParentIndex) - { - otherParentIndex = parentNodeItr->Index(); - } - ++parentNodeItr; - if (parentNodeItr->Index() != firstParentIndex) - { - otherParentIndex = parentNodeItr->Index(); - } - return otherParentIndex; - } - - bool MatmulBNFusion::MatchPath( - const Node& parentNode, - const gsl::span>>& path, - const Node& childNode) const - { - if (path.size() == 0) - { - return true; - } - - if (!graph_utils::IsSupportedOptypeVersionAndDomain(childNode, path[0].first, path[0].second) || - childNode.GetExecutionProviderType() != parentNode.GetExecutionProviderType()) - { - return false; - } - - // last node in the path can have more than one output - // because all those outputs will be 
preserved by the addition of new Gemm node - if (path.size() > 1 && childNode.GetOutputEdgesCount() != 1) - { - return false; - } - - return MatchPath(childNode, path.subspan(1), *childNode.OutputNodesBegin()); - } - - /* - * Given a MatMul node, it will verify the following pattern. - * MatMul - * | - * / \ - * / \ - * / \ - * Reshape Shape - * | | - * Transpose Cast - * | | - * BatchNormalization Cast - * | | - * Transpose | - * | / - * \ / - * \ / - * \ / - * | - * Reshape - * As of writing this fusion, we are being conversative in the pattern because the customer - * model we are targeting has this exact pattern. Above pattern will evolve in the future - * as we tend to add separate fusion to eliminate Transpose around the BatchNormalization, - * update the model optimizer script to eliminate adjacent Cast operator, etc. - * - * We have to match the path (MatMul->Shape->Cast->Cast->Reshape) because sub-merging the - * BatchNormalization into the MatMul will change MatMul's output and thus we have to make - * sure that MatMul's output is not used by any operator to which MatMul's output matters. - * Other Conditions: - * - B tensor of MatMul should be constant. - * - scale, B, mean, var tensors of BatchNormalization should be constant. - * - Every node in the path except first and last node, should have only 1 output edge. - */ - bool MatmulBNFusion::SatisfyCondition( - const Graph& graph, - const Node& node, - const logging::Logger&) const - { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", { 1, 9, 13 }) || - node.GetOutputEdgesCount() != 2) - { - return false; - } - - auto childNodeIterator = node.OutputNodesBegin(); - const Node& firstChildNode = *childNodeIterator; - ++childNodeIterator; - const Node& secondChildNode = *childNodeIterator; - - std::vector>> firstPath = - {{"Reshape", {1, 5}}, - {"Transpose", {1}}, - {"BatchNormalization", {1, 6, 7}}, - {"Transpose", {1}}, - {"Reshape", {1, 5}}}; - - std::vector>> secondPath = - {{"Shape", {1}}, - {"Cast", {1, 6}}, - {"Cast", {1, 6}}, - {"Reshape", {1, 5}}}; - - if (!(MatchPath(node, firstPath, firstChildNode) ^ MatchPath(node, secondPath, firstChildNode))) - { - return false; - } - - if (!(MatchPath(node, firstPath, secondChildNode) ^ MatchPath(node, secondPath, secondChildNode))) { - return false; - } +void AddNodesToRemove(Node::NodeConstIterator curr_iterator, + const NodeIndex& dest_node_index, + std::vector& nodes_to_remove) { + while (curr_iterator->Index() != dest_node_index) { + nodes_to_remove.push_back(curr_iterator->Index()); + curr_iterator = curr_iterator->OutputNodesBegin(); + } +} + +NodeIndex GetOtherParentOfNode(const Node& node, NodeIndex first_parent_index) { + NodeIndex other_parent_index = std::numeric_limits::max(); + if (node.GetInputEdgesCount() != 2) { + return other_parent_index; + } + + auto parent_node_iterator = node.InputNodesBegin(); + if (parent_node_iterator->Index() != first_parent_index) { + other_parent_index = parent_node_iterator->Index(); + } + ++parent_node_iterator; + if (parent_node_iterator->Index() != first_parent_index) { + other_parent_index = parent_node_iterator->Index(); + } + return other_parent_index; +} + +bool MatchPath(const Node& parent_node, + const gsl::span>>& path, + const Node& child_node) { + if (path.size() == 0) { + return true; + } + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, path[0].first, path[0].second) || + child_node.GetExecutionProviderType() != parent_node.GetExecutionProviderType()) { + return false; + } + + /* + * 
last node in the path can have more than one output + * because all those outputs will be preserved by the addition of new Gemm node + */ + if (path.size() > 1 && child_node.GetOutputEdgesCount() != 1) { + return false; + } + + return MatchPath(child_node, path.subspan(1), *child_node.OutputNodesBegin()); +} + +/* +* Given a MatMul node, it will verify the following pattern. +* MatMul +* | +* / \ +* / \ +* / \ +* Reshape Shape +* | | +* Transpose Cast +* | | +* BatchNormalization Cast +* | | +* Transpose | +* | / +* \ / +* \ / +* \ / +* | +* Reshape +* As of writing this fusion, we are being conversative in the pattern because the customer +* model we are targeting has this exact pattern. Above pattern will evolve in the future +* as we tend to add separate fusion to eliminate Transpose around the BatchNormalization, +* update the model optimizer script to eliminate adjacent Cast operator, etc. +* +* We have to match the path (MatMul->Shape->Cast->Cast->Reshape) because sub-merging the +* BatchNormalization into the MatMul will change MatMul's output and thus we have to make +* sure that MatMul's output is not used by any operator to which MatMul's output matters. +* Other Conditions: +* - B tensor of MatMul should be constant. +* - scale, B, mean, var tensors of BatchNormalization should be constant. +* - Every node in the path except first and last node, should have only 1 output edge. +*/ +bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", { 1, 9, 13 }) || + node.GetOutputEdgesCount() != 2) { + return false; + } + + auto child_node_iterator = node.OutputNodesBegin(); + const Node& first_child_node = *child_node_iterator; + ++child_node_iterator; + const Node& second_child_node = *child_node_iterator; + + std::vector>> first_path = + {{"Reshape", {1, 5}}, + {"Transpose", {1}}, + {"BatchNormalization", {1, 6, 7}}, + {"Transpose", {1}}, + {"Reshape", {1, 5}}}; + + std::vector>> second_path = + {{"Shape", {1}}, + {"Cast", {1, 6}}, + {"Cast", {1, 6}}, + {"Reshape", {1, 5}}}; + + if (!(MatchPath(node, first_path, first_child_node) ^ MatchPath(node, second_path, first_child_node))) { + return false; + } + + if (!(MatchPath(node, first_path, second_child_node) ^ MatchPath(node, second_path, second_child_node))) { + return false; + } - const auto& batchNormNode = firstChildNode.OpType() == "Reshape" ? - *firstChildNode.OutputNodesBegin()->OutputNodesBegin() : - *secondChildNode.OutputNodesBegin()->OutputNodesBegin(); + const auto& batch_norm_node = first_child_node.OpType() == "Reshape" ? + *first_child_node.OutputNodesBegin()->OutputNodesBegin() : + *second_child_node.OutputNodesBegin()->OutputNodesBegin(); - // Check that the appropriate inputs to the Matmul and BN nodes are constants. - if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || - !graph_utils::NodeArgIsConstant(graph, *batchNormNode.InputDefs()[1]) || - !graph_utils::NodeArgIsConstant(graph, *batchNormNode.InputDefs()[2]) || - !graph_utils::NodeArgIsConstant(graph, *batchNormNode.InputDefs()[3]) || - !graph_utils::NodeArgIsConstant(graph, *batchNormNode.InputDefs()[4])) - { - return false; - } - - // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. 
- const auto& output_defs = batchNormNode.OutputDefs(); - if (output_defs.size() > 1) { - for (size_t i = 1, end = output_defs.size(); i < end; ++i) { - if (output_defs[i] != nullptr && output_defs[i]->Exists()) - return false; - } - } - - if (graph.NodeProducesGraphOutput(node)) { - return false; - } - - return true; + // Check that the appropriate inputs to the Matmul and BN nodes are constants. + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[2]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[3]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[4])) { + return false; + } + + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. + const auto& output_defs = batch_norm_node.OutputDefs(); + if (output_defs.size() > 1) { + for (size_t i = 1, end = output_defs.size(); i < end; ++i) { + if (output_defs[i] != nullptr && output_defs[i]->Exists()) { + return false; + } } - - Status MatmulBNFusion::Apply( - Graph& graph, - Node& matmulNode, - RewriteRuleEffect& ruleEffect, - const logging::Logger&) const - { - auto childNodeIterator = matmulNode.OutputNodesBegin(); - const Node& firstChildNode = *childNodeIterator; - ++childNodeIterator; - const Node& secondChildNode = *childNodeIterator; - - const Node& firstReshape = firstChildNode.OpType() == "Reshape" ? firstChildNode : secondChildNode; - - NodeIndex batchNormNodeIndex = firstReshape.OutputNodesBegin()->OutputNodesBegin()->Index(); - Node& batchNormNode = *graph.GetNode(batchNormNodeIndex); - - // only perform fusion if eplison is present and is of float_32 type - auto epsilonAttr = batchNormNode.GetAttributes().find("epsilon"); - if (epsilonAttr == batchNormNode.GetAttributes().end() || - epsilonAttr->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) - { - return Status::OK(); - } - const float epsilon = epsilonAttr->second.f(); - - const onnx::TensorProto* scaleTensor = graph_utils::GetConstantInitializer(graph, batchNormNode.InputDefs()[1]->Name()); - ORT_ENFORCE(scaleTensor); - const onnx::TensorProto* biasTensor = graph_utils::GetConstantInitializer(graph, batchNormNode.InputDefs()[2]->Name()); - ORT_ENFORCE(biasTensor); - const onnx::TensorProto* meanTensor = graph_utils::GetConstantInitializer(graph, batchNormNode.InputDefs()[3]->Name()); - ORT_ENFORCE(meanTensor); - const onnx::TensorProto* varTensor = graph_utils::GetConstantInitializer(graph, batchNormNode.InputDefs()[4]->Name()); - ORT_ENFORCE(varTensor); - const onnx::TensorProto* matmulBTensor = graph_utils::GetConstantInitializer(graph, matmulNode.InputDefs()[1]->Name()); - ORT_ENFORCE(matmulBTensor); - - if (!optimizer_utils::IsFloatingPointDataType(*matmulBTensor) || - !optimizer_utils::IsFloatingPointDataType(*scaleTensor) || - !optimizer_utils::IsFloatingPointDataType(*biasTensor) || - !optimizer_utils::IsFloatingPointDataType(*meanTensor) || - !optimizer_utils::IsFloatingPointDataType(*varTensor) || - scaleTensor->dims_size() != 1 || - biasTensor->dims_size() != 1 || - meanTensor->dims_size() != 1 || - varTensor->dims_size() != 1 || - scaleTensor->dims(0) != matmulBTensor->dims(1) || - biasTensor->dims(0) != matmulBTensor->dims(1) || - meanTensor->dims(0) != matmulBTensor->dims(1) || - varTensor->dims(0) != matmulBTensor->dims(1)) - { - return Status::OK(); - } + } + + if 
(graph.NodeProducesGraphOutput(node)) { + return false; + } + + return true; +} + +/* +* BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc] +* Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML +* Expanding out the terms: +* Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias +* Here, +* Scale/sqrt(Variance + Epsilon) = alpha (constant) +* (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias = beta (constant) +* Output = alpha * Input + beta, Input = B tensor of MatMul. +* +*/ +Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + auto child_node_iterator = matmul_node.OutputNodesBegin(); + const Node& first_child_node = *child_node_iterator; + ++child_node_iterator; + const Node& second_child_node = *child_node_iterator; + + const Node& first_reshape = first_child_node.OpType() == "Reshape" ? first_child_node : second_child_node; + + NodeIndex batch_norm_node_index = first_reshape.OutputNodesBegin()->OutputNodesBegin()->Index(); + Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); + + // only perform fusion if eplison is present and is of float_32 type + auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon"); + if (epsilon_attribute == batch_norm_node.GetAttributes().end() || + epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) { + return Status::OK(); + } + const float epsilon = epsilon_attribute->second.f(); + + const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name()); + ORT_ENFORCE(scale_tensor); + const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name()); + ORT_ENFORCE(bias_tensor); + const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name()); + ORT_ENFORCE(mean_tensor); + const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name()); + ORT_ENFORCE(var_tensor); + const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name()); + ORT_ENFORCE(matmul_b_tensor); + + if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) || + !optimizer_utils::IsFloatingPointDataType(*scale_tensor) || + !optimizer_utils::IsFloatingPointDataType(*bias_tensor) || + !optimizer_utils::IsFloatingPointDataType(*mean_tensor) || + !optimizer_utils::IsFloatingPointDataType(*var_tensor) || + scale_tensor->dims_size() != 1 || + bias_tensor->dims_size() != 1 || + mean_tensor->dims_size() != 1 || + var_tensor->dims_size() != 1 || + scale_tensor->dims(0) != matmul_b_tensor->dims(1) || + bias_tensor->dims(0) != matmul_b_tensor->dims(1) || + mean_tensor->dims(0) != matmul_b_tensor->dims(1) || + var_tensor->dims(0) != matmul_b_tensor->dims(1)) { + return Status::OK(); + } - /* - * temp = scale / sqrt(var + epsilon) - * output = (temp * Input) - ((temp * mean) + bias) - */ - Initializer scale(*scaleTensor, graph.ModelPath()); - Initializer bias(*biasTensor, graph.ModelPath()); - Initializer mean(*meanTensor, graph.ModelPath()); - Initializer var(*varTensor, graph.ModelPath()); - Initializer matmulB(*matmulBTensor, graph.ModelPath()); - - var.add(epsilon); 
- var.sqrt(); - scale.div(var); // this is the temp - matmulB.scale_by_axis(scale, 1, true); - - mean.mul(scale); - bias.sub(mean); + /* + * temp = scale / sqrt(var + epsilon) + * output = (temp * Input) - ((temp * mean) + bias) + */ + Initializer scale(*scale_tensor, graph.ModelPath()); + Initializer bias(*bias_tensor, graph.ModelPath()); + Initializer mean(*mean_tensor, graph.ModelPath()); + Initializer var(*var_tensor, graph.ModelPath()); + Initializer matmul_b(*matmul_b_tensor, graph.ModelPath()); + + var.add(epsilon); + var.sqrt(); + scale.div(var); // this is the temp + matmul_b.scale_by_axis(scale, 1, true); + + mean.mul(scale); + bias.sub(mean); - // create B tensorProto for new Gemm node from initializer. - ONNX_NAMESPACE::TensorProto newGemmBTensor(*matmulBTensor); - matmulB.ToProto(newGemmBTensor); - const std::string newGemmBName = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmulBTensor->name()); - newGemmBTensor.set_name(newGemmBName); - NodeArg& newGemmBNodeArg = graph_utils::AddInitializer(graph, newGemmBTensor); - - // create bias tensorProto for new Gemm node from initializer. - ONNX_NAMESPACE::TensorProto newGemmBiasTensor(*biasTensor); - bias.ToProto(newGemmBiasTensor); - const std::string newGemmBiasName = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); - newGemmBiasTensor.set_name(newGemmBiasName); - NodeArg& newGemmBiasNodeArg = graph_utils::AddInitializer(graph, newGemmBiasTensor); - - NodeIndex lastReshapeNodeIndex = firstReshape.OutputNodesBegin()->OutputNodesBegin()-> - OutputNodesBegin()->OutputNodesBegin()->Index(); - graph.AddNode( - graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), - "Gemm", - "Generated from Matmul BatchNormalization fusion", - {matmulNode.MutableInputDefs()[0], &newGemmBNodeArg, &newGemmBiasNodeArg}, - graph.GetNode(lastReshapeNodeIndex)->MutableOutputDefs(), - nullptr, - kOnnxDomain); + // create B tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); + matmul_b.ToProto(new_gemm_b_tensor); + const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); + new_gemm_b_tensor.set_name(new_gemm_b_name); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); + + // create bias tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor); + bias.ToProto(new_gemm_bias_tensor); + const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); + new_gemm_bias_tensor.set_name(new_gemm_bias_name); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); + + NodeIndex last_reshape_node_index = first_reshape.OutputNodesBegin()->OutputNodesBegin()-> + OutputNodesBegin()->OutputNodesBegin()->Index(); + graph.AddNode( + graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), + "Gemm", + "Generated from Matmul BatchNormalization fusion", + {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg}, + graph.GetNode(last_reshape_node_index)->MutableOutputDefs(), + nullptr, + kOnnxDomain); - std::vector nodesToRemove; - nodesToRemove.push_back(matmulNode.Index()); - - // Remove non-Matmul parent of Reshape if and only if - // that parent has only 1 output. 
- NodeIndex nonMatmulParentOfFirstReshape = GetOtherParentOfNode(firstReshape, matmulNode.Index()); - if (nonMatmulParentOfFirstReshape != std::numeric_limits::max() && - graph.GetNode(nonMatmulParentOfFirstReshape)->GetOutputEdgesCount() == 1) - { - nodesToRemove.push_back(nonMatmulParentOfFirstReshape); - } - - auto currItr = matmulNode.OutputNodesBegin(); - AddNodesToRemove(currItr, lastReshapeNodeIndex, nodesToRemove); - ++currItr; - AddNodesToRemove(currItr, lastReshapeNodeIndex, nodesToRemove); - nodesToRemove.push_back(lastReshapeNodeIndex); - - for (const auto& nodeIndex : nodesToRemove) { - Node* node = graph.GetNode(nodeIndex); - graph_utils::RemoveNodeOutputEdges(graph, *node); - graph.RemoveNode(nodeIndex); - } - - ruleEffect = RewriteRuleEffect::kRemovedCurrentNode; - return Status::OK(); - } + std::vector nodes_to_remove; + nodes_to_remove.push_back(matmul_node.Index()); + + // Remove non-Matmul parent of Reshape if and only if + // that parent has only 1 output. + NodeIndex non_matmul_parent_of_first_reshape = GetOtherParentOfNode(first_reshape, matmul_node.Index()); + if (non_matmul_parent_of_first_reshape != std::numeric_limits::max() && + graph.GetNode(non_matmul_parent_of_first_reshape)->GetOutputEdgesCount() == 1) { + nodes_to_remove.push_back(non_matmul_parent_of_first_reshape); + } + + auto curr_iterator = matmul_node.OutputNodesBegin(); + AddNodesToRemove(curr_iterator, last_reshape_node_index, nodes_to_remove); + ++curr_iterator; + AddNodesToRemove(curr_iterator, last_reshape_node_index, nodes_to_remove); + nodes_to_remove.push_back(last_reshape_node_index); + + for (const auto& node_index : nodes_to_remove) { + Node* node = graph.GetNode(node_index); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node_index); + } + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} } \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h index d809b107352cb..4ef65341bad64 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.h +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -10,34 +10,17 @@ namespace onnxruntime * precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() * is true. 
*/
class MatmulBNFusion : public RewriteRule {
 public:
  MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {}

  std::vector<std::string> TargetOpTypes() const noexcept {
    return {"MatMul"};
  }

 private:
  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

  Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
};
}
\ No newline at end of file
From f63bd114edd939e6df88988bfe8ead8eddb26a7a Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 13 Oct 2023 10:55:44 -0500
Subject: [PATCH 06/24] Added testcase

---
 .../core/optimizer/matmul_bn_fusion.cc        |   4 +-
 .../test/optimizer/graph_transform_test.cc    | 110 ++++++++++++++++++
 2 files changed, 112 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index 0bde61fa91bb9..43d268f75e9e2 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -157,8 +157,8 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
 * Expanding out the terms:
 * Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias
 * Here,
-* Scale/sqrt(Variance + Epsilon) = alpha (constant)
-* (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias = beta (constant)
+* [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha`
+* [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta`
 * Output = alpha * Input + beta, Input = B tensor of MatMul.
 *
 */
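For reference, the rewrite that this comment hunk describes can be stated compactly. With s, b, mu, and sigma^2 denoting the BatchNormalization scale, bias, mean, and variance, W the constant B input of the MatMul, and X its other input, the fusion computes the following (a restatement of the comment above, not text from the patch):

```latex
\mathrm{BN}(XW) \;=\; s \odot \frac{XW - \mu}{\sqrt{\sigma^{2} + \epsilon}} + b
\;=\; X \bigl(W \operatorname{diag}(\alpha)\bigr) + \beta,
\qquad
\alpha = \frac{s}{\sqrt{\sigma^{2} + \epsilon}}, \quad
\beta = b - \mu \odot \alpha
```

Scaling the j-th column of W by alpha_j is what `matmul_b.scale_by_axis(scale, 1, true)` implements, and beta becomes the bias input of the generated Gemm node.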
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index dce1f2d40e8b9..1a9cc744ded50 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -31,6 +31,7 @@
 #include "core/optimizer/conv_add_act_fusion.h"
 #include "core/optimizer/conv_add_fusion.h"
 #include "core/optimizer/conv_bn_fusion.h"
+#include "core/optimizer/matmul_bn_fusion.h"
 #include "core/optimizer/conv_mul_fusion.h"
 #include "core/optimizer/div_mul_fusion.h"
 #include "core/optimizer/dropout_elimination.h"
@@ -1059,6 +1060,115 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) {
   }
 }
 
+TEST_F(GraphTransformationTests, FuseMatmulBN) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_bn.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "Reshape") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["BatchNormalization"] == 0);
+  ASSERT_TRUE(op_to_count["MatMul"] == 0);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the last node";
+    }
+  }
+}
+
+TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutput) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_bn.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  std::string expected_output_name;
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "Reshape") {
+      expected_output_name = node.OutputDefs()[0]->Name();
+    } else if (node.OpType() == "BatchNormalization") {
+      node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr));
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_TRUE(op_to_count["BatchNormalization"] == 0);
+  ASSERT_TRUE(op_to_count["MatMul"] == 0);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the last node";
+    }
+  }
+}
+
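The two tests above check operator counts and the preserved output name. As a standalone sanity check of the underlying algebra (not part of this PR, and independent of the ONNX Runtime APIs), one can verify on plain arrays that folding alpha into the B matrix column by column, and the remainder into the bias, matches applying BatchNormalization after the MatMul:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Standalone check with hypothetical values: W stands in for the MatMul B
// tensor in row-major [K, N] layout; the BatchNormalization parameters are
// per output column ([N]), matching the dims(0) == dims(1) checks in Apply().
int main() {
  const int K = 2, N = 3;
  std::vector<float> W = {1, 2, 3, 4, 5, 6};    // [K, N]
  std::vector<float> x = {0.5f, -1.0f};         // one input row [K]
  std::vector<float> scale = {2, 3, 4}, bias = {1, 1, 1};
  std::vector<float> mean = {0.1f, 0.2f, 0.3f}, var = {1, 4, 9};
  const float eps = 1e-5f;

  for (int j = 0; j < N; ++j) {
    const float alpha = scale[j] / std::sqrt(var[j] + eps);  // per-column scale
    const float beta = bias[j] - alpha * mean[j];            // folded bias

    float y = 0.0f;  // y = (x * W)[j], the MatMul output for column j
    for (int k = 0; k < K; ++k) y += x[k] * W[k * N + j];
    const float bn_after_matmul = alpha * y + beta;

    float fused = beta;  // Gemm with column-scaled B and folded bias
    for (int k = 0; k < K; ++k) fused += x[k] * (W[k * N + j] * alpha);

    assert(std::fabs(bn_after_matmul - fused) < 1e-4f);
  }
  return 0;
}
```

The column-wise scaling here corresponds to `scale_by_axis` with `column_major = true`: each block of `block_size` elements is multiplied element-wise by the scalers, rather than each block by a single scaler.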
should not fuse +TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutput) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_bn.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "Reshape") { + expected_output_name = node.OutputDefs()[0]->Name(); + } else if (node.OpType() == "BatchNormalization") { + // additional additional non-empty output to batchNormalization + ONNX_NAMESPACE::TypeProto optional_output_tensor_type; + optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType); + auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type); + node.MutableOutputDefs().push_back(&arg); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["BatchNormalization"] == 1); + ASSERT_TRUE(op_to_count["MatMul"] == 1); + ASSERT_TRUE(op_to_count["Gemm"] == 0); +} + TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; From 7cc2013bf36157c2b459b25c9daeb7d5fcf56795 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Fri, 13 Oct 2023 11:03:49 -0500 Subject: [PATCH 07/24] Added test file --- .../test/optimizer/graph_transform_test.cc | 6 +++--- .../testdata/transform/fusion/fuse-matmul-bn.onnx | Bin 0 -> 952 bytes 2 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn.onnx diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 1a9cc744ded50..fee74ec8d283a 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1061,7 +1061,7 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) { } TEST_F(GraphTransformationTests, FuseMatmulBN) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_bn.onnx"; + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); @@ -1096,7 +1096,7 @@ TEST_F(GraphTransformationTests, FuseMatmulBN) { } TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutput) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_bn.onnx"; + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); @@ -1135,7 +1135,7 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutput) { // should not fuse TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutput) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_bn.onnx"; + constexpr 
const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6b765284540a62410ab7633aa9fdbb343018f9e0 GIT binary patch literal 952 zcmd zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#E+;N@ zIU&K4qQt!7g8bstc$jj|5SWmbAQwwPYEiBOg9EbzqXW|dMs}_+E=(<497%cc#mR{| zsa)(pR#IkSF_@8?nwZDM1{5hvEE3`b(ojDLIVF}PXZYn8pa^Dv>vtRv{)n1%m-TokNvAvOtoV{{uh<(A)NPCgr()PU1qim<_(6|2{#ba;o zxZ3u=euVwp?48!VpN`l$p{GkN84eZ!b}QchFu=&j$m#@3gjy4ylAZKZPkXUjiS|BBPi?`9 zK&ebi31X7@ Date: Fri, 13 Oct 2023 11:14:33 -0500 Subject: [PATCH 08/24] Added extra assertion --- onnxruntime/core/optimizer/matmul_bn_fusion.cc | 16 ++++++++-------- .../test/optimizer/graph_transform_test.cc | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index 43d268f75e9e2..b46b126e7fe67 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -6,12 +6,12 @@ namespace onnxruntime { -void AddNodesToRemove(Node::NodeConstIterator curr_iterator, +void AddNodesToRemove(Node::NodeConstIterator current_iterator, const NodeIndex& dest_node_index, std::vector& nodes_to_remove) { - while (curr_iterator->Index() != dest_node_index) { - nodes_to_remove.push_back(curr_iterator->Index()); - curr_iterator = curr_iterator->OutputNodesBegin(); + while (current_iterator->Index() != dest_node_index) { + nodes_to_remove.push_back(current_iterator->Index()); + current_iterator = current_iterator->OutputNodesBegin(); } } @@ -262,10 +262,10 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& nodes_to_remove.push_back(non_matmul_parent_of_first_reshape); } - auto curr_iterator = matmul_node.OutputNodesBegin(); - AddNodesToRemove(curr_iterator, last_reshape_node_index, nodes_to_remove); - ++curr_iterator; - AddNodesToRemove(curr_iterator, last_reshape_node_index, nodes_to_remove); + auto current_iterator = matmul_node.OutputNodesBegin(); + AddNodesToRemove(current_iterator, last_reshape_node_index, nodes_to_remove); + ++current_iterator; + AddNodesToRemove(current_iterator, last_reshape_node_index, nodes_to_remove); nodes_to_remove.push_back(last_reshape_node_index); for (const auto& node_index : nodes_to_remove) { diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index fee74ec8d283a..76797385d6f19 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1086,6 +1086,7 @@ TEST_F(GraphTransformationTests, FuseMatmulBN) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["BatchNormalization"] == 0); ASSERT_TRUE(op_to_count["MatMul"] == 0); + ASSERT_TRUE(op_to_count["Gemm"] == 1); for (auto& node : graph.Nodes()) { if (node.OpType() == "Gemm") { @@ -1124,6 +1125,7 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutput) { std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["BatchNormalization"] == 0); ASSERT_TRUE(op_to_count["MatMul"] == 0); + ASSERT_TRUE(op_to_count["Gemm"] == 1); for (auto& node : graph.Nodes()) { if (node.OpType() == "Gemm") { From 
7ddeecf7aebc606c6341e06cc193f78502897cd3 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Mon, 16 Oct 2023 15:40:18 -0500 Subject: [PATCH 09/24] Use inlinedVector instead of initializer_list --- onnxruntime/core/optimizer/matmul_bn_fusion.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index b46b126e7fe67..b47f1462aff27 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -33,7 +33,7 @@ NodeIndex GetOtherParentOfNode(const Node& node, NodeIndex first_parent_index) { } bool MatchPath(const Node& parent_node, - const gsl::span>>& path, + const gsl::span>>& path, const Node& child_node) { if (path.size() == 0) { return true; @@ -99,14 +99,14 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons ++child_node_iterator; const Node& second_child_node = *child_node_iterator; - std::vector>> first_path = + std::vector>> first_path = {{"Reshape", {1, 5}}, {"Transpose", {1}}, {"BatchNormalization", {1, 6, 7}}, {"Transpose", {1}}, {"Reshape", {1, 5}}}; - std::vector>> second_path = + std::vector>> second_path = {{"Shape", {1}}, {"Cast", {1, 6}}, {"Cast", {1, 6}}, From d1842c9b9bf64673f152667cc3fa0dbc9938fb68 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Mon, 16 Oct 2023 18:44:31 -0500 Subject: [PATCH 10/24] Add override specifier --- onnxruntime/core/optimizer/matmul_bn_fusion.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h index 4ef65341bad64..58a574db04834 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.h +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -14,7 +14,7 @@ class MatmulBNFusion : public RewriteRule { public: MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {} - std::vector TargetOpTypes() const noexcept { + std::vector TargetOpTypes() const noexcept override { return {"MatMul"}; } From f367a3601b0669768b6c2730c8a81865f89410bd Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Tue, 17 Oct 2023 12:25:58 -0500 Subject: [PATCH 11/24] Addressed bot PR feedback --- onnxruntime/core/optimizer/initializer.cc | 30 ++++++++++--------- .../test/optimizer/graph_transform_test.cc | 18 +++++------ 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 9d4939076b6d1..4e03ef0963856 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -291,7 +291,11 @@ Initializer& Initializer::sqrt() { namespace { template struct ScaleByAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks, const bool column_major) const { + void operator()(Tensor& data, + const Tensor& scalers, + const size_t block_size, + const size_t num_blocks, + const bool column_major) const { ToNumeric to_numeric; const auto scaler_size = scalers.Shape().Size(); T* dst = data.MutableData(); @@ -303,19 +307,17 @@ struct ScaleByAxis { } } else { for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - if (column_major) - { - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - const auto numeric_scaler = to_numeric(scalers_data[j]); - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); - } + if (column_major) { + for (size_t j = 0; j < block_size; 
++j, ++block_offset) { + const auto numeric_scaler = to_numeric(scalers_data[j]); + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } - else - { - const auto numeric_scaler = to_numeric(scalers_data[i]); - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); - } + else { + const auto numeric_scaler = to_numeric(scalers_data[i]); + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } } } @@ -327,7 +329,7 @@ void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool colum ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; - ORT_ENFORCE(scalers.size() == 1 || + ORT_ENFORCE(scalers.size() == 1 || (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index f68b185ac04e4..5a768bd7a0468 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1104,9 +1104,9 @@ TEST_F(GraphTransformationTests, FuseMatmulBN) { ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["BatchNormalization"] == 0); - ASSERT_TRUE(op_to_count["MatMul"] == 0); - ASSERT_TRUE(op_to_count["Gemm"] == 1); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); for (auto& node : graph.Nodes()) { if (node.OpType() == "Gemm") { @@ -1143,9 +1143,9 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutput) { ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["BatchNormalization"] == 0); - ASSERT_TRUE(op_to_count["MatMul"] == 0); - ASSERT_TRUE(op_to_count["Gemm"] == 1); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); for (auto& node : graph.Nodes()) { if (node.OpType() == "Gemm") { @@ -1186,9 +1186,9 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutput) { ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); std::map op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["BatchNormalization"] == 1); - ASSERT_TRUE(op_to_count["MatMul"] == 1); - ASSERT_TRUE(op_to_count["Gemm"] == 0); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); } TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { From e604ea4997b42279c9340f70d7285cf6457cec50 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Tue, 17 Oct 2023 22:14:11 -0500 Subject: [PATCH 12/24] Update the pattern as mentioned by Jeff --- onnxruntime/core/optimizer/initializer.cc | 3 +- .../core/optimizer/matmul_bn_fusion.cc | 146 ++++-------------- onnxruntime/core/optimizer/matmul_bn_fusion.h | 3 + 
.../test/optimizer/graph_transform_test.cc | 11 +- .../transform/fusion/fuse-matmul-bn.onnx | Bin 952 -> 779 bytes 5 files changed, 41 insertions(+), 122 deletions(-) diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 4e03ef0963856..73466567ea1ab 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -312,8 +312,7 @@ struct ScaleByAxis { const auto numeric_scaler = to_numeric(scalers_data[j]); dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); } - } - else { + } else { const auto numeric_scaler = to_numeric(scalers_data[i]); for (size_t j = 0; j < block_size; ++j, ++block_offset) { dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index b47f1462aff27..abcc7b9d1fd36 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include "core/optimizer/matmul_bn_fusion.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" @@ -6,32 +9,6 @@ namespace onnxruntime { -void AddNodesToRemove(Node::NodeConstIterator current_iterator, - const NodeIndex& dest_node_index, - std::vector& nodes_to_remove) { - while (current_iterator->Index() != dest_node_index) { - nodes_to_remove.push_back(current_iterator->Index()); - current_iterator = current_iterator->OutputNodesBegin(); - } -} - -NodeIndex GetOtherParentOfNode(const Node& node, NodeIndex first_parent_index) { - NodeIndex other_parent_index = std::numeric_limits::max(); - if (node.GetInputEdgesCount() != 2) { - return other_parent_index; - } - - auto parent_node_iterator = node.InputNodesBegin(); - if (parent_node_iterator->Index() != first_parent_index) { - other_parent_index = parent_node_iterator->Index(); - } - ++parent_node_iterator; - if (parent_node_iterator->Index() != first_parent_index) { - other_parent_index = parent_node_iterator->Index(); - } - return other_parent_index; -} - bool MatchPath(const Node& parent_node, const gsl::span>>& path, const Node& child_node) { @@ -57,32 +34,13 @@ bool MatchPath(const Node& parent_node, /* * Given a MatMul node, it will verify the following pattern. -* MatMul -* | -* / \ -* / \ -* / \ -* Reshape Shape -* | | -* Transpose Cast -* | | -* BatchNormalization Cast -* | | -* Transpose | -* | / -* \ / -* \ / -* \ / -* | -* Reshape -* As of writing this fusion, we are being conversative in the pattern because the customer -* model we are targeting has this exact pattern. Above pattern will evolve in the future -* as we tend to add separate fusion to eliminate Transpose around the BatchNormalization, -* update the model optimizer script to eliminate adjacent Cast operator, etc. -* -* We have to match the path (MatMul->Shape->Cast->Cast->Reshape) because sub-merging the -* BatchNormalization into the MatMul will change MatMul's output and thus we have to make -* sure that MatMul's output is not used by any operator to which MatMul's output matters. +* MatMul +* | +* Reshape +* | +* Transpose +* | +* BatchNormalization * Other Conditions: * - B tensor of MatMul should be constant. * - scale, B, mean, var tensors of BatchNormalization should be constant. 
@@ -90,40 +48,23 @@ bool MatchPath(const Node& parent_node, */ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", { 1, 9, 13 }) || - node.GetOutputEdgesCount() != 2) { + node.GetOutputEdgesCount() != 1) { return false; } - auto child_node_iterator = node.OutputNodesBegin(); - const Node& first_child_node = *child_node_iterator; - ++child_node_iterator; - const Node& second_child_node = *child_node_iterator; + const Node& child_node = *node.OutputNodesBegin(); - std::vector>> first_path = - {{"Reshape", {1, 5}}, + std::vector>> path { + {"Reshape", {1, 5}}, {"Transpose", {1}}, - {"BatchNormalization", {1, 6, 7}}, - {"Transpose", {1}}, - {"Reshape", {1, 5}}}; - - std::vector>> second_path = - {{"Shape", {1}}, - {"Cast", {1, 6}}, - {"Cast", {1, 6}}, - {"Reshape", {1, 5}}}; + {"BatchNormalization", {1, 6, 7}} + }; - if (!(MatchPath(node, first_path, first_child_node) ^ MatchPath(node, second_path, first_child_node))) { + if (!MatchPath(node, path, child_node)) { return false; - } - - if (!(MatchPath(node, first_path, second_child_node) ^ MatchPath(node, second_path, second_child_node))) { - return false; - } - - - const auto& batch_norm_node = first_child_node.OpType() == "Reshape" ? - *first_child_node.OutputNodesBegin()->OutputNodesBegin() : - *second_child_node.OutputNodesBegin()->OutputNodesBegin(); + } + + const auto& batch_norm_node = *child_node.OutputNodesBegin()->OutputNodesBegin(); // Check that the appropriate inputs to the Matmul and BN nodes are constants. if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || @@ -163,17 +104,11 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons * */ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { - auto child_node_iterator = matmul_node.OutputNodesBegin(); - const Node& first_child_node = *child_node_iterator; - ++child_node_iterator; - const Node& second_child_node = *child_node_iterator; - - const Node& first_reshape = first_child_node.OpType() == "Reshape" ? 
first_child_node : second_child_node; - - NodeIndex batch_norm_node_index = first_reshape.OutputNodesBegin()->OutputNodesBegin()->Index(); + const Node& child_node = *matmul_node.OutputNodesBegin(); + NodeIndex batch_norm_node_index = child_node.OutputNodesBegin()->OutputNodesBegin()->Index(); Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); - // only perform fusion if eplison is present and is of float_32 type + // only perform fusion if epsilon is present and is of float_32 type auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon"); if (epsilon_attribute == batch_norm_node.GetAttributes().end() || epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) { @@ -240,40 +175,23 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& new_gemm_bias_tensor.set_name(new_gemm_bias_name); NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); - NodeIndex last_reshape_node_index = first_reshape.OutputNodesBegin()->OutputNodesBegin()-> - OutputNodesBegin()->OutputNodesBegin()->Index(); graph.AddNode( graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), "Gemm", "Generated from Matmul BatchNormalization fusion", {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg}, - graph.GetNode(last_reshape_node_index)->MutableOutputDefs(), + matmul_node.MutableOutputDefs(), nullptr, kOnnxDomain); - std::vector nodes_to_remove; - nodes_to_remove.push_back(matmul_node.Index()); - - // Remove non-Matmul parent of Reshape if and only if - // that parent has only 1 output. - NodeIndex non_matmul_parent_of_first_reshape = GetOtherParentOfNode(first_reshape, matmul_node.Index()); - if (non_matmul_parent_of_first_reshape != std::numeric_limits::max() && - graph.GetNode(non_matmul_parent_of_first_reshape)->GetOutputEdgesCount() == 1) { - nodes_to_remove.push_back(non_matmul_parent_of_first_reshape); - } - - auto current_iterator = matmul_node.OutputNodesBegin(); - AddNodesToRemove(current_iterator, last_reshape_node_index, nodes_to_remove); - ++current_iterator; - AddNodesToRemove(current_iterator, last_reshape_node_index, nodes_to_remove); - nodes_to_remove.push_back(last_reshape_node_index); - - for (const auto& node_index : nodes_to_remove) { - Node* node = graph.GetNode(node_index); - graph_utils::RemoveNodeOutputEdges(graph, *node); - graph.RemoveNode(node_index); - } - + // Remove MatMul node. + Node* node = graph.GetNode(matmul_node.Index()); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(matmul_node.Index()); + + // Delete BatchNormalization node and update the input of the child of BatchNormalization + graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(child_node.OutputNodesBegin()->Index()), batch_norm_node); + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; return Status::OK(); } diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h index 58a574db04834..cf539cb8883b5 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.h +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ #pragma once #include "core/optimizer/rewrite_rule.h" diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 5a768bd7a0468..204eea824dd8f 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1091,7 +1091,7 @@ TEST_F(GraphTransformationTests, FuseMatmulBN) { GraphViewer graphViewer(graph); for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { auto& node = *graph.GetNode(node_index); - if (node.OpType() == "Reshape") { + if (node.OpType() == "MatMul") { expected_output_name = node.OutputDefs()[0]->Name(); } } @@ -1127,10 +1127,9 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutput) { GraphViewer graphViewer(graph); for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { auto& node = *graph.GetNode(node_index); - if (node.OpType() == "Reshape") { + if (node.OpType() == "MatMul") { expected_output_name = node.OutputDefs()[0]->Name(); - } - else if (node.OpType() == "BatchNormalization") { + } else if (node.OpType() == "BatchNormalization") { node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr)); } } @@ -1167,10 +1166,10 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutput) { GraphViewer graphViewer(graph); for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { auto& node = *graph.GetNode(node_index); - if (node.OpType() == "Reshape") { + if (node.OpType() == "MatMul") { expected_output_name = node.OutputDefs()[0]->Name(); } else if (node.OpType() == "BatchNormalization") { - // additional additional non-empty output to batchNormalization + // additional non-empty output to batchNormalization ONNX_NAMESPACE::TypeProto optional_output_tensor_type; optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType); auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type); diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn.onnx index 6b765284540a62410ab7633aa9fdbb343018f9e0..386c1f729d6f0b2cb37357d8fb16270c1dd54e9d 100644 GIT binary patch delta 273 zcmdnN-pw{alIa!eMrj8|ZdES9lA^@C;)49*$x?B zm|bBqhy4!AF#FDRJM5P6CE9&f*kos!&}luzrPTi4YiIiv>ul`5zp=CTzm;!)b)%P^ zZ;q+Gv{91%$<4*~r)qrdCa-#8XIGVHzu?(D+v5ES_LKK8#m1ewoo`>qebsLFTu*zC z^MY1OSeopd*pu?&lQI*FwUjwn1sIh?miD^ZMLd0C*QYsmtCL$ delta 447 zcmeBX+rd6Tl4&~oMrjAedNVG;lA^@C;)49*)cE|;l7i9_LoW8>jKqS}ctaunqErYI zDx<_61QFm;;o{9rEXgg+foc-sK~|{58VoXlOP32pScoS%vA6`P!$66}InlAWM2nM) zsU%;5Z2=<_mj)M#3L%(7KuX~%*+D9mxu9;FoW^LI^0ZIHJ|}FgoyOfKHkk*_>_asx z?Ok0f>>tS$*q3_M+UTVk+P`bAu-`bj%zpJ(R(o-Nb^C+7#r8%na`wutA@&7FBke_g zOWX54kFuSzL*M>;6py{R<7(Ug`VsbXvv*qeemY_|nVl&%uJG(xyVR3=?36tW>_a(g z?WZTT*g3H$<;5puCKhWcbFd09Dv6|8TH5c)x@0H))YD$ Date: Wed, 18 Oct 2023 09:08:53 -0500 Subject: [PATCH 13/24] Apply LintRunner formatting changes --- onnxruntime/core/optimizer/initializer.cc | 3 +- .../core/optimizer/matmul_bn_fusion.cc | 91 +++++++++---------- onnxruntime/core/optimizer/matmul_bn_fusion.h | 14 ++- 3 files changed, 52 insertions(+), 56 deletions(-) diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 73466567ea1ab..9e807ddc7be59 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -329,7 +329,8 @@ void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool colum const size_t 
block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; ORT_ENFORCE(scalers.size() == 1 || - (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), "Invalid other(scalers) size"); + (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), + "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); } diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index abcc7b9d1fd36..f752187465e7a 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -6,9 +6,7 @@ #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" - -namespace onnxruntime -{ +namespace onnxruntime { bool MatchPath(const Node& parent_node, const gsl::span>>& path, const Node& child_node) { @@ -22,9 +20,9 @@ bool MatchPath(const Node& parent_node, } /* - * last node in the path can have more than one output - * because all those outputs will be preserved by the addition of new Gemm node - */ + * last node in the path can have more than one output + * because all those outputs will be preserved by the addition of new Gemm node + */ if (path.size() > 1 && child_node.GetOutputEdgesCount() != 1) { return false; } @@ -33,39 +31,38 @@ bool MatchPath(const Node& parent_node, } /* -* Given a MatMul node, it will verify the following pattern. -* MatMul -* | -* Reshape -* | -* Transpose -* | -* BatchNormalization -* Other Conditions: -* - B tensor of MatMul should be constant. -* - scale, B, mean, var tensors of BatchNormalization should be constant. -* - Every node in the path except first and last node, should have only 1 output edge. -*/ + * Given a MatMul node, it will verify the following pattern. + * MatMul + * | + * Reshape + * | + * Transpose + * | + * BatchNormalization + * Other Conditions: + * - B tensor of MatMul should be constant. + * - scale, B, mean, var tensors of BatchNormalization should be constant. + * - Every node in the path except first and last node, should have only 1 output edge. + */ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", { 1, 9, 13 }) || + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) || node.GetOutputEdgesCount() != 1) { return false; } const Node& child_node = *node.OutputNodesBegin(); - std::vector>> path { - {"Reshape", {1, 5}}, - {"Transpose", {1}}, - {"BatchNormalization", {1, 6, 7}} - }; + std::vector>> path{ + {"Reshape", {1, 5}}, + {"Transpose", {1}}, + {"BatchNormalization", {1, 6, 7}}}; if (!MatchPath(node, path, child_node)) { return false; - } - + } + const auto& batch_norm_node = *child_node.OutputNodesBegin()->OutputNodesBegin(); - + // Check that the appropriate inputs to the Matmul and BN nodes are constants. 
if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[1]) || @@ -93,16 +90,16 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons } /* -* BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc] -* Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML -* Expanding out the terms: -* Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias -* Here, -* [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha` -* [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta` -* Output = alpha * Input + beta, Input = B tensor of MatMul. -* -*/ + * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc] + * Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML + * Expanding out the terms: + * Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias + * Here, + * [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha` + * [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta` + * Output = alpha * Input + beta, Input = B tensor of MatMul. + * + */ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { const Node& child_node = *matmul_node.OutputNodesBegin(); NodeIndex batch_norm_node_index = child_node.OutputNodesBegin()->OutputNodesBegin()->Index(); @@ -142,11 +139,11 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& var_tensor->dims(0) != matmul_b_tensor->dims(1)) { return Status::OK(); } - + /* - * temp = scale / sqrt(var + epsilon) - * output = (temp * Input) - ((temp * mean) + bias) - */ + * temp = scale / sqrt(var + epsilon) + * output = (temp * Input) - ((temp * mean) + bias) + */ Initializer scale(*scale_tensor, graph.ModelPath()); Initializer bias(*bias_tensor, graph.ModelPath()); Initializer mean(*mean_tensor, graph.ModelPath()); @@ -155,12 +152,12 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& var.add(epsilon); var.sqrt(); - scale.div(var); // this is the temp + scale.div(var); // this is the temp matmul_b.scale_by_axis(scale, 1, true); mean.mul(scale); bias.sub(mean); - + // create B tensorProto for new Gemm node from initializer. ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); matmul_b.ToProto(new_gemm_b_tensor); @@ -183,7 +180,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& matmul_node.MutableOutputDefs(), nullptr, kOnnxDomain); - + // Remove MatMul node. 
Node* node = graph.GetNode(matmul_node.Index()); graph_utils::RemoveNodeOutputEdges(graph, *node); @@ -191,8 +188,8 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& // Delete BatchNormalization node and update the input of the child of BatchNormalization graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(child_node.OutputNodesBegin()->Index()), batch_norm_node); - + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; return Status::OK(); } -} \ No newline at end of file +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h index cf539cb8883b5..7a43483cf37d4 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.h +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -5,14 +5,12 @@ #include "core/optimizer/rewrite_rule.h" - -namespace onnxruntime -{ +namespace onnxruntime { /* -* This fusion submerges a BatchNormalization operator to it's super -* precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() -* is true. -*/ + * This fusion submerges a BatchNormalization operator into its + * preceding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() + * is true. + */ class MatmulBNFusion : public RewriteRule { public: MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {} @@ -26,4 +24,4 @@ class MatmulBNFusion : public RewriteRule { Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; -} \ No newline at end of file +} // namespace onnxruntime \ No newline at end of file From 79984f12acd5286b1ec367fd3a4d224f88234f9b Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Fri, 20 Oct 2023 13:17:28 -0500 Subject: [PATCH 14/24] Addressed PR comment --- .../core/optimizer/matmul_bn_fusion.cc | 34 +++-- .../test/optimizer/graph_transform_test.cc | 120 +++++++++++++++++- .../fusion/fuse-matmul-bn-with-reshape.onnx | Bin 0 -> 709 bytes .../fuse-matmul-bn-without-reshape.onnx | Bin 0 -> 547 bytes .../transform/fusion/fuse-matmul-bn.onnx | Bin 779 -> 0 bytes 5 files changed, 134 insertions(+), 20 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-without-reshape.onnx delete mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn.onnx diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index f752187465e7a..bc27c59a8ae02 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -32,9 +32,9 @@ bool MatchPath(const Node& parent_node, /* * Given a MatMul node, it will verify the following pattern. - * MatMul - * | - * Reshape + * MatMul MatMul + * | | + * Reshape OR BatchNormalization * | * Transpose * | @@ -42,7 +42,7 @@ bool MatchPath(const Node& parent_node, * Other Conditions: * - B tensor of MatMul should be constant. * - scale, B, mean, var tensors of BatchNormalization should be constant. - * - Every node in the path except first and last node, should have only 1 output edge. + * - Every node in the path, except the BatchNormalization, should have only 1 output edge. 
*/ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) || @@ -50,18 +50,20 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons return false; } - const Node& child_node = *node.OutputNodesBegin(); + std::vector>> path_with_reshape{ + {"Reshape", {1, 5, 13, 14, 19}}, + {"Transpose", {1, 13}}, + {"BatchNormalization", {1, 6, 7, 9, 14, 15}}}; - std::vector>> path{ - {"Reshape", {1, 5}}, - {"Transpose", {1}}, - {"BatchNormalization", {1, 6, 7}}}; + std::vector>> path_without_reshape{ + {"BatchNormalization", {1, 6, 7, 9, 14, 15}}}; - if (!MatchPath(node, path, child_node)) { + const Node& child_node = *node.OutputNodesBegin(); + if (!(MatchPath(node, path_with_reshape, child_node) ^ MatchPath(node, path_without_reshape, child_node))) { return false; } - const auto& batch_norm_node = *child_node.OutputNodesBegin()->OutputNodesBegin(); + const auto& batch_norm_node = child_node.OpType() == "Reshape" ? *child_node.OutputNodesBegin()->OutputNodesBegin() : child_node; // Check that the appropriate inputs to the Matmul and BN nodes are constants. if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || @@ -102,7 +104,8 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons */ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { const Node& child_node = *matmul_node.OutputNodesBegin(); - NodeIndex batch_norm_node_index = child_node.OutputNodesBegin()->OutputNodesBegin()->Index(); + NodeIndex batch_norm_node_index = child_node.OpType() == "Reshape" ? + child_node.OutputNodesBegin()->OutputNodesBegin()->Index() : child_node.Index(); Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // only perform fusion if epsilon is present and is of float_32 type @@ -172,7 +175,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& new_gemm_bias_tensor.set_name(new_gemm_bias_name); NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); - graph.AddNode( + Node& gemm_node = graph.AddNode( graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), "Gemm", "Generated from Matmul BatchNormalization fusion", @@ -186,8 +189,11 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& graph_utils::RemoveNodeOutputEdges(graph, *node); graph.RemoveNode(matmul_node.Index()); + // Delete optional empty output defs. // Delete BatchNormalization node and update the input of the child of BatchNormalization - graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(child_node.OutputNodesBegin()->Index()), batch_norm_node); + batch_norm_node.MutableOutputDefs().resize(1); + NodeIndex batch_norm_parent_index = child_node.OpType() == "Reshape" ? 
child_node.OutputNodesBegin()->Index() : gemm_node.Index(); + graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node); rule_effect = RewriteRuleEffect::kRemovedCurrentNode; return Status::OK(); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 204eea824dd8f..a767f0ed6d229 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1080,8 +1080,8 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) { } } -TEST_F(GraphTransformationTests, FuseMatmulBN) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn.onnx"; +TEST_F(GraphTransformationTests, FuseMatmulBNWithReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); @@ -1116,8 +1116,8 @@ TEST_F(GraphTransformationTests, FuseMatmulBN) { } } -TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutput) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn.onnx"; +TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); @@ -1155,8 +1155,8 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutput) { } // should not fuse -TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutput) { - constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn.onnx"; +TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; std::shared_ptr p_model; ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); @@ -1190,6 +1190,114 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutput) { ASSERT_EQ(op_to_count["Gemm"], 0); } +TEST_F(GraphTransformationTests, FuseMatmulBNWithoutReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-without-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the last 
node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithoutReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-without-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr)); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the last node"; + } + } +} + +// should not fuse +TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithoutReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-without-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + // additional non-empty output to batchNormalization + ONNX_NAMESPACE::TypeProto optional_output_tensor_type; + optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType); + auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type); + node.MutableOutputDefs().push_back(&arg); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx new file mode 100644 index 
0000000000000000000000000000000000000000..8e4bc49514548c604b36029f780f7ff1a56db148 GIT binary patch literal 709 zcmd zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#E+;N@ zIU&K4qQt!7g8bstc$jj|5SWmbAQwwPYEiBOg9EbzqXW|dMs}_+E=(<497%cc#mR{| zsa)(pR#IkSF_@8?nwZDM1{5hvEE3`b(ojDLIVF}PXZYn8{d>;u)k9_-R68Ut6h7pqW!TWKOYnLXR9R84eZ!b}QchFu=&j$m#@( z2`v?{(?l2;8tl=r6FkVYlp(5YTTJaU71!F`j(KX=uE%OG<|u9N1Pe}>qNNu#?VDGY z+kf6~Y%f=BV{d;x(hjT$6x3Qu5R=rr>+N}uHrviU9%8RkaM$M1!XSGmXf#BLLlc@1 o9~Tb?qYwud69*#@vnDBUK|@_gj7tP4BLI_u(u__lTnqvn0In0|n*aa+ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-without-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-without-reshape.onnx new file mode 100644 index 0000000000000000000000000000000000000000..187d03e3b49efc40a585e4fac05598c275307db8 GIT binary patch literal 547 zcmd zSDarY#0wS3FD(JeE3x?|miU(Da2au-N^r3kXCxM+#v2In7o|d&5FG|e>_HF#E+;N@ zIU&K4qQt!7g8bstc$jj|5SWmbAQwwPYEiBOg9EbzqXW|dMs}_+E=(<497%cc#mR{| zsa)(pR#IkSF_@8?nwZDM1{5hvEE3`b(ojDLIVF}PXZYn8RE@9QBo*HVV4I(0kWzK;8<-R`-b_8#X2 zt(LGf**U=?1*T|eud7|e(%C5TB|Y`yHi2Qu3irp4M? z9?i9T&EjY81dYrnacFWC;^X4sU=-ruV&Y%~V%8)DE@-F=iE)VlWdxuyAX%^qCl)RS G0S*BF=Knwd From b30662305479d8e8e04234d565d098b9d268a915 Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Fri, 20 Oct 2023 16:27:48 -0500 Subject: [PATCH 15/24] Modified pattern matching to incoroprate any combination --- .../core/optimizer/matmul_bn_fusion.cc | 99 +++++++++++------- .../test/optimizer/graph_transform_test.cc | 68 +++++++----- ...hape.onnx => fuse-matmul-bn-directly.onnx} | Bin .../fuse-matmul-bn-non-ignorable-node.onnx | Bin 0 -> 593 bytes .../fusion/fuse-matmul-bn-only-reshape.onnx | Bin 0 -> 639 bytes .../fusion/fuse-matmul-bn-only-transpose.onnx | Bin 0 -> 613 bytes 6 files changed, 103 insertions(+), 64 deletions(-) rename onnxruntime/test/testdata/transform/fusion/{fuse-matmul-bn-without-reshape.onnx => fuse-matmul-bn-directly.onnx} (100%) create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc index bc27c59a8ae02..4b1ef894cc816 100644 --- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -7,38 +7,55 @@ #include "core/optimizer/utils.h" namespace onnxruntime { -bool MatchPath(const Node& parent_node, - const gsl::span>>& path, - const Node& child_node) { - if (path.size() == 0) { - return true; + +namespace matmulbnfusion { + std::vector>> ignorable_nodes{ + {"Reshape", {1, 5, 13, 14, 19}}, + {"Transpose", {1, 13}}}; + std::pair> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}}; +} + +std::optional> MatchPath( + const Node& parent_node, + const Node& curr_node, + const std::pair>& dest, + const gsl::span>>& ignorable_nodes, + std::vector& ignorable_nodes_visited) { + + // curr_node has different execution provider then it's parent or has > 1 output + if (curr_node.GetExecutionProviderType() != parent_node.GetExecutionProviderType() || + curr_node.GetOutputEdgesCount() != 1) { + return std::nullopt; } - if (!graph_utils::IsSupportedOptypeVersionAndDomain(child_node, path[0].first, path[0].second) || - child_node.GetExecutionProviderType() != parent_node.GetExecutionProviderType()) { - 
return false; + // curr_node == dest_node + if (graph_utils::IsSupportedOptypeVersionAndDomain(curr_node, dest.first, dest.second)) { + return curr_node; } - /* - * last node in the path can have more than one output - * because all those outputs will be preserved by the addition of new Gemm node - */ - if (path.size() > 1 && child_node.GetOutputEdgesCount() != 1) { - return false; + // curr_node can be any of the ignorable_nodes. + for (size_t index = 0; index < ignorable_nodes.size(); index++) { + if (!ignorable_nodes_visited[index] && + graph_utils::IsSupportedOptypeVersionAndDomain(curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) { + ignorable_nodes_visited[index] = true; + return MatchPath(curr_node, *curr_node.OutputNodesBegin(), dest, ignorable_nodes, ignorable_nodes_visited); + } } - return MatchPath(child_node, path.subspan(1), *child_node.OutputNodesBegin()); + // curr_node is neither the dest node nor one of the ignorable_nodes. + return std::nullopt; } /* * Given a MatMul node, it will verify the following pattern. - * MatMul + * MatMul * | - * Reshape + * Reshape ^ * | - * Transpose + * Transpose ^ * | * BatchNormalization + * Note: ^ means there can be 0 or 1 occurrences of that node. * Other Conditions: * - B tensor of MatMul should be constant. * - scale, B, mean, var tensors of BatchNormalization should be constant. @@ -50,18 +50,20 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons return false; } - std::vector>> path_with_reshape{ - {"Reshape", {1, 5, 13, 14, 19}}, - {"Transpose", {1, 13}}, - {"BatchNormalization", {1, 6, 7, 9, 14, 15}}}; - std::vector>> path_without_reshape{ - {"BatchNormalization", {1, 6, 7, 9, 14, 15}}}; - const Node& child_node = *node.OutputNodesBegin(); - if (!(MatchPath(node, path_with_reshape, child_node) ^ MatchPath(node, path_without_reshape, child_node))) { + const Node& child_node = *node.OutputNodesBegin(); + std::vector ignorable_nodes_visited(matmulbnfusion::ignorable_nodes.size(), false); + std::optional> batch_norm_node = MatchPath( + node, + child_node, + matmulbnfusion::dest, + matmulbnfusion::ignorable_nodes, + ignorable_nodes_visited); + if (!batch_norm_node.has_value()) { return false; } - const auto& batch_norm_node = child_node.OpType() == "Reshape" ? *child_node.OutputNodesBegin()->OutputNodesBegin() : child_node; - // Check that the appropriate inputs to the Matmul and BN nodes are constants. if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || - !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[1]) || - !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[2]) || - !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[3]) || - !graph_utils::NodeArgIsConstant(graph, *batch_norm_node.InputDefs()[4])) { + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->get().InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->get().InputDefs()[2]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->get().InputDefs()[3]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->get().InputDefs()[4])) { return false; } // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. 
-  const auto& output_defs = batch_norm_node.OutputDefs();
+  const auto& output_defs = batch_norm_node->get().OutputDefs();
   if (output_defs.size() > 1) {
     for (size_t i = 1, end = output_defs.size(); i < end; ++i) {
       if (output_defs[i] != nullptr && output_defs[i]->Exists()) {
@@ -104,9 +118,15 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
  */
 Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
   const Node& child_node = *matmul_node.OutputNodesBegin();
-  NodeIndex batch_norm_node_index = child_node.OpType() == "Reshape" ?
-      child_node.OutputNodesBegin()->OutputNodesBegin()->Index() : child_node.Index();
-  Node& batch_norm_node = *graph.GetNode(batch_norm_node_index);
+  std::vector<bool> ignorable_nodes_visited(matmulbnfusion::ignorable_nodes.size(), false);
+  NodeIndex batch_norm_node_index = MatchPath(
+      matmul_node,
+      child_node,
+      matmulbnfusion::dest,
+      matmulbnfusion::ignorable_nodes,
+      ignorable_nodes_visited)->get().Index();
+
+  Node& batch_norm_node = *graph.GetNode(batch_norm_node_index);  // need a mutable node, hence fetch it from the graph
 
   // only perform fusion if epsilon is present and is of float32 type
   auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon");
@@ -192,7 +212,8 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
   // Delete optional empty output defs.
   // Delete BatchNormalization node and update the input of the child of BatchNormalization
   batch_norm_node.MutableOutputDefs().resize(1);
-  NodeIndex batch_norm_parent_index = child_node.OpType() == "Reshape" ? child_node.OutputNodesBegin()->Index() : gemm_node.Index();
+  NodeIndex batch_norm_parent_index = child_node.OpType() == "BatchNormalization" ? gemm_node.Index() :
+                                          batch_norm_node.InputNodesBegin()->Index();
   graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node);
 
   rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index a767f0ed6d229..dd2d3313b9ca9 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -1080,7 +1080,7 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) {
   }
 }
 
-TEST_F(GraphTransformationTests, FuseMatmulBNWithReshape) {
+TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
 
   std::shared_ptr<Model> p_model;
   ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
@@ -1111,12 +1111,12 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithReshape) {
   for (auto& node : graph.Nodes()) {
     if (node.OpType() == "Gemm") {
       ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
-          << "fusion should produce the same output name as the last node";
+          << "fusion should produce the same output name as the MatMul node";
     }
   }
 }
 
-TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithReshape) {
+TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithInBetweenNodes) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
 
   std::shared_ptr<Model> p_model;
   ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
@@ -1149,26 +1149,23 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithReshape)
   for (auto& node : graph.Nodes()) {
     if (node.OpType() == "Gemm") {
       ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
-          << "fusion should produce the same output name as the last node";
+          << "fusion should produce the same output name as the MatMul node";
     }
   }
 }
 
 // should not fuse
-TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithReshape) {
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithInBetweenNodes) {
   constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
 
   std::shared_ptr<Model> p_model;
   ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
   Graph& graph = p_model->MainGraph();
 
-  std::string expected_output_name;
   GraphViewer graphViewer(graph);
   for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
     auto& node = *graph.GetNode(node_index);
-    if (node.OpType() == "MatMul") {
-      expected_output_name = node.OutputDefs()[0]->Name();
-    } else if (node.OpType() == "BatchNormalization") {
+    if (node.OpType() == "BatchNormalization") {
       // additional non-empty output to BatchNormalization
       ONNX_NAMESPACE::TypeProto optional_output_tensor_type;
       optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType);
@@ -1190,8 +1187,8 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithReshape) {
   ASSERT_EQ(op_to_count["Gemm"], 0);
 }
 
-TEST_F(GraphTransformationTests, FuseMatmulBNWithoutReshape) {
-  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-without-reshape.onnx";
+TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly.onnx";
 
   std::shared_ptr<Model> p_model;
   ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
@@ -1226,8 +1223,8 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithoutReshape) {
   }
 }
 
-TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithoutReshape) {
-  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-without-reshape.onnx";
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";
 
   std::shared_ptr<Model> p_model;
   ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
@@ -1237,9 +1234,8 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithoutResha
   GraphViewer graphViewer(graph);
   for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
     auto& node = *graph.GetNode(node_index);
-    if (node.OpType() == "BatchNormalization") {
+    if (node.OpType() == "MatMul") {
       expected_output_name = node.OutputDefs()[0]->Name();
-      node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr));
     }
   }
 
@@ -1258,14 +1254,13 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithoutResha
   for (auto& node : graph.Nodes()) {
     if (node.OpType() == "Gemm") {
       ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
-          << "fusion should produce the same output name as the last node";
+          << "fusion should produce the same output name as the MatMul node";
     }
   }
 }
 
-// should not fuse
-TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithoutReshape) {
-  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-without-reshape.onnx";
+TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyTranspose) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx";
 
   std::shared_ptr<Model> p_model;
   ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
@@ -1275,13 +1270,8 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithoutReshape) {
   GraphViewer graphViewer(graph);
   for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
     auto& node = *graph.GetNode(node_index);
-    if (node.OpType() == "BatchNormalization") {
+    if (node.OpType() == "MatMul") {
       expected_output_name = node.OutputDefs()[0]->Name();
-      // additional non-empty output to BatchNormalization
-      ONNX_NAMESPACE::TypeProto optional_output_tensor_type;
-      optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType);
-      auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type);
-      node.MutableOutputDefs().push_back(&arg);
     }
   }
 
@@ -1292,6 +1282,34 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithoutReshape) {
 
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["BatchNormalization"], 0);
+  ASSERT_EQ(op_to_count["MatMul"], 0);
+  ASSERT_EQ(op_to_count["Gemm"], 1);
+
+  for (auto& node : graph.Nodes()) {
+    if (node.OpType() == "Gemm") {
+      ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
+          << "fusion should produce the same output name as the MatMul node";
+    }
+  }
+}
+
+// should not fuse
+TEST_F(GraphTransformationTests, FuseMatmulBNWithNonIgnorableNode) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-non-ignorable-node.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
   ASSERT_EQ(op_to_count["BatchNormalization"], 1);
   ASSERT_EQ(op_to_count["MatMul"], 1);
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-without-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx
similarity index 100%
rename from onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-without-reshape.onnx
rename to onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..1050a7285b4a6e9e13d5be17f20316f9c57a2aac
GIT binary patch
literal 593
[base85 payload omitted]

literal 0
HcmV?d00001

diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..c361a42700a30b12eeeacefe8da3825a1e9b3ab7
GIT binary patch
literal 639
[base85 payload omitted]

literal 0
HcmV?d00001

diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..5af84634c03e739e48119a08fc6a266b5f2bca7d
GIT binary patch
literal 613
[base85 payload omitted]

literal 0
HcmV?d00001

From 90d4cfbd0aa4d14780044dcb30c3f3334051bbfa Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 20 Oct 2023 16:31:36 -0500
Subject: [PATCH 16/24] updated comment

---
 onnxruntime/core/optimizer/matmul_bn_fusion.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index 4b1ef894cc816..bc12815b6df6f 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -48,11 +48,11 @@ std::optional<std::reference_wrapper<const Node>> MatchPath(
 
 /*
 * Given a MatMul node, it will verify the following pattern.
- *     MatMul
- *       |
- *     Reshape ^
- *       |
- *     Transpose ^
+ *     MatMul                 GEMM
+ *       |                      |
+ *     Reshape ^     --->     Reshape ^
+ *       |                      |
+ *     Transpose ^            Transpose ^
 *       |
 *     BatchNormalization
 * Note: ^ means there can be 0 or 1 occurrences of that node.

From 23c23daa2d4ee92520c5f4b00e0a35c0a10c97fa Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 20 Oct 2023 16:44:09 -0500
Subject: [PATCH 17/24] Apply lintrunner changes

---
 .../core/optimizer/matmul_bn_fusion.cc        | 50 +++++++++----------
 1 file changed, 25 insertions(+), 25 deletions(-)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index bc12815b6df6f..8c728c309ea06 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -9,19 +9,18 @@
 namespace onnxruntime {
 
 namespace matmulbnfusion {
-  std::vector<std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
+std::vector<std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
     {"Reshape", {1, 5, 13, 14, 19}},
     {"Transpose", {1, 13}}};
-  std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
-}
+std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
+}  // namespace matmulbnfusion
 
 std::optional<std::reference_wrapper<const Node>> MatchPath(
-  const Node& parent_node,
-  const Node& curr_node,
-  const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>& dest,
-  const gsl::span<const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>>& ignorable_nodes,
-  std::vector<bool>& ignorable_nodes_visited) {
-
+    const Node& parent_node,
+    const Node& curr_node,
+    const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>& dest,
+    const gsl::span<const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>>& ignorable_nodes,
+    std::vector<bool>& ignorable_nodes_visited) {
   // curr_node has a different execution provider than its parent or has > 1 output
   if (curr_node.GetExecutionProviderType() != parent_node.GetExecutionProviderType() ||
       curr_node.GetOutputEdgesCount() != 1) {
@@ -48,14 +47,14 @@ std::optional<std::reference_wrapper<const Node>> MatchPath(
 
 /*
 * Given a MatMul node, it will verify the following pattern.
- *     MatMul                 GEMM
- *       |                      |
+ *     MatMul                 GEMM
+ *       |                      |
 *     Reshape ^     --->     Reshape ^
 *       |                      |
 *     Transpose ^            Transpose ^
 *       |
 *     BatchNormalization
- * Note: ^ means there can be 0 or 1 occurrences of that node.
+ * Note: ^ means there can be 0 or 1 occurrences of that node.
 * Other Conditions:
 *   - B tensor of MatMul should be constant.
 *   - scale, B, mean, var tensors of BatchNormalization should be constant.
@@ -70,11 +69,11 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
   const Node& child_node = *node.OutputNodesBegin();
   std::vector<bool> ignorable_nodes_visited(matmulbnfusion::ignorable_nodes.size(), false);
   std::optional<std::reference_wrapper<const Node>> batch_norm_node = MatchPath(
-    node,
-    child_node,
-    matmulbnfusion::dest,
-    matmulbnfusion::ignorable_nodes,
-    ignorable_nodes_visited);
+      node,
+      child_node,
+      matmulbnfusion::dest,
+      matmulbnfusion::ignorable_nodes,
+      ignorable_nodes_visited);
   if (!batch_norm_node.has_value()) {
     return false;
   }
@@ -120,13 +119,15 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
   const Node& child_node = *matmul_node.OutputNodesBegin();
   std::vector<bool> ignorable_nodes_visited(matmulbnfusion::ignorable_nodes.size(), false);
   NodeIndex batch_norm_node_index = MatchPath(
-      matmul_node,
-      child_node,
-      matmulbnfusion::dest,
-      matmulbnfusion::ignorable_nodes,
-      ignorable_nodes_visited)->get().Index();
+                                        matmul_node,
+                                        child_node,
+                                        matmulbnfusion::dest,
+                                        matmulbnfusion::ignorable_nodes,
+                                        ignorable_nodes_visited)
+                                        ->get()
+                                        .Index();
 
-  Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // need a mutable node, hence fetch it from the graph
+  Node& batch_norm_node = *graph.GetNode(batch_norm_node_index);  // need a mutable node, hence fetch it from the graph
 
   // only perform fusion if epsilon is present and is of float32 type
   auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon");
@@ -212,8 +213,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
   // Delete optional empty output defs.
   // Delete BatchNormalization node and update the input of the child of BatchNormalization
   batch_norm_node.MutableOutputDefs().resize(1);
-  NodeIndex batch_norm_parent_index = child_node.OpType() == "BatchNormalization" ? gemm_node.Index() :
-                                          batch_norm_node.InputNodesBegin()->Index();
+  NodeIndex batch_norm_parent_index = child_node.OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index();
   graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node);
 
   rule_effect = RewriteRuleEffect::kRemovedCurrentNode;

From 1a26722fa87c917589ea0c4fead633749081fadd Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 20 Oct 2023 18:37:21 -0500
Subject: [PATCH 18/24] Replaced recursion with iteration

---
 .../core/optimizer/matmul_bn_fusion.cc        | 104 +++++++++---------
 1 file changed, 52 insertions(+), 52 deletions(-)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index 8c728c309ea06..bbf7a55986ada 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -9,39 +9,50 @@
 namespace onnxruntime {
 
 namespace matmulbnfusion {
-std::vector<std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
+const std::vector<std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
     {"Reshape", {1, 5, 13, 14, 19}},
     {"Transpose", {1, 13}}};
-std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
+const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
 }  // namespace matmulbnfusion
 
-std::optional<std::reference_wrapper<const Node>> MatchPath(
-    const Node& parent_node,
-    const Node& curr_node,
-    const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>& dest,
-    const gsl::span<const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>>& ignorable_nodes,
-    std::vector<bool>& ignorable_nodes_visited) {
-  // curr_node has a different execution provider than its parent or has > 1 output
-  if (curr_node.GetExecutionProviderType() != parent_node.GetExecutionProviderType() ||
-      curr_node.GetOutputEdgesCount() != 1) {
-    return std::nullopt;
-  }
+bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
+  const Node* curr_node = graph.GetNode(curr_node_index);
 
-  // curr_node == dest_node
-  if (graph_utils::IsSupportedOptypeVersionAndDomain(curr_node, dest.first, dest.second)) {
-    return curr_node;
+  // curr_node has a different execution provider than its parent or has > 1 output
+  if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() ||
+      curr_node->GetOutputEdgesCount() != 1) {
+    return false;
   }
 
   // curr_node can be any of the ignorable_nodes.
-  for (size_t index = 0; index < ignorable_nodes.size(); index++) {
-    if (!ignorable_nodes_visited[index] &&
-        graph_utils::IsSupportedOptypeVersionAndDomain(curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) {
-      ignorable_nodes_visited[index] = true;
-      return MatchPath(curr_node, *curr_node.OutputNodesBegin(), dest, ignorable_nodes, ignorable_nodes_visited);
+  for (size_t index = 0; index < matmulbnfusion::ignorable_nodes.size(); index++) {
+    if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, matmulbnfusion::ignorable_nodes[index].first,
+                                                       matmulbnfusion::ignorable_nodes[index].second)) {
+      return true;
     }
   }
 
-  // curr_node is neither a dest node nor any of the ignorable_nodes.
+  return false;
+}
+
+std::optional<NodeIndex> MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
+  while (NodeIsIgnorable(graph, root_node, curr_node_index)) {
+    curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index();
+  }
+
+  // curr_node is neither ignorable nor dest
+  const Node* curr_node = graph.GetNode(curr_node_index);
+  if (curr_node->OpType() != matmulbnfusion::dest.first) {
+    return std::nullopt;
+  }
+
+  if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() &&
+      graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, matmulbnfusion::dest.first, matmulbnfusion::dest.second)) {
+    return curr_node_index;
+  }
+
+  // either curr_node has a different execution provider or
+  // an invalid opset.
   return std::nullopt;
 }
 
@@ -54,7 +65,7 @@
 *     Transpose ^            Transpose ^
 *       |
 *     BatchNormalization
- * Note: ^ means there can be 0 or 1 occurrences of that node.
+ * Note: ^ means there can be 0 or more occurrences of that node.
 * Other Conditions:
 *   - B tensor of MatMul should be constant.
 *   - scale, B, mean, var tensors of BatchNormalization should be constant.
@@ -66,29 +77,30 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
     return false;
   }
 
-  const Node& child_node = *node.OutputNodesBegin();
-  std::vector<bool> ignorable_nodes_visited(matmulbnfusion::ignorable_nodes.size(), false);
-  std::optional<std::reference_wrapper<const Node>> batch_norm_node = MatchPath(
-      node,
-      child_node,
-      matmulbnfusion::dest,
-      matmulbnfusion::ignorable_nodes,
-      ignorable_nodes_visited);
-  if (!batch_norm_node.has_value()) {
+  if (graph.NodeProducesGraphOutput(node)) {
+    return false;
+  }
+
+  // because the node does not produce a graph output, it must have a child node
+  NodeIndex child_node_index = node.OutputNodesBegin()->Index();
+  std::optional<NodeIndex> batch_norm_index = MatchPath(graph, node, child_node_index);
+  if (!batch_norm_index.has_value()) {
     return false;
   }
 
+  const Node* batch_norm_node = graph.GetNode(*batch_norm_index);
+
   // Check that the appropriate inputs to the Matmul and BN nodes are constants.
   if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
-      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->get().InputDefs()[1]) ||
-      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->get().InputDefs()[2]) ||
-      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->get().InputDefs()[3]) ||
-      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->get().InputDefs()[4])) {
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) ||
+      !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) {
     return false;
   }
 
   // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
-  const auto& output_defs = batch_norm_node->get().OutputDefs();
+  const auto& output_defs = batch_norm_node->OutputDefs();
   if (output_defs.size() > 1) {
     for (size_t i = 1, end = output_defs.size(); i < end; ++i) {
       if (output_defs[i] != nullptr && output_defs[i]->Exists()) {
@@ -97,10 +109,6 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
     }
   }
 
-  if (graph.NodeProducesGraphOutput(node)) {
-    return false;
-  }
-
   return true;
 }
 
@@ -116,16 +124,8 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
  *
  */
 Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
-  const Node& child_node = *matmul_node.OutputNodesBegin();
-  std::vector<bool> ignorable_nodes_visited(matmulbnfusion::ignorable_nodes.size(), false);
-  NodeIndex batch_norm_node_index = MatchPath(
-                                        matmul_node,
-                                        child_node,
-                                        matmulbnfusion::dest,
-                                        matmulbnfusion::ignorable_nodes,
-                                        ignorable_nodes_visited)
-                                        ->get()
-                                        .Index();
+  NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index();
+  NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value();
 
   Node& batch_norm_node = *graph.GetNode(batch_norm_node_index);  // need a mutable node, hence fetch it from the graph
 
   // only perform fusion if epsilon is present and is of float32 type
   auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon");
@@ -213,7 +213,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
   // Delete optional empty output defs.
   // Delete BatchNormalization node and update the input of the child of BatchNormalization
   batch_norm_node.MutableOutputDefs().resize(1);
-  NodeIndex batch_norm_parent_index = child_node.OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index();
+  NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index();
   graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node);
 
   rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
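The traversal introduced in PATCH 18 above boils down to: advance past any run of ignorable nodes (Reshape/Transpose with a single output edge and the same execution provider), then check whether the node the walk stopped at is the BatchNormalization target. The following is a toy, self-contained model of that loop over a plain vector instead of onnxruntime's Graph API; ToyNode and the single-consumer chain are illustrative assumptions, not code from this patch series.

#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for a graph node; node i feeds node i + 1, mirroring the
// single-output-edge requirement enforced by NodeIsIgnorable above.
struct ToyNode {
  std::string op_type;
};

// Iterative equivalent of the removed recursive matcher: skip ignorable ops,
// then test whether the walk stopped at the target op type.
std::optional<size_t> MatchPathIterative(const std::vector<ToyNode>& chain,
                                         size_t start,
                                         const std::unordered_set<std::string>& ignorable,
                                         const std::string& target) {
  size_t i = start;
  while (i < chain.size() && ignorable.count(chain[i].op_type) != 0) {
    ++i;  // the loop replaces the recursion of the previous revision
  }
  if (i < chain.size() && chain[i].op_type == target) {
    return i;
  }
  return std::nullopt;
}

// Example: {"Reshape", "Transpose", "BatchNormalization"} starting at index 0
// yields 2, while {"Reshape", "Relu", "BatchNormalization"} yields nullopt,
// matching the "non-ignorable node" test added earlier in this series.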
From 95e3efb8277e59dd05a0f9d79881ec5a670875fe Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 20 Oct 2023 18:45:51 -0500
Subject: [PATCH 19/24] updated test model

---
 .../fusion/fuse-matmul-bn-directly.onnx       | Bin 547 -> 513 bytes
 1 file changed, 0 insertions(+), 0 deletions(-)

diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx
index 187d03e3b49efc40a585e4fac05598c275307db8..fa11adaac8d95db4772990bac6f5b0b072f4d5c1 100644
GIT binary patch
delta 279
[base85 payload omitted]

delta 313
[base85 payload omitted]

From 009b86c8210a9db2b607472b0c57d5b009af8a51 Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 20 Oct 2023 19:16:04 -0500
Subject: [PATCH 20/24] Addressed PR comment

---
 onnxruntime/core/optimizer/matmul_bn_fusion.cc     |  13 ++++++-------
 .../fusion/fuse-matmul-bn-only-transpose.onnx      | Bin 613 -> 579 bytes
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index bbf7a55986ada..78ed6bf1a1fed 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -8,12 +8,12 @@
 
 namespace onnxruntime {
 
-namespace matmulbnfusion {
+namespace {
 const std::vector<std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
     {"Reshape", {1, 5, 13, 14, 19}},
     {"Transpose", {1, 13}}};
 const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
-}  // namespace matmulbnfusion
+}  // namespace
 
 bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
   const Node* curr_node = graph.GetNode(curr_node_index);
@@ -25,9 +25,8 @@ bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_n
   }
 
   // curr_node can be any of the ignorable_nodes.
-  for (size_t index = 0; index < matmulbnfusion::ignorable_nodes.size(); index++) {
-    if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, matmulbnfusion::ignorable_nodes[index].first,
-                                                       matmulbnfusion::ignorable_nodes[index].second)) {
+  for (size_t index = 0; index < ignorable_nodes.size(); index++) {
+    if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) {
       return true;
     }
   }
@@ -42,12 +41,12 @@ std::optional<NodeIndex> MatchPath(const Graph& graph, const Node& root_node, No
 
   // curr_node is neither ignorable nor dest
   const Node* curr_node = graph.GetNode(curr_node_index);
-  if (curr_node->OpType() != matmulbnfusion::dest.first) {
+  if (curr_node->OpType() != dest.first) {
     return std::nullopt;
   }
 
   if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() &&
-      graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, matmulbnfusion::dest.first, matmulbnfusion::dest.second)) {
+      graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) {
     return curr_node_index;
   }
 
diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx
index 5af84634c03e739e48119a08fc6a266b5f2bca7d..f70ae2e6229e7a89eb90cb8d68360e35e4c32291 100644
GIT binary patch
delta 252
[base85 payload omitted]

From 4044adc4566ba6c0621e5022fadbd70a0f4a8eaa Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 20 Oct 2023 19:18:54 -0500
Subject: [PATCH 21/24] Added comments

---
 onnxruntime/core/optimizer/matmul_bn_fusion.cc | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index 78ed6bf1a1fed..9e32348e7b0f7 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -65,6 +65,13 @@ std::optional<NodeIndex> MatchPath(const Graph& graph, const Node& root_node, No
 *       |
 *     BatchNormalization
 * Note: ^ means there can be 0 or more occurrences of that node.
+ * A few example fusable patterns:
+ *   - MatMul -> Reshape -> Transpose -> BatchNormalization
+ *   - MatMul -> Reshape -> BatchNormalization
+ *   - MatMul -> Transpose -> BatchNormalization
+ *   - MatMul -> Reshape -> Reshape -> BatchNormalization
+ *   - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization
+ *   - MatMul -> BatchNormalization
 * Other Conditions:
 *   - B tensor of MatMul should be constant.
 *   - scale, B, mean, var tensors of BatchNormalization should be constant.
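For reference, the arithmetic behind this fusion is ordinary batch-norm constant folding. With s = scale / sqrt(var + epsilon), BN(x*W) = s * (x*W - mean) + B = x * (W * s) + (B - mean * s), which is exactly one Gemm. The sketch below performs that folding on raw float buffers; the function name, the [K, N] row-major layout, and the per-output-channel vectors are illustrative assumptions, whereas the patch series performs the equivalent computation on the constant initializers.

#include <cmath>
#include <cstddef>
#include <vector>

// Fold BatchNormalization(scale, B, mean, var, epsilon) into the weight and
// bias of a Gemm computing x * W, where W is [K, N] row-major and the BN
// parameter vectors all have length N (one entry per output channel).
void FoldBnIntoGemm(std::vector<float>& W, std::size_t K, std::size_t N,
                    std::vector<float>& bias,  // filled to length N
                    const std::vector<float>& scale,
                    const std::vector<float>& B,
                    const std::vector<float>& mean,
                    const std::vector<float>& var,
                    float epsilon) {
  bias.resize(N);
  for (std::size_t n = 0; n < N; ++n) {
    const float s = scale[n] / std::sqrt(var[n] + epsilon);
    for (std::size_t k = 0; k < K; ++k) {
      W[k * N + n] *= s;  // scale column n of W (the trailing axis)
    }
    bias[n] = B[n] - mean[n] * s;  // becomes the Gemm C input
  }
}

Scaling along the trailing axis of W is presumably why the series introduces a scale_to_axis helper next to the existing scale_by_axis: the BN vectors broadcast against the last dimension of the MatMul weight rather than the leading one.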
From 65e067d88afec411ed7f144e2ec6bebd242bd301 Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Fri, 20 Oct 2023 19:23:13 -0500
Subject: [PATCH 22/24] Updated comment

---
 onnxruntime/core/optimizer/matmul_bn_fusion.cc | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index 9e32348e7b0f7..c67421caef028 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -66,12 +66,12 @@ std::optional<NodeIndex> MatchPath(const Graph& graph, const Node& root_node, No
 *     BatchNormalization
 * Note: ^ means there can be 0 or more occurrences of that node.
 * A few example fusable patterns:
- *   - MatMul -> Reshape -> Transpose -> BatchNormalization
- *   - MatMul -> Reshape -> BatchNormalization
- *   - MatMul -> Transpose -> BatchNormalization
- *   - MatMul -> Reshape -> Reshape -> BatchNormalization
- *   - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization
- *   - MatMul -> BatchNormalization
+ *   - MatMul -> Reshape -> Transpose -> BatchNormalization            --->  GEMM -> Reshape -> Transpose
+ *   - MatMul -> Reshape -> BatchNormalization                         --->  GEMM -> Reshape
+ *   - MatMul -> Transpose -> BatchNormalization                       --->  GEMM -> Transpose
+ *   - MatMul -> Reshape -> Reshape -> BatchNormalization              --->  GEMM -> Reshape -> Reshape
+ *   - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization --->  GEMM -> Reshape -> Transpose -> Reshape
+ *   - MatMul -> BatchNormalization                                    --->  GEMM
 * Other Conditions:
 *   - B tensor of MatMul should be constant.
 *   - scale, B, mean, var tensors of BatchNormalization should be constant.

From 018cdfb5fdac25af9d649b04093487526c545b8e Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Sun, 22 Oct 2023 20:21:00 -0500
Subject: [PATCH 23/24] Add test case without batchnormalization

---
 .../core/optimizer/matmul_bn_fusion.cc        |  4 ++-
 .../test/optimizer/graph_transform_test.cc    | 28 +++++++++++++++++++
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index c67421caef028..477ab7fa3ca9a 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -18,7 +18,9 @@ const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>
 bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
   const Node* curr_node = graph.GetNode(curr_node_index);
 
-  // curr_node has a different execution provider than its parent or has > 1 output
+  // curr_node has a different execution provider than its parent or
+  // has output edge count != 1 (this condition handles the case when an ignorable node
+  // is a graph output, i.e. a graph like "MatMul->Transpose")
   if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() ||
       curr_node->GetOutputEdgesCount() != 1) {
     return false;
   }
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index dd2d3313b9ca9..6215d03cf235d 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -1295,6 +1295,34 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyTranspose) {
   }
 }
 
+TEST_F(GraphTransformationTests, FuseMatmulBNWithoutBatchNormalization) {
+  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx";
+
+  std::shared_ptr<Model> p_model;
+  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
+  Graph& graph = p_model->MainGraph();
+
+  GraphViewer graphViewer(graph);
+  for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
+    auto& node = *graph.GetNode(node_index);
+    if (node.OpType() == "BatchNormalization") {
+      graph_utils::RemoveNode(graph, node);
+    }
+  }
+
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
+  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
+
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
+
+  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
+  ASSERT_EQ(op_to_count["MatMul"], 1);
+
+
+}
+
 // should not fuse

From d79a6074e80a80c503673203f56b70e4cb6de432 Mon Sep 17 00:00:00 2001
From: Sumit Agarwal
Date: Sun, 22 Oct 2023 20:23:03 -0500
Subject: [PATCH 24/24] Apply lintrunner

---
 onnxruntime/core/optimizer/matmul_bn_fusion.cc     | 2 +-
 onnxruntime/test/optimizer/graph_transform_test.cc | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
index 477ab7fa3ca9a..e944522c9c338 100644
--- a/onnxruntime/core/optimizer/matmul_bn_fusion.cc
+++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -18,7 +18,7 @@ const std::pair<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>
 bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
   const Node* curr_node = graph.GetNode(curr_node_index);
 
-  // curr_node has a different execution provider than its parent or
+  // curr_node has a different execution provider than its parent or
   // has output edge count != 1 (this condition handles the case when an ignorable node
   // is a graph output, i.e. a graph like "MatMul->Transpose")
   if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() ||
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 6215d03cf235d..46b95a127b75c 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -1319,8 +1319,6 @@ TEST_F(GraphTransformationTests, FuseMatmulBNWithoutBatchNormalization) {
 
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
   ASSERT_EQ(op_to_count["MatMul"], 1);
-
-
 }
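Because MatmulBNFusion is registered with the other Level 1 rewrite rules, the fusion can also be observed end to end through the public API: create a session over a model containing the pattern with basic optimizations enabled and serialize the optimized graph. A minimal sketch against the C++ API; the model and output file names are placeholders:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "matmul_bn_fusion_demo");
  Ort::SessionOptions session_options;
  // Level 1 ("basic") optimizations include the rewrite rules, so a
  // MatMul -> (Reshape/Transpose)* -> BatchNormalization chain should come
  // out of session creation as a single Gemm plus any kept Reshape/Transpose.
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
  // Dump the optimized model so the fused graph can be inspected offline.
  session_options.SetOptimizedModelFilePath(ORT_TSTR("model_optimized.onnx"));
  Ort::Session session(env, ORT_TSTR("matmul_bn_model.onnx"), session_options);
  return 0;
}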