diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index ee220b2a7ee45..b41d7dd6fddef 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -158,6 +158,201 @@ bool HasValidBinaryOpQuantizedInputs(const NodeUnit& node_unit) { return true; } +bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers, + const NodeUnitIODef& io_def, + const OpSupportCheckParams& params, + const std::string& op_type, + bool is_quant_matmul, + bool is_conv_matmul_weight, + bool is_conv_matmul_u8s8_weight) { + const auto scale_name = io_def.quant_param->scale.Name(); + if (!Contains(initializers, scale_name)) { + LOGS_DEFAULT(VERBOSE) << "The scale of " << op_type << " must be an initializer tensor"; + return false; + } + + const auto& scale_tensor = *initializers.at(scale_name); + int64_t scales_dim = scale_tensor.dims().empty() ? 1 : scale_tensor.dims()[0]; + if (!is_conv_matmul_u8s8_weight) { + if (scales_dim != 1) { + LOGS_DEFAULT(VERBOSE) << op_type << " does not support per-channel quantization, " + << " for now, only u8s8 QlinearConv supports per-channel quantization on API 29+"; + return false; + } + } else if (scales_dim != 1) { + // For u8s8 Qlinear[Conv/MatMul], we support + // 1. Per-tensor, the weight will be transformed to uint8 later + // 2. Per-channel, only from Android API level 29 + if (is_quant_matmul) { + LOGS_DEFAULT(VERBOSE) << "QLinearMatMul does not support per-channel quantization"; + return false; + } + + if (params.android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) { + LOGS_DEFAULT(VERBOSE) << op_type << " only supports per-channel quantization on Android API 29+, " + << "system NNAPI feature level: " << params.android_feature_level; + return false; + } + + Shape weight_shape; + if (!GetShape(io_def.node_arg, weight_shape)) + return false; + + if (weight_shape[0] != scales_dim) { + LOGS_DEFAULT(VERBOSE) << op_type << " mismatch int8 per-channel quantization weight," + << " weight dimension[0] " << weight_shape[0] + << " scale dimension " << scales_dim; + return false; + } + } + + return true; +} + +bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, + const NodeUnitIODef& io_def, + const OpSupportCheckParams& params, + const std::string& op_type, + const Path& model_path, + bool is_quant_matmul, + bool is_conv_matmul_weight, + bool is_conv_matmul_u8s8_weight) { + // zero point is optional here + if (!io_def.quant_param->zero_point) + return true; + + const auto& zero_point_name = io_def.quant_param->zero_point->Name(); + if (!Contains(initializers, zero_point_name)) { + LOGS_DEFAULT(VERBOSE) << "The zero point of " << op_type << " must be an initializer tensor"; + return false; + } + + const auto& zero_tensor = *initializers.at(zero_point_name); + int64_t zero_dim = zero_tensor.dims().empty() ? 1 : zero_tensor.dims()[0]; + + if (!is_conv_matmul_u8s8_weight) { + if (zero_dim != 1) { + LOGS_DEFAULT(VERBOSE) << op_type << " does not support per-channel quantization, " + << " for now, only u8s8 QlinearConv supports per-channel quantization on API 29+"; + return false; + } + } else { + // For u8s8 Qlinear[Conv/MatMul], we support + // 1. Per-tensor, the weight will be transformed to uint8 later + // 2. 
Per-channel, only from Android API level 29 + if (zero_tensor.data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { + LOGS_DEFAULT(VERBOSE) << "u8s8 Qlinear[Conv/MatMul] only supports int8 zero point for weight, " + << "actual zero point type: [" << zero_tensor.data_type() << "]"; + return false; + } + + if (zero_dim != 1) { + if (is_quant_matmul) { + LOGS_DEFAULT(VERBOSE) << "QLinearMatMul does not support per-channel quantization"; + return false; + } + } + + // For onnx, u8s8 QlinearConv, the weight zero point can be a scalar, + // or a tensor with the same channel count as the weight; for NNAPI we only support it being + // 0 (scalar) or all 0s (tensor), since NNAPI assumes the zero point for per-channel + // quantization is 0 and has no input for it + Shape weight_shape; + if (!GetShape(io_def.node_arg, weight_shape)) + return false; + + if (weight_shape[0] != zero_dim && zero_dim != 1) { + LOGS_DEFAULT(VERBOSE) << op_type << " mismatch int8 per-channel quantization weight," + << " weight dimension[0] " << weight_shape[0] + << " zero point dimension " << zero_dim; + return false; + } + + std::vector<uint8_t> unpacked_tensor; + auto status = onnxruntime::utils::UnpackInitializerData(zero_tensor, model_path, unpacked_tensor); + if (!status.IsOK()) { + LOGS_DEFAULT(ERROR) << "Qlinear[Conv/MatMul] error when unpacking zero tensor: " << zero_point_name + << ", error msg: " << status.ErrorMessage(); + return false; + } + + // Verify all onnx weight zero point(s) are 0(s) + const int8_t* zero_points = reinterpret_cast<const int8_t*>(unpacked_tensor.data()); + for (size_t i = 0; i < unpacked_tensor.size(); i++) { + if (zero_points[i] != 0) { + LOGS_DEFAULT(VERBOSE) << "u8s8 Qlinear[Conv/MatMul] only supports 0 as zero point, " + << "zero_points[" << i << "] has value: " << zero_points[i]; + return false; + } + } + } + + return true; +} + +bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const std::vector<size_t>& indices, const OpSupportCheckParams& params, bool is_input) { + const auto& op_type = node_unit.OpType(); + auto quant_op_type = GetQuantizedOpType(node_unit); + + ORT_ENFORCE(quant_op_type != QuantizedOpType::Unknown, "[", op_type, "] is not a quantized op"); + + bool is_quant_conv = IsQuantizedConv(quant_op_type); + bool is_quant_matmul = (quant_op_type == QuantizedOpType::QLinearMatMul); + const auto& io_defs = is_input ? node_unit.Inputs() : node_unit.Outputs(); + + for (const auto idx : indices) { + if (idx >= io_defs.size()) { + LOGS_DEFAULT(VERBOSE) << (is_input ? 
"Input" : "Output") << " index, " << idx + << " >= size, " << io_defs.size() + << " of NodeUnit: " << node_unit.Name(); + return false; + } + + const auto& io_def = io_defs[idx]; + ORT_ENFORCE(io_def.quant_param.has_value(), "Input index, ", idx, " has no quant_param"); + + // If this op is Qlinear[Conv/MatMul], we want to check u8s8 support for weight tensor (or B tensor for QlinearMatMul) + bool is_conv_matmul_weight = is_input && (is_quant_conv || is_quant_matmul) && idx == 1; + bool is_conv_matmul_u8s8_weight = false; + + if (is_conv_matmul_weight) { + int32_t weight_type; + if (!GetType(io_def.node_arg, weight_type)) + return false; + is_conv_matmul_u8s8_weight = weight_type == ONNX_NAMESPACE::TensorProto_DataType_INT8; + } + + int32_t input_type; + if (!GetType(io_def.node_arg, input_type)) + return false; + + // We only support u8 for most of the inputs and all outputs, with the exception of Quantized MatMul and Conv, + // which allow s8 weight (u8s8) + // TODO, add support of s8s8 + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8 && !is_conv_matmul_u8s8_weight) { + LOGS_DEFAULT(VERBOSE) << op_type << "NodeUnit [" << node_unit.Name() + << "], type [" << op_type << "]'s " + << (is_input ? "Input" : "Output") << " index [" << idx + << "] has unsupported type [" << input_type << "]"; + return false; + } + + // Check scale and zero point + if (!IsQuantizationScaleSupported(initializers, io_def, params, op_type, + is_quant_matmul, is_conv_matmul_weight, is_conv_matmul_u8s8_weight)) { + return false; + } + + if (!IsQuantizationZeroPointSupported(initializers, io_def, params, op_type, node_unit.ModelPath(), + is_quant_matmul, is_conv_matmul_weight, is_conv_matmul_u8s8_weight)) { + return false; + } + } + + return true; +} + bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const std::vector<size_t>& indices, const OpSupportCheckParams& params, bool is_input) { const auto& op_type = node_unit.OpType(); @@ -192,8 +387,10 @@ bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const bool is_conv_matmul_u8s8_weight = false; if (is_conv_matmul_weight) { - const auto& weight_tensor = *initializers.at(io_def.node_arg.Name()); - is_conv_matmul_u8s8_weight = weight_tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8; + int32_t weight_type; + if (!GetType(io_def.node_arg, weight_type)) + return false; + is_conv_matmul_u8s8_weight = weight_type == ONNX_NAMESPACE::TensorProto_DataType_INT8; } const auto& scale_tensor = *initializers.at(scale_name); @@ -219,10 +416,13 @@ bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const return false; } - const auto& weight_tensor = *initializers.at(io_def.node_arg.Name()); - if (weight_tensor.dims()[0] != scales_dim) { + Shape weight_shape; + if (!GetShape(io_def.node_arg, weight_shape)) + return false; + + if (weight_shape[0] != scales_dim) { LOGS_DEFAULT(VERBOSE) << op_type << " mismatch int8 per-channel quantization weight," - << " weight dimension[0] " << weight_tensor.dims()[0] + << " weight dimension[0] " << weight_shape[0] << " scale dimension " << scales_dim; return false; } @@ -269,8 +469,10 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co bool is_conv_matmul_u8s8_weight = false; if (is_conv_matmul_weight) { - const auto& weight_tensor = *initializers.at(io_def.node_arg.Name()); - is_conv_matmul_u8s8_weight = weight_tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8; + int32_t weight_type; + if 
(!GetType(io_def.node_arg, weight_type)) + return false; + is_conv_matmul_u8s8_weight = weight_type == ONNX_NAMESPACE::TensorProto_DataType_INT8; } const auto& zero_tensor = *initializers.at(zero_point_name); @@ -303,10 +505,13 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co // or a tensor with same channel as weight, for NNAPI we only support it be // 0 (scalar) or all 0 (tensor), NNAPI will assume the zero point for per-channel // quantization is 0 there is no input for it - const auto& weight_tensor = *initializers.at(io_def.node_arg.Name()); - if (weight_tensor.dims()[0] != zero_dim && zero_dim != 1) { + Shape weight_shape; + if (!GetShape(io_def.node_arg, weight_shape)) + return false; + + if (weight_shape[0] != zero_dim && zero_dim != 1) { LOGS_DEFAULT(VERBOSE) << op_type << " mismatch int8 per-channel quantization weight," - << " weight dimension[0] " << weight_tensor.dims()[0] + << " weight dimension[0] " << weight_shape[0] << " zero point dimension " << zero_dim; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h index c5b3e1106d966..de32a895e1e18 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h @@ -114,6 +114,10 @@ bool HasValidUnaryOpQuantizedInputs(const NodeUnit& node_unit); // Check if a qlinear binary op has valid inputs, Qlinear[Conv/MatMul/Add] bool HasValidBinaryOpQuantizedInputs(const NodeUnit& node_unit); +// Check if the given quantized input(s) or output(s) is supported +bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const std::vector<size_t>& indices, const OpSupportCheckParams& params, bool is_input); + // Check if a qlinear op has valid scales for given indices bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const std::vector<size_t>& indices, const OpSupportCheckParams& params, bool is_input); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc index a8b299e6d1ba8..acebb29400f54 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc @@ -100,7 +100,7 @@ class BaseOpSupportChecker : public IOpSupportChecker { return ANEURALNETWORKS_FEATURE_LEVEL_1; } - virtual bool HasSupportedInputsImpl(const NodeUnit& node_unit) const; + virtual bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const; virtual int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const { return 1; } virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 15; } @@ -112,7 +112,7 @@ class BaseOpSupportChecker : public IOpSupportChecker { private: bool HasSupportedOpSet(const NodeUnit& node_unit) const; - bool HasSupportedInputs(const NodeUnit& node_unit) const; + bool HasSupportedInputOutputs(const NodeUnit& node_unit) const; }; /* static */ void BaseOpSupportChecker::CreateSharedOpSupportChecker( @@ -138,7 +138,7 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer if (!IsNodeUnitTypeSupported(node_unit)) return false; - if (!HasSupportedInputs(node_unit)) + if (!HasSupportedInputOutputs(node_unit)) return false; // We do not support external initializers for now @@ 
-151,7 +151,7 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer return IsOpSupportedImpl(initializers, node_unit, params); } -bool BaseOpSupportChecker::HasSupportedInputs(const NodeUnit& node_unit) const { +bool BaseOpSupportChecker::HasSupportedInputOutputs(const NodeUnit& node_unit) const { // We do not support unknown(null) input shape auto has_shape = [](const NodeArg& node_arg, const std::string& name, const std::string op_type) { if (!node_arg.Shape()) { @@ -176,10 +176,10 @@ bool BaseOpSupportChecker::HasSupportedInputs(const NodeUnit& node_unit) const { return false; } } - return HasSupportedInputsImpl(node_unit); + return HasSupportedInputOutputsImpl(node_unit); } -bool BaseOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool BaseOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { // We only check the type of input 0 by default // specific op builder can override this const auto& input = node_unit.Inputs()[0].node_arg; @@ -236,7 +236,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker { const OpSupportCheckParams& params) const override; bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; }; @@ -278,11 +278,11 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const NodeUnit& node_unit) cons return 1; } -bool BinaryOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool BinaryOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { bool is_qlinear_add = node_unit.OpType() == "QLinearAdd"; bool is_pow = node_unit.OpType() == "Pow"; if (!is_qlinear_add && !is_pow) - return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); + return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit); if (is_qlinear_add) { // QLinearAdd @@ -373,7 +373,7 @@ class TransposeOpSupportChecker : public BaseOpSupportChecker { return ANEURALNETWORKS_FEATURE_LEVEL_2; } - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; }; bool TransposeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, @@ -392,7 +392,7 @@ bool TransposeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* return true; } -bool TransposeOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool TransposeOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) return false; @@ -552,7 +552,7 @@ class PoolOpSupportChecker : public BaseOpSupportChecker { return params.use_nchw ? 
ANEURALNETWORKS_FEATURE_LEVEL_3 : ANEURALNETWORKS_FEATURE_LEVEL_2; } - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; }; /* static */ void PoolOpSupportChecker::CreateSharedOpSupportChecker( @@ -682,11 +682,11 @@ bool PoolOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial return true; } -bool PoolOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool PoolOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { bool is_max_pool = node_unit.OpType() == "MaxPool"; bool is_qlinear_average_pool = node_unit.OpType() == "QLinearAveragePool"; if (!is_max_pool && !is_qlinear_average_pool) - return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); + return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit); if (is_qlinear_average_pool) { return HasValidUnaryOpQuantizedInputs(node_unit); @@ -727,7 +727,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker { return params.use_nchw ? ANEURALNETWORKS_FEATURE_LEVEL_3 : ANEURALNETWORKS_FEATURE_LEVEL_2; } - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } static bool IsQuantizedOp(const NodeUnit& node_unit); }; @@ -746,9 +746,9 @@ class ConvOpSupportChecker : public BaseOpSupportChecker { return IsQuantizedConv(GetQuantizedOpType(node_unit)); } -bool ConvOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool ConvOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { if (!IsQuantizedOp(node_unit)) - return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); + return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit); // QLinearConv only supports input of uint8 for now if (!HasValidBinaryOpQuantizedInputs(node_unit)) @@ -916,13 +916,13 @@ class GemmOpSupportChecker : public BaseOpSupportChecker { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; }; -bool GemmOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool GemmOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { if (node_unit.OpType() != "QLinearMatMul") - return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); + return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit); // QLinearMatMul if (!HasValidBinaryOpQuantizedInputs(node_unit)) @@ -1112,7 +1112,7 @@ class UnaryOpSupportChecker : public BaseOpSupportChecker { int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, const OpSupportCheckParams& params) const override; - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; @@ -1161,10 +1161,10 @@ int32_t UnaryOpSupportChecker::GetMinSupportedNNAPIFeatureLevel(const NodeUnit& return ANEURALNETWORKS_FEATURE_LEVEL_1; } -bool 
UnaryOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool UnaryOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { // We only need to override input check for QLinearSigmoid if (node_unit.OpType() != "QLinearSigmoid") - return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); + return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit); return HasValidUnaryOpQuantizedInputs(node_unit); } @@ -1234,7 +1234,7 @@ class ConcatOpSupportChecker : public BaseOpSupportChecker { bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; }; bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, @@ -1253,7 +1253,7 @@ bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* in return true; } -bool ConcatOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool ConcatOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) return false; @@ -1361,7 +1361,7 @@ class DequantizeLinearOpSupportChecker : public BaseOpSupportChecker { const OpSupportCheckParams& /* params */) const override { return ANEURALNETWORKS_FEATURE_LEVEL_1; } - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; }; bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, @@ -1376,7 +1376,7 @@ bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensor return true; } -bool DequantizeLinearOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool DequantizeLinearOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) return false; @@ -1465,7 +1465,7 @@ class ResizeOpSupportChecker : public BaseOpSupportChecker { // We only support Resize opset 11+ here int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 11; } - bool HasSupportedInputsImpl(const NodeUnit& node_unit) const override; + bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder }; @@ -1638,7 +1638,7 @@ int32_t ResizeOpSupportChecker::GetMinSupportedNNAPIFeatureLevel(const NodeUnit& return ANEURALNETWORKS_FEATURE_LEVEL_2; } -bool ResizeOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { +bool ResizeOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const { int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) return false; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index bca4eb737c81c..26c2a8e1ccf8a 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -39,7 +39,7 @@ 
namespace test { template void QDQTransformerConvTests() { auto test_case = [&](const std::vector& input_shape, const std::vector& weights_shape) { - auto check_conv_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); if constexpr (std::is_same::value && std::is_same::value && @@ -57,7 +57,7 @@ void QDQTransformerConvTests() { }; TransformerTester(BuildQDQConvTestCase(input_shape, weights_shape), - check_conv_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -136,7 +136,7 @@ TEST(QDQTransformerTests, ConvMaxPoolReshape_UInt8) { builder.AddQuantizeLinearNode(reshape_output, .0039f, 135, output_arg); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearConv"], 1); EXPECT_EQ(op_to_count["MaxPool"], 1); @@ -146,7 +146,7 @@ TEST(QDQTransformerTests, ConvMaxPoolReshape_UInt8) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset_version); @@ -197,7 +197,7 @@ TEST(QDQTransformerTests, ConvMaxPoolReshape_Int8) { } }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearConv"], 1); EXPECT_EQ(op_to_count["MaxPool"], 1); @@ -206,7 +206,7 @@ TEST(QDQTransformerTests, ConvMaxPoolReshape_Int8) { EXPECT_EQ(op_to_count["DequantizeLinear"], 0); }; - TransformerTester(build_test_case, check_mp_reshape_graph, TransformerLevel::Level1, TransformerLevel::Level2); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; test_case({1, 12, 37}, {32, 12, 5}); @@ -241,7 +241,7 @@ void QDQTransformerAveragePoolTests() { output_arg); }; - auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); if constexpr (std::is_same::value) { EXPECT_EQ(op_to_count["com.microsoft.QLinearAveragePool"], 1); @@ -257,7 +257,7 @@ void QDQTransformerAveragePoolTests() { }; TransformerTester(build_test_case, - check_binary_op_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -335,7 +335,7 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) { output_arg); }; - auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); if (std::is_same::value && std::is_same::value) { @@ -352,7 +352,7 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) { }; TransformerTester(build_test_case, - check_binary_op_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -450,7 +450,7 @@ void QDQTransformerMatMulTests(bool has_output_q) { } }; - auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); if (has_output_q) { if constexpr (std::is_same::value && @@ -483,7 +483,7 @@ void QDQTransformerMatMulTests(bool has_output_q) { }; 
TransformerTester(build_test_case, - check_binary_op_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -554,14 +554,14 @@ TEST(QDQTransformerTests, Gather) { builder.AddQuantizeLinearNode(gather_output, .003f, 1, output_arg); }; - auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["Gather"], 1); EXPECT_EQ(op_to_count["QuantizeLinear"], 0); EXPECT_EQ(op_to_count["DequantizeLinear"], 0); }; - TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; test_case({12, 37}, {24, 12}); @@ -586,14 +586,14 @@ TEST(QDQTransformerTests, Transpose) { builder.AddQuantizeLinearNode(transpose_output, .003f, 1, output_arg); }; - auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["Transpose"], 1); EXPECT_EQ(op_to_count["QuantizeLinear"], 0); EXPECT_EQ(op_to_count["DequantizeLinear"], 0); }; - TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; test_case({2, 13, 12, 37}, {0, 3, 1, 2}); @@ -618,13 +618,13 @@ TEST(QDQTransformerTests, Transpose_No_Fusion) { builder.AddQuantizeLinearNode(transpose_output, .003f, 1, output_arg); }; - auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QuantizeLinear"], 1); EXPECT_EQ(op_to_count["DequantizeLinear"], 1); }; - TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; test_case({2, 13, 12, 37}, {0, 3, 1, 2}); @@ -633,7 +633,7 @@ TEST(QDQTransformerTests, Transpose_No_Fusion) { TEST(QDQTransformerTests, Resize) { auto test_case = [&](const std::vector& input1_shape, const std::vector& sizes_shape) { - auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["Resize"], 1); EXPECT_EQ(op_to_count["QuantizeLinear"], 0); @@ -641,7 +641,7 @@ TEST(QDQTransformerTests, Resize) { }; TransformerTester(BuildQDQResizeTestCase(input1_shape, sizes_shape), - check_matmul_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -686,7 +686,7 @@ TEST(QDQTransformerTests, Resize_No_Fusion) { builder.AddQuantizeLinearNode(resize_output, .003f, 1, output_arg); }; - auto check_qdq_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["Resize"], 1); EXPECT_EQ(op_to_count["Concat"], 1); @@ -694,7 +694,7 @@ TEST(QDQTransformerTests, Resize_No_Fusion) { EXPECT_EQ(op_to_count["DequantizeLinear"], 1); }; - TransformerTester(build_test_case, check_qdq_graph, + TransformerTester(build_test_case, check_graph, 
TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -725,7 +725,7 @@ TEST(QDQTransformerTests, ResizeReshape) { builder.AddNode("Reshape", {qdq_resize_output, reshape_shape}, {output_arg}); }; - auto check_qdq_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["Resize"], 1); EXPECT_EQ(op_to_count["Reshape"], 1); @@ -733,7 +733,7 @@ TEST(QDQTransformerTests, ResizeReshape) { EXPECT_EQ(op_to_count["DequantizeLinear"], 1); }; - TransformerTester(build_test_case, check_qdq_graph, + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -763,13 +763,13 @@ TEST(QDQTransformerTests, ArgMax) { argmax_node.AddAttribute("select_last_index", static_cast(select_last_index)); }; - auto check_argmax_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["ArgMax"], 1); EXPECT_EQ(op_to_count["DequantizeLinear"], 0); }; - TransformerTester(build_test_case, check_argmax_graph, + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, /* opset_version */ 13); @@ -797,14 +797,14 @@ TEST(QDQTransformerTests, QLinearMatMul) { builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg); }; - auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearMatMul"], 1); EXPECT_EQ(op_to_count["QuantizeLinear"], 2); EXPECT_EQ(op_to_count["DequantizeLinear"], 0); }; - TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; test_case({12, 37}, {37, 12}); @@ -828,7 +828,7 @@ TEST(QDQTransformerTests, MatMul_No_Fusion) { builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg); }; - auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["MatMul"], 1); EXPECT_EQ(op_to_count["QLinearMatMul"], 0); @@ -836,7 +836,7 @@ TEST(QDQTransformerTests, MatMul_No_Fusion) { EXPECT_EQ(op_to_count["DequantizeLinear"], 1); }; - TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; test_case({12, 37}, {37, 12}); @@ -864,7 +864,7 @@ TEST(QDQTransformerTests, MatMul_1st_Input_Int8) { builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg); }; - auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["MatMul"], 1); EXPECT_EQ(op_to_count["QLinearMatMul"], 0); @@ -872,7 +872,7 @@ TEST(QDQTransformerTests, MatMul_1st_Input_Int8) { EXPECT_EQ(op_to_count["DequantizeLinear"], 2); }; - TransformerTester(build_test_case, check_matmul_graph, TransformerLevel::Level1, TransformerLevel::Level2); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; 
test_case({12, 37}, {37, 12}); @@ -901,7 +901,7 @@ TEST(QDQTransformerTests, MatMulIntegerToFloat) { builder.AddNode("MatMul", {dq_output_1, dq_output_2}, {output_arg}); }; - auto check_matmul_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); EXPECT_EQ(op_to_count["QuantizeLinear"], 0); @@ -909,7 +909,7 @@ TEST(QDQTransformerTests, MatMulIntegerToFloat) { }; TransformerTester(build_test_case, - check_matmul_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -944,7 +944,7 @@ TEST(QDQTransformerTests, ConvRelu) { builder.AddQuantizeLinearNode(relu_output, .0039f, is_zp_zero ? 0 : 1, output_arg); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); if (is_zp_zero) { EXPECT_EQ(op_to_count["QLinearConv"], 1); @@ -962,7 +962,7 @@ TEST(QDQTransformerTests, ConvRelu) { } }; - TransformerTester(build_test_case, check_mp_reshape_graph, TransformerLevel::Level1, TransformerLevel::Level2); + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; test_case({1, 12, 37}, {32, 12, 5}, true); @@ -1008,7 +1008,7 @@ TEST(QDQTransformerTests, ConvAveragePoolReshape_UInt8) { builder.AddDequantizeLinearNode(q_output, .0035f, 135, output_arg); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearConv"], 1); EXPECT_EQ(op_to_count["com.microsoft.QLinearAveragePool"], 1); @@ -1018,7 +1018,7 @@ TEST(QDQTransformerTests, ConvAveragePoolReshape_UInt8) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -1071,7 +1071,7 @@ TEST(QDQTransformerTests, ConvAveragePoolReshape_Int8) { } }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearConv"], 1); EXPECT_EQ(op_to_count["com.microsoft.QLinearAveragePool"], 1); @@ -1081,7 +1081,7 @@ TEST(QDQTransformerTests, ConvAveragePoolReshape_Int8) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -1135,7 +1135,7 @@ TEST(QDQTransformerTests, ConvAveragePoolReshape_Int8_Fail) { } }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["Conv"], 1); EXPECT_EQ(op_to_count["QLinearConv"], 0); @@ -1146,7 +1146,7 @@ TEST(QDQTransformerTests, ConvAveragePoolReshape_Int8_Fail) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -1183,7 +1183,7 @@ void QDQTransformerLeakyReluTests() { output_arg); }; - auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); if 
constexpr (std::is_same::value) { EXPECT_EQ(op_to_count["com.microsoft.QLinearLeakyRelu"], 1); @@ -1199,7 +1199,7 @@ void QDQTransformerLeakyReluTests() { }; TransformerTester(build_test_case, - check_binary_op_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -1259,7 +1259,7 @@ TEST(QDQTransformerTests, ConvTranspose_QBackward) { } }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearConv"], 1); EXPECT_EQ(op_to_count["Transpose"], 1); @@ -1268,7 +1268,7 @@ TEST(QDQTransformerTests, ConvTranspose_QBackward) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1319,7 +1319,7 @@ TEST(QDQTransformerTests, QBackward_MutilpleSteps) { } }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearConv"], 1); EXPECT_EQ(op_to_count["MaxPool"], 1); @@ -1330,7 +1330,7 @@ TEST(QDQTransformerTests, QBackward_MutilpleSteps) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1370,7 +1370,7 @@ TEST(QDQTransformerTests, ConvTranspose_DQForward) { } }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearConv"], 1); EXPECT_EQ(op_to_count["Transpose"], 1); @@ -1379,7 +1379,7 @@ TEST(QDQTransformerTests, ConvTranspose_DQForward) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1430,7 +1430,7 @@ TEST(QDQTransformerTests, DQForward_MutilpleSteps) { } }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QLinearConv"], 1); EXPECT_EQ(op_to_count["MaxPool"], 1); @@ -1441,7 +1441,7 @@ TEST(QDQTransformerTests, DQForward_MutilpleSteps) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1487,7 +1487,7 @@ TEST(QDQTransformerTests, Concat) { } }; - auto check_mp_reshape_graph = [&input_shapes, &has_input_float, &has_input_int8, &has_output_int8](InferenceSessionWrapper& session) { + auto check_graph = [&input_shapes, &has_input_float, &has_input_int8, &has_output_int8](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); if (has_input_float || has_input_int8 || has_output_int8) { EXPECT_EQ(op_to_count["com.microsoft.QLinearConcat"], 0); @@ -1499,7 +1499,7 @@ TEST(QDQTransformerTests, Concat) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 12 /*opset_version*/, @@ -1546,7 +1546,7 @@ TEST(QDQTransformerTests, QDQPropagation_QDQCancelOut) { builder.AddNode("Reshape", {maxpool_output, reshape_shape}, {output_arg}); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto 
op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["MaxPool"], 1); EXPECT_EQ(op_to_count["Reshape"], 1); @@ -1556,7 +1556,7 @@ TEST(QDQTransformerTests, QDQPropagation_QDQCancelOut) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1582,7 +1582,7 @@ TEST(QDQTransformerTests, QDQPropagation_QDQ_CancelOut_More) { builder.AddQuantizeLinearNode(reshape_output, same_scale ? .004f : .0039f, same_zp ? 129 : 128, output_arg); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["Reshape"], 1); EXPECT_EQ(op_to_count["QuantizeLinear"], same_scale && same_zp ? 1 : 2); @@ -1590,7 +1590,7 @@ TEST(QDQTransformerTests, QDQPropagation_QDQ_CancelOut_More) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1616,7 +1616,7 @@ TEST(QDQTransformerTests, QDQPropagation_Q_No_Parent) { builder.AddQuantizeLinearNode(transpose_output, .0035f, 135, output_arg); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { GraphViewer graph_viewer(session.GetGraph()); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "QuantizeLinear"); @@ -1624,7 +1624,7 @@ TEST(QDQTransformerTests, QDQPropagation_Q_No_Parent) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1649,7 +1649,7 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_No_Children) { transpose_node.AddAttribute("perm", perms); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); GraphViewer graph_viewer(session.GetGraph()); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -1658,7 +1658,7 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_No_Children) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1685,7 +1685,7 @@ TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { transpose_node.AddAttribute("perm", perms); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); GraphViewer graph_viewer(session.GetGraph()); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -1694,7 +1694,7 @@ TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { }; TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); }; @@ -1718,14 +1718,14 @@ TEST(QDQTransformerTests, QDQPropagation_DQ_Q) { builder.AddQuantizeLinearNode(dq_output, .0035f, 135, output_arg); }; - auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) { + auto check_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); EXPECT_EQ(op_to_count["QuantizeLinear"], 1); EXPECT_EQ(op_to_count["DequantizeLinear"], 1); }; 
TransformerTester(build_test_case, - check_mp_reshape_graph, + check_graph, TransformerLevel::Level1, TransformerLevel::Level2); };