Commit 7aaa136

Fix MatMulBnFusion to exclude cases when tensors are not 2D tensors (microsoft#22762)

### Description
Fixes microsoft#22512: MatMul and Add can be fused into a single Gemm even when the tensor dimensions are greater than 2, which Gemm does not support. This PR excludes those cases from the fusion.



### Motivation and Context
ORT crashes on valid models due to this unexpected fusion.
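For reference, the algebra that makes this fusion attractive in the first place (the standard MatMul + BatchNormalization folding; the notation below is illustrative): with per-channel scale $\gamma$, bias $\beta$, running mean $\mu$, variance $\sigma^2$, and $d = \gamma / \sqrt{\sigma^2 + \epsilon}$,

$$\mathrm{BN}(XW) = \gamma \odot \frac{XW - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta = X\,(W\,\mathrm{diag}(d)) + (\beta - \mu \odot d),$$

which is exactly Gemm(X, W', B') with $W' = W\,\mathrm{diag}(d)$ and $B' = \beta - \mu \odot d$. The ONNX Gemm operator only accepts rank-2 A and B, so the rewrite is valid only when both MatMul inputs are 2D — hence the checks added below.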
xadupre authored and Ishwar Raut committed Nov 19, 2024
1 parent 210eae2 commit 7aaa136
Showing 3 changed files with 46 additions and 0 deletions.
17 changes: 17 additions & 0 deletions onnxruntime/core/optimizer/matmul_bn_fusion.cc
@@ -107,6 +107,22 @@ bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, cons
return false;
}

// Check that the first input of MatMul has exactly 2 dimensions.
// The same check for the second input is done in Apply, which accesses the constant initializer.
if (node.InputDefs()[0] == nullptr) {
// This should never happen but just in case.
return false;
}
auto shape_a = node.InputDefs()[0]->Shape();
if (shape_a == nullptr) {
// We cannot determine the rank. It is better to avoid fusing.
return false;
}
if (shape_a->dim_size() != 2) {
// Gemm only supports 2D tensors.
return false;
}

// First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
const auto& output_defs = batch_norm_node->OutputDefs();
if (output_defs.size() > 1) {
@@ -165,6 +181,7 @@ Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect&
bias_tensor->dims_size() != 1 ||
mean_tensor->dims_size() != 1 ||
var_tensor->dims_size() != 1 ||
matmul_b_tensor->dims_size() != 2 ||
scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||
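Taken together, the new guards reduce to a simple predicate: fuse only if the rank of both MatMul inputs is known to be 2 (and, as before, the BN parameter vectors have one entry per column of B). A minimal standalone sketch of that predicate, with plain vectors standing in for the real ORT Graph/TensorProto types (all names here are hypothetical):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for what the fusion pass reads from the graph:
// the (possibly unknown) shape of MatMul input A, the shape of the
// constant input B, and the length of the BN scale/bias/mean/var vectors.
bool CanFuseMatMulBnIntoGemm(const std::optional<std::vector<int64_t>>& shape_a,
                             const std::vector<int64_t>& shape_b,
                             int64_t bn_param_length) {
  // Unknown rank of A: skip the fusion (mirrors the shape_a == nullptr check).
  if (!shape_a.has_value()) return false;
  // Gemm only supports 2D tensors, so both MatMul inputs must be rank-2.
  if (shape_a->size() != 2 || shape_b.size() != 2) return false;
  // BN parameters are per output channel, i.e. one entry per column of B.
  return bn_param_length == shape_b[1];
}
```

The unknown-rank case returns false because, as the new comment in SatisfyCondition notes, it is better to skip the fusion than to risk producing an invalid Gemm.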
29 changes: 29 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -1764,6 +1764,35 @@ TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
}
}

TEST_F(GraphTransformationTests, DoNotApplyFuseMatmulBNDirectly) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly-dont-fuse.onnx";

std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

std::string expected_output_name;
GraphViewer graphViewer(graph);
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
auto& node = *graph.GetNode(node_index);
if (node.OpType() == "BatchNormalization") {
expected_output_name = node.OutputDefs()[0]->Name();
}
}

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["BatchNormalization"], 1);
ASSERT_EQ(op_to_count["MatMul"], 1);
ASSERT_EQ(op_to_count["Gemm"], 0);
}

TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";

Binary file not shown (the new test model fusion/fuse-matmul-bn-directly-dont-fuse.onnx used by the test above).
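Finally, a self-contained numerical check of the folding algebra from the description (plain C++, no ORT dependency; all values and names are illustrative). It computes BatchNormalization(MatMul(X, W)) directly and via the folded Gemm form, and verifies the two agree on 2D inputs:

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Plain 2D matrix product: (m x k) * (k x n) -> (m x n).
Matrix MatMul2D(const Matrix& a, const Matrix& b) {
  size_t m = a.size(), k = b.size(), n = b[0].size();
  Matrix out(m, std::vector<double>(n, 0.0));
  for (size_t i = 0; i < m; ++i)
    for (size_t j = 0; j < n; ++j)
      for (size_t p = 0; p < k; ++p) out[i][j] += a[i][p] * b[p][j];
  return out;
}

int main() {
  const double eps = 1e-5;
  Matrix x = {{1.0, 2.0}, {3.0, 4.0}};             // A: 2 x 2
  Matrix w = {{0.5, -1.0, 2.0}, {1.5, 0.0, 1.0}};  // B: 2 x 3
  // Per-channel BN parameters, one entry per column of B.
  std::vector<double> scale = {1.0, 2.0, 0.5}, bias = {0.1, -0.2, 0.3};
  std::vector<double> mean = {0.5, 1.0, -1.0}, var = {1.0, 4.0, 0.25};

  // Reference path: BatchNormalization applied to the MatMul output.
  Matrix y = MatMul2D(x, w);
  for (auto& row : y)
    for (size_t j = 0; j < row.size(); ++j)
      row[j] = scale[j] * (row[j] - mean[j]) / std::sqrt(var[j] + eps) + bias[j];

  // Folded path: Gemm(x, w', b') with w' = w * diag(d), b' = bias - mean * d.
  std::vector<double> d(scale.size());
  for (size_t j = 0; j < d.size(); ++j) d[j] = scale[j] / std::sqrt(var[j] + eps);
  Matrix w_folded = w;
  for (auto& row : w_folded)
    for (size_t j = 0; j < row.size(); ++j) row[j] *= d[j];
  Matrix z = MatMul2D(x, w_folded);
  for (auto& row : z)
    for (size_t j = 0; j < row.size(); ++j) row[j] += bias[j] - mean[j] * d[j];

  // The two paths agree on 2D inputs.
  for (size_t i = 0; i < y.size(); ++i)
    for (size_t j = 0; j < y[i].size(); ++j)
      assert(std::fabs(y[i][j] - z[i][j]) < 1e-9);
  std::puts("folded Gemm matches MatMul+BN");
  return 0;
}
```

On rank-3 inputs the same arithmetic still makes sense mathematically, but the ONNX Gemm operator cannot express it — which is exactly the case the fusion now refuses to rewrite.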
