diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index f769d31092d19..ff7da413e5765 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -134,12 +134,12 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
       rules.push_back(std::make_unique());
       rules.push_back(std::make_unique());
       rules.push_back(std::make_unique());
+      rules.push_back(std::make_unique<ReluQuantFusion>());
       rules.push_back(std::make_unique());
       break;
 
     case TransformerLevel::Level2:
       rules.push_back(std::make_unique());
-      rules.push_back(std::make_unique<ReluQuantFusion>());
       rules.push_back(std::make_unique());
       break;
 
diff --git a/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc
index e756ffe78a289..7417212c570c8 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/relu_quantizelinear.cc
@@ -13,15 +13,13 @@ namespace onnxruntime {
 
 bool ReluQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const {
   if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6, 13, 14}) ||
-      !graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) ||
       !optimizer_utils::CheckOutputEdges(graph, node, 1)) {
     return false;
   }
 
   // if Relu is followed by QuantizeLinear, it can be fused into QuantizeLinear potentially
   const auto& next_node = *node.OutputNodesBegin();
-  if (!graph_utils::IsSupportedProvider(next_node, {kCpuExecutionProvider}) ||
-      !QDQ::MatchQNode(next_node)) {
+  if (!QDQ::MatchQNode(next_node)) {
     return false;
   }
 
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index d07977d4b97b8..792ea2793bb4f 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -3101,57 +3101,6 @@ TEST(QDQTransformerTests, Clip) {
   }
 }
 
-// Test that the ReluQuantFusion transformer only runs for optimization level >= 2.
-TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) {
-  auto test_case = [&](TransformerLevel opt_level, int8_t zp) {
-    auto build_test_case = [&](ModelTestBuilder& builder) {
-      auto* input_arg = builder.MakeInput<int8_t>({1, 2, 2, 2},
-                                                  {-4, -3, -2, 0, 1, 2, 3, 4});
-      auto* output_arg = builder.MakeOutput();
-
-      // add DQ
-      auto* dq_output = builder.MakeIntermediate();
-      builder.AddDequantizeLinearNode(input_arg, 1.0f, zp, dq_output);
-
-      // add Relu
-      auto* relu_output = builder.MakeIntermediate();
-      builder.AddNode("Relu", {dq_output}, {relu_output});
-
-      // add Q + DQ
-      auto* q_output = builder.MakeIntermediate();
-      builder.AddQuantizeLinearNode(relu_output, 1.0f, zp, q_output);
-      builder.AddDequantizeLinearNode(q_output, 1.0f, zp, output_arg);
-    };
-
-    auto check_relu_graph = [&](InferenceSessionWrapper& session) {
-      auto op_to_count = CountOpsInGraph(session.GetGraph());
-      const QDQOpKeys qdq_keys = GetQDQOpKeys(false);
-      // Only fuse relu into Q if level >= 2 and zero_point == -128 for int8.
-      // Level1 graph: input -> DQ -> Relu -> Q -> DQ -> output
-      // Level2+ graph: input -> DQ -> output (QuantReluFusion + QDQFinalCleanupTransformer transformers)
-      const bool fuse_relu = (zp == -128) &&
-                             (opt_level == TransformerLevel::Level2 || opt_level == TransformerLevel::Level3);
-      EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], fuse_relu ? 0 : 1);
-      EXPECT_EQ(op_to_count["Relu"], fuse_relu ? 0 : 1);
-      EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], fuse_relu ? 1 : 2);
-    };
-
-    constexpr float epsilon = std::numeric_limits<float>::epsilon();
-
-    TransformerTester(build_test_case, check_relu_graph,
-                      TransformerLevel::Default,
-                      opt_level,
-                      18,
-                      epsilon,
-                      epsilon);
-  };
-
-  test_case(TransformerLevel::Level1, -128);  // Will not fuse Relu into QuantizeLinear due to level1 opt.
-  test_case(TransformerLevel::Level2, -128);  // Will fuse Relu into QuantizeLinear.
-  test_case(TransformerLevel::Level3, -128);  // Will fuse Relu into QuantizeLinear.
-  test_case(TransformerLevel::Level3, 0);     // Will not fuse Relu into QuantizeLinear due to zero-point != -128
-}
-
 TEST(QDQTransformerTests, Concat) {
   auto test_case = [&](const std::vector<std::vector<int64_t>>& input_shapes, int64_t axis,