Revert "Move ReluQuantFusion to Level2 for CPU EP only (#21329)"
This reverts commit 22d4d82.
cloudhan committed Oct 29, 2024
1 parent 7f1dd50 commit afe2e76
Showing 3 changed files with 2 additions and 55 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -134,12 +134,12 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
       rules.push_back(std::make_unique<ConvBNFusion>());
       rules.push_back(std::make_unique<PadFusion>());
       rules.push_back(std::make_unique<MatmulBNFusion>());
+      rules.push_back(std::make_unique<ReluQuantFusion>());
       rules.push_back(std::make_unique<LabelEncoderFusion>());
       break;

     case TransformerLevel::Level2:
       rules.push_back(std::make_unique<ClipQuantFusion>());
-      rules.push_back(std::make_unique<ReluQuantFusion>());
       rules.push_back(std::make_unique<GemmTransposeFusion>());
       break;

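With this revert, ReluQuantFusion is registered as a Level1 rewrite rule again rather than a Level2 rule. In terms of the public API, Level1 rules run once basic graph optimizations are enabled, while Level2 rules additionally require the extended level. A minimal usage sketch of that mapping (not part of this change; it assumes a model file named model.onnx and the public C++ API headers):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "opt-level-demo");

  // ORT_ENABLE_BASIC runs the Level1 rewrite rules (which, after this revert,
  // include ReluQuantFusion again).
  Ort::SessionOptions basic_opts;
  basic_opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
  Ort::Session basic_session(env, "model.onnx", basic_opts);  // path is a wide string on Windows

  // ORT_ENABLE_EXTENDED additionally runs the Level2 set.
  Ort::SessionOptions extended_opts;
  extended_opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
  Ort::Session extended_session(env, "model.onnx", extended_opts);
  return 0;
}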
@@ -13,15 +13,13 @@ namespace onnxruntime {

 bool ReluQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
   if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
-      !graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) ||
       !optimizer_utils::CheckOutputEdges(graph, node, 1)) {
     return false;
   }

   // if Relu is followed by QuantizeLinear, it can be fused into QuantizeLinear potentially
   const auto& next_node = *node.OutputNodesBegin();
-  if (!graph_utils::IsSupportedProvider(next_node, {kCpuExecutionProvider}) ||
-      !QDQ::MatchQNode(next_node)) {
+  if (!QDQ::MatchQNode(next_node)) {
     return false;
   }

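The two deleted lines only gated the rule on the CPU execution provider; the fusion itself remains conditional on the QuantizeLinear zero point (the test removed below only expects fusion when zero_point == -128 for int8). A small self-contained sketch, independent of the ONNX Runtime code base, of why that zero-point condition makes the Relu redundant:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Reference int8 quantization: q = clamp(round(x / scale) + zero_point, -128, 127).
int8_t Quantize(float x, float scale, int32_t zero_point) {
  const int32_t q = static_cast<int32_t>(std::nearbyint(x / scale)) + zero_point;
  return static_cast<int8_t>(std::min(127, std::max(-128, q)));
}

int main() {
  const float scale = 1.0f;
  for (float x = -10.0f; x <= 10.0f; x += 0.25f) {
    const float relu_x = std::max(x, 0.0f);
    // When zero_point == -128 (the int8 minimum), every negative input already
    // saturates to -128, so Quantize(Relu(x)) == Quantize(x) and the Relu node
    // can be dropped in favor of the QuantizeLinear alone.
    assert(Quantize(relu_x, scale, -128) == Quantize(x, scale, -128));
    // With zero_point == 0 the two differ for negative x (e.g. x = -4 maps to
    // -4 vs. 0), which is why the test only expected fusion at zp == -128.
  }
  return 0;
}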
51 changes: 0 additions & 51 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -3101,57 +3101,6 @@ TEST(QDQTransformerTests, Clip) {
   }
 }

-// Test that the ReluQuantFusion transformer only runs for optimization level >= 2.
-TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) {
-  auto test_case = [&](TransformerLevel opt_level, int8_t zp) {
-    auto build_test_case = [&](ModelTestBuilder& builder) {
-      auto* input_arg = builder.MakeInput<int8_t>({1, 2, 2, 2},
-                                                  {-4, -3, -2, 0, 1, 2, 3, 4});
-      auto* output_arg = builder.MakeOutput();
-
-      // add DQ
-      auto* dq_output = builder.MakeIntermediate();
-      builder.AddDequantizeLinearNode<int8_t>(input_arg, 1.0f, zp, dq_output);
-
-      // add Relu
-      auto* relu_output = builder.MakeIntermediate();
-      builder.AddNode("Relu", {dq_output}, {relu_output});
-
-      // add Q + DQ
-      auto* q_output = builder.MakeIntermediate();
-      builder.AddQuantizeLinearNode<int8_t>(relu_output, 1.0f, zp, q_output);
-      builder.AddDequantizeLinearNode<int8_t>(q_output, 1.0f, zp, output_arg);
-    };
-
-    auto check_relu_graph = [&](InferenceSessionWrapper& session) {
-      auto op_to_count = CountOpsInGraph(session.GetGraph());
-      const QDQOpKeys qdq_keys = GetQDQOpKeys(false);
-      // Only fuse relu into Q if level >= 2 and zero_point == -128 for int8.
-      // Level1 graph: input -> DQ -> Relu -> Q -> DQ -> output
-      // Level2+ graph: input -> DQ -> output (QuantReluFusion + QDQFinalCleanupTransformer transformers)
-      const bool fuse_relu = (zp == -128) &&
-                             (opt_level == TransformerLevel::Level2 || opt_level == TransformerLevel::Level3);
-      EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], fuse_relu ? 0 : 1);
-      EXPECT_EQ(op_to_count["Relu"], fuse_relu ? 0 : 1);
-      EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], fuse_relu ? 1 : 2);
-    };
-
-    constexpr float epsilon = std::numeric_limits<float>::epsilon();
-
-    TransformerTester(build_test_case, check_relu_graph,
-                      TransformerLevel::Default,
-                      opt_level,
-                      18,
-                      epsilon,
-                      epsilon);
-  };
-
-  test_case(TransformerLevel::Level1, -128);  // Will not fuse Relu into QuantizeLinear due to level1 opt.
-  test_case(TransformerLevel::Level2, -128);  // Will fuse Relu into QuantizeLinear.
-  test_case(TransformerLevel::Level3, -128);  // Will fuse Relu into QuantizeLinear.
-  test_case(TransformerLevel::Level3, 0);     // Will not fuse Relu into QuantizeLinear due to zero-point != -128
-}
-
 TEST(QDQTransformerTests, Concat) {
   auto test_case = [&](const std::vector<std::vector<int64_t>>& input_shapes,
                        int64_t axis,

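The deleted test drove the DQ -> Relu -> Q -> DQ pattern through the internal TransformerTester helper at different optimization levels. Outside the test suite, one way to observe whether the fusion fired on a real model is to serialize the graph after optimization and inspect it; a hedged sketch, assuming a QDQ model saved as qdq_relu.onnx:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "dump-optimized");
  Ort::SessionOptions opts;
  opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
  // Write out the graph as it looks after optimization so the presence or
  // absence of the standalone Relu/QuantizeLinear nodes can be checked
  // (e.g. with Netron). Paths are wide strings on Windows.
  opts.SetOptimizedModelFilePath("qdq_relu.optimized.onnx");
  Ort::Session session(env, "qdq_relu.onnx", opts);
  return 0;
}

If the fusion (and the subsequent QDQ cleanup) ran, the dumped graph should no longer contain the standalone Relu node.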