From 8e47bb9a4a3a630a3731ce19edbde6a5546b2a50 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Tue, 15 Feb 2022 12:46:05 -0800 Subject: [PATCH] [NNAPI QDQ] Add QDQReshape op support (#10533) * wip * wip * save * address partial pr comments * update * minor change * move isquantizedop to baseopbuilderorchecker * update * format * update * update * address pr comments * update Co-authored-by: rachguo --- .../nnapi/nnapi_builtin/builders/helper.cc | 2 + .../nnapi/nnapi_builtin/builders/helper.h | 1 + .../nnapi_builtin/builders/op_builder.cc | 54 +++++++++++++------ .../builders/op_support_checker.cc | 51 ++++++++++++++---- onnxruntime/test/optimizer/qdq_test_utils.cc | 22 ++++++++ onnxruntime/test/optimizer/qdq_test_utils.h | 3 ++ .../test/providers/nnapi/nnapi_basic_test.cc | 9 ++++ 7 files changed, 114 insertions(+), 28 deletions(-) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 7ae031d45f7c8..fe193a3a630af 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -80,6 +80,8 @@ QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit) { return QuantizedOpType::QDQMul; else if (op_type == "Transpose") return QuantizedOpType::QDQTranspose; + else if (op_type == "Reshape") + return QuantizedOpType::QDQReshape; } else { // throw? } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h index 859a92b0bf982..b8f754a925c50 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h @@ -91,6 +91,7 @@ enum class QuantizedOpType : uint8_t { QDQAdd, QDQMul, QDQTranspose, + QDQReshape, // TODO, add other QDQ NodeUnit types }; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index cd467017c3ab4..07d310ac2b1ab 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -622,6 +622,7 @@ class BaseOpBuilder : public IOpBuilder { protected: virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const = 0; static bool IsOpSupported(const ModelBuilder& model_builder, const NodeUnit& node_unit) ORT_MUST_USE_RESULT; + virtual bool IsQuantizedOp(const NodeUnit& /* node_unit */) const { return false; } }; /* static */ bool BaseOpBuilder::IsOpSupported(const ModelBuilder& model_builder, const NodeUnit& node_unit) { @@ -651,11 +652,11 @@ class BinaryOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; -/* static */ bool BinaryOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { +bool BinaryOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { const auto quant_type = GetQuantizedOpType(node_unit); return quant_type == QuantizedOpType::QLinearAdd || quant_type == QuantizedOpType::QLinearMul || @@ -793,7 +794,7 @@ class TransposeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; void TransposeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { @@ -804,7 +805,7 @@ void TransposeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp } -/* static */ bool TransposeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { +bool TransposeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQTranspose; } @@ -867,12 +868,21 @@ class ReshapeOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; static bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_rank, size_t output_rank); + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { + if (IsQuantizedOp(node_unit)) { + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Inputs()[0].quant_param); // x_scale, x_zp + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp + } model_builder.AddInitializerToSkip(node_unit.Inputs()[1].node_arg.Name()); } +bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { + return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQReshape; +} + // We can skip the Reshape if all the output edges satisfies both the following conditions // 1. The output the reshape/flatten is not an output of the graph // 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators, @@ -956,12 +966,15 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const auto input_rank = shaper[input].size(); auto output_rank = shaper[output].size(); + // For reshape, the output type should be the same as the input type except the shape is different + auto output_operand_type = operand_types.at(input); + output_operand_type.SetDimensions(shaper[output]); + // Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now) // We will try to see if we the skip the Reshape to prevent context switching between // NNAPI CPU impl and NNAPI hardware accelerator impl if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) { // Since reshape can be skipped, only register the dimension and type, with same index and new name - const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false); } else { // We still need to perform a reshape here @@ -974,8 +987,6 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen); ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type)); input_indices.push_back(operand_indices.at(shape_name)); - - const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}, {false})); } @@ -1006,6 +1017,15 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons shape[i] = dim == 0 ? input_shape[i] : dim; } + // Check if the quantization scale and ZP are correct + float x_scale = 0.0f; + int32_t x_zero_point = 0; + if (IsQuantizedOp(node_unit)) { + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); + } + return AddReshapeOperator(model_builder, node_unit, input, shape); } @@ -1131,11 +1151,11 @@ class PoolOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; -/* static */ bool PoolOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { +bool PoolOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return IsQuantizedPool(GetQuantizedOpType(node_unit)); } @@ -1284,11 +1304,11 @@ class ConvOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; -/* static */ bool ConvOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { +bool ConvOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return IsQuantizedConv(GetQuantizedOpType(node_unit)); } @@ -1673,11 +1693,11 @@ class GemmOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; -/* static */ bool GemmOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { +bool GemmOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { // TODO, add support for QDQ NodeUnit return node_unit.OpType() == "QLinearMatMul"; } @@ -1833,11 +1853,11 @@ class UnaryOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; -/* static */ bool UnaryOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { +bool UnaryOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { // TODO, add support for QDQ NodeUnit return node_unit.OpType() == "QLinearSigmoid"; } @@ -2287,10 +2307,10 @@ class ResizeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; -/* static */ bool ResizeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { +bool ResizeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQResize; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc index 5f0b4d840a67a..5150d7ae37e65 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc @@ -289,6 +289,8 @@ class BaseOpSupportChecker : public IOpSupportChecker { return true; } + virtual bool IsQuantizedOp(const NodeUnit& /* node_unit */) const { return false; } + virtual int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, const OpSupportCheckParams& /* params */) const { // ANEURALNETWORKS_FEATURE_LEVEL_1 is the baseline version of NNAPI, @@ -453,7 +455,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker { int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; bool IsNodeUnitTypeSupported(const NodeUnit& node_unit) const override; - static bool IsQuantizedOp(const NodeUnit& node_unit); + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; /* static */ void BinaryOpSupportChecker::CreateSharedOpSupportChecker( @@ -481,7 +483,7 @@ bool BinaryOpSupportChecker::IsNodeUnitTypeSupported(const NodeUnit& node_unit) return true; } -/* static */ bool BinaryOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) { +bool BinaryOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const { const auto quant_type = GetQuantizedOpType(node_unit); return quant_type == QuantizedOpType::QLinearAdd || quant_type == QuantizedOpType::QLinearMul || @@ -593,10 +595,10 @@ class TransposeOpSupportChecker : public BaseOpSupportChecker { const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; -/* static */ bool TransposeOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) { +bool TransposeOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQTranspose; } @@ -653,8 +655,17 @@ class ReshapeOpSupportChecker : public BaseOpSupportChecker { // Reshape opset 4- uses attributes for new shape which we do not support for now int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 5; } + bool HasSupportedInputOutputsImpl( + const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, + const OpSupportCheckParams& /* params */) const override; + bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; +bool ReshapeOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const { + return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQReshape; +} + bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); @@ -685,7 +696,7 @@ bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& init const auto perm_size = SafeInt(perm_tensor.dims()[0]); NodeAttrHelper helper(node_unit); - const bool allow_zero = helper.Get("allowzero ", 0) == 1; + const bool allow_zero = helper.Get("allowzero", 0) == 1; for (uint32_t i = 0; i < perm_size; i++) { // NNAPI reshape does not support 0 as dimension if (raw_perm[i] == 0) { @@ -704,6 +715,24 @@ bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& init return true; } +bool ReshapeOpSupportChecker::HasSupportedInputOutputsImpl( + const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const { + if (!IsQuantizedOp(node_unit)) { + return BaseOpSupportChecker::HasSupportedInputOutputsImpl(initializers, node_unit, params); + } + + if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, IOKind::Input)) { + return false; + } + + if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, IOKind::Output)) { + return false; + } + + return true; +} + #pragma endregion #pragma region op_batchnormalization @@ -790,7 +819,7 @@ class PoolOpSupportChecker : public BaseOpSupportChecker { const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override; - static bool IsQuantizedOp(const NodeUnit& node_unit); + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; /* static */ void PoolOpSupportChecker::CreateSharedOpSupportChecker( @@ -815,7 +844,7 @@ bool PoolOpSupportChecker::IsNodeUnitTypeSupported(const NodeUnit& node_unit) co return true; } -/* static */ bool PoolOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) { +bool PoolOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const { return IsQuantizedPool(GetQuantizedOpType(node_unit)); } @@ -979,7 +1008,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker { const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } - static bool IsQuantizedOp(const NodeUnit& node_unit); + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; /* static */ void ConvOpSupportChecker::CreateSharedOpSupportChecker( @@ -992,7 +1021,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker { }); } -/* static */ bool ConvOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) { +bool ConvOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const { return IsQuantizedConv(GetQuantizedOpType(node_unit)); } @@ -1642,10 +1671,10 @@ class ResizeOpSupportChecker : public BaseOpSupportChecker { const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } - static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; -/* static */ bool ResizeOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) { +bool ResizeOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQResize; } diff --git a/onnxruntime/test/optimizer/qdq_test_utils.cc b/onnxruntime/test/optimizer/qdq_test_utils.cc index 08c1c751991b4..d40889306bede 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.cc +++ b/onnxruntime/test/optimizer/qdq_test_utils.cc @@ -36,5 +36,27 @@ GetQDQTestCaseFn BuildQDQResizeTestCase( }; } +GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector& input_shape, + const std::vector& reshape_shape) { + return [input_shape, reshape_shape](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, .003f, 1, dq_output); + + // add Reshape + auto* reshape_output = builder.MakeIntermediate(); + auto* shape = builder.Make1DInitializer(reshape_shape); + builder.AddNode("Reshape", {dq_output, shape}, {reshape_output}); + + // add Q + builder.AddQuantizeLinearNode(reshape_output, .003f, 1, output_arg); + }; +} + } // namespace test } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 1a0c0f044cb34..b7f8dee81b99a 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -212,5 +212,8 @@ GetQDQTestCaseFn BuildQDQTransposeTestCase( builder.AddQuantizeLinearNode(transpose_output, .003f, q_zp, output_arg); }; } + +GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector& input_shape, + const std::vector& reshape_shape); } // namespace test } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index a136e0b22c11a..bdbde8c8b00a3 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -370,6 +370,15 @@ TEST(NnapiExecutionProviderTest, TestQDQTranspose) { }); } +TEST(NnapiExecutionProviderTest, TestQDQReshape) { + RunQDQModelTest(BuildQDQReshapeTestCase({1, 3, 64, 64} /* input_shape */, + {1, 64, 64, 3} /* reshape_shape */), + "nnapi_qdq_test_graph_reshape", + { + true /* verify_entire_graph_use_ep */ + }); +} + #endif // !(ORT_MINIMAL_BUILD) TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {