Skip to content

Commit

Permalink
Do not fuse DQ+Node+Q if DQ produces graph output (#14509)
Browse files Browse the repository at this point in the history
Fix issue #14501
  • Loading branch information
yufenglee authored Feb 1, 2023
1 parent 3d388a1 commit d9e675a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Nod
return false;
}

auto does_node_produce_graph_output = [&graph_viewer](const Node* node_ptr) {
return graph_viewer.NodeProducesGraphOutput(*node_ptr);
};

if (std::any_of(dq_nodes.begin(), dq_nodes.end(), does_node_produce_graph_output)) {
return false;
}

if (q_nodes.empty()) {
return is_empty_q_nodes_allowed;
}
Expand Down
56 changes: 56 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2760,5 +2760,61 @@ TEST(QDQTransformerTests, QDQFinalCleanupTransformer_GraphInputToOutput) {
test_case(false);
}

// Do not fuse DQ -> Softmax -> Q when the DQ node's output is also a graph output.
TEST(QDQTransformerTests, DQ_Produce_Graph_Output) {
  auto test_case = [&](const std::vector<int64_t>& input_shape, int64_t axis) {
    // Builds: input -> Q -> DQ -> Softmax -> Q -> DQ -> output,
    // where the first DQ writes directly into a NodeArg created with MakeOutput().
    auto build_test_case = [&](ModelTestBuilder& builder) {
      // Quantization parameters for the QDQ pairs on each side of Softmax.
      constexpr float input_scale = .105f;
      constexpr int input_zero_point = 127;
      constexpr float output_scale = 1.0f / (std::numeric_limits<uint8_t>::max() + 1);
      constexpr int output_zero_point = 0;

      auto* graph_input = builder.MakeInput<float>(input_shape, -5.f, 5.f);
      // Both of these are created as graph outputs; the first one is fed by the input DQ.
      auto* dq_graph_output = builder.MakeOutput();
      auto* final_output = builder.MakeOutput();

      // Input side: Q -> DQ, with the DQ result exposed as a graph output.
      auto* quantized_input = builder.MakeIntermediate();
      builder.AddQuantizeLinearNode<uint8_t>(graph_input, input_scale, input_zero_point, quantized_input);
      builder.AddDequantizeLinearNode<uint8_t>(quantized_input, input_scale, input_zero_point, dq_graph_output);

      // Softmax reads the same NodeArg that the input DQ writes.
      auto* softmax_result = builder.MakeIntermediate();
      auto& softmax = builder.AddNode("Softmax", {dq_graph_output}, {softmax_result});
      softmax.AddAttribute("axis", axis);

      // Output side: Q -> DQ.
      auto* quantized_result = builder.MakeIntermediate();
      builder.AddQuantizeLinearNode<uint8_t>(softmax_result, output_scale, output_zero_point, quantized_result);
      builder.AddDequantizeLinearNode<uint8_t>(quantized_result, output_scale, output_zero_point, final_output);
    };

    // After the transformer runs, nothing must have been fused: the float Softmax
    // and both Q/DQ pairs are expected to still be present.
    auto check_graph = [&](InferenceSessionWrapper& session) {
      auto op_counts = CountOpsInGraph(session.GetGraph());
      EXPECT_EQ(op_counts["com.microsoft.QLinearSoftmax"], 0);
      EXPECT_EQ(op_counts["Softmax"], 1);
      EXPECT_EQ(op_counts["QuantizeLinear"], 2);
      EXPECT_EQ(op_counts["DequantizeLinear"], 2);
    };

    TransformerTester(build_test_case,
                      check_graph,
                      TransformerLevel::Level1,
                      TransformerLevel::Level2,
                      12 /*opset_version*/,
                      0.01 /*per_sample_tolerance*/,
                      0.01 /*relative_per_sample_tolerance*/,
                      std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
  };

  test_case({1, 12, 37}, -1);
}

} // namespace test
} // namespace onnxruntime

0 comments on commit d9e675a

Please sign in to comment.