diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 2fc1afd435279..af93017649e78 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -5,22 +5,18 @@ #include #include -#include -#include -#include -#include -#include -#include +#include "helper.h" +#include "core/common/safeint.h" +#include "core/common/logging/logging.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph.h" +#include "core/graph/graph_viewer.h" +#include "core/providers/common.h" #include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/shared/utils/utils.h" -#include "helper.h" #include "op_support_checker.h" -using onnxruntime::NodeUnit; -using std::string; -using std::vector; - namespace onnxruntime { namespace nnapi { @@ -72,16 +68,11 @@ QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { return QLinearOpType::Unknown; } -ConvType GetConvType(const onnxruntime::Node& node, const InitializedTensorSet& initializers) { - const auto& op_type = node.OpType(); - bool is_qlinear_conv = (op_type == "QLinearConv"); - ORT_ENFORCE(op_type == "Conv" || is_qlinear_conv); - - NodeAttrHelper helper(node); +ConvType GetConvType(const NodeUnit& node_unit, const InitializedTensorSet& initializers) { + NodeAttrHelper helper(node_unit); const auto group = helper.Get("group", 1); - size_t w_idx = is_qlinear_conv ? 3 : 1; - const auto& weight = node.InputDefs()[w_idx]->Name(); + const auto& weight = node_unit.Inputs()[1].node_arg.Name(); const auto& weight_tensor = *initializers.at(weight); // For ONNX we only have 1 conv ops @@ -104,13 +95,13 @@ bool IsQLinearBinaryOp(QLinearOpType qlinear_op_type) { qlinear_op_type == QLinearOpType::QLinearAdd; } -bool HasValidUnaryOpQuantizedInputs(const Node& node) { +bool HasValidUnaryOpQuantizedInputs(const NodeUnit& node_unit) { int32_t input_type; - if (!GetType(*node.InputDefs()[0], input_type)) + if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) return false; if (input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) { - LOGS_DEFAULT(VERBOSE) << "[" << node.OpType() + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() << "] Input type: [" << input_type << "] is not supported for now"; return false; @@ -119,18 +110,18 @@ bool HasValidUnaryOpQuantizedInputs(const Node& node) { return true; } -bool HasValidBinaryOpQuantizedInputs(const Node& node) { - auto op_type = GetQLinearOpType(node); +bool HasValidBinaryOpQuantizedInputs(const NodeUnit& node_unit) { + auto op_type = GetQLinearOpType(node_unit.GetNode()); int32_t a_input_type, b_input_type; if (!IsQLinearBinaryOp(op_type)) { - LOGS_DEFAULT(VERBOSE) << "[" << node.OpType() << "] is not a binary qlinear op"; + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() << "] is not a binary qlinear op"; return false; } - const auto input_defs(node.InputDefs()); - if (!GetType(*input_defs[0], a_input_type)) + const auto& inputs = node_unit.Inputs(); + if (!GetType(inputs[0].node_arg, a_input_type)) return false; - if (!GetType(*input_defs[3], b_input_type)) + if (!GetType(inputs[1].node_arg, b_input_type)) return false; // QlinearConv supports u8u8 or u8s8 @@ -143,7 +134,7 @@ bool HasValidBinaryOpQuantizedInputs(const Node& node) { if (a_input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8 || (!is_qlinear_conv && a_input_type != b_input_type) || (is_qlinear_conv && 
!has_valid_qlinear_conv_weight)) {
-    LOGS_DEFAULT(VERBOSE) << "[" << node.OpType()
+    LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType()
                           << "] A Input type: [" << a_input_type
                           << "] B Input type: [" << b_input_type
                           << "] is not supported for now";
     return false;
@@ -153,32 +144,41 @@ bool HasValidBinaryOpQuantizedInputs(const Node& node) {
   return true;
 }
 
-bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const Node& node,
-                                const std::vector& indices, const OpSupportCheckParams& params) {
-  const auto& op_type = node.OpType();
-  auto qlinear_op_type = GetQLinearOpType(node);
+bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+                                const std::vector& indices, const OpSupportCheckParams& params, bool is_input) {
+  const auto& op_type = node_unit.OpType();
+  auto qlinear_op_type = GetQLinearOpType(node_unit.GetNode());
   bool is_qlinear_conv = (qlinear_op_type == QLinearOpType::QLinearConv);
   bool is_qlinear_matmul = (qlinear_op_type == QLinearOpType::QLinearMatMul);
-  const auto input_defs(node.InputDefs());
+  const auto& io_defs = is_input ? node_unit.Inputs() : node_unit.Outputs();
 
   for (const auto idx : indices) {
-    if (idx >= input_defs.size()) {
-      LOGS_DEFAULT(VERBOSE) << "HasValidQuantizationScales, Input index, " << idx
-                            << " >= input number, " << input_defs.size();
+    if (idx >= io_defs.size()) {
+      LOGS_DEFAULT(VERBOSE) << (is_input ? "Input" : "Output") << " index, " << idx
+                            << " >= size, " << io_defs.size()
+                            << " of NodeUnit: " << node_unit.Name();
+      return false;
+    }
+
+    const auto& io_def = io_defs[idx];
+    if (!io_def.quant_param.has_value()) {
+      LOGS_DEFAULT(VERBOSE) << "HasValidQuantizationScales, " << (is_input ? "Input" : "Output") << " index, " << idx
+                            << " has no quant_param";
       return false;
     }
 
-    const auto scale_name = input_defs[idx]->Name();
+    const auto scale_name = io_def.quant_param->scale.Name();
+
     if (!Contains(initializers, scale_name)) {
       LOGS_DEFAULT(VERBOSE) << "The scale of " << op_type << " must be an initializer tensor";
       return false;
     }
 
     // If this op is Qlinear[Conv/MatMul], we want to check u8s8 support for weight tensor (or B tensor for QlinearMatMul)
-    bool is_conv_matmul_weight = (is_qlinear_conv || is_qlinear_matmul) && idx == 4;
+    bool is_conv_matmul_weight = is_input && (is_qlinear_conv || is_qlinear_matmul) && idx == 1;
     bool is_conv_matmul_u8s8_weight = false;
 
     if (is_conv_matmul_weight) {
-      const auto& weight_tensor = *initializers.at(node.InputDefs()[3]->Name());
+      const auto& weight_tensor = *initializers.at(io_def.node_arg.Name());
       is_conv_matmul_u8s8_weight = weight_tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8;
     }
 
@@ -205,7 +205,7 @@ bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const
       return false;
     }
 
-    const auto& weight_tensor = *initializers.at(node.InputDefs()[3]->Name());
+    const auto& weight_tensor = *initializers.at(io_def.node_arg.Name());
     if (weight_tensor.dims()[0] != scales_dim) {
       LOGS_DEFAULT(VERBOSE) << op_type << " mismatch int8 per-channel quantization weight,"
                             << " weight dimension[0] " << weight_tensor.dims()[0]
@@ -218,30 +218,44 @@ bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const
   return true;
 }
 
-bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, const Node& node,
-                                    const std::vector& indices) {
-  const auto& op_type = node.OpType();
-  auto qlinear_op_type = GetQLinearOpType(node);
+bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+                                    const std::vector&
indices, bool is_input) { + const auto& op_type = node_unit.OpType(); + auto qlinear_op_type = GetQLinearOpType(node_unit.GetNode()); bool is_qlinear_conv = (qlinear_op_type == QLinearOpType::QLinearConv); bool is_qlinear_matmul = (qlinear_op_type == QLinearOpType::QLinearMatMul); - const auto input_defs(node.InputDefs()); + + const auto& io_defs = is_input ? node_unit.Inputs() : node_unit.Outputs(); for (const auto idx : indices) { - if (idx >= input_defs.size()) { + if (idx >= io_defs.size()) { + LOGS_DEFAULT(VERBOSE) << "HasValidQuantizationZeroPoints, " + << (is_input ? "Input" : "Output") << " index, " << idx + << " >= size, " << io_defs.size(); + return false; + } + + const auto& io_def = io_defs[idx]; + if (!io_def.quant_param.has_value()) { LOGS_DEFAULT(VERBOSE) << "HasValidQuantizationZeroPoints, Input index, " << idx - << " >= input number, " << input_defs.size(); + << " has no quant_param"; return false; } - const auto zero_point_name = input_defs[idx]->Name(); + // zero point is optional here + if (!io_def.quant_param->zero_point) + return true; + + const auto& zero_point_name = io_def.quant_param->zero_point->Name(); if (!Contains(initializers, zero_point_name)) { LOGS_DEFAULT(VERBOSE) << "The zero point of " << op_type << " must be an initializer tensor"; return false; } - bool is_conv_matmul_weight = is_qlinear_conv && idx == 5; + bool is_conv_matmul_weight = is_input && (is_qlinear_conv || is_qlinear_matmul) && idx == 1; bool is_conv_matmul_u8s8_weight = false; + if (is_conv_matmul_weight) { - const auto& weight_tensor = *initializers.at(node.InputDefs()[3]->Name()); + const auto& weight_tensor = *initializers.at(io_def.node_arg.Name()); is_conv_matmul_u8s8_weight = weight_tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT8; } @@ -275,7 +289,7 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co // or a tensor with same channel as weight, for NNAPI we only support it be // 0 (scalar) or all 0 (tensor), NNAPI will assume the zero point for per-channel // quantization is 0 there is no input for it - const auto& weight_tensor = *initializers.at(node.InputDefs()[3]->Name()); + const auto& weight_tensor = *initializers.at(io_def.node_arg.Name()); if (weight_tensor.dims()[0] != zero_dim && zero_dim != 1) { LOGS_DEFAULT(VERBOSE) << op_type << " mismatch int8 per-channel quantization weight," << " weight dimension[0] " << weight_tensor.dims()[0] @@ -284,7 +298,7 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co } std::vector unpacked_tensor; - auto status = onnxruntime::utils::UnpackInitializerData(zero_tensor, node.ModelPath(), unpacked_tensor); + auto status = onnxruntime::utils::UnpackInitializerData(zero_tensor, node_unit.ModelPath(), unpacked_tensor); if (!status.IsOK()) { LOGS_DEFAULT(ERROR) << "Qlinear[Conv/MatMul] error when unpack zero tensor: " << zero_point_name << ", error msg: " << status.ErrorMessage(); @@ -306,33 +320,61 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co return true; } -common::Status GetQuantizationScale(const InitializedTensorSet& initializers, const Node& node, - size_t idx, float& scale) { - std::vector unpacked_tensor; - const auto& name = node.InputDefs()[idx]->Name(); - const auto& scale_tensor = *initializers.at(name); - ORT_RETURN_IF_ERROR( - onnxruntime::utils::UnpackInitializerData(scale_tensor, node.ModelPath(), unpacked_tensor)); - - // The scale should be one or more floats - ORT_RETURN_IF(unpacked_tensor.size() < 4, "The 
initializer [", name, "] should have one or more floats ", - "with size no less than 4, actual size: ", unpacked_tensor.size()); - scale = reinterpret_cast(unpacked_tensor.data())[0]; +common::Status GetQuantizationScaleAndZeroPoint( + const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, + float& scale, int32_t& zero_point) { + scale = 0.0f; + zero_point = 0; + + if (!io_def.quant_param) { // Not a quantized IO + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "NodeArg: ", io_def.node_arg.Name(), " is not quantized"); + } + + const auto unpack_tensor = [&model_path](const InitializedTensorSet& initializers, + const std::string& name, std::vector& unpacked_tensor) { + const auto& tensor = *initializers.at(name); + ORT_RETURN_IF_ERROR( + onnxruntime::utils::UnpackInitializerData(tensor, model_path, unpacked_tensor)); + return Status::OK(); + }; + + const auto& quant_param = *io_def.quant_param; + { // get the scale + std::vector unpacked_tensor; + const auto& name = quant_param.scale.Name(); + ORT_RETURN_IF_ERROR(unpack_tensor(initializers, name, unpacked_tensor)); + // The scale should be one or more floats + ORT_RETURN_IF(unpacked_tensor.size() < 4, + "The initializer [", name, "] should have one or more floats ", + "with size no less than 4, actual size: ", unpacked_tensor.size()); + scale = reinterpret_cast(unpacked_tensor.data())[0]; + } + + if (quant_param.zero_point) { // get the zero point if it's there + std::vector unpacked_tensor; + const auto& name = quant_param.zero_point->Name(); + ORT_RETURN_IF_ERROR(unpack_tensor(initializers, name, unpacked_tensor)); + ORT_RETURN_IF(unpacked_tensor.empty(), "The initializer [", name, "] is empty"); + // Onnx quantization uses uint8 [int8 not yet supported], need to cast to int32_t used by NNAPI + zero_point = static_cast(unpacked_tensor[0]); + } + return Status::OK(); } -common::Status GetQuantizationZeroPoint(const InitializedTensorSet& initializers, - const Node& node, size_t idx, int32_t& zero_point) { - std::vector unpacked_tensor; - const auto& name = node.InputDefs()[idx]->Name(); - const auto& zero_point_tensor = *initializers.at(name); - ORT_RETURN_IF_ERROR( - onnxruntime::utils::UnpackInitializerData(zero_point_tensor, node.ModelPath(), unpacked_tensor)); - - ORT_RETURN_IF(unpacked_tensor.empty(), "The initializer [", name, "] is empty"); - // Onnx quantization uses uint8 [int8 not yet supported], need to cast to int32_t used by NNAPI - zero_point = static_cast(unpacked_tensor[0]); - return Status::OK(); +common::Status GetQuantizationScaleAndZeroPoint( + const InitializedTensorSet& initializers, const NodeUnit& node_unit, const std::string& name, + float& scale, int32_t& zero_point, bool is_input) { + const auto& io_defs = is_input ? 
node_unit.Inputs() : node_unit.Outputs();
+  for (const auto& io_def : io_defs) {
+    if (io_def.node_arg.Name() == name)
+      return GetQuantizationScaleAndZeroPoint(initializers, io_def, node_unit.ModelPath(),
+                                              scale, zero_point);
+  }
+
+  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                         "Unknown input: ", name, ", for NodeUnit with node index: ", node_unit.Index());
 }
 
 bool GetShape(const NodeArg& node_arg, Shape& shape) {
@@ -363,9 +405,9 @@ bool GetType(const NodeArg& node_arg, int32_t& type) {
   return true;
 }
 
-void GetFlattenOutputShape(const Node& node, const Shape& input_shape, int32_t& dim_1, int32_t& dim_2) {
+void GetFlattenOutputShape(const NodeUnit& node_unit, const Shape& input_shape, int32_t& dim_1, int32_t& dim_2) {
   int32_t rank = static_cast(input_shape.size());
-  NodeAttrHelper helper(node);
+  NodeAttrHelper helper(node_unit);
   int32_t axis = helper.Get("axis", 1);
   // axis == rank is a valid input, but invalid for HandleNegativeAxis
   // Skip non-negative axis here
@@ -491,10 +533,11 @@ std::string Shape2String(const std::vector& shape) {
   return os.str();
 }
 
-bool CheckIsInitializer(const InitializedTensorSet& initializers, const Node& node,
-                        size_t input_idx, const char* input_name) {
-  if (!Contains(initializers, node.InputDefs()[input_idx]->Name())) {
-    LOGS_DEFAULT(VERBOSE) << input_name << " of " << node.OpType() << " must be an initializer tensor";
+bool CheckIsInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+                        const std::string& input_name, const char* input_description) {
+  if (!Contains(initializers, input_name)) {
+    LOGS_DEFAULT(VERBOSE) << input_description << " of " << node_unit.Name() << " of type ["
+                          << node_unit.OpType() << "] must be an initializer tensor";
     return false;
   }
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h
index d8d89269c9f55..c3729fb1c8f10 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h
@@ -26,10 +26,13 @@ namespace onnxruntime {
 using Shape = std::vector;
 using InitializerMap = std::unordered_map;
 
+class GraphViewer;
 class Node;
 class NodeArg;
 class NodeUnit;
-class GraphViewer;
+class Path;
+
+struct NodeUnitIODef;
 
 namespace nnapi {
 
@@ -94,28 +97,32 @@ QLinearOpType GetQLinearOpType(const onnxruntime::Node& node);
 
 // Return the type of the conv ops,
 // This function assumes the input is a 2d conv node
-ConvType GetConvType(const onnxruntime::Node& node, const InitializedTensorSet& initializers);
+ConvType GetConvType(const NodeUnit& node_unit, const InitializedTensorSet& initializers);
 
 // This qlinear op is an operator takes 2 inputs and produces 1 output
 // Such as QLinearConv, QLinearMatMul, QLinearAdd, ...
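A minimal sketch of how these NodeUnit-based helpers are typically combined in a support check (not part of the patch; initializers, node_unit and params are assumed to be in scope, and the indices are illustrative):

    // Check the quantization parameters of a binary qlinear op such as QLinearAdd.
    // With the NodeUnit-based API, indices 0 and 1 refer to inputs A and B of the NodeUnit.
    const auto qlinear_op_type = GetQLinearOpType(node_unit.GetNode());
    if (IsQLinearBinaryOp(qlinear_op_type)) {
      const bool supported =
          HasValidBinaryOpQuantizedInputs(node_unit) &&
          HasValidQuantizationScales(initializers, node_unit, {0, 1}, params, true /* is_input */) &&
          HasValidQuantizationZeroPoints(initializers, node_unit, {0, 1}, true /* is_input */);
      // 'supported' decides whether the op support checker accepts this NodeUnit for NNAPI
    }
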
bool IsQLinearBinaryOp(QLinearOpType qlinear_op_type); // Check if a qlinear unary op has valid inputs, Qlinear[Sigmoid/AveragePool] -bool HasValidUnaryOpQuantizedInputs(const Node& node); +bool HasValidUnaryOpQuantizedInputs(const NodeUnit& node_unit); // Check if a qlinear binary op has valid inputs, Qlinear[Conv/MatMul/Add] -bool HasValidBinaryOpQuantizedInputs(const Node& node); +bool HasValidBinaryOpQuantizedInputs(const NodeUnit& node_unit); + // Check if a qlinear op has valid scales for given indices -bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const Node& node, - const std::vector& indices, const OpSupportCheckParams& params); +bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const std::vector& indices, const OpSupportCheckParams& params, bool is_input); + // Check if a qlinear op has valid zero points for given indices -bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, const Node& node, - const std::vector& indices); +bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const std::vector& indices, bool is_input); -common::Status GetQuantizationScale(const InitializedTensorSet& initializers, const Node& node, - size_t idx, float& scale); +common::Status GetQuantizationScaleAndZeroPoint( + const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, + float& scale, int32_t& zero_point); -common::Status GetQuantizationZeroPoint(const InitializedTensorSet& initializers, - const Node& node, size_t idx, int32_t& zero_point) ORT_MUST_USE_RESULT; +common::Status GetQuantizationScaleAndZeroPoint( + const InitializedTensorSet& initializers, const NodeUnit& node_unit, const std::string& name, + float& scale, int32_t& zero_point, bool is_input = true); // Get Shape/Type of a NodeArg // TODO, move to shared_utils @@ -123,7 +130,7 @@ bool GetShape(const NodeArg& node_arg, Shape& shape); bool GetType(const NodeArg& node_arg, int32_t& type); // Get the output shape of Flatten Op -void GetFlattenOutputShape(const Node& node, const Shape& input_shape, int32_t& dim_1, int32_t& dim_2); +void GetFlattenOutputShape(const NodeUnit& node_unit, const Shape& input_shape, int32_t& dim_1, int32_t& dim_2); // If a node is supported by NNAPI bool IsNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph_viewer, const OpSupportCheckParams& params); @@ -144,8 +151,10 @@ bool IsValidSupportedNodeGroup(const std::vector& supported_node_gr std::string Shape2String(const std::vector& shape); // Check the given input is an initializer tensor -bool CheckIsInitializer(const InitializedTensorSet& initializers, const Node& node, - size_t index, const char* input_name) ORT_MUST_USE_RESULT; +// input_name is the name of the initializer +// input_description is the string describing the input in the output message (if any) +bool CheckIsInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const std::string& input_name, const char* input_description); } // namespace nnapi } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 645ab23a85109..fe6eade431770 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -1,22 +1,23 @@ // Copyright (c) 
Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include +#include "model_builder.h" +#include "core/common/logging/logging.h" +#include "core/common/safeint.h" +#include "core/common/status.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph_viewer.h" #include "core/providers/common.h" #include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" + #include "helper.h" -#include "model_builder.h" #include "op_builder.h" #include "op_support_checker.h" -using onnxruntime::NodeUnit; using namespace android::nn::wrapper; -using std::vector; namespace onnxruntime { namespace nnapi { @@ -31,7 +32,7 @@ int32_t ModelBuilder::GetNNAPIFeatureLevel() const { // Scalar operand is copied into the model, no need to persist #define DEFINE_ADD_OPERAND_FROM_SCALAR(scalar_type, op_type) \ Status ModelBuilder::AddOperandFromScalar(scalar_type value, uint32_t& index) { \ - OperandType operandType(Type::op_type, vector{}); \ + OperandType operandType(Type::op_type, std::vector{}); \ ORT_RETURN_IF_ERROR(AddNewNNAPIOperand(operandType, index)); \ RETURN_STATUS_ON_ERROR_WITH_NOTE( \ nnapi_->ANeuralNetworksModel_setOperandValue( \ @@ -50,13 +51,12 @@ void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { skipped_initializers_.insert(tensor_name); } -static std::unordered_map> GetAllQuantizedOpInputs(const GraphViewer& graph_viewer); - Status ModelBuilder::Prepare() { nnapi_model_ = std::unique_ptr(new Model()); RETURN_STATUS_ON_ERROR(nnapi_->ANeuralNetworksModel_create(&nnapi_model_->model_)); ORT_RETURN_IF_ERROR(GetTargetDevices()); - all_quantized_op_inputs_ = GetAllQuantizedOpInputs(graph_viewer_); + PreprocessNodeUnits(); + GetAllQuantizedOpInputs(); PreprocessInitializers(); PreprocessActivations(); ORT_RETURN_IF_ERROR(RegisterInitializers()); @@ -118,74 +118,87 @@ Status ModelBuilder::GetTargetDevices() { } void ModelBuilder::PreprocessInitializers() { - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); - if (const auto* op_builder = GetOpBuilder(*node)) { - const NodeUnit node_unit(*node); - op_builder->AddInitializersToSkip(*this, node_unit); + for (const auto& node_unit : node_unit_holder_) { + if (const auto* op_builder = GetOpBuilder(*node_unit)) { + op_builder->AddInitializersToSkip(*this, *node_unit); } } } void ModelBuilder::PreprocessActivations() { - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); - const auto& op_type(node->OpType()); - + for (const auto& node_unit : node_unit_holder_) { + const auto& node = node_unit->GetNode(); + const auto& op_type(node.OpType()); if (op_type == "Relu") { - activation_nodes_.emplace(node->Index(), ANEURALNETWORKS_FUSED_RELU); + activation_node_units_.emplace(node_unit.get(), ANEURALNETWORKS_FUSED_RELU); } else if (op_type == "Clip") { // Relu1 or Relu6 float min, max; - if (!GetClipMinMax(GetInitializerTensors(), *node, min, max, logging::LoggingManager::DefaultLogger())) + if (!GetClipMinMax(GetInitializerTensors(), node, min, max, logging::LoggingManager::DefaultLogger())) continue; if (min == -1.0f && max == 1.0f) { - activation_nodes_.emplace(node->Index(), 
ANEURALNETWORKS_FUSED_RELU1); + activation_node_units_.emplace(node_unit.get(), ANEURALNETWORKS_FUSED_RELU1); } else if (min == 0.0f && max == 6.0f) { - activation_nodes_.emplace(node->Index(), ANEURALNETWORKS_FUSED_RELU6); + activation_node_units_.emplace(node_unit.get(), ANEURALNETWORKS_FUSED_RELU6); } } } } -// Help to get all quantized operators' input and the node(s) using the input -static std::unordered_map> GetAllQuantizedOpInputs(const GraphViewer& graph_viewer) { - std::unordered_map> all_quantized_op_inputs; - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (const auto& node_idx : node_indices) { - const auto* node(graph_viewer.GetNode(node_idx)); - auto qlinear_op_type = GetQLinearOpType(*node); +const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const { + // In theory, if node_unit_map_ is generated correctly, see PreprocessNodeUnits(), a NodeUnit can be + // found for any single node in the graph_viewer_, unless the given node is not from graph_viewer_ + return *node_unit_map_.at(node); +} + +void ModelBuilder::PreprocessNodeUnits() { + // TODO, hookup shared QDQ selectors here to identify all the qdq NodeUnit in the graph + const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + for (size_t i = 0; i < node_indices.size(); i++) { + const auto node_idx = node_indices[i]; + // TODO, check if the node is already part of a qdq group + const auto* node(graph_viewer_.GetNode(node_idx)); + auto node_unit = std::make_unique(*node); + node_unit_map_.insert({node, node_unit.get()}); + node_unit_holder_.push_back(std::move(node_unit)); + } +} + +// Help to get all quantized operators' input and the NodeUnit(s) using the input +void ModelBuilder::GetAllQuantizedOpInputs() { + for (const auto& node_unit : node_unit_holder_) { + // TODO, hookup getting quantized inputs with QDQ NodeUnits and remove the ORT_ENFORCE + ORT_ENFORCE(node_unit->UnitType() == NodeUnit::Type::SingleNode, "QDQ NodeUnit is not yet implemented"); + + auto qlinear_op_type = GetQLinearOpType(node_unit->GetNode()); // Not a qlinear op + // TODO, add handling for QDQ NodeUnit if (qlinear_op_type == QLinearOpType::Unknown) continue; + const auto add_quantized_input = + [&all_quantized_op_inputs = all_quantized_op_inputs_](const NodeUnit& node_unit, size_t input_idx) { + const auto& input_name = node_unit.Inputs()[input_idx].node_arg.Name(); + all_quantized_op_inputs[input_name].push_back(&node_unit); + }; + // All qlinear ops EXCEPT QuantizeLinear has quantized input if (qlinear_op_type != QLinearOpType::QuantizeLinear) { - const auto& input_name = node->InputDefs()[0]->Name(); - if (Contains(all_quantized_op_inputs, input_name)) - all_quantized_op_inputs.at(input_name).push_back(node); - else - all_quantized_op_inputs.emplace(input_name, vector{node}); + add_quantized_input(*node_unit, 0); } if (IsQLinearBinaryOp(qlinear_op_type)) { - const auto& input_name = node->InputDefs()[3]->Name(); - if (Contains(all_quantized_op_inputs, input_name)) - all_quantized_op_inputs.at(input_name).push_back(node); - else - all_quantized_op_inputs.emplace(input_name, vector{node}); + add_quantized_input(*node_unit, 1); } - } - return all_quantized_op_inputs; + // TODO, add handling for varidiac nodes such as QLinearConcat + } } static Status GetInputDataType( const InitializedTensorSet& initializers, - const std::unordered_map>& all_quantized_op_inputs, + const std::unordered_map>& all_quantized_op_inputs, const std::string& name, int32_t data_type, const Shape& shape, OperandType& 
operand_type) { Type type = Type::TENSOR_FLOAT32; @@ -208,10 +221,9 @@ static Status GetInputDataType( } // TODO, verify the scale and zero point match if there are multiple op using same input - const auto* node = all_quantized_op_inputs.at(name)[0]; - const NodeUnit node_unit(*node); - ORT_RETURN_IF_ERROR(GetQuantizedInputScaleAndZeroPoint( - initializers, node_unit, name, scale, zero_point)); + const auto* node_unit = all_quantized_op_inputs.at(name)[0]; + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + initializers, *node_unit, name, scale, zero_point, true /* is_input */)); break; } // case ONNX_NAMESPACE::TensorProto_DataType_INT8: @@ -491,15 +503,23 @@ Status ModelBuilder::AddOperandFromPersistMemoryBuffer( Status ModelBuilder::AddOperations() { const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); + std::unordered_set processed_node_units; for (size_t i = 0; i < node_indices.size(); i++) { const auto* node(graph_viewer_.GetNode(node_indices[i])); - if (const auto* op_builder = GetOpBuilder(*node)) { - const NodeUnit node_unit(*node); + const NodeUnit& node_unit = GetNodeUnit(node); + + // Since a NodeUnit may contain multiple nodes, avoid processing the same NodeUnit multiple times + if (Contains(processed_node_units, &node_unit)) + continue; + + if (const auto* op_builder = GetOpBuilder(node_unit)) { ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, node_unit)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Node [", node->Name(), "], type [", node->OpType(), "] is not supported"); + "Node [", node_unit.Name(), "], type [", node_unit.OpType(), "] is not supported"); } + + processed_node_units.insert(&node_unit); } return Status::OK(); @@ -605,20 +625,40 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { return Status::OK(); } -int32_t ModelBuilder::FindActivation(const Node& node, const NodeArg& output) { +int32_t ModelBuilder::FindActivation(const NodeUnit& node_unit) { int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE; + const auto& output_nodes = node_unit.GetOutputNodes(); + if (node_unit.GetOutputNodes().size() != 1) { + LOGS_DEFAULT(VERBOSE) << "FindActivation does not support, NodeUnit [" << node_unit.Name() + << "] type [" << node_unit.OpType() + << "], with " << output_nodes.size() << " output nodes"; + return fuse_code; + } + + const auto& outputs = node_unit.Outputs(); + if (outputs.size() != 1) { + LOGS_DEFAULT(VERBOSE) << "FindActivation does not support, NodeUnit [" << node_unit.Name() + << "] type [" << node_unit.OpType() + << "], with " << outputs.size() << " outputs"; + return fuse_code; + } + const NodeArg& output = outputs[0].node_arg; + const auto& output_node = *output_nodes[0]; + + // TODO, add support of activation fusion for quantized node group (qdq or qlinear) // We do not support activation fusion for quantized operators for now - auto qlinear_op_type = GetQLinearOpType(node); + auto qlinear_op_type = GetQLinearOpType(node_unit.GetNode()); if (qlinear_op_type != QLinearOpType::Unknown) return fuse_code; - for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { + for (auto it = output_node.OutputEdgesBegin(), end = output_node.OutputEdgesEnd(); it != end; ++it) { const auto& dst_node = it->GetNode(); const auto* dst_input = dst_node.InputDefs()[it->GetDstArgIndex()]; - if (Contains(activation_nodes_, dst_node.Index())) { + const auto& dst_node_unit = GetNodeUnit(&dst_node); + if (Contains(activation_node_units_, &dst_node_unit)) { if (&output == dst_input) { - fuse_code 
= activation_nodes_.at(dst_node.Index()); + fuse_code = activation_node_units_.at(&dst_node_unit); } } else { // if there is any other non-relu node using the output @@ -628,14 +668,14 @@ int32_t ModelBuilder::FindActivation(const Node& node, const NodeArg& output) { } } - // if output is a graph output, will add relu separately + // if output is a graph output, will add activation separately if (fuse_code != ANEURALNETWORKS_FUSED_NONE) { - for (const auto* graph_output : graph_viewer_.GetOutputs()) { - if (&output == graph_output) - return ANEURALNETWORKS_FUSED_NONE; + const auto& graph_outputs = graph_viewer_.GetOutputs(); + if (std::find(graph_outputs.cbegin(), graph_outputs.cend(), &output) != graph_outputs.cend()) { + return ANEURALNETWORKS_FUSED_NONE; } - LOGS_DEFAULT(VERBOSE) << "Node [" << node.Name() << "] type [" << node.OpType() + LOGS_DEFAULT(VERBOSE) << "Node [" << node_unit.Name() << "] type [" << node_unit.OpType() << "], fused the output [" << output.Name() << "]"; fused_activations_.insert(output.Name()); @@ -644,12 +684,13 @@ int32_t ModelBuilder::FindActivation(const Node& node, const NodeArg& output) { return fuse_code; } -/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) { +/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const NodeUnit& node_unit) { const auto& op_builders = GetOpBuilders(); - if (!Contains(op_builders, node.OpType())) + const auto& op_type = node_unit.GetNode().OpType(); + if (!Contains(op_builders, op_type)) return nullptr; - return op_builders.at(node.OpType()); + return op_builders.at(op_type); } std::string ModelBuilder::GetUniqueName(const std::string& base_name) { @@ -663,6 +704,10 @@ std::string ModelBuilder::GetUniqueName(const std::string& base_name) { return unique_name; } +const InitializedTensorSet& ModelBuilder::GetInitializerTensors() const { + return graph_viewer_.GetAllInitializedTensors(); +} + void ModelBuilder::RegisterNHWCOperand(const std::string& name) { nhwc_operands_.insert(name); } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index d7dfd78ac0a43..2269c986f60ea 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -5,16 +5,22 @@ #include #include -#include +#include "core/graph/basic_types.h" #include "core/providers/nnapi/nnapi_builtin/model.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/NeuralNetworksWrapper.h" -#include "op_support_checker.h" #include "shaper.h" namespace onnxruntime { + +class GraphViewer; +class NodeUnit; +class Node; +class NodeArg; + namespace nnapi { class IOpBuilder; +class IOpSupportChecker; class ModelBuilder { public: @@ -33,30 +39,30 @@ class ModelBuilder { }; ModelBuilder(const GraphViewer& graph_viewer); - ~ModelBuilder() = default; - Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; + common::Status Compile(std::unique_ptr& model); int32_t GetNNAPIFeatureLevel() const; // Add an NNAPI operation (operator) - Status AddOperation(int op, const std::vector& input_indices, - const std::vector& output_names, - const std::vector& types, - const std::vector& is_nhwc_vec) ORT_MUST_USE_RESULT; + common::Status AddOperation(int op, const std::vector& input_indices, + const std::vector& output_names, + const std::vector& types, + const std::vector& is_nhwc_vec); - // Find if an output has a fuseable activation (Relu) - int32_t 
FindActivation(const Node& node, const NodeArg& output); + // Find if the given node_unit has a fuseable activation (Relu/Relu1/Relu6) + // For now we only support node_unit with a single output + int32_t FindActivation(const NodeUnit& node_unit); // Add an NNAPI scalar operand - Status AddOperandFromScalar(bool value, uint32_t& index) ORT_MUST_USE_RESULT; - Status AddOperandFromScalar(float value, uint32_t& index) ORT_MUST_USE_RESULT; - Status AddOperandFromScalar(int32_t value, uint32_t& index) ORT_MUST_USE_RESULT; + common::Status AddOperandFromScalar(bool value, uint32_t& index); + common::Status AddOperandFromScalar(float value, uint32_t& index); + common::Status AddOperandFromScalar(int32_t value, uint32_t& index); // Add an NNAPI tensor operand (and allocate persist buffer) - Status AddOperandFromPersistMemoryBuffer( + common::Status AddOperandFromPersistMemoryBuffer( const std::string& name, const void* buffer, - const android::nn::wrapper::OperandType& operand_type) ORT_MUST_USE_RESULT; + const android::nn::wrapper::OperandType& operand_type); // The initializer will be processed separately, skip it as an initializer void AddInitializerToSkip(const std::string& tensor_name); @@ -96,7 +102,7 @@ class ModelBuilder { const std::unordered_set& GetFusedActivations() const { return fused_activations_; } - const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } + const InitializedTensorSet& GetInitializerTensors() const; const GraphViewer& GetGraphViewer() const { return graph_viewer_; } @@ -107,10 +113,13 @@ class ModelBuilder { bool GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name); bool GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name); - Status SetNHWCToNCHWOperandMap(const std::string& nhwc_name, - const std::string& nchw_name) ORT_MUST_USE_RESULT; - Status SetNCHWToNHWCOperandMap(const std::string& nchw_name, - const std::string& nhwc_name) ORT_MUST_USE_RESULT; + // Get the NodeUnit which contains the given node + const NodeUnit& GetNodeUnit(const Node* node) const; + + common::Status SetNHWCToNCHWOperandMap(const std::string& nhwc_name, + const std::string& nchw_name); + common::Status SetNCHWToNHWCOperandMap(const std::string& nchw_name, + const std::string& nhwc_name); private: const NnApi* nnapi_{nullptr}; @@ -134,8 +143,8 @@ class ModelBuilder { std::unordered_set skipped_initializers_; - // All activation nodes (Relu, Relu1, Relu6) as a map - std::unordered_map activation_nodes_; + // All activation nodes (Relu, Relu1, Relu6) as a map + std::unordered_map activation_node_units_; std::unordered_map> op_support_checkers_; @@ -149,9 +158,14 @@ class ModelBuilder { std::vector input_index_vec_; std::vector output_index_vec_; - // Contains all quantized operators' input and the node(s) using the input - // In the form of {input_name, [node(s) using the input]} - std::unordered_map> all_quantized_op_inputs_; + // Contains all quantized operators' input and the NodeUnit(s) using the input + // In the form of {input_name, [NodeUnit(s) using the input]} + std::unordered_map> all_quantized_op_inputs_; + + // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is + // valid throughout the lifetime of the ModelBuilder + std::vector> node_unit_holder_; + std::unordered_map node_unit_map_; std::unordered_set unique_names_; @@ -164,32 +178,38 @@ class ModelBuilder { uint32_t next_index_ = 0; // Convert the onnx model to ANeuralNetworksModel - Status Prepare() 
ORT_MUST_USE_RESULT; + common::Status Prepare(); - Status GetTargetDevices() ORT_MUST_USE_RESULT; + common::Status GetTargetDevices(); // If a NNAPI operation will use initializers directly, we will add the initializers to the skip list void PreprocessInitializers(); // Preprocess all the activation nodes (Relu/Relu1/Relu6) for easy query later void PreprocessActivations(); // Copy and process all the initializers to NNAPI model - Status RegisterInitializers() ORT_MUST_USE_RESULT; - Status RegisterModelInputs() ORT_MUST_USE_RESULT; - Status AddOperations() ORT_MUST_USE_RESULT; - Status RegisterModelOutputs() ORT_MUST_USE_RESULT; + common::Status RegisterInitializers(); + common::Status RegisterModelInputs(); + common::Status AddOperations(); + common::Status RegisterModelOutputs(); // After constructing the NNAPI model, will set the shape inferencing record to the Model void RegisterModelShaper(); - Status SetOperandValue(uint32_t index, Model::NNMemory* memory, - size_t size, size_t offset) ORT_MUST_USE_RESULT; + // Get all quantized inputs in the underlying graph_viewer + void GetAllQuantizedOpInputs(); - Status AddNewNNAPIOperand(const android::nn::wrapper::OperandType& type, uint32_t& index) ORT_MUST_USE_RESULT; - Status AddNewOperand(const std::string& name, - const android::nn::wrapper::OperandType& operand_type, - bool is_nhwc, - uint32_t& index) ORT_MUST_USE_RESULT; + // Go through the underlying graph_viewer, and generate NodeUnits, Many initializing functions are + // using the result of PreprocessNodeUnits, this need to run early in the Prepare() + void PreprocessNodeUnits(); + + common::Status SetOperandValue(uint32_t index, Model::NNMemory* memory, size_t size, size_t offset); + + common::Status AddNewNNAPIOperand(const android::nn::wrapper::OperandType& type, uint32_t& index); + common::Status AddNewOperand(const std::string& name, + const android::nn::wrapper::OperandType& operand_type, + bool is_nhwc, + uint32_t& index); - static const IOpBuilder* GetOpBuilder(const Node& node); + static const IOpBuilder* GetOpBuilder(const NodeUnit& node_unit); }; } // namespace nnapi diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index 66f870df15074..bff260cb6741c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -3,12 +3,13 @@ #include "op_builder.h" -#include -#include -#include -#include #include +#include "core/common/logging/logging.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph_viewer.h" +#include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/cpu/tensor/slice_helper.h" @@ -16,9 +17,7 @@ #include "model_builder.h" #include "op_support_checker.h" -using onnxruntime::NodeUnit; using namespace android::nn::wrapper; -using std::vector; namespace onnxruntime { namespace nnapi { @@ -40,13 +39,7 @@ struct OpBuilderRegistrations { Status AddTransposeOperator(ModelBuilder& model_builder, const std::string& input, const std::string& perm_name, - vector perm, - const std::string& output, - bool output_is_nhwc) ORT_MUST_USE_RESULT; -Status AddTransposeOperator(ModelBuilder& model_builder, - const std::string& input, - const std::string& perm_name, - vector perm, + std::vector perm, const std::string& output, 
bool output_is_nhwc) { auto& shaper(model_builder.GetShaper()); @@ -69,10 +62,6 @@ Status AddTransposeOperator(ModelBuilder& model_builder, {output_operand_type}, {output_is_nhwc}); } -Status TransposeBetweenNCHWAndNHWC(ModelBuilder& model_builder, - const std::string& input, - const std::string& output, - bool nchw_to_nhwc) ORT_MUST_USE_RESULT; Status TransposeBetweenNCHWAndNHWC(ModelBuilder& model_builder, const std::string& input, const std::string& output, @@ -83,7 +72,7 @@ Status TransposeBetweenNCHWAndNHWC(ModelBuilder& model_builder, "TransposeBetweenNCHWAndNHWC input has to be a 4d tensor, actual dimensions: ", shaper[input].size()); std::string perm_name; - vector perm; + std::vector perm; if (nchw_to_nhwc) { perm_name = model_builder.GetUniqueName(input + "nchw_to_nhwc_perm"); perm = {0, 2, 3, 1}; @@ -110,18 +99,12 @@ Status TransposeBetweenNCHWAndNHWC(ModelBuilder& model_builder, return Status::OK(); } -Status TransposeNHWCToNCHW(ModelBuilder& model_builder, - const std::string& input, - const std::string& output) ORT_MUST_USE_RESULT; Status TransposeNHWCToNCHW(ModelBuilder& model_builder, const std::string& input, const std::string& output) { return TransposeBetweenNCHWAndNHWC(model_builder, input, output, false /* nchw_to_nhwc */); } -Status TransposeNCHWToNHWC(ModelBuilder& model_builder, - const std::string& input, - const std::string& output) ORT_MUST_USE_RESULT; Status TransposeNCHWToNHWC(ModelBuilder& model_builder, const std::string& input, const std::string& output) { @@ -130,22 +113,22 @@ Status TransposeNCHWToNHWC(ModelBuilder& model_builder, // Convert the input from nchw to nhwc // Caller should ensure input is currently in nchw format using ModelBuilder::IsOperandNHWC -Status GetNHWCInput(ModelBuilder& model_builder, const Node& node, size_t input_index, std::string& input) { - const auto& nchw_input = node.InputDefs()[input_index]->Name(); - if (!model_builder.GetNHWCOperand(nchw_input, input)) { - input = model_builder.GetUniqueName(nchw_input + "_nchw_to_nhwc"); - ORT_RETURN_IF_ERROR(TransposeNCHWToNHWC(model_builder, nchw_input, input)); +Status GetNHWCInput(ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_index, std::string& nhwc_input) { + const auto& nchw_input = node_unit.Inputs()[input_index].node_arg.Name(); + if (!model_builder.GetNHWCOperand(nchw_input, nhwc_input)) { + nhwc_input = model_builder.GetUniqueName(nchw_input + "_nchw_to_nhwc"); + ORT_RETURN_IF_ERROR(TransposeNCHWToNHWC(model_builder, nchw_input, nhwc_input)); } return Status::OK(); } // Convert the input from nhwc to nchw // Caller should ensure input is currently in nhwc format using ModelBuilder::IsOperandNHWC -Status GetNCHWInput(ModelBuilder& model_builder, const Node& node, size_t input_index, std::string& input) { - const auto& nhwc_input = node.InputDefs()[input_index]->Name(); - if (!model_builder.GetNCHWOperand(nhwc_input, input)) { - input = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw"); - ORT_RETURN_IF_ERROR(TransposeNHWCToNCHW(model_builder, nhwc_input, input)); +Status GetNCHWInput(ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_index, std::string& nchw_input) { + const auto& nhwc_input = node_unit.Inputs()[input_index].node_arg.Name(); + if (!model_builder.GetNCHWOperand(nhwc_input, nchw_input)) { + nchw_input = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw"); + ORT_RETURN_IF_ERROR(TransposeNHWCToNCHW(model_builder, nhwc_input, nchw_input)); } return Status::OK(); } @@ -154,12 +137,7 @@ Status 
GetNCHWInput(ModelBuilder& model_builder, const Node& node, size_t input_ // and return the layout type of output tensor // If both inputs have same layout, the output will have the same layout // Otherwise we will need transpose the nhwc input back to nchw, and output will be nchw -Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const Node& node, - size_t input1_idx, size_t input2_idx, - std::string& input1, std::string& input2, - bool& output_is_nhwc) ORT_MUST_USE_RESULT; -Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const Node& node, - size_t input1_idx, size_t input2_idx, +Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const NodeUnit& node_unit, std::string& input1, std::string& input2, bool& output_is_nhwc) { bool input1_is_nhwc = model_builder.IsOperandNHWC(input1); @@ -170,10 +148,10 @@ Status TransposeBinaryOpInputLayout(ModelBuilder& model_builder, const Node& nod output_is_nhwc = input1_is_nhwc; } else if (input1_is_nhwc) { // need transpose input1 back to nchw - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, input1_idx, input1)); + ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input1)); } else { // input2_is_nhwc // need transpose input2 back to nchw - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, input2_idx, input2)); + ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 1, input2)); } return Status::OK(); @@ -188,17 +166,7 @@ static Status AddBinaryOperator(int32_t op_type, const std::string& output, bool output_is_nhwc, float output_scale = 0.0f, - int32_t output_zero_point = 0) ORT_MUST_USE_RESULT; -static Status AddBinaryOperator(int32_t op_type, - ModelBuilder& model_builder, - const std::string& input1, - const std::string& input2, - bool add_activation, - int32_t fuse_code, - const std::string& output, - bool output_is_nhwc, - float output_scale, - int32_t output_zero_point) { + int32_t output_zero_point = 0) { auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); @@ -222,11 +190,7 @@ static Status AddBinaryOperator(int32_t op_type, static Status AddSqueezeOp(ModelBuilder& model_builder, const std::string& node_name, const std::string& input, const std::string& output, - vector axes) ORT_MUST_USE_RESULT; -static Status AddSqueezeOp(ModelBuilder& model_builder, - const std::string& node_name, - const std::string& input, const std::string& output, - vector axes) { + std::vector axes) { if (model_builder.GetNNAPIFeatureLevel() < ANEURALNETWORKS_FEATURE_LEVEL_2) { return ORT_MAKE_STATUS( ONNXRUNTIME, FAIL, "Squeeze is not supported on API level ", model_builder.GetNNAPIFeatureLevel()); @@ -283,11 +247,6 @@ enum DataLayout { // since NNAPI requires X and W to be same type for per-tensor quantization, // the initializer tensor W will be converted from int8 to uint8 by flip each byte by XOR 0x80 // byte ^ 0x80 == byte + 128 -static Status AddInitializerInNewLayout(ModelBuilder& model_builder, - const std::string& name, - const OperandType& source_operand_type, - DataLayout new_layout, - bool is_per_tensor_u8s8) ORT_MUST_USE_RESULT; static Status AddInitializerInNewLayout(ModelBuilder& model_builder, const std::string& name, const OperandType& source_operand_type, @@ -373,10 +332,6 @@ static Status AddInitializerInNewLayout(ModelBuilder& model_builder, // and input B is signed int8), in this case, since NNAPI requires A and B to be same type, // the initializer tensor B will be 
converted from int8 to uint8 by flip each byte by XOR 0x80 // byte ^ 0x80 == byte + 128 -static Status AddInitializerTransposed(ModelBuilder& model_builder, - const OperandType& source_operand_type, - const std::string& name, - bool is_per_tensor_u8s8) ORT_MUST_USE_RESULT; static Status AddInitializerTransposed(ModelBuilder& model_builder, const OperandType& source_operand_type, const std::string& name, @@ -430,13 +385,7 @@ static Status ComputeConvPads( const uint32_t weight_size_y, const uint32_t weight_size_x, const std::vector& onnx_pads, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, bool nchw, - vector& pads_out) ORT_MUST_USE_RESULT; -static Status ComputeConvPads( - const Shape& input_dimen, - const uint32_t weight_size_y, const uint32_t weight_size_x, - const std::vector& onnx_pads, const std::vector& onnx_strides, const std::vector& onnx_dilations, - AutoPadType auto_pad_type, bool nchw, - vector& pads_out) { + std::vector& pads_out) { const int32_t input_size_y = nchw ? input_dimen[2] : input_dimen[1]; const int32_t input_size_x = nchw ? input_dimen[3] : input_dimen[2]; const int32_t stride_y = onnx_strides[0]; @@ -467,21 +416,11 @@ static Status ComputeConvPads( static Status HandleAutoPad(const Shape& input_shape, const uint32_t weight_size_y, const uint32_t weight_size_x, - const vector& onnx_strides, - const vector& onnx_dilations, - AutoPadType auto_pad_type, - bool use_nchw, - vector& onnx_pads, - int32_t& nnapi_padding_code, - bool& use_auto_pad) ORT_MUST_USE_RESULT; -static Status HandleAutoPad(const Shape& input_shape, - const uint32_t weight_size_y, - const uint32_t weight_size_x, - const vector& onnx_strides, - const vector& onnx_dilations, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, AutoPadType auto_pad_type, bool use_nchw, - vector& onnx_pads, + std::vector& onnx_pads, int32_t& nnapi_padding_code, bool& use_auto_pad) { use_auto_pad = false; @@ -498,7 +437,7 @@ static Status HandleAutoPad(const Shape& input_shape, } } else if (onnx_dilations == std::vector{1, 1}) { // Since NNAPI runs more efficiently using auto_pad, we try to map the NOTSET padding to auto_pad - vector same_upper_pads; + std::vector same_upper_pads; ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, onnx_pads, onnx_strides, onnx_dilations, AutoPadType::SAME_UPPER, use_nchw, @@ -516,20 +455,15 @@ static Status HandleAutoPad(const Shape& input_shape, // QLinearConv, QLinearMatmul, QLinearAdd // a, b are inputs, and y is output static Status GetBinaryOpQuantizationScaleAndZeroPoint( - const ModelBuilder& model_builder, const Node& node, - float& a_scale, float& b_scale, float& y_scale, - int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point) ORT_MUST_USE_RESULT; -static Status GetBinaryOpQuantizationScaleAndZeroPoint( - const ModelBuilder& model_builder, const Node& node, + const InitializedTensorSet& initializers, const NodeUnit& node_unit, float& a_scale, float& b_scale, float& y_scale, int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point) { - const auto& initializers = model_builder.GetInitializerTensors(); - ORT_RETURN_IF_ERROR(GetQuantizationScale(initializers, node, 1, a_scale)); - ORT_RETURN_IF_ERROR(GetQuantizationScale(initializers, node, 4, b_scale)); - ORT_RETURN_IF_ERROR(GetQuantizationScale(initializers, node, 6, y_scale)); - ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 2, a_zero_point)); - 
ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 5, b_zero_point)); - ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 7, y_zero_point)); + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Inputs()[0], node_unit.ModelPath(), a_scale, a_zero_point)); + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Inputs()[1], node_unit.ModelPath(), b_scale, b_zero_point)); + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); return Status::OK(); } @@ -544,26 +478,21 @@ static Status GetBinaryOpQuantizationScaleAndZeroPoint( // will be convert to uint8 later, will return the same scale and 128 as zero point // Also will set is_per_tensor_u8s8 to true to be used later static Status GetConvMatMulOpQuantizationScaleAndZeroPoint( - const ModelBuilder& model_builder, const Node& node, + const ModelBuilder& model_builder, const NodeUnit& node_unit, float& a_scale, float& w_scale, float& y_scale, int32_t& a_zero_point, int32_t& w_zero_point, int32_t& y_zero_point, - optional>& w_scales, bool& is_per_tensor_u8s8) ORT_MUST_USE_RESULT; -static Status GetConvMatMulOpQuantizationScaleAndZeroPoint( - const ModelBuilder& model_builder, const Node& node, - float& a_scale, float& w_scale, float& y_scale, - int32_t& a_zero_point, int32_t& w_zero_point, int32_t& y_zero_point, - optional>& w_scales, bool& is_per_tensor_u8s8) { + optional>& w_scales, bool& is_per_tensor_u8s8) { is_per_tensor_u8s8 = false; + const auto& initializers(model_builder.GetInitializerTensors()); // Get scale and zero points // We will handle per-channel weight scale and zero point later ORT_RETURN_IF_ERROR( - GetBinaryOpQuantizationScaleAndZeroPoint(model_builder, node, + GetBinaryOpQuantizationScaleAndZeroPoint(initializers, node_unit, a_scale, w_scale, y_scale, a_zero_point, w_zero_point, y_zero_point)); - const auto input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - const auto& weight_tensor = *initializers.at(input_defs[3]->Name()); + const auto& inputs = node_unit.Inputs(); + const auto& weight_tensor = *initializers.at(inputs[1].node_arg.Name()); // We are done here is this is u8u8 QLinearConv if (weight_tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) @@ -574,7 +503,7 @@ static Status GetConvMatMulOpQuantizationScaleAndZeroPoint( // For this case we will need to convert the int8 weight tensor to uint8 // And have same scale and 128 as zero point // The conversion of the weight tensor itself will be done in the OpBuilder - const auto& scale_tensor = *initializers.at(input_defs[4]->Name()); + const auto& scale_tensor = *initializers.at(inputs[1].quant_param->scale.Name()); int64_t scale_dim = scale_tensor.dims().empty() ? 1 : scale_tensor.dims()[0]; if (scale_dim == 1) { w_zero_point = 128; @@ -593,7 +522,7 @@ static Status GetConvMatMulOpQuantizationScaleAndZeroPoint( ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(scale_tensor, unpacked_tensor)); const float* scales = reinterpret_cast(unpacked_tensor.data()); const size_t scales_size = scale_tensor.dims().empty() ? 
1 : scale_tensor.dims()[0]; - vector scales_vec(scales, scales + scales_size); + std::vector scales_vec(scales, scales + scales_size); w_scales = onnxruntime::make_optional(std::move(scales_vec)); return Status::OK(); } @@ -601,10 +530,6 @@ static Status GetConvMatMulOpQuantizationScaleAndZeroPoint( // NNAPI has the quantization scale and zero point embedded in the ANeuralNetworksOperandType // ONNX has the quantization scale and zero point as the inputs of the qlinear operators // We want to verify the scale and zeropoint of the ONNX inputs matches the values embedded in the NNAPI inputs -static Status IsValidInputQuantizedType(const ModelBuilder& model_builder, - const std::string& input_name, - float scale, - int32_t zero_point) ORT_MUST_USE_RESULT; static Status IsValidInputQuantizedType(const ModelBuilder& model_builder, const std::string& input_name, float scale, @@ -631,12 +556,7 @@ static Status IsValidConvWeightQuantizedType(const ModelBuilder& model_builder, const std::string& input_name, float scale, int32_t zero_point, - const optional>& scales) ORT_MUST_USE_RESULT; -static Status IsValidConvWeightQuantizedType(const ModelBuilder& model_builder, - const std::string& input_name, - float scale, - int32_t zero_point, - const optional>& scales) { + const optional>& scales) { // first verify as the weight has no per-channel quantization ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input_name, scale, zero_point)); @@ -656,57 +576,23 @@ static Status IsValidConvWeightQuantizedType(const ModelBuilder& model_builder, return Status::OK(); } -static void AddBinaryOpQuantizationScaleAndZeroPointToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) { - const auto& node = node_unit.GetNode(); - const auto input_defs(node.InputDefs()); - model_builder.AddInitializerToSkip(input_defs[1]->Name()); // a_scale - model_builder.AddInitializerToSkip(input_defs[2]->Name()); // a_zero_point - model_builder.AddInitializerToSkip(input_defs[4]->Name()); // b_scale - model_builder.AddInitializerToSkip(input_defs[5]->Name()); // b_zero_point - model_builder.AddInitializerToSkip(input_defs[6]->Name()); // y_scale - model_builder.AddInitializerToSkip(input_defs[7]->Name()); // y_zero_point -} - -Status GetQuantizedInputScaleAndZeroPoint(const InitializedTensorSet& initializers, - const NodeUnit& node_unit, - const std::string& input_name, - float& scale, - int32_t& zero_point) { - const auto& node = node_unit.GetNode(); - const auto& op_type = node.OpType(); - auto qlinear_op_type = GetQLinearOpType(node); - assert(qlinear_op_type != QLinearOpType::Unknown && - qlinear_op_type != QLinearOpType::QuantizeLinear); - - size_t scale_idx, zero_point_idx; - if (qlinear_op_type == QLinearOpType::DequantizeLinear || - qlinear_op_type == QLinearOpType::QLinearSigmoid || - qlinear_op_type == QLinearOpType::QLinearAveragePool) { - scale_idx = 1; - zero_point_idx = 2; - } else if (IsQLinearBinaryOp(qlinear_op_type)) { - const auto input_defs(node.InputDefs()); - if (input_name == input_defs[0]->Name()) { - scale_idx = 1; - zero_point_idx = 2; - } else if (input_name == input_defs[3]->Name()) { - scale_idx = 4; - zero_point_idx = 5; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Unknown input: ", input_name, ", for op: ", op_type); - } - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported op: ", op_type); - } - - ORT_RETURN_IF_ERROR(GetQuantizationScale(initializers, node, scale_idx, scale)); - zero_point = 0; - if (node.InputDefs().size() > 
zero_point_idx) { - ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, zero_point_idx, zero_point)); +static void AddQuantizationScaleAndZeroPointToSkip(ModelBuilder& model_builder, + const NodeUnitIODef::QuantParam& quant_param) { + // If we reach here, we assume the io_def has quant_param + model_builder.AddInitializerToSkip(quant_param.scale.Name()); // scale + LOGS_DEFAULT(VERBOSE) << quant_param.scale.Name() << "is skipped"; + if (quant_param.zero_point) { + model_builder.AddInitializerToSkip(quant_param.zero_point->Name()); // zero_point + LOGS_DEFAULT(VERBOSE) << quant_param.zero_point->Name() << "is skipped"; } +} - return Status::OK(); +// Ignore the input (with quantization scale and ZP if available) +// The input (usually weight) is already embedded in the NNAPI model +static void AddInputToSkip(ModelBuilder& model_builder, const NodeUnitIODef& io_def) { + model_builder.AddInitializerToSkip(io_def.node_arg.Name()); // main input + if (io_def.quant_param) + AddQuantizationScaleAndZeroPointToSkip(model_builder, *io_def.quant_param); } template @@ -731,20 +617,24 @@ class BaseOpBuilder : public IOpBuilder { public: virtual ~BaseOpBuilder() = default; virtual void AddInitializersToSkip(ModelBuilder& /* model_builder */, const NodeUnit& /* node_unit */) const override {} - Status AddToModelBuilder(ModelBuilder& model_builder, const NodeUnit& node_unit) const override final ORT_MUST_USE_RESULT; + Status AddToModelBuilder(ModelBuilder& model_builder, const NodeUnit& node_unit) const override final; protected: - virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const ORT_MUST_USE_RESULT = 0; + virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const = 0; + static bool IsOpSupported(const ModelBuilder& model_builder, const NodeUnit& node_unit) ORT_MUST_USE_RESULT; }; -Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const NodeUnit& node_unit) const { +/* static */ bool BaseOpBuilder::IsOpSupported(const ModelBuilder& model_builder, const NodeUnit& node_unit) { OpSupportCheckParams params{ model_builder.GetNNAPIFeatureLevel(), model_builder.UseNCHW(), }; - ORT_RETURN_IF_NOT(IsNodeSupported(node_unit, model_builder.GetGraphViewer(), params), - "Unsupported operator ", node_unit.OpType()); + return IsNodeSupported(node_unit, model_builder.GetGraphViewer(), params); +} + +Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const NodeUnit& node_unit) const { + ORT_RETURN_IF_NOT(IsOpSupported(model_builder, node_unit), "Unsupported operator ", node_unit.OpType()); ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node_unit)); LOGS_DEFAULT(VERBOSE) << "Operator name: [" << node_unit.Name() << "] type: [" << node_unit.OpType() << "] was added"; @@ -761,14 +651,23 @@ class BinaryOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; +/* static */ bool BinaryOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { + // TODO, add support for QDQ NodeUnit + return node_unit.OpType() == "QLinearAdd"; +} + void 
BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& op = node_unit.OpType(); - if (op == "QLinearAdd") { - AddBinaryOpQuantizationScaleAndZeroPointToSkip(model_builder, node_unit); - } + if (!IsQuantizedOp(node_unit)) + return; + + const auto& inputs = node_unit.Inputs(); + AddQuantizationScaleAndZeroPointToSkip(model_builder, *inputs[0].quant_param); // a_scale, a_zp + AddQuantizationScaleAndZeroPointToSkip(model_builder, *inputs[1].quant_param); // b_scale, b_zp + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp } /* static */ void BinaryOpBuilder::CreateSharedOpBuilder( @@ -786,9 +685,8 @@ void BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N } Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - const auto& op_type(node.OpType()); - const auto input_defs(node.InputDefs()); + const auto& op_type(node_unit.OpType()); + const auto& inputs = node_unit.Inputs(); int32_t op_code; bool add_activation = true; @@ -808,18 +706,13 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "UnaryOpBuilder, unknown op: ", op_type); } - size_t a_idx = 0, b_idx = 1; - if (op_is_qlinear) { - b_idx = 3; - } - - std::string input1 = input_defs[a_idx]->Name(); - std::string input2 = input_defs[b_idx]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + std::string input1 = inputs[0].node_arg.Name(); + std::string input2 = inputs[1].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = false; ORT_RETURN_IF_ERROR( - TransposeBinaryOpInputLayout(model_builder, node, a_idx, b_idx, input1, input2, output_is_nhwc)); + TransposeBinaryOpInputLayout(model_builder, node_unit, input1, input2, output_is_nhwc)); float a_scale = 0.0f, b_scale = 0.0f, @@ -829,9 +722,10 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const y_zero_point = 0; if (op_is_qlinear) { - ORT_RETURN_IF_ERROR(GetBinaryOpQuantizationScaleAndZeroPoint(model_builder, node, - a_scale, b_scale, y_scale, - a_zero_point, b_zero_point, y_zero_point)); + ORT_RETURN_IF_ERROR(GetBinaryOpQuantizationScaleAndZeroPoint( + model_builder.GetInitializerTensors(), node_unit, + a_scale, b_scale, y_scale, + a_zero_point, b_zero_point, y_zero_point)); } // Verify if the scale and zero point matchs from onnx input and nnapi input match @@ -842,7 +736,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const int32_t fuse_code = ANEURALNETWORKS_FUSED_NONE; if (add_activation) { - fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]); + fuse_code = model_builder.FindActivation(node_unit); } return AddBinaryOperator(op_code, model_builder, @@ -857,24 +751,23 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const class ReluOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status ReluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const 
auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto& input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); // skip this relu if it is some op's fuse output if (Contains(model_builder.GetFusedActivations(), input)) { - LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node.Name() << "] fused"; + LOGS_DEFAULT(VERBOSE) << "Relu Node [" << node_unit.Name() << "] fused"; model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc); } else { std::vector input_indices; @@ -892,17 +785,16 @@ Status ReluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N class TransposeOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); - auto input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); - NodeAttrHelper helper(node); - vector perm = helper.Get("perm", vector()); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); + NodeAttrHelper helper(node_unit); + std::vector perm = helper.Get("perm", std::vector()); auto input_dims = shaper[input].size(); if (perm.empty()) { for (int32_t i = input_dims - 1; i >= 0; i--) @@ -920,7 +812,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co perm[i] = axis_nchw_to_nhwc[perm[i]]; } - std::string perm_name = model_builder.GetUniqueName(node.Name() + input + "perm"); + std::string perm_name = model_builder.GetUniqueName(node_unit.Name() + input + "perm"); // It is possible this onnx transpose operator can be nchw->nhwc, but so far I don't see // any scenario will do this since onnx is nchw only, assume the output is always not nhwc @@ -938,17 +830,17 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co class ReshapeOpBuilder : public BaseOpBuilder { public: void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; - static Status AddReshapeOperator(ModelBuilder& model_builder, const Node& node, - const std::string& input, const std::vector& shape) ORT_MUST_USE_RESULT; + static Status AddReshapeOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, + const std::string& input, const std::vector& shape); private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; - static bool CanSkipReshape(const ModelBuilder& model_builder, const Node& node, size_t input_rank, size_t output_rank); + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; + static bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, + size_t input_rank, size_t 
output_rank); }; void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + model_builder.AddInitializerToSkip(node_unit.Inputs()[1].node_arg.Name()); } // We can skip the Reshape if all the output edges satisfies both the following conditions @@ -963,25 +855,34 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const // between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution // If we are going to skip the reshape, we will still add correct shape and operand type for the output in // onnxruntime::nnapi::Model. -/* static */ bool ReshapeOpBuilder::CanSkipReshape(const ModelBuilder& model_builder, const Node& node, +/* static */ bool ReshapeOpBuilder::CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_rank, size_t output_rank) { - const auto& output = node.OutputDefs()[0]->Name(); + const auto& output_node_arg = node_unit.Outputs()[0].node_arg; + const auto& output_name = output_node_arg.Name(); + const auto& output_node = *node_unit.GetOutputNodes()[0]; + // We will go through all the output edges - for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { - const auto& op_type = it->GetNode().OpType(); + for (auto it = output_node.OutputEdgesBegin(), end = output_node.OutputEdgesEnd(); it != end; ++it) { + const auto& dest_node_unit = model_builder.GetNodeUnit(&it->GetNode()); + const auto& op_type = dest_node_unit.OpType(); // TODO add quantized matmul when reshape support quantized input if (op_type != "Gemm" && op_type != "MatMul") { LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is Gemm/Matmul" << " or no op is using the output (output is graph output)" - << ", output name, " << output + << ", output name, " << output_name << " is used by " << op_type; return false; } + // Now the dest node is Gemm/Matmul, we want to make sure it is supported + if (!BaseOpBuilder::IsOpSupported(model_builder, node_unit)) { + return false; + } + // NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten the input 0 - if (it->GetDstArgIndex() != 0) { + if (&output_node_arg != &dest_node_unit.Inputs()[0].node_arg) { LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when the output is input 0 of Gemm/Matmul" - << ", output name, " << output; + << ", output name, " << output_name; return false; } @@ -989,7 +890,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const // And NNAPI ANEURALNETWORKS_FULLY_CONNECTED will only flatten input rank >= 2 if (input_rank < 2 || output_rank != 2) { LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can only be skipped when input_rank >= 2 and output_rank == 2" - << ", output name, " << output + << ", output name, " << output_name << ", the actual input_rank, " << input_rank << ", the actual output_rank, " << output_rank; return false; @@ -1000,26 +901,26 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const // Check if the Reshape output is a graph output, if so we cannot skip the Reshape // We do not care the case where the Reshape output is a dead end for (const auto* node_arg : model_builder.GetGraphViewer().GetOutputs()) { - if (node_arg->Name() == output) { + if (node_arg == &output_node_arg) { LOGS_DEFAULT(VERBOSE) << "Reshape/Flatten can not be skipped when the output is a graph output" - << ", 
output name, " << output; + << ", output name, " << output_name; return false; } } LOGS_DEFAULT(VERBOSE) << "Skipping Reshape/Flatten node [" - << node.Name() << "] with output, " << output; + << node_unit.Name() << "] with output, " << output_name; return true; } /* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder, - const Node& node, + const NodeUnit& node_unit, const std::string& input, const std::vector& shape) { auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); ORT_RETURN_IF_ERROR(shaper.Reshape(input, shape, output)); auto input_rank = shaper[input].size(); auto output_rank = shaper[output].size(); @@ -1027,7 +928,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const // Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now) // We will try to see if we the skip the Reshape to prevent context switching between // NNAPI CPU impl and NNAPI hardware accelerator impl - if (CanSkipReshape(model_builder, node, input_rank, output_rank)) { + if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) { // Since reshape can be skipped, only register the dimension and type, with same index and new name const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false); @@ -1038,7 +939,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const input_indices.push_back(operand_indices.at(input)); // Add new shape Shape shape_dimen = {static_cast(shape.size())}; - std::string shape_name = model_builder.GetUniqueName(node.Name() + input + "newshape"); + std::string shape_name = model_builder.GetUniqueName(node_unit.Name() + input + "newshape"); OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen); ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type)); input_indices.push_back(operand_indices.at(shape_name)); @@ -1051,17 +952,16 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const } Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& initializers(model_builder.GetInitializerTensors()); - auto input = node.InputDefs()[0]->Name(); + auto input = node_unit.Inputs()[0].node_arg.Name(); if (model_builder.IsOperandNHWC(input)) { // We want to transpose nhwc operand back to nchw before reshape - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, 0, input)); + ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input)); } - const auto& shape_tensor = *initializers.at(node.InputDefs()[1]->Name()); + const auto& shape_tensor = *initializers.at(node_unit.Inputs()[1].node_arg.Name()); std::vector unpacked_tensor; ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(shape_tensor, unpacked_tensor)); const int64_t* raw_shape = reinterpret_cast(unpacked_tensor.data()); @@ -1075,7 +975,7 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons shape[i] = dim == 0 ? 
input_shape[i] : dim; } - return AddReshapeOperator(model_builder, node, input, shape); + return AddReshapeOperator(model_builder, node_unit, input, shape); } #pragma endregion op_reshape @@ -1087,38 +987,37 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); // skip everything except input0 for BatchNormalization - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // scale - model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // B - model_builder.AddInitializerToSkip(node.InputDefs()[3]->Name()); // mean - model_builder.AddInitializerToSkip(node.InputDefs()[4]->Name()); //var + model_builder.AddInitializerToSkip(node_unit.Inputs()[1].node_arg.Name()); // scale + model_builder.AddInitializerToSkip(node_unit.Inputs()[2].node_arg.Name()); // B + model_builder.AddInitializerToSkip(node_unit.Inputs()[3].node_arg.Name()); // mean + model_builder.AddInitializerToSkip(node_unit.Inputs()[4].node_arg.Name()); //var } Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_types(model_builder.GetOperandTypes()); const auto& initializers(model_builder.GetInitializerTensors()); - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); + const auto& inputs = node_unit.Inputs(); // For reshape we are not really doing anything but // register a new operand with new shape - const auto& input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = inputs[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); - const auto& scale_tensor = *initializers.at(node.InputDefs()[1]->Name()); - const auto& bias_tensor = *initializers.at(node.InputDefs()[2]->Name()); - const auto& mean_tensor = *initializers.at(node.InputDefs()[3]->Name()); - const auto& var_tensor = *initializers.at(node.InputDefs()[4]->Name()); + const auto& scale_tensor = *initializers.at(inputs[1].node_arg.Name()); + const auto& bias_tensor = *initializers.at(inputs[2].node_arg.Name()); + const auto& mean_tensor = *initializers.at(inputs[3].node_arg.Name()); + const auto& var_tensor = *initializers.at(inputs[4].node_arg.Name()); const auto eps = helper.Get("epsilon", 1e-5f); const auto size = SafeInt(scale_tensor.dims()[0]); - vector a, b; + std::vector a, b; a.reserve(size); b.reserve(size); @@ -1144,9 +1043,9 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu bias_data[i]); } - const auto tensor_a_name = model_builder.GetUniqueName(node.Name() + input + "_imm_a"); - const auto tensor_b_name = model_builder.GetUniqueName(node.Name() + input + "_imm_b"); - const auto tensor_imm_product_name = model_builder.GetUniqueName(node.Name() + input + "_imm_mul"); + const auto tensor_a_name = model_builder.GetUniqueName(node_unit.Name() + input + "_imm_a"); + const auto tensor_b_name = model_builder.GetUniqueName(node_unit.Name() + input + 
"_imm_b"); + const auto tensor_imm_product_name = model_builder.GetUniqueName(node_unit.Name() + input + "_imm_mul"); Shape tensor_a_dimen = {size}; bool input_is_nhwc = model_builder.IsOperandNHWC(input); @@ -1180,7 +1079,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu output_is_nhwc)); // Add - int32_t fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]); + int32_t fuse_code = model_builder.FindActivation(node_unit); ORT_RETURN_IF_ERROR(AddBinaryOperator(ANEURALNETWORKS_ADD, model_builder, tensor_imm_product_name, tensor_b_name, @@ -1201,24 +1100,22 @@ class PoolOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; +/* static */ bool PoolOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { + // TODO, add support for QDQ NodeUnit + return node_unit.OpType() == "QLinearAveragePool"; +} + void PoolOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - const auto& op = node.OpType(); - if (op != "QLinearAveragePool") + if (!IsQuantizedOp(node_unit)) return; - const auto input_defs = node.InputDefs(); - // skip input/output scales and zeropoints - model_builder.AddInitializerToSkip(input_defs[1]->Name()); // X_scale - model_builder.AddInitializerToSkip(input_defs[2]->Name()); // X_zero_point - model_builder.AddInitializerToSkip(input_defs[3]->Name()); // Y_scale - - if (input_defs.size() == 5) // has Y_zero_point input - model_builder.AddInitializerToSkip(input_defs[4]->Name()); // Y_zero_point + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Inputs()[0].quant_param); // x_scale, x_zp + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp } /* static */ void PoolOpBuilder::CreateSharedOpBuilder( @@ -1235,15 +1132,13 @@ void PoolOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod } Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); - auto input = node.InputDefs()[0]->Name(); + auto input = node_unit.Inputs()[0].node_arg.Name(); bool use_nchw = model_builder.UseNCHW(); bool input_is_nhwc = model_builder.IsOperandNHWC(input); bool output_is_nhwc = false; @@ -1252,12 +1147,12 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { output_is_nhwc = true; if (!input_is_nhwc) { - ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node, 0, input)); + ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input)); } } - const auto& output = node.OutputDefs()[0]->Name(); - const auto& op_type = node.OpType(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); + const auto& op_type = node_unit.OpType(); int32_t op_code; bool is_qlinear_average_pool 
= op_type == "QLinearAveragePool"; @@ -1267,15 +1162,15 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N else // (op_type == "MaxPool" || op_type == "GlobalMaxPool") op_code = ANEURALNETWORKS_MAX_POOL_2D; - vector onnx_pads, onnx_strides, kernel_shape; + std::vector onnx_pads, onnx_strides, kernel_shape; bool use_auto_pad = false; int32_t nnapi_padding_code = ANEURALNETWORKS_PADDING_VALID; const auto& input_shape = shaper[input]; if (is_average_pool || op_type == "MaxPool") { const auto auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); - kernel_shape = helper.Get("kernel_shape", vector{0, 0}); - onnx_strides = helper.Get("strides", vector{1, 1}); - onnx_pads = helper.Get("pads", vector{0, 0, 0, 0}); + kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + onnx_strides = helper.Get("strides", std::vector{1, 1}); + onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); const auto weight_size_y = static_cast(kernel_shape[0]); const auto weight_size_x = static_cast(kernel_shape[1]); ORT_RETURN_IF_ERROR( @@ -1286,18 +1181,18 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { // (op_type == "GlobalAveragePool" || op_type == "GlobalMaxPool") use_auto_pad = true; nnapi_padding_code = ANEURALNETWORKS_PADDING_VALID; - onnx_strides = vector{1, 1}; - onnx_pads = vector{0, 0, 0, 0}; + onnx_strides = std::vector{1, 1}; + onnx_pads = std::vector{0, 0, 0, 0}; if (use_nchw) { - kernel_shape = vector{static_cast(input_shape[2]), - static_cast(input_shape[3])}; + kernel_shape = std::vector{static_cast(input_shape[2]), + static_cast(input_shape[3])}; } else { - kernel_shape = vector{static_cast(input_shape[1]), - static_cast(input_shape[2])}; + kernel_shape = std::vector{static_cast(input_shape[1]), + static_cast(input_shape[2])}; } } - int32_t fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]); + int32_t fuse_code = model_builder.FindActivation(node_unit); // Get output scale and zero point if this is QLinearAveragePool // Otherwise we will use the scale and zero point of the input @@ -1307,16 +1202,14 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (is_qlinear_average_pool) { const auto& initializers = model_builder.GetInitializerTensors(); float x_scale = 0.0f; - ORT_RETURN_IF_ERROR(GetQuantizationScale(initializers, node, 1 /* idx */, x_scale)); int32_t x_zero_point = 0; - ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 2 /* idx */, x_zero_point)); + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); // Verify if the scale and zero point values from onnx input and nnapi input match ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); - - ORT_RETURN_IF_ERROR(GetQuantizationScale(initializers, node, 3 /* idx */, y_scale)); - if (node.InputDefs().size() > 4) - ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 4 /* idx */, y_zero_point)); + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); } std::vector input_indices; @@ -1361,10 +1254,17 @@ class ConvOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const 
override ORT_MUST_USE_RESULT; + static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; -/* static */ void ConvOpBuilder::CreateSharedOpBuilder( +/* static */ bool ConvOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { + // TODO, add support for QDQ NodeUnit + return node_unit.OpType() == "QLinearConv"; +} + +/* static */ void +ConvOpBuilder::CreateSharedOpBuilder( const std::string& op_type, OpBuilderRegistrations& op_registrations) { CreateSharedOpBuilderImpl( op_type, op_registrations, @@ -1375,50 +1275,42 @@ class ConvOpBuilder : public BaseOpBuilder { } void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - const auto& op = node.OpType(); - const auto input_defs = node.InputDefs(); - + const auto& inputs = node_unit.Inputs(); // skip the weight for conv as we need to transpose - if (op == "QLinearConv") { - AddBinaryOpQuantizationScaleAndZeroPointToSkip(model_builder, node_unit); - model_builder.AddInitializerToSkip(input_defs[3]->Name()); // w - if (input_defs.size() > 8) - model_builder.AddInitializerToSkip(input_defs[8]->Name()); // B + if (IsQuantizedOp(node_unit)) { + AddQuantizationScaleAndZeroPointToSkip(model_builder, *inputs[0].quant_param); // x_scale, x_zp + AddInputToSkip(model_builder, inputs[1]); // w, w_scale, w_zp + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp + if (inputs.size() > 2) + AddInputToSkip(model_builder, inputs[2]); // B, B_scale, B_zp } else { - model_builder.AddInitializerToSkip(input_defs[1]->Name()); // w + model_builder.AddInitializerToSkip(inputs[1].node_arg.Name()); // w } } Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); const auto& initializers(model_builder.GetInitializerTensors()); - NodeAttrHelper helper(node); - const auto input_defs = node.InputDefs(); - const auto& op_type = node.OpType(); - bool is_qlinear_conv = (op_type == "QLinearConv"); + NodeAttrHelper helper(node_unit); + const auto inputs = node_unit.Inputs(); + bool is_qlinear_conv = IsQuantizedOp(node_unit); // onnx strides are in the order height, width // while nnapi strides are in the order width, height - const auto onnx_strides = helper.Get("strides", vector{1, 1}); + const auto onnx_strides = helper.Get("strides", std::vector{1, 1}); // onnx pads are in the order top, left, bottom, right // while nnapi pads is in the order left, right, top, bottom - auto onnx_pads = helper.Get("pads", vector{0, 0, 0, 0}); + auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); // onnx dilations is in the order height, width // while nnapi dilations are in the order width, height - const auto onnx_dilations = helper.Get("dilations", vector{1, 1}); + const auto onnx_dilations = helper.Get("dilations", std::vector{1, 1}); const auto group = helper.Get("group", 1); - size_t x_idx = 0, - w_idx = is_qlinear_conv ? 3 : 1, - b_idx = is_qlinear_conv ? 
8 : 2; - - auto input = input_defs[x_idx]->Name(); + auto input = inputs[0].node_arg.Name(); bool use_nchw = model_builder.UseNCHW(); bool input_is_nhwc = model_builder.IsOperandNHWC(input); bool output_is_nhwc = false; @@ -1427,13 +1319,13 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { output_is_nhwc = true; if (!input_is_nhwc) { - ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node, x_idx, input)); + ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input)); } } - const auto& weight = input_defs[w_idx]->Name(); + const auto& weight = inputs[1].node_arg.Name(); const auto& weight_tensor = *initializers.at(weight); - auto conv_type = GetConvType(node, model_builder.GetGraphViewer().GetAllInitializedTensors()); + auto conv_type = GetConvType(node_unit, model_builder.GetInitializerTensors()); bool conv_2d = (conv_type == ConvType::Regular), depthwise_conv_2d = (conv_type == ConvType::Depthwise), grouped_conv_2d = (conv_type == ConvType::Grouped); @@ -1446,10 +1338,10 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N y_zero_point = 0; // this is for per-channel quantization weights - optional> w_scales; + optional> w_scales; bool is_per_tensor_u8s8 = false; if (is_qlinear_conv) { - ORT_RETURN_IF_ERROR(GetConvMatMulOpQuantizationScaleAndZeroPoint(model_builder, node, + ORT_RETURN_IF_ERROR(GetConvMatMulOpQuantizationScaleAndZeroPoint(model_builder, node_unit, x_scale, w_scale, y_scale, x_zero_point, w_zero_point, y_zero_point, w_scales, is_per_tensor_u8s8)); @@ -1505,8 +1397,8 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N ORT_RETURN_IF_ERROR(IsValidConvWeightQuantizedType(model_builder, weight, w_scale, w_zero_point, w_scales)); } - bool hasBias = (input_defs.size() > b_idx); - std::string bias = hasBias ? input_defs[b_idx]->Name() : weight + "_bias"; + bool hasBias = (inputs.size() > 2); + std::string bias = hasBias ? 
inputs[2].node_arg.Name() : weight + "_bias"; if (!hasBias) { const auto weight_dimen = shaper[weight]; Shape bias_dimen; @@ -1517,11 +1409,11 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto& weight_type = operand_types.at(weight).type; if (weight_type == Type::TENSOR_FLOAT32) { - vector buffer(bias_dimen[0], 0.0f); + std::vector buffer(bias_dimen[0], 0.0f); OperandType bias_operand_type(Type::TENSOR_FLOAT32, bias_dimen, x_scale * w_scale); ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(bias, buffer.data(), bias_operand_type)); } else if (weight_type == Type::TENSOR_QUANT8_ASYMM || weight_type == Type::TENSOR_QUANT8_SYMM_PER_CHANNEL) { - vector buffer(bias_dimen[0], 0); + std::vector buffer(bias_dimen[0], 0); OperandType bias_operand_type(Type::TENSOR_INT32, bias_dimen, x_scale * w_scale); ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(bias, buffer.data(), bias_operand_type)); } else { @@ -1582,7 +1474,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } } - int32_t fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]); + int32_t fuse_code = model_builder.FindActivation(node_unit); ADD_SCALAR_OPERAND(model_builder, input_indices, fuse_code); if (model_builder.GetNNAPIFeatureLevel() > ANEURALNETWORKS_FEATURE_LEVEL_2) { @@ -1600,7 +1492,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } int32_t operationCode; - const auto& output = node.OutputDefs()[0]->Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); if (conv_2d || grouped_conv_2d) { operationCode = conv_2d ? ANEURALNETWORKS_CONV_2D : ANEURALNETWORKS_GROUPED_CONV_2D; @@ -1628,17 +1520,16 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N class CastOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); - const auto& input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); auto to = helper.Get("to", 0); @@ -1669,25 +1560,24 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N class SoftMaxOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); const auto android_feature_level = model_builder.GetNNAPIFeatureLevel(); - 
NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); - auto input = node.InputDefs()[0]->Name(); + auto input = node_unit.Inputs()[0].node_arg.Name(); bool input_is_nhwc = model_builder.IsOperandNHWC(input); bool output_is_nhwc = input_is_nhwc; if (android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) { if (model_builder.IsOperandNHWC(input)) { output_is_nhwc = false; // We want to transpose nhwc operand back to nchw before softmax - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, 0, input)); + ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input)); } } @@ -1697,7 +1587,7 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons axis = axis_nchw_to_nhwc[axis]; } - const auto& output = node.OutputDefs()[0]->Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); float beta = 1.f; std::vector input_indices; input_indices.push_back(operand_indices.at(input)); @@ -1721,20 +1611,18 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons class IdentityOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status IdentityOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - // Identity is not really going to do anything // Just register the dimension and type, with same index and new name auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto& input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); std::vector input_indices; @@ -1756,9 +1644,15 @@ class GemmOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; +/* static */ bool GemmOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { + // TODO, add support for QDQ NodeUnit + return node_unit.OpType() == "QLinearMatMul"; +} + /* static */ void GemmOpBuilder::CreateSharedOpBuilder( const std::string& op_type, OpBuilderRegistrations& op_registrations) { CreateSharedOpBuilderImpl( @@ -1771,43 +1665,38 @@ class GemmOpBuilder : public BaseOpBuilder { } void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - - const auto& op = node.OpType(); - const auto input_defs(node.InputDefs()); - if (op == "MatMul") { - model_builder.AddInitializerToSkip(input_defs[1]->Name()); - } else if (op == "Gemm") { - NodeAttrHelper helper(node); - const auto transB = helper.Get("transB", 0); - if (transB == 0) - model_builder.AddInitializerToSkip(input_defs[1]->Name()); - } else if (op == 
"QLinearMatMul") { - AddBinaryOpQuantizationScaleAndZeroPointToSkip(model_builder, node_unit); - model_builder.AddInitializerToSkip(input_defs[3]->Name()); // b + const auto& inputs = node_unit.Inputs(); + if (IsQuantizedOp(node_unit)) { + AddQuantizationScaleAndZeroPointToSkip(model_builder, *inputs[0].quant_param); // b_scale, b_zp + AddInputToSkip(model_builder, inputs[1]); // b, b_scale, b_zp + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp + } else { + const auto& op = node_unit.OpType(); + if (op == "MatMul") { + model_builder.AddInitializerToSkip(inputs[1].node_arg.Name()); + } else if (op == "Gemm") { + NodeAttrHelper helper(node_unit); + const auto transB = helper.Get("transB", 0); + if (transB == 0) + model_builder.AddInitializerToSkip(inputs[1].node_arg.Name()); + } } } Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); const auto& initializers(model_builder.GetInitializerTensors()); - const auto& op = node.OpType(); - const auto input_defs(node.InputDefs()); - NodeAttrHelper helper(node); + const auto& op = node_unit.OpType(); + const auto& inputs = node_unit.Inputs(); + NodeAttrHelper helper(node_unit); bool is_qlinear_matmul = op == "QLinearMatMul"; - size_t a_idx = 0, - b_idx = is_qlinear_matmul ? 3 : 1, - c_idx = 2; // QLinearMatMul has no bias - - const auto& input1 = input_defs[a_idx]->Name(); - const auto& input2 = input_defs[b_idx]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input1 = inputs[0].node_arg.Name(); + const auto& input2 = inputs[1].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); const auto transB = helper.Get("transB", 0); float a_scale = 0.0f, @@ -1819,9 +1708,9 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N bool is_per_tensor_u8s8 = false; if (is_qlinear_matmul) { - optional> w_scales; + optional> w_scales; ORT_RETURN_IF_ERROR( - GetConvMatMulOpQuantizationScaleAndZeroPoint(model_builder, node, + GetConvMatMulOpQuantizationScaleAndZeroPoint(model_builder, node_unit, a_scale, b_scale, y_scale, a_zero_point, b_zero_point, y_zero_point, w_scales, is_per_tensor_u8s8)); @@ -1853,14 +1742,14 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } uint32_t bias_idx; - bool has_bias = (op == "Gemm") && (input_defs.size() > 2); + bool has_bias = inputs.size() > 2; if (has_bias) { - const auto& bias = input_defs[c_idx]->Name(); + const auto& bias = inputs[2].node_arg.Name(); // We need squeeze the input tensor to 1d if necessary if (shaper[bias].size() > 1) { - std::string bias_squeezed = model_builder.GetUniqueName(node.Name() + op + "_bias_squeezed"); + std::string bias_squeezed = model_builder.GetUniqueName(node_unit.Name() + op + "_bias_squeezed"); // We will use squeeze all here - ORT_RETURN_IF_ERROR(AddSqueezeOp(model_builder, node.Name(), + ORT_RETURN_IF_ERROR(AddSqueezeOp(model_builder, node_unit.Name(), bias, bias_squeezed, {} /* axes */)); bias_idx = operand_indices.at(bias_squeezed); @@ -1873,7 +1762,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } } else { // No C supplied, we need a vector of 0 - std::string bias = model_builder.GetUniqueName(node.Name() + op + 
"_bias"); + std::string bias = model_builder.GetUniqueName(node_unit.Name() + op + "_bias"); const auto& bias_type = operand_types.at(input2).type; const Shape& bias_dimen = {shaper[input2][0]}; if (bias_type == Type::TENSOR_FLOAT32) { @@ -1895,7 +1784,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N input_indices.push_back(operand_indices.at(input1)); // A input_indices.push_back(input_2_idx); // B input_indices.push_back(bias_idx); // C - int32_t fuse_code = model_builder.FindActivation(node, *node.OutputDefs()[0]); + int32_t fuse_code = model_builder.FindActivation(node_unit); ADD_SCALAR_OPERAND(model_builder, input_indices, fuse_code); ORT_RETURN_IF_ERROR(shaper.FC(input1, input2, output)); @@ -1915,24 +1804,21 @@ class UnaryOpBuilder : public BaseOpBuilder { static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; +/* static */ bool UnaryOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) { + // TODO, add support for QDQ NodeUnit + return node_unit.OpType() == "QLinearSigmoid"; +} + void UnaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - const auto& op = node.OpType(); - if (op != "QLinearSigmoid") + if (!IsQuantizedOp(node_unit)) return; - const auto input_defs = node.InputDefs(); - - // skip input/output scales and zeropoints - model_builder.AddInitializerToSkip(input_defs[1]->Name()); // X_scale - model_builder.AddInitializerToSkip(input_defs[2]->Name()); // X_zero_point - model_builder.AddInitializerToSkip(input_defs[3]->Name()); // Y_scale - - if (input_defs.size() == 5) // has Y_zero_point input - model_builder.AddInitializerToSkip(input_defs[4]->Name()); // Y_zero_point + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Inputs()[0].quant_param); // x_scale, x_zp + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp } /* static */ void UnaryOpBuilder::CreateSharedOpBuilder( @@ -1954,14 +1840,13 @@ void UnaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const No } Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto& op_type(node.OpType()); + const auto& op_type(node_unit.OpType()); - const auto& input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); @@ -1995,9 +1880,9 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const if (is_qlinear_sigmoid) { const auto& initializers = model_builder.GetInitializerTensors(); float x_scale = 0.0f; - ORT_RETURN_IF_ERROR(GetQuantizationScale(initializers, 
node, 1, x_scale)); int32_t x_zero_point = 0; - ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(initializers, node, 2, x_zero_point)); + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); // Verify if the scale and zero point values from onnx input and nnapi input match ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); @@ -2021,21 +1906,21 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const class ConcatOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); + const auto& inputs = node_unit.Inputs(); std::vector input_indices; - const auto& input0 = node.InputDefs()[0]->Name(); + const auto& input0 = inputs[0].node_arg.Name(); bool all_input_have_same_layout = true; bool output_is_nhwc = false; - const auto node_input_size = node.InputDefs().size(); + const auto node_input_size = inputs.size(); // First if the inputs are uint8, we need verify all the inputs have same scale and zero points if (operand_types.at(input0).type == android::nn::wrapper::Type::TENSOR_QUANT8_ASYMM) { @@ -2044,7 +1929,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Compare scale and zp of input0 to input1~n for (size_t i = 1; i < node_input_size; i++) { - const auto& type = operand_types.at(node.InputDefs()[i]->Name()); + const auto& type = operand_types.at(inputs[i].node_arg.Name()); ORT_RETURN_IF_NOT(scale == type.operandType.scale, "Input[", i, "]'s scale: ", type.operandType.scale, " is different than input[0]'s scale: ", scale); @@ -2059,31 +1944,31 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const for (size_t i = 0; i < node_input_size - 1; i++) { all_input_have_same_layout = all_input_have_same_layout && - model_builder.IsOperandNHWC(node.InputDefs()[i]->Name()) == - model_builder.IsOperandNHWC(node.InputDefs()[i + 1]->Name()); + model_builder.IsOperandNHWC(inputs[i].node_arg.Name()) == + model_builder.IsOperandNHWC(inputs[i + 1].node_arg.Name()); } - std::vector inputs; - inputs.reserve(node_input_size); + std::vector input_names; + input_names.reserve(node_input_size); if (all_input_have_same_layout) { // if all the inputs are of same layout, output will be the same layout output_is_nhwc = model_builder.IsOperandNHWC(input0); for (size_t i = 0; i < node_input_size; i++) { - auto input = node.InputDefs()[i]->Name(); + const auto& input = inputs[i].node_arg.Name(); input_indices.push_back(operand_indices.at(input)); - inputs.push_back(input); + input_names.push_back(input); } } else { // if all the inputs are not same layout, // will need transpos those nhwc tensors back to nchw for (size_t i = 0; i < node_input_size; i++) { - auto input = node.InputDefs()[i]->Name(); + auto input = inputs[i].node_arg.Name(); if (model_builder.IsOperandNHWC(input)) { - 
ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, i, input)); + ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, i, input)); } input_indices.push_back(operand_indices.at(input)); - inputs.push_back(input); + input_names.push_back(input); } } @@ -2099,8 +1984,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } ADD_SCALAR_OPERAND(model_builder, input_indices, axis); - const auto& output = node.OutputDefs()[0]->Name(); - ORT_RETURN_IF_ERROR(shaper.Concat(inputs, axis, output)); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); + ORT_RETURN_IF_ERROR(shaper.Concat(input_names, axis, output)); OperandType output_operand_type = operand_types.at(input0); output_operand_type.SetDimensions(shaper[output]); ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_CONCATENATION, input_indices, @@ -2117,25 +2002,24 @@ class SqueezeOpBuilder : public BaseOpBuilder { void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; - static Status GetAxes(ModelBuilder& model_builder, const Node& node, vector& axes); + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; + static Status GetAxes(ModelBuilder& model_builder, const NodeUnit& node_unit, std::vector& axes); }; void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + if (node_unit.SinceVersion() > 12 && node_unit.Inputs().size() > 1) { + model_builder.AddInitializerToSkip(node_unit.Inputs()[1].node_arg.Name()); } } /* static */ Status SqueezeOpBuilder::GetAxes(ModelBuilder& model_builder, - const Node& node, vector& axes) { + const NodeUnit& node_unit, std::vector& axes) { // Squeeze opset 13 use input as axes - if (node.SinceVersion() > 12) { + if (node_unit.SinceVersion() > 12) { // If axes is not supplied, return an empty axes as default to squeeze all - if (node.InputDefs().size() > 1) { + if (node_unit.Inputs().size() > 1) { const auto& initializers(model_builder.GetInitializerTensors()); - const auto& axes_tensor = *initializers.at(node.InputDefs()[1]->Name()); + const auto& axes_tensor = *initializers.at(node_unit.Inputs()[1].node_arg.Name()); std::vector unpacked_tensor; ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(axes_tensor, unpacked_tensor)); const int64_t* raw_axes = reinterpret_cast(unpacked_tensor.data()); @@ -2147,25 +2031,23 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const } } } else { - NodeAttrHelper helper(node); - axes = helper.Get("axes", vector()); + NodeAttrHelper helper(node_unit); + axes = helper.Get("axes", std::vector()); } return Status::OK(); } Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - - auto input = node.InputDefs()[0]->Name(); + auto input = node_unit.Inputs()[0].node_arg.Name(); if (model_builder.IsOperandNHWC(input)) { // We want to transpose nhwc operand back to nchw before squeeze - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, 0, input)); + ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input)); } - vector axes; - 
ORT_RETURN_IF_ERROR(GetAxes(model_builder, node, axes)); - return AddSqueezeOp(model_builder, node.Name(), input, node.OutputDefs()[0]->Name(), axes); + std::vector axes; + ORT_RETURN_IF_ERROR(GetAxes(model_builder, node_unit, axes)); + return AddSqueezeOp(model_builder, node_unit.Name(), input, node_unit.Outputs()[0].node_arg.Name(), axes); } #pragma endregion @@ -2177,38 +2059,27 @@ class QuantizeLinearOpBuilder : public BaseOpBuilder { void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; void QuantizeLinearOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - const auto input_defs(node.InputDefs()); - - model_builder.AddInitializerToSkip(input_defs[1]->Name()); - - if (input_defs.size() == 3) // has zero_point input - model_builder.AddInitializerToSkip(input_defs[2]->Name()); + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param); // y_scale, y_zp } Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); - const auto input_defs(node.InputDefs()); - const auto& input = input_defs[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); float scale = 0.0f; - ORT_RETURN_IF_ERROR(GetQuantizationScale(model_builder.GetInitializerTensors(), node, 1, scale)); int32_t zero_point = 0; - Type output_type = Type::TENSOR_QUANT8_ASYMM; - - if (input_defs.size() == 3) { // Get zero point - ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder.GetInitializerTensors(), node, 2, zero_point)); - } + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + model_builder.GetInitializerTensors(), node_unit.Outputs()[0], node_unit.ModelPath(), scale, zero_point)); + Type output_type = Type::TENSOR_QUANT8_ASYMM; ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(output_type, shaper[output], scale, zero_point); std::vector input_indices; @@ -2227,36 +2098,26 @@ class DequantizeLinearOpBuilder : public BaseOpBuilder { void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; void DequantizeLinearOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - const auto input_defs(node.InputDefs()); - - model_builder.AddInitializerToSkip(input_defs[1]->Name()); - - if (input_defs.size() == 3) // has zero_point input - model_builder.AddInitializerToSkip(input_defs[2]->Name()); + AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Inputs()[0].quant_param); // x_scale, x_zp } Status 
DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); - const auto input_defs(node.InputDefs()); + const auto& inputs = node_unit.Inputs(); - const auto& input = input_defs[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = inputs[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); float scale = 0.0; - ORT_RETURN_IF_ERROR(GetQuantizationScale(model_builder.GetInitializerTensors(), node, 1, scale)); int32_t zero_point = 0; - if (input_defs.size() == 3) { // Get zero point - ORT_RETURN_IF_ERROR(GetQuantizationZeroPoint(model_builder.GetInitializerTensors(), node, 2, zero_point)); - } + ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( + model_builder.GetInitializerTensors(), node_unit.Inputs()[0], node_unit.ModelPath(), scale, zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, scale, zero_point)); @@ -2276,25 +2137,24 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil class LRNOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); const auto android_feature_level = model_builder.GetNNAPIFeatureLevel(); - auto input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + auto input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); if (android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) { // on android api level 28, we need to transpose the nchw input to nhwc output_is_nhwc = true; if (!model_builder.IsOperandNHWC(input)) { - ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node, 0, input)); + ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input)); } } @@ -2338,40 +2198,39 @@ class ClipOpBuilder : public BaseOpBuilder { void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - if (node.InputDefs().size() > 1) - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // min + const auto& inputs = node_unit.Inputs(); + if (inputs.size() > 1) + model_builder.AddInitializerToSkip(inputs[1].node_arg.Name()); // min - if (node.InputDefs().size() > 2) - model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // max + if 
(inputs.size() > 2) + model_builder.AddInitializerToSkip(inputs[2].node_arg.Name()); // max } Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto& input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); if (Contains(model_builder.GetFusedActivations(), input)) { - LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node.Name() << "] fused"; + LOGS_DEFAULT(VERBOSE) << "Clip Node [" << node_unit.Name() << "] fused"; model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc); return Status::OK(); } float min, max; - GetClipMinMax(model_builder.GetInitializerTensors(), node, min, max, logging::LoggingManager::DefaultLogger()); + GetClipMinMax(model_builder.GetInitializerTensors(), node_unit.GetNode(), min, max, + logging::LoggingManager::DefaultLogger()); int32_t op_code; if (min == 0.0f && max == 6.0f) @@ -2398,34 +2257,33 @@ class ResizeOpBuilder : public BaseOpBuilder { void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); + const auto& inputs = node_unit.Inputs(); // We don't really use ROI here, so add them to skipped list - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI + model_builder.AddInitializerToSkip(inputs[1].node_arg.Name()); // ROI // We will still add scales to the skipped list even sizes are present // since there is no use of it, we will not process it later - model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // scales + model_builder.AddInitializerToSkip(inputs[2].node_arg.Name()); // scales - if (node.InputDefs().size() > 3) - model_builder.AddInitializerToSkip(node.InputDefs()[3]->Name()); // sizes + if (inputs.size() > 3) + model_builder.AddInitializerToSkip(inputs[3].node_arg.Name()); // sizes } Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); const auto& initializers(model_builder.GetInitializerTensors()); - NodeAttrHelper helper(node); - const auto input_defs = node.InputDefs(); + NodeAttrHelper helper(node_unit); + const auto& inputs = node_unit.Inputs(); const auto android_feature_level = model_builder.GetNNAPIFeatureLevel(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); - auto input = input_defs[0]->Name(); + auto 
input = inputs[0].node_arg.Name(); bool use_nchw = model_builder.UseNCHW(); bool input_is_nhwc = model_builder.IsOperandNHWC(input); bool output_is_nhwc = false; @@ -2434,7 +2292,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const } else { output_is_nhwc = true; if (!input_is_nhwc) { - ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node, 0, input)); + ORT_RETURN_IF_ERROR(GetNHWCInput(model_builder, node_unit, 0, input)); } } @@ -2447,8 +2305,8 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const bool using_half_pixel = coord_trans_mode == "half_pixel"; bool using_align_corners = coord_trans_mode == "align_corners"; - if (input_defs.size() == 3) { // we are using scales - const auto& scales_name = input_defs[2]->Name(); + if (inputs.size() == 3) { // we are using scales + const auto& scales_name = inputs[2].node_arg.Name(); const auto& scales_tensor = *initializers.at(scales_name); std::vector unpacked_tensor; ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor)); @@ -2458,7 +2316,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ORT_RETURN_IF_ERROR( shaper.ResizeUsingScales(input, scale_h, scale_w, use_nchw, output)); } else { // we are using sizes - const auto& sizes_name = input_defs[3]->Name(); + const auto& sizes_name = inputs[3].node_arg.Name(); const auto& sizes_tensor = *initializers.at(sizes_name); std::vector unpacked_tensor; ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor)); @@ -2505,31 +2363,29 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const class FlattenOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - - auto input = node.InputDefs()[0]->Name(); + auto input = node_unit.Inputs()[0].node_arg.Name(); if (model_builder.IsOperandNHWC(input)) { // We want to transpose nhwc operand back to nchw before reshape - ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node, 0, input)); + ORT_RETURN_IF_ERROR(GetNCHWInput(model_builder, node_unit, 0, input)); } // Flatten is basically a reshape to 2d tensor // Get the shape for Reshape here Shape input_shape; - GetShape(*node.InputDefs()[0], input_shape); + GetShape(node_unit.Inputs()[0].node_arg, input_shape); int32_t dim_1 = 1; int32_t dim_2 = 1; - GetFlattenOutputShape(node, input_shape, dim_1, dim_2); + GetFlattenOutputShape(node_unit, input_shape, dim_1, dim_2); // If the input is of dynamic shape, replace 0 (dynamic) dimension with -1 // We cannot have dim_1 and dim_2 both be 0 here, it was checked in IsOpSupportedImpl dim_1 = dim_1 == 0 ? -1 : dim_1; dim_2 = dim_2 == 0 ? 
-1 : dim_2; std::vector shape{dim_1, dim_2}; - return ReshapeOpBuilder::AddReshapeOperator(model_builder, node, input, shape); + return ReshapeOpBuilder::AddReshapeOperator(model_builder, node_unit, input, shape); } #pragma endregion @@ -2539,12 +2395,12 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons class MinMaxOpBuilder : public BaseOpBuilder { public: static void CreateSharedOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); - static Status AddMinMaxOperator(ModelBuilder& model_builder, const Node& node, - const std::string& input1, const std::string& input2, - bool output_is_nhwc) ORT_MUST_USE_RESULT; private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; + static Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, + const std::string& input1, const std::string& input2, + bool output_is_nhwc); }; /* static */ void MinMaxOpBuilder::CreateSharedOpBuilder( @@ -2557,16 +2413,16 @@ class MinMaxOpBuilder : public BaseOpBuilder { }); } -/* static */ Status MinMaxOpBuilder::AddMinMaxOperator(ModelBuilder& model_builder, const Node& node, +/* static */ Status MinMaxOpBuilder::AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, const std::string& input1, const std::string& input2, bool output_is_nhwc) { auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); - const auto& op_type(node.OpType()); + const auto& op_type(node_unit.OpType()); int32_t op_code; if (op_type == "Min") op_code = ANEURALNETWORKS_MINIMUM; @@ -2588,17 +2444,14 @@ class MinMaxOpBuilder : public BaseOpBuilder { } Status MinMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - const auto input_defs(node.InputDefs()); - std::string input1 = input_defs[0]->Name(); - std::string input2 = input_defs[1]->Name(); + const auto& inputs = node_unit.Inputs(); + std::string input1 = inputs[0].node_arg.Name(); + std::string input2 = inputs[1].node_arg.Name(); bool output_is_nhwc = false; - ORT_RETURN_IF_ERROR(TransposeBinaryOpInputLayout(model_builder, node, - 0 /* input1_idx */, - 1 /* input2_idx */, + ORT_RETURN_IF_ERROR(TransposeBinaryOpInputLayout(model_builder, node_unit, input1, input2, output_is_nhwc)); - return AddMinMaxOperator(model_builder, node, input1, input2, output_is_nhwc); + return AddMinMaxOperator(model_builder, node_unit, input1, input2, output_is_nhwc); } #pragma endregion @@ -2607,21 +2460,19 @@ Status MinMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const class EluOpBuilder : public BaseOpBuilder { private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; Status EluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& 
operand_types(model_builder.GetOperandTypes()); - const auto& input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); ORT_RETURN_IF_ERROR(shaper.Identity(input, output)); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); const auto alpha = helper.Get("alpha", 1.0f); std::vector input_indices; input_indices.push_back(operand_indices.at(input)); @@ -2639,32 +2490,28 @@ class SliceOpBuilder : public BaseOpBuilder { void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override ORT_MUST_USE_RESULT; + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override; }; void SliceOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - // Skip everything except input0 for Slice - const auto input_defs = node.InputDefs(); - model_builder.AddInitializerToSkip(input_defs[1]->Name()); // starts - model_builder.AddInitializerToSkip(input_defs[2]->Name()); // ends - if (input_defs.size() > 3) { - model_builder.AddInitializerToSkip(input_defs[3]->Name()); // axes - if (input_defs.size() > 4) { - model_builder.AddInitializerToSkip(input_defs[4]->Name()); // steps + const auto& inputs = node_unit.Inputs(); + model_builder.AddInitializerToSkip(inputs[1].node_arg.Name()); // starts + model_builder.AddInitializerToSkip(inputs[2].node_arg.Name()); // ends + if (inputs.size() > 3) { + model_builder.AddInitializerToSkip(inputs[3].node_arg.Name()); // axes + if (inputs.size() > 4) { + model_builder.AddInitializerToSkip(inputs[4].node_arg.Name()); // steps } } } Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { - const auto& node = node_unit.GetNode(); - auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto input_defs = node.InputDefs(); - const auto& input_shape = shaper[input_defs[0]->Name()]; + const auto& inputs = node_unit.Inputs(); + const auto& input_shape = shaper[inputs[0].node_arg.Name()]; std::vector input_shape_64(input_shape.cbegin(), input_shape.cend()); SliceOp::PrepareForComputeMetadata compute_metadata(input_shape_64); @@ -2678,15 +2525,14 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const std::vector input_axes; std::vector input_steps; - const auto CopyInputData = [&node, &model_builder](size_t input_idx, std::vector& data) { + const auto CopyInputData = [&inputs, &model_builder](size_t input_idx, std::vector& data) { data.clear(); - const auto input_defs = node.InputDefs(); // This is an optional input, return empty vector - if (input_defs.size() <= input_idx) + if (inputs.size() <= input_idx) return Status::OK(); - const auto& input_name = input_defs[input_idx]->Name(); + const auto& input_name = inputs[input_idx].node_arg.Name(); const auto& initializers(model_builder.GetInitializerTensors()); const auto& tensor = *initializers.at(input_name); @@ -2728,8 +2574,8 @@ Status 
SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const std::back_inserter(nnapi_output_shape), [](int64_t i) { return SafeInt(i); }); - const auto& input = node.InputDefs()[0]->Name(); - const auto& output = node.OutputDefs()[0]->Name(); + const auto& input = inputs[0].node_arg.Name(); + const auto& output = node_unit.Outputs()[0].node_arg.Name(); bool output_is_nhwc = model_builder.IsOperandNHWC(input); // No shape inference for Slice, everything is calculated here, we only need to add the output shape @@ -2744,14 +2590,14 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Shape param_dimen = {static_cast(input_shape.size())}; // helper function to add begin/end/strides of ANEURALNETWORKS_STRIDED_SLICE - const auto AddOperand = [&model_builder, &node, &input_indices, &operand_indices]( + const auto AddOperand = [&model_builder, &node_unit, &input_indices, &operand_indices]( const char* name, const Shape& shape, const std::vector& param_raw_data) { std::vector param_data; param_data.reserve(param_raw_data.size()); std::transform(param_raw_data.cbegin(), param_raw_data.cend(), std::back_inserter(param_data), [](int64_t i) { return SafeInt(i); }); - std::string param_name = model_builder.GetUniqueName(node.Name() + name); + std::string param_name = model_builder.GetUniqueName(node_unit.Name() + name); OperandType param_operand_type(Type::TENSOR_INT32, shape); ORT_RETURN_IF_ERROR( model_builder.AddOperandFromPersistMemoryBuffer(param_name, param_data.data(), param_operand_type)); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h index 46acbc4eff4b9..6483c432f1442 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h @@ -7,7 +7,6 @@ #include #include "core/graph/basic_types.h" -#include "core/session/onnxruntime_c_api.h" namespace onnxruntime { @@ -31,7 +30,7 @@ class IOpBuilder { virtual void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const = 0; // Add the operator to NNAPI model - virtual common::Status AddToModelBuilder(ModelBuilder& model_builder, const NodeUnit& node_unit) const ORT_MUST_USE_RESULT = 0; + virtual common::Status AddToModelBuilder(ModelBuilder& model_builder, const NodeUnit& node_unit) const = 0; }; // Get the lookup table with IOpBuilder delegates for different onnx operators @@ -40,13 +39,7 @@ class IOpBuilder { const std::unordered_map& GetOpBuilders(); // Transpose the NHWC input to NCHW output -common::Status TransposeNHWCToNCHW(ModelBuilder& model_builder, const std::string& input, const std::string& output) - ORT_MUST_USE_RESULT; - -// Get the quantized input's scale and zero point for the given input -common::Status GetQuantizedInputScaleAndZeroPoint(const InitializedTensorSet& initializers, - const NodeUnit& node_unit, const std::string& input_name, - float& scale, int32_t& zero_point) ORT_MUST_USE_RESULT; +common::Status TransposeNHWCToNCHW(ModelBuilder& model_builder, const std::string& input, const std::string& output); } // namespace nnapi } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc index ca2ba5e90fb93..75eab4c837d00 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc +++ 
b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc @@ -1,19 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include +#include "op_support_checker.h" +#include "core/common/logging/logging.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/graph/graph.h" #include "core/providers/common.h" #include "core/providers/shared/node_unit/node_unit.h" #include "core/providers/shared/utils/utils.h" #include "helper.h" -#include "op_support_checker.h" - -using onnxruntime::NodeUnit; -using std::vector; namespace onnxruntime { namespace nnapi { @@ -25,19 +22,37 @@ struct OpSupportCheckerRegistrations { std::unordered_map op_support_checker_map; }; -bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node) { - for (const auto* node_arg : node.InputDefs()) { - const auto& input_name(node_arg->Name()); - if (!Contains(initializers, input_name)) +bool HasExternalInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit) { + const auto is_ext_initializer = + [&](const NodeArg& node_arg) { + const auto& input_name(node_arg.Name()); + if (!Contains(initializers, input_name)) + return false; + + const auto& tensor = *initializers.at(input_name); + if (tensor.has_data_location() && + tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + LOGS_DEFAULT(VERBOSE) << "Initializer [" << input_name + << "] with external data location are not currently supported"; + return true; + } + + return false; + }; + + const auto& inputs = node_unit.Inputs(); + for (const auto& input : inputs) { + if (is_ext_initializer(input.node_arg)) + return true; + + if (!input.quant_param) continue; - const auto& tensor = *initializers.at(input_name); - if (tensor.has_data_location() && - tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - LOGS_DEFAULT(VERBOSE) << "Initializer [" << input_name - << "] with external data location are not currently supported"; + if (is_ext_initializer(input.quant_param->scale)) + return true; + + if (input.quant_param->zero_point && is_ext_initializer(*input.quant_param->zero_point)) return true; - } } return false; @@ -118,10 +133,8 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer if (!HasSupportedInputs(node_unit)) return false; - const auto& node = node_unit.GetNode(); - // We do not support external initializers for now - if (HasExternalInitializer(initializers, node)) + if (HasExternalInitializer(initializers, node_unit)) return false; if (!HasSupportedOpSet(node_unit)) @@ -247,30 +260,25 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const NodeUnit& node_unit) cons } bool BinaryOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { - // TODO, change to use node unit and quant_param of IODef - const auto& node = node_unit.GetNode(); - bool is_qlinear_add = node.OpType() == "QLinearAdd"; - bool is_pow = node.OpType() == "Pow"; + bool is_qlinear_add = node_unit.OpType() == "QLinearAdd"; + bool is_pow = node_unit.OpType() == "Pow"; if (!is_qlinear_add && !is_pow) return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); if (is_qlinear_add) { // QLinearAdd - if (!HasValidBinaryOpQuantizedInputs(node)) + if (!HasValidBinaryOpQuantizedInputs(node_unit)) return false; } // Pow we only support both input as fp32 now if (is_pow) { - const auto& input1 = *node.InputDefs()[0]; - 
const auto& input2 = *node.InputDefs()[1]; - int32_t input_type_1; - if (!GetType(input1, input_type_1)) + if (!GetType(node_unit.Inputs()[0].node_arg, input_type_1)) return false; int32_t input_type_2; - if (!GetType(input2, input_type_2)) + if (!GetType(node_unit.Inputs()[1].node_arg, input_type_2)) return false; if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) { @@ -286,24 +294,18 @@ bool BinaryOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) c bool BinaryOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); - - const auto& op_type(node.OpType()); - const auto input_defs(node.InputDefs()); + const auto& op_type(node_unit.OpType()); + const auto& inputs = node_unit.Inputs(); bool op_is_qlinear = op_type == "QLinearAdd"; - size_t a_idx = 0, b_idx = 1; - if (op_is_qlinear) { - b_idx = 3; - } Shape input1_shape, input2_shape; - if (!GetShape(*input_defs[a_idx], input1_shape) || - !GetShape(*input_defs[b_idx], input2_shape)) + if (!GetShape(inputs[0].node_arg, input1_shape) || + !GetShape(inputs[1].node_arg, input2_shape)) return false; const auto input1_size = input1_shape.size(); const auto input2_size = input2_shape.size(); if (input1_size > 4 || input2_size > 4) { - LOGS_DEFAULT(VERBOSE) << node.OpType() << " only support up to 4d shape, input1 is " + LOGS_DEFAULT(VERBOSE) << op_type << " only support up to 4d shape, input1 is " << input1_size << "d shape, input 2 is " << input2_size << "d shape"; return false; @@ -312,7 +314,7 @@ bool BinaryOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi if (op_is_qlinear) { // For QLinearAdd, we only support uint8 output now int32_t output_type; - if (!GetType(*node.OutputDefs()[0], output_type)) + if (!GetType(node_unit.Outputs()[0].node_arg, output_type)) return false; if (output_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) { @@ -322,13 +324,16 @@ bool BinaryOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi return false; } - // All scale/zero points are initializer scalars - // a/b/y_scale - if (!HasValidQuantizationScales(initializers, node, {1, 4, 6}, params)) + // Check input scales and ZPs + if (!HasValidQuantizationScales(initializers, node_unit, {0, 1}, params, true /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0, 1}, true /* is_input */)) return false; - // a/b/y_zero_point - if (!HasValidQuantizationZeroPoints(initializers, node, {2, 5, 7})) + // Check output scale and ZP + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, false /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, false /* is_input */)) return false; } @@ -354,9 +359,8 @@ class TransposeOpSupportChecker : public BaseOpSupportChecker { bool TransposeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); Shape input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; const auto input_size = input_shape.size(); @@ -400,15 +404,15 @@ class ReshapeOpSupportChecker : public BaseOpSupportChecker { bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const 
NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); - const auto& perm_name = node.InputDefs()[1]->Name(); + const auto& inputs = node_unit.Inputs(); + const auto& perm_name = inputs[1].node_arg.Name(); if (!Contains(initializers, perm_name)) { LOGS_DEFAULT(VERBOSE) << "New shape of reshape must be known"; return false; } Shape input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(inputs[0].node_arg, input_shape)) return false; if (input_shape.size() > 4 || input_shape.empty()) { @@ -427,7 +431,7 @@ bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& init const int64_t* raw_perm = reinterpret_cast(unpacked_tensor.data()); const auto perm_size = SafeInt(perm_tensor.dims()[0]); - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); const bool allow_zero = helper.Get("allowzero ", 0) == 1; for (uint32_t i = 0; i < perm_size; i++) { // NNAPI reshape does not support 0 as dimension @@ -462,16 +466,15 @@ class BatchNormalizationOpSupportChecker : public BaseOpSupportChecker { bool BatchNormalizationOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); - if (node.OutputDefs().size() != 1) { + if (node_unit.Outputs().size() != 1) { LOGS_DEFAULT(VERBOSE) << "Your onnx model may be in training mode, please export " "it in test mode."; return false; } - const auto& input_defs = node.InputDefs(); + const auto& inputs = node_unit.Inputs(); Shape input_shape; - if (!GetShape(*input_defs[0], input_shape)) + if (!GetShape(inputs[0].node_arg, input_shape)) return false; const auto input_size = input_shape.size(); @@ -481,17 +484,17 @@ bool BatchNormalizationOpSupportChecker::IsOpSupportedImpl(const InitializedTens return false; } - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); const auto spatial = helper.Get("spatial", 1); if (spatial != 1) { LOGS_DEFAULT(VERBOSE) << "Non-spatial BN is not supported"; return false; } - const auto& scale_name = input_defs[1]->Name(); - const auto& b_name = input_defs[2]->Name(); - const auto& mean_name = input_defs[3]->Name(); - const auto& var_name = input_defs[4]->Name(); + const auto& scale_name = inputs[1].node_arg.Name(); + const auto& b_name = inputs[2].node_arg.Name(); + const auto& mean_name = inputs[3].node_arg.Name(); + const auto& var_name = inputs[4].node_arg.Name(); if (!Contains(initializers, scale_name)) { LOGS_DEFAULT(VERBOSE) << "Scale of BN must be known"; return false; @@ -548,25 +551,24 @@ class PoolOpSupportChecker : public BaseOpSupportChecker { bool PoolOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); - const auto& op_name = node.Name(); - const auto& op_type = node.OpType(); - const auto& input_defs = node.InputDefs(); + const auto& op_name = node_unit.Name(); + const auto& op_type = node_unit.OpType(); + const auto& inputs = node_unit.Inputs(); Shape input_shape; - if (!GetShape(*input_defs[0], input_shape)) + if (!GetShape(inputs[0].node_arg, input_shape)) return false; const auto input_size = input_shape.size(); if (input_size != 4) { LOGS_DEFAULT(VERBOSE) << op_type << " only supports rank-4 tensor, input [" - << input_defs[0]->Name() << "] has actual dim count " << input_size; + << inputs[0].node_arg.Name() << "] has actual 
dim count " << input_size; return false; } bool is_qlinear_average_pool = op_type == "QLinearAveragePool"; if (op_type == "AveragePool" || op_type == "MaxPool" || is_qlinear_average_pool) { - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); const auto count_include_pad = helper.Get("count_include_pad", 0); if (count_include_pad == 1) { @@ -596,7 +598,7 @@ bool PoolOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial return false; } - if (node.OutputDefs().size() != 1) { + if (node_unit.Outputs().size() != 1) { LOGS_DEFAULT(VERBOSE) << "Argmax in maxpooling is not supported"; return false; } @@ -607,37 +609,38 @@ bool PoolOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial // We need to check if we have valid scales and zero points for QLinearAveragePool if (is_qlinear_average_pool) { - if (input_defs.size() < 4) + // Check input scales and ZPs + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, true /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, true /* is_input */)) return false; - // the output zero point can be optional - bool has_output_zp = input_defs.size() == 5; + // Check output scale and ZP - if (!HasValidQuantizationScales(initializers, node, {1, 3}, params)) + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, false /* is_input */)) return false; - - if (!HasValidQuantizationZeroPoints(initializers, node, - has_output_zp - ? std::vector{2} - : std::vector{2, 4})) { + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, false /* is_input */)) return false; - } // NNAPI requires Quantized Average Pool has same scale and zero point for both input and output float input_scale = 0.0f; - auto status = GetQuantizationScale(initializers, node, 1, input_scale); + int32_t input_zp = 0; + auto status = GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Inputs()[0], node_unit.ModelPath(), input_scale, input_zp); if (!status.IsOK()) { LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name - << "] GetQuantizationScale for input_scale failed, message: " + << "] GetQuantizationScaleAndZeroPoint for input_scale/zp failed, message: " << status.ErrorMessage(); return false; } float output_scale = 0.0f; - status = GetQuantizationScale(initializers, node, 3, output_scale); + int32_t output_zp = 0; + status = GetQuantizationScaleAndZeroPoint( + initializers, node_unit.Outputs()[0], node_unit.ModelPath(), output_scale, output_zp); if (!status.IsOK()) { LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name - << "] GetQuantizationScale for output_scale failed, message: " + << "] GetQuantizationScaleAndZeroPoint for output_scale/zp failed, message: " << status.ErrorMessage(); return false; } @@ -649,26 +652,6 @@ bool PoolOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial return false; } - int32_t input_zp = 0; - int32_t output_zp = 0; - status = GetQuantizationZeroPoint(initializers, node, 2, input_zp); - if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name - << "] GetQuantizationZeroPoint for input_zp failed, message: " - << status.ErrorMessage(); - return false; - } - - if (has_output_zp) { - status = GetQuantizationZeroPoint(initializers, node, 4, output_zp); - if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name - << "] GetQuantizationZeroPoint for output_zp failed, message: " - << status.ErrorMessage(); - return 
false; - } - } - if (input_zp != output_zp) { LOGS_DEFAULT(VERBOSE) << "Op [" << op_type << "] name [" << op_name << "] has different input_zp: " << input_zp @@ -681,26 +664,24 @@ bool PoolOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial } bool PoolOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { - // TODO, change to use node unit and quant_param of IODef - const auto& node = node_unit.GetNode(); - bool is_max_pool = node.OpType() == "MaxPool"; - bool is_qlinear_average_pool = node.OpType() == "QLinearAveragePool"; + bool is_max_pool = node_unit.OpType() == "MaxPool"; + bool is_qlinear_average_pool = node_unit.OpType() == "QLinearAveragePool"; if (!is_max_pool && !is_qlinear_average_pool) return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); if (is_qlinear_average_pool) { - return HasValidUnaryOpQuantizedInputs(node); + return HasValidUnaryOpQuantizedInputs(node_unit); } // is_max_pool // For max pool, we can support both float and uint8 input int32_t input_type; - if (!GetType(*node.InputDefs()[0], input_type)) + if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) return false; if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) { - LOGS_DEFAULT(VERBOSE) << "[" << node.OpType() + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() << "] Input type: [" << input_type << "] is not supported for now"; return false; @@ -741,13 +722,11 @@ class ConvOpSupportChecker : public BaseOpSupportChecker { } bool ConvOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { - // TODO, change to use node unit and quant_param of IODef - const auto& node = node_unit.GetNode(); - if (node.OpType() != "QLinearConv") + if (node_unit.OpType() != "QLinearConv") return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); // QLinearConv only supports input of uint8 for now - if (!HasValidBinaryOpQuantizedInputs(node)) + if (!HasValidBinaryOpQuantizedInputs(node_unit)) return false; return true; @@ -755,21 +734,19 @@ bool ConvOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) con bool ConvOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); - const auto& op_type = node.OpType(); + const auto& op_type = node_unit.OpType(); const bool is_qlinear_conv = (op_type == "QLinearConv"); // We don't support nhwc com.microsoft.QLinearConv for now - if (is_qlinear_conv && node.Domain() == kMSDomain) { + if (is_qlinear_conv && node_unit.Domain() == kMSDomain) { LOGS_DEFAULT(VERBOSE) << "com.microsoft.QLinearConv is not supported"; return false; } - const auto input_defs = node.InputDefs(); - NodeAttrHelper helper(node); - size_t w_idx = is_qlinear_conv ? 
3 : 1; + const auto& inputs = node_unit.Inputs(); + NodeAttrHelper helper(node_unit); const auto group = helper.Get("group", 1); - const auto weight_name = input_defs[w_idx]->Name(); + const auto weight_name = inputs[1].node_arg.Name(); if (Contains(initializers, weight_name)) { const auto& tensor = *initializers.at(weight_name); if (tensor.dims().size() != 4) { @@ -777,8 +754,8 @@ bool ConvOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial return false; } - const auto onnx_dilations = helper.Get("dilations", vector{1, 1}); - if (onnx_dilations != vector{1, 1}) { + const auto onnx_dilations = helper.Get("dilations", std::vector{1, 1}); + if (onnx_dilations != std::vector{1, 1}) { if (group != 1 && tensor.dims()[1] != 1) { LOGS_DEFAULT(VERBOSE) << "dilation is not supported on grouped conv"; return false; @@ -798,7 +775,7 @@ bool ConvOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial if (is_qlinear_conv) { // For QLinearConv, we only support uint8 output now int32_t output_type; - if (!GetType(*node.OutputDefs()[0], output_type)) + if (!GetType(node_unit.Outputs()[0].node_arg, output_type)) return false; if (output_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) { @@ -808,17 +785,21 @@ bool ConvOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial return false; } - if (input_defs.size() > 8 && !Contains(initializers, input_defs[8]->Name())) { + if (inputs.size() > 2 && !Contains(initializers, inputs[2].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "Bias of QLinearConv must be known"; return false; } - // a/b/y_scale - if (!HasValidQuantizationScales(initializers, node, {1, 4, 6}, params)) + // Check input scales and ZPs + if (!HasValidQuantizationScales(initializers, node_unit, {0, 1}, params, true /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0, 1}, true /* is_input */)) return false; - // a/b/y_zero_point - if (!HasValidQuantizationZeroPoints(initializers, node, {2, 5, 7})) + // Check output scale and ZP + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, false /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, false /* is_input */)) return false; } @@ -845,8 +826,7 @@ class CastOpSupportChecker : public BaseOpSupportChecker { bool CastOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); const auto to = helper.Get("to", 0); if (to != ONNX_NAMESPACE::TensorProto::FLOAT && to != ONNX_NAMESPACE::TensorProto::INT32) { @@ -874,9 +854,8 @@ class SoftMaxOpSupportChecker : public BaseOpSupportChecker { bool SoftMaxOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); Shape input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; const auto input_size = input_shape.size(); @@ -887,7 +866,7 @@ bool SoftMaxOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* i } if (params.android_feature_level < ANEURALNETWORKS_FEATURE_LEVEL_3) { - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); int32_t axis = helper.Get("axis", 1); if (axis != 1) { 
LOGS_DEFAULT(VERBOSE) @@ -917,13 +896,11 @@ class GemmOpSupportChecker : public BaseOpSupportChecker { }; bool GemmOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { - // TODO, change to use node unit and quant_param of IODef - const auto& node = node_unit.GetNode(); - if (node.OpType() != "QLinearMatMul") + if (node_unit.OpType() != "QLinearMatMul") return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); // QLinearMatMul - if (!HasValidBinaryOpQuantizedInputs(node)) + if (!HasValidBinaryOpQuantizedInputs(node_unit)) return false; return true; @@ -985,19 +962,13 @@ int GemmOpSupportChecker::GetMinSupportedOpSet(const NodeUnit& node_unit) const bool GemmOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); - const auto& op_type = node.OpType(); - const auto input_defs(node.InputDefs()); - size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C + const auto& op_type = node_unit.OpType(); + const auto& inputs = node_unit.Inputs(); bool is_qlinear_matmul = op_type == "QLinearMatMul"; - if (is_qlinear_matmul) { - a_idx = 0; - b_idx = 3; - } Shape a_shape; { - if (!GetShape(*input_defs[a_idx], a_shape)) + if (!GetShape(inputs[0].node_arg, a_shape)) return false; if (a_shape.size() != 2) { @@ -1008,7 +979,7 @@ bool GemmOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial Shape b_shape; { - if (!GetShape(*input_defs[b_idx], b_shape)) + if (!GetShape(inputs[1].node_arg, b_shape)) return false; if (b_shape.size() != 2) { @@ -1021,7 +992,7 @@ bool GemmOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial // Only support // 1. A*B'+C // 2. A*B+C and B is an initializer - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); const auto transA = helper.Get("transA", 0); const auto transB = helper.Get("transB", 0); const auto alpha = helper.Get("alpha", 1.0f); @@ -1037,14 +1008,14 @@ bool GemmOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial return false; } - if (transB == 0 && !Contains(initializers, input_defs[b_idx]->Name())) { + if (transB == 0 && !Contains(initializers, inputs[1].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "B of Gemm must be known if transB != 1"; return false; } - if (input_defs.size() == 3) { + if (inputs.size() == 3) { Shape c_shape; - if (!GetShape(*input_defs[c_idx], c_shape)) + if (!GetShape(inputs[2].node_arg, c_shape)) return false; uint32_t c_size; @@ -1062,7 +1033,7 @@ bool GemmOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial } } else if (op_type == "MatMul" || is_qlinear_matmul) { // Only support A*B B is an initializer - if (!Contains(initializers, input_defs[b_idx]->Name())) { + if (!Contains(initializers, inputs[1].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "B of MatMul must be known"; return false; } @@ -1070,7 +1041,7 @@ bool GemmOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial if (is_qlinear_matmul) { // For QLinearMatMul, we only support uint8 output now int32_t output_type; - if (!GetType(*node.OutputDefs()[0], output_type)) + if (!GetType(node_unit.Outputs()[0].node_arg, output_type)) return false; if (output_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) { @@ -1081,12 +1052,16 @@ bool GemmOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial } // All scale/zero points are initializer scalars - // a/b/y_scale - if (!HasValidQuantizationScales(initializers, node, 
{1, 4, 6}, params)) + // Check input scales and ZPs + if (!HasValidQuantizationScales(initializers, node_unit, {0, 1}, params, true /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0, 1}, true /* is_input */)) return false; - // a/b/y_zero_point - if (!HasValidQuantizationZeroPoints(initializers, node, {2, 5, 7})) + // Check output scale and ZP + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, false /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, false /* is_input */)) return false; } } else { @@ -1116,7 +1091,7 @@ class UnaryOpSupportChecker : public BaseOpSupportChecker { int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; - static bool IsQuantizedOpSupported(const InitializedTensorSet& initializers, const Node& node, + static bool IsQuantizedOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params); }; @@ -1140,9 +1115,8 @@ class UnaryOpSupportChecker : public BaseOpSupportChecker { bool UnaryOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); - if (node.OpType() == "QLinearSigmoid") - return IsQuantizedOpSupported(initializers, node, params); + if (node_unit.OpType() == "QLinearSigmoid") + return IsQuantizedOpSupported(initializers, node_unit, params); else // Everything except "QLinearSigmoid" are by default supported return true; } @@ -1163,13 +1137,11 @@ int32_t UnaryOpSupportChecker::GetMinSupportedNNAPIFeatureLevel(const NodeUnit& } bool UnaryOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const { - // TODO, change to use node unit and quant_param of IODef - const auto& node = node_unit.GetNode(); // We only need to override input check for QLinearSigmoid - if (node.OpType() != "QLinearSigmoid") + if (node_unit.OpType() != "QLinearSigmoid") return BaseOpSupportChecker::HasSupportedInputsImpl(node_unit); - return HasValidUnaryOpQuantizedInputs(node); + return HasValidUnaryOpQuantizedInputs(node_unit); } // All ops except "Sin" opset 5- uses consumed_inputs attribute which is not supported for now @@ -1183,35 +1155,35 @@ int UnaryOpSupportChecker::GetMinSupportedOpSet(const NodeUnit& node_unit) const } /* static */ bool UnaryOpSupportChecker::IsQuantizedOpSupported( - const InitializedTensorSet& initializers, const Node& node, const OpSupportCheckParams& params) { - const auto& op_type = node.OpType(); + const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) { + const auto& op_type = node_unit.OpType(); ORT_ENFORCE(op_type == "QLinearSigmoid"); - const auto& op_name = node.Name(); - const auto input_defs(node.InputDefs()); - // const auto output_defs(node.OutputDefs()); + const auto& op_name = node_unit.Name(); - if (input_defs.size() < 4) + // Check input scales and ZPs + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, true /* is_input */)) return false; - - bool has_output_zp = input_defs.size() == 5; - - if (!HasValidQuantizationScales(initializers, node, {1, 3}, params)) + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, true /* is_input */)) return false; - if (!HasValidQuantizationZeroPoints(initializers, node, - has_output_zp - ? 
std::vector{2} - : std::vector{2, 4})) + // Check output scale and ZP + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, false /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, false /* is_input */)) return false; + // NNAPI requires the scale be 1.f/256 and zero point to be 0 // See https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/android10-c2f2-release/nn/common/operations/Activation.cpp#180 float output_scale = 0.0f; - auto status = GetQuantizationScale(initializers, node, 3, output_scale); + int32_t output_zp = 0; + auto status = GetQuantizationScaleAndZeroPoint(initializers, node_unit.Outputs()[0], node_unit.ModelPath(), + output_scale, output_zp); if (!status.IsOK()) { LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name - << "] GetQuantizationScale failed, message: " << status.ErrorMessage(); + << "] GetQuantizationScaleAndZeroPoint failed, message: " << status.ErrorMessage(); return false; } @@ -1221,20 +1193,10 @@ int UnaryOpSupportChecker::GetMinSupportedOpSet(const NodeUnit& node_unit) const return false; } - int32_t output_zp; - if (has_output_zp) { - status = GetQuantizationZeroPoint(initializers, node, 4, output_zp); - if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name - << "] GetQuantizationZeroPoint failed, message: " << status.ErrorMessage(); - return false; - } - - if (output_zp != 0) { - LOGS_DEFAULT(VERBOSE) << "Op [" << op_type << "] name [" << op_name - << "] output zero point can only be 0, actual zero point: " << output_scale; - return false; - } + if (output_zp != 0) { + LOGS_DEFAULT(VERBOSE) << "Op [" << op_type << "] name [" << op_name + << "] output zero point can only be 0, actual zero point: " << output_zp; + return false; } return true; @@ -1254,9 +1216,8 @@ class ConcatOpSupportChecker : public BaseOpSupportChecker { bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); Shape input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; const auto input_size = input_shape.size(); @@ -1302,21 +1263,21 @@ class SqueezeOpSupportChecker : public BaseOpSupportChecker { bool SqueezeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); + const auto& inputs = node_unit.Inputs(); Shape input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(inputs[0].node_arg, input_shape)) return false; - const auto input_size = input_shape.size(); - if (input_size > 4 || input_size == 0) { + const auto input_rank = input_shape.size(); + if (input_rank > 4 || input_rank == 0) { LOGS_DEFAULT(VERBOSE) << "Squeeze only supports 1-4d shape, input is " - << input_size << "d shape"; + << input_rank << "d shape"; return false; } // Squeeze opset 13 use input 1 as axes, if we have input 1 then it need to be an initializer - if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { - const auto& axes_name = node.InputDefs()[1]->Name(); + if (node_unit.SinceVersion() > 12 && inputs.size() > 1) { + const auto& axes_name = inputs[1].node_arg.Name(); if (!Contains(initializers, axes_name)) { LOGS_DEFAULT(VERBOSE) << "Input axes of 
Squeeze must be known"; return false; @@ -1343,28 +1304,23 @@ class QuantizeLinearOpSupportChecker : public BaseOpSupportChecker { bool QuantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); - const auto input_defs(node.InputDefs()); - const auto output_defs(node.OutputDefs()); - int32_t output_type; - if (!GetType(*output_defs[0], output_type)) + if (!GetType(node_unit.Outputs()[0].node_arg, output_type)) return false; if (output_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) { - LOGS_DEFAULT(VERBOSE) << "[" << node.OpType() + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() << "] output type: [" << output_type << "] is not supported for now"; return false; } - if (!HasValidQuantizationScales(initializers, node, {1}, params)) + // For QuantizeLinear only output is quantized + // Check output scale and ZP + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, false /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, false /* is_input */)) return false; - - if (input_defs.size() == 3) { // has zero_point input - if (!HasValidQuantizationZeroPoints(initializers, node, {2})) - return false; - } return true; } @@ -1387,15 +1343,12 @@ class DequantizeLinearOpSupportChecker : public BaseOpSupportChecker { bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); - const auto input_defs(node.InputDefs()); - if (!HasValidQuantizationScales(initializers, node, {1}, params)) + // For DequantizeLinear only input is quantized + // Check input scale and ZP + if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, true /* is_input */)) + return false; + if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, true /* is_input */)) return false; - - if (input_defs.size() == 3) { // has zero_point input - if (!HasValidQuantizationZeroPoints(initializers, node, {2})) - return false; - } return true; } @@ -1432,9 +1385,8 @@ class LRNOpSupportChecker : public BaseOpSupportChecker { bool LRNOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); Shape input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; const auto input_size = input_shape.size(); @@ -1459,9 +1411,8 @@ class ClipOpSupportChecker : public BaseOpSupportChecker { bool ClipOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); float min, max; - if (!GetClipMinMax(initializers, node, min, max, logging::LoggingManager::DefaultLogger())) + if (!GetClipMinMax(initializers, node_unit.GetNode(), min, max, logging::LoggingManager::DefaultLogger())) return false; // We only supoort relu6 or relu1 @@ -1496,9 +1447,8 @@ class ResizeOpSupportChecker : public BaseOpSupportChecker { bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { - const auto& node = node_unit.GetNode(); Shape 
input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; const auto input_size = input_shape.size(); @@ -1509,7 +1459,7 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi } { // check attributes - NodeAttrHelper helper(node); + NodeAttrHelper helper(node_unit); const auto mode = helper.Get("mode", "nearest"); bool is_linear_resize = mode == "linear"; bool is_nearest_resize = mode == "nearest"; @@ -1556,27 +1506,27 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi } { // scales and sizes (if present) must be initializers - const auto input_defs = node.InputDefs(); - if (input_defs.size() < 3) { + const auto inputs = node_unit.Inputs(); + if (inputs.size() < 3) { LOGS_DEFAULT(VERBOSE) << "Input scales or sizes of Resize must be known"; return false; } // scales - if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) { + if (inputs.size() == 3 && !Contains(initializers, inputs[2].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "Input scales of Resize must be known"; return false; } // sizes - if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) { + if (inputs.size() > 3 && !Contains(initializers, inputs[3].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "Input sizes of Resize must be known"; return false; } // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (input_defs.size() == 3) { // we are using scales - const auto& scales_tensor = *initializers.at(input_defs[2]->Name()); + if (inputs.size() == 3) { // we are using scales + const auto& scales_tensor = *initializers.at(inputs[2].node_arg.Name()); std::vector unpacked_tensor; auto status = onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor); if (!status.IsOK()) { @@ -1594,7 +1544,7 @@ bool ResizeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi } } else { // we are using sizes - const auto& sizes_name = input_defs[3]->Name(); + const auto& sizes_name = inputs[3].node_arg.Name(); const auto& sizes_tensor = *initializers.at(sizes_name); std::vector unpacked_tensor; auto status = onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor); @@ -1659,9 +1609,8 @@ class FlattenOpSupportChecker : public BaseOpSupportChecker { bool FlattenOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); Shape input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; if (input_shape.size() > 4 || input_shape.empty()) { @@ -1672,7 +1621,7 @@ bool FlattenOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* i int32_t dim_1 = 1; int32_t dim_2 = 1; - GetFlattenOutputShape(node, input_shape, dim_1, dim_2); + GetFlattenOutputShape(node_unit, input_shape, dim_1, dim_2); if (dim_1 == 0 && dim_2 == 0) { LOGS_DEFAULT(VERBOSE) << "The dynamical input shape " << Shape2String(input_shape) @@ -1717,11 +1666,10 @@ class MinMaxOpSupportChecker : public BaseOpSupportChecker { bool MinMaxOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); // TODO support 2+ inputs for Min/Max op - if 
(node.InputDefs().size() != 2) { - LOGS_DEFAULT(VERBOSE) << "[" << node.OpType() << "] only supports 2 inputs, " - << "actual input number, " << node.InputDefs().size(); + if (node_unit.Inputs().size() != 2) { + LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType() << "] only supports 2 inputs, " + << "actual input number, " << node_unit.Inputs().size(); return false; } @@ -1763,9 +1711,8 @@ class SliceOpSupportChecker : public BaseOpSupportChecker { bool SliceOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { - const auto& node = node_unit.GetNode(); Shape input_shape; - if (!GetShape(*node.InputDefs()[0], input_shape)) + if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; if (input_shape.size() > 4) { @@ -1780,19 +1727,19 @@ bool SliceOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initia return false; } - if (!CheckIsInitializer(initializers, node, 1, "starts")) { + if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[1].node_arg.Name(), "starts")) { return false; } - if (!CheckIsInitializer(initializers, node, 2, "ends")) { + if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[2].node_arg.Name(), "ends")) { return false; } - const auto& input_defs = node.InputDefs(); - if (input_defs.size() > 3) { - if (!CheckIsInitializer(initializers, node, 3, "axes")) { + const auto& inputs = node_unit.Inputs(); + if (inputs.size() > 3) { + if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[3].node_arg.Name(), "axes")) { return false; } - if (input_defs.size() > 4) { - if (!CheckIsInitializer(initializers, node, 4, "steps")) { + if (inputs.size() > 4) { + if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[4].node_arg.Name(), "steps")) { return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc index 9d675ecfa84c0..1652237622e71 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc @@ -3,20 +3,17 @@ #include "core/providers/common.h" -#include "helper.h" #include "shaper.h" +#include "helper.h" namespace onnxruntime { namespace nnapi { -using std::string; -using std::vector; - std::pair ComputeConvOutputShape(const uint32_t input_size_y, const uint32_t input_size_x, const uint32_t weight_size_y, const uint32_t weight_size_x, - const vector& onnx_pads, - const vector& onnx_strides, - const vector& onnx_dilations) { + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations) { int32_t padding_top = onnx_pads[0]; int32_t padding_bottom = onnx_pads[2]; int32_t padding_left = onnx_pads[1]; @@ -53,9 +50,9 @@ std::pair ComputeConvOutputShape(const uint32_t input_size_y Status Shaper::Conv(const std::string& input_name, const std::string& weight_name, - const vector& onnx_pads, - const vector& onnx_strides, - const vector& onnx_dilations, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, bool nchw, const std::string& output_name) { SHAPER_FUNC(Conv, @@ -150,9 +147,9 @@ Status Shaper::ResizeUsingOutputSizes(const std::string& input_name, Status Shaper::ConvImpl(const std::string& input_name, const std::string& weight_name, - const vector& onnx_pads, - const vector& onnx_strides, - const vector& onnx_dilations, + 
const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, bool nchw, const std::string& output_name) { const Shape& input_dimen = shape_map_.at(input_name); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.h index b9299454dce44..8656328804f46 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.h @@ -6,7 +6,8 @@ #include #include #include -#include + +#include "core/common/status.h" namespace onnxruntime { namespace nnapi { @@ -20,115 +21,103 @@ class Shaper { return shape_map_.at(key); } - Status Conv(const std::string& input_name, - const std::string& weight_name, - const std::vector& onnx_pads, - const std::vector& onnx_strides, - const std::vector& onnx_dilations, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; - - Status DepthwiseConv(const std::string& input_name, - const std::string& weight_name, - const std::vector& onnx_pads, - const std::vector& onnx_strides, - const std::vector& onnx_dilations, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; - - Status Pool(const std::string& input_name, - const std::vector& onnx_pads, - const std::vector& onnx_strides, - const std::vector& kernel_shape, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; - - Status Reshape(const std::string& input_name, const std::vector& shape, const std::string& output_name) - ORT_MUST_USE_RESULT; - - Status Transpose(const std::string& input_name, const std::vector& perm, const std::string& output_name) - ORT_MUST_USE_RESULT; - - Status Eltwise(const std::string& input1_name, const std::string& input2_name, const std::string& output_name) - ORT_MUST_USE_RESULT; - - Status Identity(const std::string& input_name, const std::string& output_name) ORT_MUST_USE_RESULT; - - Status FC(const std::string& input1_name, const std::string& input2_name, const std::string& output_name) - ORT_MUST_USE_RESULT; - - Status Concat(const std::vector& input_names, const int32_t axis, const std::string& output_name) - ORT_MUST_USE_RESULT; - - Status Squeeze(const std::string& input_name, const std::vector& axes, const std::string& output_name) - ORT_MUST_USE_RESULT; - - Status ResizeUsingScales(const std::string& input_name, - const float scale_h, const float scale_w, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; - Status ResizeUsingOutputSizes(const std::string& input_name, - const uint32_t output_h, const uint32_t output_w, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; + common::Status Conv(const std::string& input_name, + const std::string& weight_name, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + bool nchw, + const std::string& output_name); + + common::Status DepthwiseConv(const std::string& input_name, + const std::string& weight_name, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + bool nchw, + const std::string& output_name); + + common::Status Pool(const std::string& input_name, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& kernel_shape, + bool nchw, + const std::string& output_name); + + common::Status Reshape(const std::string& input_name, const std::vector& shape, const std::string& output_name); + + common::Status Transpose(const 
std::string& input_name, const std::vector& perm, const std::string& output_name); + + common::Status Eltwise(const std::string& input1_name, const std::string& input2_name, const std::string& output_name); + + common::Status Identity(const std::string& input_name, const std::string& output_name); + + common::Status FC(const std::string& input1_name, const std::string& input2_name, const std::string& output_name); + + common::Status Concat(const std::vector& input_names, const int32_t axis, const std::string& output_name); + + common::Status Squeeze(const std::string& input_name, const std::vector& axes, const std::string& output_name); + + common::Status ResizeUsingScales(const std::string& input_name, + const float scale_h, const float scale_w, + bool nchw, + const std::string& output_name); + common::Status ResizeUsingOutputSizes(const std::string& input_name, + const uint32_t output_h, const uint32_t output_w, + bool nchw, + const std::string& output_name); // If the shape of certain input is dynamic // Use the following 2 functions to update the particular shape // and calculate the new output shape // Only perform this when the NNAPI model is finalized! - Status UpdateShape(const std::string& name, const Shape& new_shape) ORT_MUST_USE_RESULT; - Status UpdateDynamicDimensions() ORT_MUST_USE_RESULT; + common::Status UpdateShape(const std::string& name, const Shape& new_shape); + common::Status UpdateDynamicDimensions(); void Clear(); private: - Status ConvImpl(const std::string& input_name, - const std::string& weight_name, - const std::vector& onnx_pads, - const std::vector& onnx_strides, - const std::vector& onnx_dilations, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; - - Status DepthwiseConvImpl(const std::string& input_name, - const std::string& weight_name, - const std::vector& onnx_pads, - const std::vector& onnx_strides, - const std::vector& onnx_dilations, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; - - Status PoolImpl(const std::string& input_name, - const std::vector& onnx_pads, - const std::vector& onnx_strides, - const std::vector& kernel_shape, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; - - Status ReshapeImpl(const std::string& input_name, const std::vector& shape, const std::string& output_name) - ORT_MUST_USE_RESULT; - Status TransposeImpl(const std::string& input_name, const std::vector& perm, const std::string& output_name) - ORT_MUST_USE_RESULT; - Status EltwiseImpl(const std::string& input1_name, const std::string& input2_name, const std::string& output_name) - ORT_MUST_USE_RESULT; - Status IdentityImpl(const std::string& input_name, const std::string& output_name) ORT_MUST_USE_RESULT; - Status FCImpl(const std::string& input1_name, const std::string& input2_name, const std::string& output_name) - ORT_MUST_USE_RESULT; - Status ConcatImpl(const std::vector& input_names, const int32_t axis, const std::string& output_name) - ORT_MUST_USE_RESULT; - Status SqueezeImpl(const std::string& input_names, const std::vector& axes, const std::string& output_name) - ORT_MUST_USE_RESULT; - Status ResizeUsingScalesImpl(const std::string& input_name, - const float scale_h, const float scale_w, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; - Status ResizeUsingOutputSizesImpl(const std::string& input_name, - const uint32_t output_h, const uint32_t output_w, - bool nchw, - const std::string& output_name) ORT_MUST_USE_RESULT; + common::Status ConvImpl(const std::string& input_name, + const 
std::string& weight_name, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + bool nchw, + const std::string& output_name); + + common::Status DepthwiseConvImpl(const std::string& input_name, + const std::string& weight_name, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& onnx_dilations, + bool nchw, + const std::string& output_name); + + common::Status PoolImpl(const std::string& input_name, + const std::vector& onnx_pads, + const std::vector& onnx_strides, + const std::vector& kernel_shape, + bool nchw, + const std::string& output_name); + + common::Status ReshapeImpl(const std::string& input_name, const std::vector& shape, const std::string& output_name); + common::Status TransposeImpl(const std::string& input_name, const std::vector& perm, const std::string& output_name); + common::Status EltwiseImpl(const std::string& input1_name, const std::string& input2_name, const std::string& output_name); + common::Status IdentityImpl(const std::string& input_name, const std::string& output_name); + common::Status FCImpl(const std::string& input1_name, const std::string& input2_name, const std::string& output_name); + common::Status ConcatImpl(const std::vector& input_names, const int32_t axis, const std::string& output_name); + common::Status SqueezeImpl(const std::string& input_names, const std::vector& axes, const std::string& output_name); + common::Status ResizeUsingScalesImpl(const std::string& input_name, + const float scale_h, const float scale_w, + bool nchw, + const std::string& output_name); + common::Status ResizeUsingOutputSizesImpl(const std::string& input_name, + const uint32_t output_h, const uint32_t output_w, + bool nchw, + const std::string& output_name); std::unordered_map shape_map_; - std::vector> shape_ops_; + std::vector> shape_ops_; }; } // namespace nnapi diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.cc index 7a2036252ae8f..887384e6bd3bc 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.cc @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include - #include "model.h" + +#include "core/common/logging/logging.h" #include "core/providers/common.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h index 8ce72538affc4..6326e60cf9797 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h @@ -103,7 +103,7 @@ class Model { // this output may need special handling bool IsScalarOutput(const std::string& output_name) const; - Status PrepareForExecution(std::unique_ptr& execution) ORT_MUST_USE_RESULT; + common::Status PrepareForExecution(std::unique_ptr& execution); private: const NnApi* nnapi_{nullptr}; @@ -143,7 +143,7 @@ class Model { void AddScalarOutput(const std::string& output_name); - void SetShaper(const Shaper shaper) { shaper_ = shaper; } + void SetShaper(const Shaper& shaper) { shaper_ = shaper; } int32_t GetNNAPIFeatureLevel() const; }; @@ -172,17 +172,16 @@ class Execution { // Set the input/output data buffers // These need to be called before calling Predict() - Status SetInputBuffers(const std::vector& inputs) ORT_MUST_USE_RESULT; - Status SetOutputBuffers(const std::vector& outputs) ORT_MUST_USE_RESULT; + common::Status SetInputBuffers(const std::vector& inputs); + common::Status SetOutputBuffers(const std::vector& outputs); // Execute the NNAPI model // if there is dynamic output shape, will output the actual output shapes - Status Predict(const std::vector& dynamic_outputs, std::vector& dynamic_output_shapes) - ORT_MUST_USE_RESULT; + common::Status Predict(const std::vector& dynamic_outputs, std::vector& dynamic_output_shapes); private: - Status SetInputBuffer(const int32_t index, const InputBuffer& input) ORT_MUST_USE_RESULT; - Status SetOutputBuffer(const int32_t index, const OutputBuffer& output) ORT_MUST_USE_RESULT; + common::Status SetInputBuffer(const int32_t index, const InputBuffer& input); + common::Status SetOutputBuffer(const int32_t index, const OutputBuffer& output); const NnApi* nnapi_{nullptr}; ANeuralNetworksExecution* execution_; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index fa876a7ef6bca..85a0cf3ad13fc 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -21,8 +21,6 @@ #include "core/providers/nnapi/nnapi_builtin/model.h" #endif -using onnxruntime::NodeUnit; - namespace onnxruntime { namespace { @@ -189,14 +187,6 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view } #ifdef __ANDROID__ -static Status GetOutputBuffer(Ort::CustomOpApi& ort, - OrtKernelContext* context, - const nnapi::Model& model, - const std::string& output_name, - const std::vector& output_shape, - const android::nn::wrapper::Type output_type, - void** output_buffer) ORT_MUST_USE_RESULT; - static Status GetOutputBuffer(Ort::CustomOpApi& ort, OrtKernelContext* context, const nnapi::Model& model, diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.cc b/onnxruntime/core/providers/shared/node_unit/node_unit.cc index 80492ce701eff..d443fe858f36b 100644 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.cc +++ 
b/onnxruntime/core/providers/shared/node_unit/node_unit.cc @@ -6,40 +6,170 @@ namespace onnxruntime { +namespace { + +// The QLinearOpType GetQLinearOpType, is very similar to the one in NNAPI +// However, the NNAPI ones are only the subset of the ones here, +// TODO, make these shared +enum class QLinearOpType : uint8_t { + Unknown, // Unknown or not a linear quantized op + DequantizeLinear, + QuantizeLinear, + QLinearConv, + QLinearMatMul, + QLinearAdd, + QLinearSigmoid, + QLinearAveragePool, + QLinearMul, + QLinearReduceMean, + QLinearConcat, + QLinearGlobalAveragePool, + QLinearLeakyRelu, +}; + +QLinearOpType GetQLinearOpType(const onnxruntime::Node& node) { + const auto& op_type = node.OpType(); + if (op_type == "DequantizeLinear") + return QLinearOpType::DequantizeLinear; + else if (op_type == "QuantizeLinear") + return QLinearOpType::QuantizeLinear; + else if (op_type == "QLinearConv") + return QLinearOpType::QLinearConv; + else if (op_type == "QLinearMatMul") + return QLinearOpType::QLinearMatMul; + else if (op_type == "QLinearAdd") + return QLinearOpType::QLinearAdd; + else if (op_type == "QLinearSigmoid") + return QLinearOpType::QLinearSigmoid; + else if (op_type == "QLinearAveragePool") + return QLinearOpType::QLinearAveragePool; + else if (op_type == "QLinearMul") + return QLinearOpType::QLinearMul; + else if (op_type == "QLinearReduceMean") + return QLinearOpType::QLinearReduceMean; + else if (op_type == "QLinearConcat") + return QLinearOpType::QLinearConcat; + else if (op_type == "QLinearGlobalAveragePool") + return QLinearOpType::QLinearGlobalAveragePool; + else if (op_type == "QLinearLeakyRelu") + return QLinearOpType::QLinearLeakyRelu; + + return QLinearOpType::Unknown; +} + +// Ops have 1 input +bool IsUnaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearSigmoid || + type == QLinearOpType::QLinearAveragePool || + type == QLinearOpType::QLinearGlobalAveragePool || + type == QLinearOpType::QLinearLeakyRelu || + type == QLinearOpType::QLinearReduceMean; +} + +// Ops have 2 inputs +bool IsBinaryQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConv || + type == QLinearOpType::QLinearMatMul || + type == QLinearOpType::QLinearAdd || + type == QLinearOpType::QLinearMul; +} + +// Ops have 1 or more inputs +bool IsVariadicQLinearOp(QLinearOpType type) { + return type == QLinearOpType::QLinearConcat; +} + +} // namespace + NodeUnit::NodeUnit(const Node& node) - : nodes_{&node}, - node_(node), + : output_nodes_{&node}, + target_node_(node), type_(Type::SingleNode) { InitForNode(); } -const std::string& NodeUnit::Domain() const noexcept { return node_.Domain(); } -const std::string& NodeUnit::OpType() const noexcept { return node_.OpType(); } -const std::string& NodeUnit::Name() const noexcept { return node_.Name(); } -int NodeUnit::SinceVersion() const noexcept { return node_.SinceVersion(); } -NodeIndex NodeUnit::Index() const noexcept { return node_.Index(); } -const Path& NodeUnit::ModelPath() const noexcept { return node_.ModelPath(); } -ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return node_.GetExecutionProviderType(); } +const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } +const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } +const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } +int NodeUnit::SinceVersion() const noexcept { return target_node_.SinceVersion(); } +NodeIndex NodeUnit::Index() const noexcept { return 
target_node_.Index(); } +const Path& NodeUnit::ModelPath() const noexcept { return target_node_.ModelPath(); } +ProviderType NodeUnit::GetExecutionProviderType() const noexcept { return target_node_.GetExecutionProviderType(); } void NodeUnit::InitForNode() { - const auto& input_defs = node_.InputDefs(); - const auto& output_defs = node_.OutputDefs(); - // The 1st step is to hookup the NodeUnit with the NNAPI builder interface - // So we are not handling quantization here now - // TODO, enable quantization - // auto qlinear_type = GetQLinearOpType(node_); - // if (qlinear_type == QLinearOpType::Unknown) { - // Not a Qlinear op, add all inputs/outputs - auto add_all_io = [](std::vector& defs, - const ConstPointerContainer>& node_defs) { - defs.reserve(node_defs.size()); - - for (const auto def : node_defs) { - defs.push_back(NodeUnit::IODef{*def, std::nullopt}); + const auto& input_defs = target_node_.InputDefs(); + const auto& output_defs = target_node_.OutputDefs(); + auto qlinear_type = GetQLinearOpType(target_node_); + if (qlinear_type == QLinearOpType::Unknown || + IsVariadicQLinearOp(qlinear_type)) { // TODO, add variadic support + // Not a Qlinear op, add all inputs / outputs + auto add_all_io = [](std::vector& defs, + const ConstPointerContainer>& node_defs) { + defs.reserve(node_defs.size()); + + for (const auto def : node_defs) { + defs.push_back(NodeUnitIODef{*def, std::nullopt}); + } + }; + add_all_io(inputs_, input_defs); + add_all_io(outputs_, output_defs); + } else if (IsUnaryQLinearOp(qlinear_type)) { + // Unary QLinear Op has 5 inputs + // x, x_scale, x_zp, y_scale, y_zp (optional) + inputs_.push_back(NodeUnitIODef{ + *input_defs[0], + NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + + outputs_.push_back(NodeUnitIODef{ + *output_defs[0], + NodeUnitIODef::QuantParam{*input_defs[3], + input_defs.size() > 4 + ? input_defs[4] + : nullptr}}); + } else if (IsBinaryQLinearOp(qlinear_type)) { + // Binary QLinear Op has 9 inputs + // x1, x1_scale, x1_zp, x2/w, x2_scale, x2_zp, y_scale , y_zp, B + inputs_.push_back(NodeUnitIODef{ + *input_defs[0], + NodeUnitIODef::QuantParam{*input_defs[1], input_defs[2]}}); + inputs_.push_back(NodeUnitIODef{ + *input_defs[3], + NodeUnitIODef::QuantParam{*input_defs[4], input_defs[5]}}); + + if (input_defs.size() == 9) { // has Bias + inputs_.push_back(NodeUnitIODef{ + *input_defs[8], + std::nullopt}); // for Bias the scale and zp are optional } - }; - add_all_io(input_defs_, input_defs); - add_all_io(output_defs_, output_defs); + + outputs_.push_back(NodeUnitIODef{ + *output_defs[0], + NodeUnitIODef::QuantParam{*input_defs[6], input_defs[7]}}); + } else if (qlinear_type == QLinearOpType::DequantizeLinear) { + // DequantizeLinear has 3 inputs + // x, x_scale, x_zp + // output is not quantized + inputs_.push_back(NodeUnitIODef{ + *input_defs[0], + NodeUnitIODef::QuantParam{*input_defs[1], + input_defs.size() == 3 + ? input_defs[2] + : nullptr}}); + outputs_.push_back(NodeUnitIODef{*output_defs[0], std::nullopt}); + } else if (qlinear_type == QLinearOpType::QuantizeLinear) { + // QuantizeLinear the input is not quantized and has 3 inputs + // x, y_scale, y_zp (optional) + // The output is quantized + inputs_.push_back(NodeUnitIODef{*input_defs[0], std::nullopt}); + outputs_.push_back(NodeUnitIODef{ + *output_defs[0], + NodeUnitIODef::QuantParam{*input_defs[1], + input_defs.size() == 3 + ? 
input_defs[2] + : nullptr}}); + } else { + ORT_THROW("The QLinear op [", static_cast(qlinear_type), "] is not supported"); + } } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared/node_unit/node_unit.h b/onnxruntime/core/providers/shared/node_unit/node_unit.h index d94b0e2fc55e7..e109703c9316f 100644 --- a/onnxruntime/core/providers/shared/node_unit/node_unit.h +++ b/onnxruntime/core/providers/shared/node_unit/node_unit.h @@ -21,6 +21,20 @@ namespace QDQ { struct NodeGroup; } +// Definition of one input or output +// If the optional quant_param is present, then this is a quantized input, +// otherwise this is a regular input +struct NodeUnitIODef { + // The quantization parameter, scale is manadatory, and zero_point is optional + struct QuantParam { + const NodeArg& scale; + const NodeArg* zero_point{nullptr}; + }; + + const NodeArg& node_arg; + const std::optional quant_param; +}; + /** @class NodeUnit Class to represent a single node or a QDQ group of nodes, which will be used as a single unit. @@ -33,27 +47,13 @@ class NodeUnit { QDQGroup, // The NodeUnit contain a QDQ group of nodes, such as "DQ->Sigmoid->Q" }; - // Definition of one input or output - // If the optional quant_param is present, then this is a quantized input, - // otherwise this is a regular input - struct IODef { - // The quantization parmeter, scale is manadatory, and zero_point is optional - struct QuantParam { - const NodeArg& scale; - const NodeArg* zero_point{nullptr}; - }; - - const NodeArg& node_arg; - const std::optional quant_param; - }; - public: explicit NodeUnit(const Node& node); Type UnitType() const noexcept { return type_; } - const std::vector& Inputs() const noexcept { return input_defs_; } - const std::vector& Outputs() const noexcept { return output_defs_; } + const std::vector& Inputs() const noexcept { return inputs_; } + const std::vector& Outputs() const noexcept { return outputs_; } const std::string& Domain() const noexcept; const std::string& OpType() const noexcept; @@ -63,16 +63,15 @@ class NodeUnit { const Path& ModelPath() const noexcept; ProviderType GetExecutionProviderType() const noexcept; - const Node& GetNode() const noexcept { return node_; } - - const std::vector GetAllNodes() const noexcept { return nodes_; } + const Node& GetNode() const noexcept { return target_node_; } + const std::vector GetOutputNodes() const noexcept { return output_nodes_; } private: - std::vector input_defs_; - std::vector output_defs_; + std::vector inputs_; + std::vector outputs_; - const std::vector nodes_; // all nodes in this NodeUnit - const Node& node_; // target Node + const std::vector output_nodes_; // all the nodes producing outputs for this NodeUnit + const Node& target_node_; Type type_; void InitForNode(); // Initializing for single Node diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 6f38a8e368ea4..d0e33062d7f00 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -8,6 +8,7 @@ #include #include #include +#include "core/providers/shared/node_unit/node_unit.h" namespace onnxruntime { @@ -81,6 +82,9 @@ bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node) : node_attributes_(node.GetAttributes()) {} +NodeAttrHelper::NodeAttrHelper(const NodeUnit& node_unit) + : node_attributes_(node_unit.GetNode().GetAttributes()) {} + float NodeAttrHelper::Get(const 
std::string& key, float def_val) const { if (!HasAttr(key)) return def_val; diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 925df731fcee3..26898aa95e893 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -17,6 +17,7 @@ class Logger; class Node; class NodeArg; +class NodeUnit; // Get the min/max of a Clip operator. // If min/max are not known initializer tensors, will return false @@ -34,7 +35,10 @@ bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logg */ class NodeAttrHelper { public: - NodeAttrHelper(const onnxruntime::Node& node); + explicit NodeAttrHelper(const Node& node); + + // Get the attributes from the target node of the node_unit + explicit NodeAttrHelper(const NodeUnit& node_unit); float Get(const std::string& key, float def_val) const; @@ -52,7 +56,7 @@ class NodeAttrHelper { bool HasAttr(const std::string& key) const; private: - const onnxruntime::NodeAttributes& node_attributes_; + const NodeAttributes& node_attributes_; }; } // namespace onnxruntime
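
A few illustrative sketches of the mechanics this patch relies on follow; none of the code below is part of the patch itself. First, the NodeUnitIODef/QuantParam pairing added in node_unit.h, together with NodeUnit::InitForNode in node_unit.cc above, is what lets the support checkers address quantized tensors by logical index (0, 1, ...) instead of by position in the flat QLinear input list. The sketch below mirrors the binary-QLinear branch of InitForNode using simplified string-based stand-ins; the struct and function names are illustrative only and are not the ONNX Runtime NodeArg/NodeUnitIODef types.

```cpp
// Simplified, standalone illustration of the binary-QLinear mapping performed by
// NodeUnit::InitForNode. The string-based structs are stand-ins for this sketch only.
#include <cassert>
#include <optional>
#include <string>
#include <vector>

struct QuantParam {
  std::string scale;                      // scale tensor name (mandatory)
  std::optional<std::string> zero_point;  // zero-point tensor name (optional)
};

struct IODef {
  std::string arg;                        // tensor name
  std::optional<QuantParam> quant_param;  // present only for quantized defs
};

// Binary QLinear ops pack x, x_scale, x_zp, w, w_scale, w_zp, y_scale, y_zp[, B]
// into one flat input list; the NodeUnit view regroups them per logical tensor.
void MapBinaryQLinear(const std::vector<std::string>& in, const std::string& out,
                      std::vector<IODef>& inputs, std::vector<IODef>& outputs) {
  assert(in.size() == 8 || in.size() == 9);
  inputs.push_back({in[0], QuantParam{in[1], in[2]}});            // x + its scale/zp
  inputs.push_back({in[3], QuantParam{in[4], in[5]}});            // w + its scale/zp
  if (in.size() == 9) inputs.push_back({in[8], std::nullopt});    // bias, no quant param
  outputs.push_back({out, QuantParam{in[6], in[7]}});             // y + its scale/zp
}

int main() {
  std::vector<IODef> inputs, outputs;
  MapBinaryQLinear({"x", "x_s", "x_zp", "w", "w_s", "w_zp", "y_s", "y_zp", "B"},
                   "y", inputs, outputs);
  // Checkers can now address the weight as input index 1 regardless of op type.
  assert(inputs.size() == 3 && inputs[1].quant_param->scale == "w_s");
  assert(outputs[0].quant_param->zero_point.value() == "y_zp");
  return 0;
}
```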
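
The Shaper::Conv/ConvImpl hunks above feed ComputeConvOutputShape, whose pad handling (onnx_pads ordered top, left, bottom, right) is visible in the patch. The body of the size computation is not shown here, so the sketch below reproduces the standard ONNX output-size arithmetic it is expected to follow (dilation widening the effective kernel, floor division by the stride); the function name is hypothetical.

```cpp
// Standalone sketch of the standard ONNX convolution output-size arithmetic
// (auto_pad NOTSET): out = floor((in + pad_begin + pad_end - eff_kernel) / stride) + 1.
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

std::pair<uint32_t, uint32_t> ConvOutputHW(uint32_t in_h, uint32_t in_w,
                                           uint32_t k_h, uint32_t k_w,
                                           const std::vector<int32_t>& pads,       // {top, left, bottom, right}
                                           const std::vector<int32_t>& strides,    // {stride_y, stride_x}
                                           const std::vector<int32_t>& dilations)  // {dilation_y, dilation_x}
{
  // Dilation enlarges the effective kernel extent: (k - 1) * d + 1.
  const int64_t eff_k_h = static_cast<int64_t>(k_h - 1) * dilations[0] + 1;
  const int64_t eff_k_w = static_cast<int64_t>(k_w - 1) * dilations[1] + 1;
  const int64_t out_h = (in_h - eff_k_h + pads[0] + pads[2]) / strides[0] + 1;
  const int64_t out_w = (in_w - eff_k_w + pads[1] + pads[3]) / strides[1] + 1;
  return {static_cast<uint32_t>(out_h), static_cast<uint32_t>(out_w)};
}

int main() {
  // 224x224 input, 3x3 kernel, pad 1 on every side, stride 2, no dilation -> 112x112.
  auto hw = ConvOutputHW(224, 224, 3, 3, {1, 1, 1, 1}, {2, 2}, {1, 1});
  assert(hw.first == 112 && hw.second == 112);
  return 0;
}
```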
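
The reworked QuantizeLinear and DequantizeLinear checkers now validate the scale and zero point only on the quantized side: output index 0 for QuantizeLinear, input index 0 for DequantizeLinear. As a reminder of the arithmetic those parameters describe, here is a simplified uint8 asymmetric quantize/dequantize pair; it follows the usual ONNX operator semantics but is an illustrative sketch, not code from this PR.

```cpp
// Simplified uint8 asymmetric quantization: q = saturate(round(x / scale) + zp),
// x ~= (q - zp) * scale. Only the quantized tensor carries a scale/zero-point pair.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

uint8_t QuantizeU8(float x, float scale, uint8_t zero_point) {
  const int q = static_cast<int>(std::nearbyint(x / scale)) + zero_point;
  return static_cast<uint8_t>(std::clamp(q, 0, 255));  // saturate to uint8 range
}

float DequantizeU8(uint8_t q, float scale, uint8_t zero_point) {
  return (static_cast<int>(q) - zero_point) * scale;
}

int main() {
  const float scale = 0.1f;
  const uint8_t zp = 128;
  const uint8_t q = QuantizeU8(1.5f, scale, zp);  // round(15) + 128 = 143
  assert(q == 143);
  assert(std::fabs(DequantizeU8(q, scale, zp) - 1.5f) < 1e-6f);
  return 0;
}
```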
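
The Resize checker above unpacks the scales (or sizes) initializer to verify the op is not resizing the batch or channel dimensions. Assuming an NCHW-ordered 4-element scales tensor (the ONNX model layout), the intent of that check can be sketched as below; the helper name is hypothetical and the actual checker also handles the sizes path and other layouts.

```cpp
// Sketch of the N/C check on a 4-element NCHW scales tensor: only H/W may be resized.
#include <cassert>
#include <vector>

bool OnlyResizesHW(const std::vector<float>& scales_nchw) {
  return scales_nchw.size() == 4 &&
         scales_nchw[0] == 1.0f &&   // N scale must stay 1
         scales_nchw[1] == 1.0f;     // C scale must stay 1
}

int main() {
  assert(OnlyResizesHW({1.0f, 1.0f, 2.0f, 2.0f}));    // spatial-only resize: supported
  assert(!OnlyResizesHW({1.0f, 2.0f, 2.0f, 2.0f}));   // resizing C: rejected
  return 0;
}
```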
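
Finally, the shaper.h and model.h hunks drop the per-declaration ORT_MUST_USE_RESULT annotations while switching the return type to common::Status. In general C++ the same "don't ignore this return value" diagnostic can come from a single [[nodiscard]] attribute on the status type itself; whether onnxruntime::common::Status is annotated that way is not shown in this patch, so the sketch below uses a hypothetical MyStatus type purely to illustrate the mechanism.

```cpp
// Standalone sketch: marking a status type [[nodiscard]] makes every function that
// returns it must-use, without repeating a macro on each declaration.
#include <string>

class [[nodiscard]] MyStatus {
 public:
  static MyStatus OK() { return MyStatus{}; }
  static MyStatus Error(std::string msg) { return MyStatus{std::move(msg)}; }
  bool IsOK() const { return msg_.empty(); }

 private:
  MyStatus() = default;
  explicit MyStatus(std::string msg) : msg_(std::move(msg)) {}
  std::string msg_;
};

MyStatus Reshape() { return MyStatus::OK(); }  // no extra annotation needed

int main() {
  // Reshape();                       // discarding the result would warn
  if (!Reshape().IsOK()) return 1;    // checking the result compiles cleanly
  return 0;
}
```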