From 7a328477612cafe690a5479b1ce60ee9b87bd4c4 Mon Sep 17 00:00:00 2001 From: Guoyu Wang Date: Mon, 7 Feb 2022 00:41:30 -0800 Subject: [PATCH] move test case to util --- .../builders/op_support_checker.cc | 2 +- onnxruntime/test/optimizer/qdq_test_utils.h | 48 ++++++++++++++++++ .../test/optimizer/qdq_transformer_test.cc | 50 +------------------ 3 files changed, 51 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc index 37ea601ba9852..c1431c08875f7 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc @@ -210,7 +210,7 @@ static bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, con const auto& op_type = node_unit.OpType(); auto quant_op_type = GetQuantizedOpType(node_unit); - ORT_ENFORCE(quant_op_type != QuantizedOpType::QLinearMatMul, "[", op_type, "] is not a quantized op"); + ORT_ENFORCE(quant_op_type != QuantizedOpType::Unknown, "[", op_type, "] is not a quantized op"); bool is_quant_conv = IsQuantizedConv(quant_op_type); bool is_quant_matmul = (quant_op_type == QuantizedOpType::QLinearMatMul); diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 118ed8697dae0..bfbda03d53d94 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -83,5 +83,53 @@ GetQDQTestCaseFn BuildQDQResizeTestCase(const std::vector& input_shape, const std::string& mode = "nearest", const std::string& coordinate_transformation_mode = "half_pixel"); +template +GetQDQTestCaseFn BuildBinaryOpTestCase(const std::vector& input_shape, + const std::string& op_type) { + return [input_shape, op_type](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add QDQ 1 + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .004f, + std::numeric_limits::max() / 2, + q1_output); + builder.AddDequantizeLinearNode(q1_output, + .0039f, + std::numeric_limits::max() / 2, + dq1_output); + + // add QDQ 2 + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .004f, + std::numeric_limits::max() / 2, + q2_output); + builder.AddDequantizeLinearNode(q2_output, + .0039f, + std::numeric_limits::max() / 2, + dq2_output); + + // add binary operator + auto* binary_op_output = builder.MakeIntermediate(); + builder.AddNode(op_type, {dq1_output, dq2_output}, {binary_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(binary_op_output, + .0038f, + std::numeric_limits::max() / 2, + q3_output); + builder.AddDequantizeLinearNode(q3_output, + .0039f, + std::numeric_limits::max() / 2, + output_arg); + }; +} } // namespace test } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index c0fcac7d593f6..39e644029912f 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -290,51 +290,6 @@ TEST(QDQTransformerTests, AveragePool_U8S8) { template void QDQTransformerBinaryOpTests(const std::string& op_type) { auto test_case = [&](const std::vector& input_shape) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput(input_shape, -1.f, 1.f); - auto* input2_arg = builder.MakeInput(input_shape, -1.f, 1.f); - auto* output_arg = builder.MakeOutput(); - - // add QDQ 1 - auto* q1_output = builder.MakeIntermediate(); - auto* dq1_output = builder.MakeIntermediate(); - builder.AddQuantizeLinearNode(input1_arg, - .004f, - std::numeric_limits::max() / 2, - q1_output); - builder.AddDequantizeLinearNode(q1_output, - .0039f, - std::numeric_limits::max() / 2, - dq1_output); - - // add QDQ 2 - auto* q2_output = builder.MakeIntermediate(); - auto* dq2_output = builder.MakeIntermediate(); - builder.AddQuantizeLinearNode(input2_arg, - .004f, - std::numeric_limits::max() / 2, - q2_output); - builder.AddDequantizeLinearNode(q2_output, - .0039f, - std::numeric_limits::max() / 2, - dq2_output); - - // add binary operator - auto* binary_op_output = builder.MakeIntermediate(); - builder.AddNode(op_type, {dq1_output, dq2_output}, {binary_op_output}); - - // add QDQ output - auto* q3_output = builder.MakeIntermediate(); - builder.AddQuantizeLinearNode(binary_op_output, - .0038f, - std::numeric_limits::max() / 2, - q3_output); - builder.AddDequantizeLinearNode(q3_output, - .0039f, - std::numeric_limits::max() / 2, - output_arg); - }; - auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); if (std::is_same::value && @@ -351,7 +306,7 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) { } }; - TransformerTester(build_test_case, + TransformerTester(BuildBinaryOpTestCase(input_shape, op_type), check_graph, TransformerLevel::Level1, TransformerLevel::Level2, @@ -614,8 +569,7 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); - if ((!has_output_q || std::is_same_v)&& - (!has_bias || (std::is_same_v && !beta_not_one)) && + if ((!has_output_q || std::is_same_v)&&(!has_bias || (std::is_same_v && !beta_not_one)) && (std::is_same_v || std::is_same_v)) { EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); EXPECT_EQ(op_to_count["Gemm"], 0);