From d50890b0f0ac9788b6bb942cd652c6575757a836 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Thu, 25 Jul 2024 03:53:49 -0700 Subject: [PATCH 01/20] Try fuse conv + relu in qnn --- .../qnn/builder/qnn_conv_activation_fusion.cc | 320 ++++++++++++++++++ .../qnn/builder/qnn_conv_activation_fusion.h | 24 ++ .../core/providers/qnn/builder/qnn_fusions.cc | 135 +++++--- .../providers/qnn/builder/qnn_model_wrapper.h | 6 + .../core/providers/qnn/builder/qnn_utils.h | 9 + 5 files changed, 441 insertions(+), 53 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc new file mode 100644 index 0000000000000..d09c6876bc279 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc @@ -0,0 +1,320 @@ +#include "core/providers/qnn/builder/qnn_conv_activation_fusion.h" + +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" + +namespace onnxruntime { +namespace qnn { + +static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_set& handled_node_units) { + const Node& parent_node = parent_node_unit.GetNode(); + + // Parent must have a single child (1 output edge) and must not produce a graph output. + if (parent_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(parent_node)) { + return nullptr; + } + + // Child must be of a valid type. + const Node& child_node = parent_node.OutputEdgesBegin()->GetNode(); + const std::string& child_type = child_node.OpType(); + bool is_valid_child_type = false; + + for (const auto& valid_op_type : child_op_types) { + if (valid_op_type == child_type) { + is_valid_child_type = true; + break; + } + } + + if (!is_valid_child_type) { + return nullptr; + } + + const auto child_node_unit_it = node_unit_map.find(&child_node); + assert(child_node_unit_it != node_unit_map.end()); + const NodeUnit* child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. + if (handled_node_units.count(child_node_unit) != 0) { + return nullptr; + } + + // child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (child_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + return child_node_unit; +} + +static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& clip_node_unit, + const NodeUnit& q_node_unit) { + assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QDQ::QOpName); + // TODO(adrianlizarraga): Implement. 
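+  // A sketch of the intended check (an assumption; not implemented yet): Clip is redundant when
+  // its [min, max] interval covers every real value representable by the Q node's output type,
+  // i.e., min <= (q_min - zero_point) * scale and max >= (q_max - zero_point) * scale, where
+  // q_min/q_max are the limits of the quantized integer type. Until then, this placeholder
+  // optimistically returns true.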
+  (void)qnn_model_wrapper;
+  return true;
+}
+
+static bool CanReluBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
+                             const NodeUnit& relu_node_unit,
+                             const NodeUnit& q_node_unit) {
+  assert(relu_node_unit.OpType() == "Relu" && q_node_unit.OpType() == QDQ::QOpName);
+  const auto& q_inputs = q_node_unit.GetNode().InputDefs();
+
+  // Require an explicit zero-point input for now.
+  if (q_inputs.size() != 3 || !q_inputs[QDQ::ZERO_POINT_ID]->Exists()) {
+    return false;
+  }
+
+  std::vector<int32_t> zero_points;
+  int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED;
+  Status status = qnn_model_wrapper.UnpackZeroPoints(q_inputs[QDQ::ZERO_POINT_ID]->Name(),
+                                                     zero_points, zp_data_type);
+
+  // Should only have one zero-point (per-tensor).
+  if (!status.IsOK() || zero_points.size() != 1) {
+    return false;
+  }
+
+  int32_t onnx_zero_point = -zero_points[0];  // QNN zero-points are negated.
+
+  // Relu is redundant if the zero-point is set to the smallest quantized value.
+  switch (zp_data_type) {
+    case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT8:
+      return onnx_zero_point == static_cast<int32_t>(std::numeric_limits<int8_t>::lowest());
+    case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT8:
+      return onnx_zero_point == static_cast<int32_t>(std::numeric_limits<uint8_t>::lowest());
+    case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT16:
+      return onnx_zero_point == static_cast<int32_t>(std::numeric_limits<int16_t>::lowest());
+    case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT16:
+      return onnx_zero_point == static_cast<int32_t>(std::numeric_limits<uint16_t>::lowest());
+    default:
+      return false;
+  }
+}
+
+static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
+                                   const NodeUnit& activation_node_unit,
+                                   const NodeUnit& q_node_unit) {
+  const std::string& activation_type = activation_node_unit.OpType();
+
+  if (activation_type == "Relu") {
+    return CanReluBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit);
+  }
+
+  if (activation_type == "Clip") {
+    return CanClipBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit);
+  }
+
+  return false;
+}
+
+// adjust for an optional input/output that has an entry but does not exist
+static int NumActualValues(const Node& node, bool input) {
+  const auto& defs = input ? node.InputDefs() : node.OutputDefs();
+  return gsl::narrow_cast<int>(std::count_if(defs.cbegin(), defs.cend(),
+                                             [](const NodeArg* def) { return def && def->Exists(); }));
+}
+
+static std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, const Node& node, bool find_dq_nodes) {
+  // First get all the upstream (DQ) or downstream (Q) nodes
+  std::vector<const Node*> nodes =
+      find_dq_nodes ?
graph_utils::FindParentsByType(node, QDQ::DQOpName) + : graph_utils::FindChildrenByType(node, QDQ::QOpName); + + // Remove all the nodes which are not in the graph_viewer + nodes.erase(std::remove_if(nodes.begin(), nodes.end(), + [&graph_viewer](const Node* _node) { + return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr; + }), + nodes.end()); + + return nodes; +} + +static std::optional GetConvQDQNodeGroup( + const GraphViewer& graph_viewer, + const std::unordered_map& node_unit_map, + const std::unordered_set& handled_node_units, + const Node& conv_node, + const Node& q_node) { + assert((conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose") && + q_node.OpType() == QDQ::QOpName); + std::vector dq_nodes = FindQDQNodes(graph_viewer, conv_node, /*find_dq_nodes*/ true); + std::vector q_nodes = {&q_node}; + int num_dq_inputs = NumActualValues(conv_node, /*input*/ true); + + // Within a QDQ node group, a target node input is the only consumer of each DQ. + if (num_dq_inputs != static_cast(dq_nodes.size())) { + return std::nullopt; + } + + for (const auto* dq_node : dq_nodes) { + if (graph_viewer.NodeProducesGraphOutput(*dq_node)) { + return std::nullopt; + } + + const bool dq_has_single_output_edge_to_target = + dq_node->GetOutputEdgesCount() == 1 && + dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index(); + if (!dq_has_single_output_edge_to_target) { + return std::nullopt; + } + + const auto dq_node_unit_it = node_unit_map.find(dq_node); + assert(dq_node_unit_it != node_unit_map.end()); + const NodeUnit* dq_node_unit = dq_node_unit_it->second; + + if (handled_node_units.count(dq_node_unit) != 0) { + return std::nullopt; + } + + if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return std::nullopt; + } + } + + // input and output types need to be same + int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_weight = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input != dt_output) { + return std::nullopt; + } + + if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + if (dt_weight != dt_input) { + return std::nullopt; + } + } + + if (dq_nodes.size() == 3) { // has bias + int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { + return std::nullopt; + } + } + + QDQ::NodeGroup node_group; + node_group.dq_nodes.reserve(dq_nodes.size()); + node_group.q_nodes.reserve(q_nodes.size()); + node_group.target_node = conv_node.Index(); + auto get_node_idx = [&](const Node* n) { return n->Index(); }; + std::transform(dq_nodes.begin(), dq_nodes.end(), std::back_inserter(node_group.dq_nodes), get_node_idx); + std::transform(q_nodes.begin(), q_nodes.end(), std::back_inserter(node_group.q_nodes), get_node_idx); + return node_group; +} + +Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes, + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_unit_map, + const std::unordered_set& handled_node_units, + const logging::Logger& logger, + bool do_op_validation) { + // Expect that this function is called with a standalone Conv or ConvTranspose. 
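+  // Target pattern: DQ* -> Conv/ConvTranspose -> Relu/Clip -> Q. The activation can be dropped
+  // because the Q node's quantization parameters clamp the output to an equivalent range, so the
+  // whole group maps onto a single QNN Conv/ConvTranspose op.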
+ assert((conv_node_unit.OpType() == "Conv" || conv_node_unit.OpType() == "ConvTranspose") && + conv_node_unit.UnitType() == NodeUnit::Type::SingleNode); + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Conv must have a single Relu or Clip child. + const std::array activation_op_types = {"Relu", "Clip"}; + const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types, + node_unit_map, handled_node_units); + if (activation_node_unit == nullptr) { + return Status::OK(); + } + + // Relu/Clip must have a single Q child. + const std::array q_op_types = {QDQ::QOpName}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, *activation_node_unit, q_op_types, + node_unit_map, handled_node_units); + + if (q_node_unit == nullptr) { + return Status::OK(); + } + + // Check if Clip/Relu can be removed because the Q node provides an equivalent effect. + if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit)) { + return Status::OK(); + } + + // Create a QDQ node group with DQ* -> Conv -> Q + const Node& conv_node = conv_node_unit.GetNode(); + const Node& activation_node = activation_node_unit->GetNode(); + const Node& q_node = q_node_unit->GetNode(); + std::optional qdq_node_group = GetConvQDQNodeGroup(graph_viewer, + node_unit_map, + handled_node_units, + conv_node, + q_node); + + if (!qdq_node_group.has_value()) { + return Status::OK(); + } + + NodeUnit qdq_node_unit(graph_viewer, *qdq_node_group); + + // Create a temporary QnnModelWrapper for validation only. We need to be sure that this fusion will work before + // modifying the actual QnnModelWrapper. This allows us to revert to the traditional OpBuilder workflow if this + // fusion doesn't work out. + QnnModelWrapper tmp_model_wrapper(graph_viewer, + logger, + qnn_model_wrapper.GetQnnInterface(), + qnn_model_wrapper.GetQnnBackendHandle(), + qnn_model_wrapper.GetInputIndexMap(), + qnn_model_wrapper.GetOutputIndexMap(), + qnn_model_wrapper.GetInitializerLookup(), + qnn_model_wrapper.GetQnnBackendType()); + + const auto* conv_op_builder = qnn::GetOpBuilder(qdq_node_unit.OpType()); + if (conv_op_builder == nullptr) { + return Status::OK(); + } + + QNN_RETURN_OK_IF_ERROR(conv_op_builder->IsOpSupported(tmp_model_wrapper, qdq_node_unit, logger), logger); + + // ====== The following statements modify qnn_model_wrapper. ======== + // Validation passed, so we're now committed to doing a fusion. + // If we encounter an error, we return it directly to caller. + LOGS(logger, VERBOSE) << " Adding Conv + Activation via fusion. conv_node name: [" << conv_node.Name() + << "] activation_node optype: [" << activation_node.OpType() + << "] activation_node name: [" << activation_node.Name() + << "]"; + + if (do_op_validation) { + ORT_RETURN_IF_ERROR(conv_op_builder->IsOpSupported(qnn_model_wrapper, qdq_node_unit, logger)); + } else { + ORT_RETURN_IF_ERROR(conv_op_builder->AddToModelBuilder(qnn_model_wrapper, qdq_node_unit, logger)); + } + + // Success. Add all nodes to fused_nodes so that caller can mark them as handled. 
+ for (const Node* dq_node : qdq_node_unit.GetDQNodes()) { + const auto dq_node_unit_it = node_unit_map.find(dq_node); + ORT_RETURN_IF(dq_node_unit_it == node_unit_map.end(), "DQ node does not have a NodeUnit"); + fused_nodes.push_back(dq_node_unit_it->second); + } + + fused_nodes.push_back(&conv_node_unit); + fused_nodes.push_back(activation_node_unit); + fused_nodes.push_back(q_node_unit); + + return Status::OK(); +} +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h new file mode 100644 index 0000000000000..76c08a269f90b --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" + +namespace onnxruntime { +namespace qnn { + +Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes, + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_unit_map, + const std::unordered_set& handled_node_units, + const logging::Logger& logger, + bool do_op_validation); +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc index b04075f11203c..e30b34cfa7ca5 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc @@ -15,15 +15,7 @@ #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" - -#define QNN_RETURN_OK_IF_ERROR(expr, logger) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - LOGS((logger), VERBOSE) << _status.ErrorMessage(); \ - return Status::OK(); \ - } \ - } while (0) +#include "core/providers/qnn/builder/qnn_conv_activation_fusion.h" namespace onnxruntime { namespace qnn { @@ -34,53 +26,26 @@ namespace qnn { * * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. * \param qnn_model_wrapper The QNN model that is being built. - * \param start_node_unit The node unit that could potentially start the DQ -> Q sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. + * \param dq_node_unit The DQ node unit. + * \param q_node_unit The Q node unit. * \param logger The logger. * \param do_op_validation True if should call QNN operator validation APIs. * \return An onnxruntime::Status */ static Status TryHandleConvertSequence(std::vector& fused_nodes, QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& start_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, const logging::Logger& logger, bool do_op_validation) { const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); // Looking for a standalone DQ to start the sequence. 
- if (start_node_unit.OpType() != QDQ::DQOpName || start_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } - - const Node& dq_node = start_node_unit.GetNode(); - - // DQ must have a single Q child. DQ must not produce a graph output. - auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); - if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { - return Status::OK(); - } - - const Node& q_node = *children[0]; - const auto q_node_unit_it = node_unit_map.find(&q_node); - - ORT_RETURN_IF(q_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); - - const NodeUnit* q_node_unit = q_node_unit_it->second; - - // Check if Q node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(q_node_unit) != 0) { - return Status::OK(); - } + assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode); + assert(q_node_unit.OpType() == QDQ::QOpName && q_node_unit.UnitType() == NodeUnit::Type::SingleNode); - // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); - } + const Node& dq_node = dq_node_unit.GetNode(); + const Node& q_node = q_node_unit.GetNode(); auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { return graph_viewer.GetConstantInitializer(initializer_name, true); @@ -91,9 +56,9 @@ static Status TryHandleConvertSequence(std::vector& fused_nodes return Status::OK(); } - const auto& node_name = utils::GetNodeName(start_node_unit); - const NodeUnitIODef& input_def = start_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = q_node_unit->Outputs()[0]; + const auto& node_name = utils::GetNodeName(dq_node_unit); + const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; QnnTensorWrapper input_tensor; QnnTensorWrapper output_tensor; @@ -115,14 +80,14 @@ static Status TryHandleConvertSequence(std::vector& fused_nodes // If we encounter an error, we return it directly to caller. LOGS(logger, VERBOSE) << " Adding QNN Convert via fusion. dq_node name: [" << dq_node.Name() << "] dq_node optype: [" << dq_node.OpType() - << "] q_node name: [" << q_node_unit->Name() - << "] q_node optype: [" << q_node_unit->OpType() + << "] q_node name: [" << q_node_unit.Name() + << "] q_node optype: [" << q_node_unit.OpType() << "]"; // Add a QNN Convert to the model. Get the input from the DQ node, and the output from the Q node. 
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(*q_node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CONVERT, {input_def.node_arg.Name()}, @@ -131,8 +96,72 @@ static Status TryHandleConvertSequence(std::vector& fused_nodes do_op_validation), "Failed to add fused Convert node."); - fused_nodes.push_back(&start_node_unit); - fused_nodes.push_back(q_node_unit); + fused_nodes.push_back(&dq_node_unit); + fused_nodes.push_back(&q_node_unit); + + return Status::OK(); +} + +/** + * Tries to fuse sequences that start with a DQ node. + * + * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. + * \param qnn_model_wrapper The QNN model that is being built. + * \param dq_node_unit The DQ node unit that could potentially start a sequence. + * \param node_unit_map Maps a node to its node unit. + * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes + * in this set. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. + * \return An onnxruntime::Status + */ +static Status TryHandleDequantize(std::vector& fused_nodes, + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_unit_map, + const std::unordered_set& handled_node_units, + const logging::Logger& logger, + bool do_op_validation) { + // Expect that this function is called with a standalone DQ. + assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode); + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& dq_node = dq_node_unit.GetNode(); + + // DQ must have a single child (1 output edge) and must not produce a graph output. + if (dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { + return Status::OK(); + } + + const Node& child_node = dq_node.OutputEdgesBegin()->GetNode(); + const auto child_node_unit_it = node_unit_map.find(&child_node); + ORT_RETURN_IF(child_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); + const NodeUnit* child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. Should not be the case if this + // fusion function has been called in topological order, but check to be safe. + if (handled_node_units.count(child_node_unit) != 0) { + return Status::OK(); + } + + // child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (child_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return Status::OK(); + } + + const std::string& child_type = child_node.OpType(); + + // Try (DQ -> Q) into QNN's Convert op. + if (child_type == QDQ::QOpName) { + return TryHandleConvertSequence(fused_nodes, qnn_model_wrapper, dq_node_unit, *child_node_unit, + logger, do_op_validation); + } + + // Try (DQ -> Conv/ConvTranspose -> Relu/Clip -> Q) into QNN Conv/ConvTranspose. 
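+  // TryConvActivationFusion performs the remaining pattern checks (single Relu/Clip child,
+  // single Q child, and matching DQ inputs) and leaves fused_nodes empty if the sequence
+  // does not match.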
+ if (child_type == "Conv" || child_type == "ConvTranspose") { + return TryConvActivationFusion(fused_nodes, qnn_model_wrapper, *child_node_unit, node_unit_map, + handled_node_units, logger, do_op_validation); + } return Status::OK(); } @@ -269,7 +298,7 @@ Status TryFusions(/*out*/ std::vector& fused_nodes, bool validate) { // Maps a starting operator type to the fusion function. static std::unordered_map fusions = { - {"DequantizeLinear", TryHandleConvertSequence}, + {"DequantizeLinear", TryHandleDequantize}, {"HardSigmoid", TryHandleHardSigmoidSequence}, }; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index 9ab122b7f8e28..fdf6616393ff8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -52,6 +52,12 @@ class QnnModelWrapper { ~QnnModelWrapper() = default; + const QNN_INTERFACE_VER_TYPE& GetQnnInterface() const { return qnn_interface_; } + const Qnn_BackendHandle_t& GetQnnBackendHandle() const { return backend_handle_; } + const std::unordered_map& GetInputIndexMap() const { return input_index_map_; } + const std::unordered_map& GetOutputIndexMap() const { return output_index_map_; } + const std::unordered_set& GetInitializerLookup() const { return initializer_lookup_; } + bool CreateQnnGraph(const Qnn_ContextHandle_t& context, const std::string& graph_name, const QnnGraph_Config_t** graph_configs = nullptr); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index 2392040d284b7..4305bf56f522e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -13,6 +13,15 @@ #include "core/framework/node_unit.h" #include "core/util/qmath.h" +#define QNN_RETURN_OK_IF_ERROR(expr, logger) \ + do { \ + auto _status = (expr); \ + if ((!_status.IsOK())) { \ + LOGS((logger), VERBOSE) << _status.ErrorMessage(); \ + return Status::OK(); \ + } \ + } while (0) + namespace onnxruntime { namespace qnn { class QnnOpConfigWrapper; From 9c164ce2255128c9bcda3788c409fbb76bc87d0e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Fri, 26 Jul 2024 14:48:57 -0700 Subject: [PATCH 02/20] Check in progress --- onnxruntime/core/framework/node_unit.cc | 37 ++ onnxruntime/core/framework/node_unit.h | 2 + .../qnn/builder/qnn_conv_activation_fusion.cc | 30 +- .../qnn/builder/qnn_conv_activation_fusion.h | 8 + .../core/providers/qnn/builder/qnn_fusions.cc | 335 ++++++++++++++++++ .../core/providers/qnn/builder/qnn_fusions.h | 29 ++ .../test/providers/qnn/qnn_basic_test.cc | 66 +++- 7 files changed, 501 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index e2c06fbdfa621..54964b0275fc8 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -272,6 +272,43 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } +NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group, + const Node& output_activation_node) + : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, + target_node_(*graph_viewer.GetNode(node_group.target_node)), + q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, + type_(Type::QDQGroup), + inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, + 
outputs_{GetQDQIODefs(output_activation_node, node_group, false /* is_input */)} { + input_edge_count_ = std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), + [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); + + // add edges for inputs that are not from DQ nodes. there is one edge to each DQ node. + // other inputs could come from initializers or graph inputs (no edges) or other nodes (edge). + input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size(); + + // create output edges. each target node output either goes to Q node/s or non-Q node/s. + // ValidateNodeGroupQDQNodes ensures this. + auto cur_edge = output_activation_node.OutputEdgesBegin(); + auto end_edge = output_activation_node.OutputEdgesEnd(); + for (; cur_edge != end_edge; ++cur_edge) { + const Node& node = cur_edge->GetNode(); + + // if node is in q_nodes we hide the Q node. + if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) { + auto src_idx = cur_edge->GetSrcArgIndex(); + auto q_cur_edge = node.OutputEdgesBegin(); + auto q_end_edge = node.OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()}); + } + } else { + // non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is. + output_edges_.insert(*cur_edge); + } + } +} + const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } const std::string& NodeUnit::OpType() const noexcept { return target_node_.OpType(); } const std::string& NodeUnit::Name() const noexcept { return target_node_.Name(); } diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index e84e62479162f..494d7bd849b4b 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,6 +68,8 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); + explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group, + const Node& output_activation_node); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc index d09c6876bc279..fd9779cb5512a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc @@ -219,6 +219,34 @@ static std::optional GetConvQDQNodeGroup( return node_group; } +Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, + gsl::span dq_node_units, + const NodeUnit* conv_node_unit, + const NodeUnit* activation_node_unit, + const NodeUnit* q_node_unit, + const logging::Logger& logger, + bool validate) { + QDQ::NodeGroup custom_node_group; + custom_node_group.dq_nodes.reserve(dq_node_units.size()); + custom_node_group.q_nodes = std::vector{q_node_unit->Index()}; + custom_node_group.target_node = conv_node_unit->Index(); + auto get_node_idx = [](const NodeUnit* n) { return n->Index(); }; + std::transform(dq_node_units.begin(), dq_node_units.end(), std::back_inserter(custom_node_group.dq_nodes), + get_node_idx); + + NodeUnit custom_node_unit(qnn_model_wrapper.GetGraphViewer(), custom_node_group, activation_node_unit->GetNode()); + const auto* conv_op_builder = 
qnn::GetOpBuilder(custom_node_unit.OpType()); + if (conv_op_builder == nullptr) { + return Status::OK(); + } + + if (validate) { + return conv_op_builder->IsOpSupported(qnn_model_wrapper, custom_node_unit, logger); + } + + return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); +} + Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes, QnnModelWrapper& qnn_model_wrapper, const NodeUnit& conv_node_unit, @@ -268,7 +296,7 @@ Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes return Status::OK(); } - NodeUnit qdq_node_unit(graph_viewer, *qdq_node_group); + NodeUnit qdq_node_unit(graph_viewer, *qdq_node_group, activation_node); // Create a temporary QnnModelWrapper for validation only. We need to be sure that this fusion will work before // modifying the actual QnnModelWrapper. This allows us to revert to the traditional OpBuilder workflow if this diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h index 76c08a269f90b..6ad37650122d6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h @@ -13,6 +13,14 @@ namespace onnxruntime { namespace qnn { +Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, + gsl::span dq_node_units, + const NodeUnit* conv_node_unit, + const NodeUnit* activation_node_unit, + const NodeUnit* q_node_unit, + const logging::Logger& logger, + bool validate = false); + Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes, QnnModelWrapper& qnn_model_wrapper, const NodeUnit& conv_node_unit, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc index e30b34cfa7ca5..18d5216f65e60 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc @@ -4,6 +4,7 @@ #include "core/providers/qnn/builder/qnn_fusions.h" #include +#include #include #include #include @@ -20,6 +21,234 @@ namespace onnxruntime { namespace qnn { +static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool validate = false) { + assert(dq_node_unit.OpType() == QDQ::DQOpName && q_node_unit.OpType() == QDQ::QOpName); + const auto& node_name = utils::GetNodeName(dq_node_unit); + const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {}), + logger); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + 
validate), + "Failed to add fused Convert node."); + } +} + +static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, + const logging::Logger& logger, + bool validate = false) { + assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); + const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); + const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {}), + logger); + } else { + LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. HardSigmoid name: [" << hardsigmoid_node_unit.Name() + << "] Mul name: [" << mul_node_unit.Name() << "]"; + + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused HardSwish node."); + } +} + +std::string_view QnnNodeGroup::TypeToString(QnnNodeGroup::Type type) { + static std::array(QnnNodeGroup::Type::COUNT)> type_names = { + "Undefined", + "NodeUnit", + "ConvActivationFusion", + "DQQFusion", + "HardSigmoidMulFusion", + }; + + return type_names[static_cast(type)]; +} + +Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + switch (type_) { + case Type::NodeUnit: { + ORT_RETURN_IF_NOT(node_units_.size() == 1 && node_units_[0] != nullptr, ""); + const auto* op_builder = qnn::GetOpBuilder(node_units_[0]->OpType()); + ORT_RETURN_IF_NOT(op_builder != nullptr, ""); + return op_builder->IsOpSupported(qmw, *node_units_[0], logger); + } + case Type::ConvActivationFusion: { + const size_t num_node_units = node_units_.size(); + ORT_RETURN_IF_NOT((num_node_units == 5 || num_node_units == 6), ""); + + const bool has_bias_dq = num_node_units == 6; + std::vector dq_node_units = {node_units_[0], node_units_[1]}; + const NodeUnit* conv_node_unit = node_units_[num_node_units - 3]; + const NodeUnit* activation_node_unit = node_units_[num_node_units - 2]; + const NodeUnit* q_node_unit = node_units_[num_node_units - 1]; + + if (has_bias_dq) { + dq_node_units.push_back(node_units_[2]); + } + return QnnConvActivationFusionAdd(qmw, + dq_node_units, + conv_node_unit, + activation_node_unit, + q_node_unit, + logger, + /*validate*/ true); + } + case Type::DQQFusion: { + ORT_RETURN_IF_NOT(node_units_.size() == 2, ""); + const NodeUnit* dq_node_unit = node_units_[0]; + const NodeUnit* q_node_unit = node_units_[1]; + ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, ""); + return QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ true); + } + case Type::HardSigmoidMulFusion: { + ORT_RETURN_IF_NOT(node_units_.size() == 2, ""); + const NodeUnit* hardsigmoid_node_unit 
= node_units_[0]; + const NodeUnit* mul_node_unit = node_units_[1]; + ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, ""); + return QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, /*validate*/ true); + } + default: + std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(type_), + " in QnnNodeGroup::IsSupported()"); + LOGS(logger, ERROR) << error_msg; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); + } +} + +Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + switch (type_) { + case Type::NodeUnit: { + ORT_RETURN_IF_NOT(node_units_.size() == 1 && node_units_[0] != nullptr, ""); + const auto* op_builder = qnn::GetOpBuilder(node_units_[0]->OpType()); + ORT_RETURN_IF_NOT(op_builder != nullptr, ""); + return op_builder->AddToModelBuilder(qmw, *node_units_[0], logger, /*do_op_validation*/ false); + } + case Type::ConvActivationFusion: { + const size_t num_node_units = node_units_.size(); + ORT_RETURN_IF_NOT((num_node_units == 5 || num_node_units == 6), ""); + + const bool has_bias_dq = num_node_units == 6; + std::vector dq_node_units = {node_units_[0], node_units_[1]}; + const NodeUnit* conv_node_unit = node_units_[num_node_units - 3]; + const NodeUnit* activation_node_unit = node_units_[num_node_units - 2]; + const NodeUnit* q_node_unit = node_units_[num_node_units - 1]; + + if (has_bias_dq) { + dq_node_units.push_back(node_units_[2]); + } + return QnnConvActivationFusionAdd(qmw, + dq_node_units, + conv_node_unit, + activation_node_unit, + q_node_unit, + logger, + /*validate*/ false); + } + case Type::DQQFusion: { + ORT_RETURN_IF_NOT(node_units_.size() == 2, ""); + const NodeUnit* dq_node_unit = node_units_[0]; + const NodeUnit* q_node_unit = node_units_[1]; + ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, ""); + return QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ false); + } + case Type::HardSigmoidMulFusion: { + ORT_RETURN_IF_NOT(node_units_.size() == 2, ""); + const NodeUnit* hardsigmoid_node_unit = node_units_[0]; + const NodeUnit* mul_node_unit = node_units_[1]; + ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, ""); + return QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, /*validate*/ false); + } + default: + std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(type_), + " in QnnNodeGroup::AddToModelBuilder()"); + LOGS(logger, ERROR) << error_msg; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); + } +} + +const NodeUnit* QnnNodeGroup::GetTargetNodeUnit(const logging::Logger& logger) const { + switch (type_) { + case Type::NodeUnit: { + if (node_units_.size() != 1) { + return nullptr; + } + return node_units_[0]; + } + case Type::ConvActivationFusion: { + const size_t num_node_units = node_units_.size(); + if (!(num_node_units == 5 || num_node_units == 6)) { + return nullptr; + } + return node_units_[num_node_units - 3]; + } + case Type::DQQFusion: { + if (node_units_.size() != 2) { + return nullptr; + } + return node_units_[0]; + } + case Type::HardSigmoidMulFusion: { + if (node_units_.size() != 2) { + return nullptr; + } + return node_units_[0]; + } + default: + std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(type_), + " in QnnNodeGroup::AddToModelBuilder()"); + LOGS(logger, ERROR) << error_msg; + return nullptr; + } +} + /** * Tries to merge a DQ -> Q sequence into 
a QNN Convert operator. The DQ -> Q must be converting from * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). @@ -319,5 +548,111 @@ Status TryFusions(/*out*/ std::vector& fused_nodes, return Status::OK(); } +static Status TryQnnFusions(/*out*/ std::optional& fused_node_group, + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& starting_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + return Status::OK(); +} + +Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + const logging::Logger& logger) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); + const size_t approx_num_node_units = static_cast(graph_viewer.NumberOfNodes() / 2); + + std::vector sorted_qnn_node_group_indices; + sorted_qnn_node_group_indices.reserve(approx_num_node_units); + + std::vector tmp_qnn_node_groups; + tmp_qnn_node_groups.reserve(approx_num_node_units); + + { + std::unordered_map node_unit_to_qnn_node_group; + std::vector> sorted_node_units; + sorted_node_units.reserve(approx_num_node_units); + + // Create QnnNodeGroups for fusions first. + for (NodeIndex node_index : sorted_node_indices) { + gsl::not_null node = graph_viewer.GetNode(node_index); + + // Get the NodeUnit associated with the node. + const auto node_unit_it = node_to_node_unit.find(node); + ORT_RETURN_IF_NOT(node_unit_it != node_to_node_unit.end(), "Could not find NodeUnit for Node ", node->Name()); + gsl::not_null node_unit = node_unit_it->second; + + // Skip this node if it is not the NodeUnit's target node to ensure NodeUnits are visited in topological order. + if (node != &node_unit->GetNode()) { + continue; + } + + sorted_node_units.push_back(node_unit); + + if (node_unit_to_qnn_node_group.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + std::optional fused_node_group; + ORT_RETURN_IF_ERROR(TryQnnFusions(fused_node_group, qnn_model_wrapper, *node_unit, + node_to_node_unit, node_unit_to_qnn_node_group, logger)); + + if (fused_node_group.has_value()) { + const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size(); + fused_node_group->index_ = index; + + for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { + assert(fused_node_unit != nullptr); + node_unit_to_qnn_node_group.insert({fused_node_unit, index}); + } + + tmp_qnn_node_groups.push_back(std::move(*fused_node_group)); + } + } + + // Create QnnNodeGroups for the leftover NodeUnits. + for (gsl::not_null node_unit : sorted_node_units) { + const auto it = node_unit_to_qnn_node_group.find(node_unit); + if (it != node_unit_to_qnn_node_group.end()) { + // Already handled this NodeUnit. 
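+        // Emit each fused group's sorted position exactly once: only when visiting the group's
+        // target NodeUnit. Anchoring on the target keeps sorted_qnn_node_group_indices in
+        // topological order even though the group's DQ/Q NodeUnits are visited at other times.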
+ const QnnNodeGroup& qnn_node_group = tmp_qnn_node_groups[it->second]; + if (node_unit == qnn_node_group.GetTargetNodeUnit(logger)) { + sorted_qnn_node_group_indices.push_back(qnn_node_group.index_); + } + continue; + } + + const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size(); + QnnNodeGroup fused_node_group = {}; + fused_node_group.type_ = QnnNodeGroup::Type::NodeUnit; + fused_node_group.index_ = index; + fused_node_group.node_units_.resize(1); + fused_node_group.node_units_[0] = node_unit; + tmp_qnn_node_groups.push_back(std::move(fused_node_group)); + + node_unit_to_qnn_node_group.insert({node_unit, index}); + sorted_qnn_node_group_indices.push_back(index); + } + + assert(tmp_qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + } + + // Copy QnnNodeGroups to output in sorted (topological) order. + qnn_node_groups.resize(0); + qnn_node_groups.reserve(tmp_qnn_node_groups.size()); + for (auto index : sorted_qnn_node_group_indices) { + assert(index < tmp_qnn_node_groups.size()); + QnnNodeGroup qnn_node_group = std::move(tmp_qnn_node_groups[index]); + qnn_node_group.index_ = qnn_node_groups.size(); + qnn_node_groups.push_back(std::move(qnn_node_group)); + } + + assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + + return Status::OK(); +} } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h b/onnxruntime/core/providers/qnn/builder/qnn_fusions.h index 39e2e71c01d8c..5bb652df1fa44 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -13,6 +14,34 @@ namespace onnxruntime { namespace qnn { +struct QnnNodeGroup { + using IndexType = size_t; + enum class Type : uint8_t { + Undefined = 0, + NodeUnit, + ConvActivationFusion, + DQQFusion, + HardSigmoidMulFusion, + COUNT, + }; + + static std::string_view TypeToString(QnnNodeGroup::Type type); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const; + const std::vector& GetNodeUnits() const { return node_units_; } + const NodeUnit* GetTargetNodeUnit(const logging::Logger& logger) const; + + QnnNodeGroup::Type type_ = QnnNodeGroup::Type::Undefined; + IndexType index_ = 0; + std::vector node_units_; +}; + +Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_unit_map, + const logging::Logger& logger); + /** * Tries to fuse a node sequence starting from the given starting node. Should be called in a topologically ordered * walk of node units. 
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 9d19c36dc94b2..0ea638567e83a 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -835,14 +835,14 @@ TEST_F(QnnHTPBackendTests, HTPGraphFinalizationOptimizationModes) { // Test that models run with various SoC model values TEST_F(QnnHTPBackendTests, HTPSocModels) { - constexpr std::array soc_models = {"", // No explicit SoC model specified - "0", // "Unknown" + constexpr std::array soc_models = { "", // No explicit SoC model specified + "0", // "Unknown" #if defined(_M_ARM64) - "37"}; // SC8280X + "37" }; // SC8280X #elif defined(__linux__) - "30"}; // SM8350 + "30" }; // SM8350 #else - ""}; + "" }; #endif for (auto soc_model : soc_models) { @@ -948,6 +948,62 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) { 0.008f); } +TEST_F(QnnHTPBackendTests, TestOD) { + Ort::SessionOptions so; + +#if 1 + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "od_current_tf2onnx.onnx"; + // so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); +#else + const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "unet.preprocessed.quant.onnx_ctx.onnx"; +#endif + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + + // Ensure all type/shape inference warnings result in errors! + so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "0"); // Disable fallback to the CPU EP. + so.AddConfigEntry(kDebugLayoutTransformation, "1"); + so.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + so.AppendExecutionProvider("QNN", options); + + Ort::Session session(*ort_env, ort_model_path, so); + + std::vector input_data(300 * 300 * 3, 0.5f); + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector ort_inputs; + std::vector ort_input_names; + + // Add input "serving_default_input_3:0" + std::array input_1_shape{1, 300, 300, 3}; + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_1_shape.data(), input_1_shape.size())); + ort_input_names.push_back("serving_default_input_3:0"); + + // Run session and get outputs + std::array output_names{"StatefulPartitionedCall:1", "StatefulPartitionedCall:0"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output shape. 
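+  // Prints the first few output values for manual inspection; no expected values are asserted.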
+ Ort::Value& ort_output = ort_outputs[0]; + auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); + const float* results = ort_output.GetTensorData(); + + for (size_t i = 0; i < typeshape.GetElementCount() && i < 20; i++) { + std::cout << i << ": " << results[i] << std::endl; + } +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) From 4cab5f953c8cfcb5aa92b1c33d8e8fb31d84f5b8 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Fri, 26 Jul 2024 18:56:00 -0700 Subject: [PATCH 03/20] Get it working on test model --- .../qnn/builder/qnn_conv_activation_fusion.cc | 146 +++---- .../qnn/builder/qnn_conv_activation_fusion.h | 11 +- .../core/providers/qnn/builder/qnn_fusions.cc | 360 ++++++++---------- .../core/providers/qnn/builder/qnn_fusions.h | 25 +- .../core/providers/qnn/builder/qnn_model.cc | 49 +-- .../providers/qnn/qnn_execution_provider.cc | 114 +++--- .../providers/qnn/qnn_execution_provider.h | 3 - .../test/providers/qnn/qnn_basic_test.cc | 20 +- 8 files changed, 301 insertions(+), 427 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc index fd9779cb5512a..26e7a5ff4cbf8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc @@ -18,7 +18,7 @@ static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, const NodeUnit& parent_node_unit, gsl::span child_op_types, const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units) { + const std::unordered_map& node_unit_to_qnn_node_group) { const Node& parent_node = parent_node_unit.GetNode(); // Parent must have a single child (1 output edge) and must not produce a graph output. @@ -48,7 +48,7 @@ static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, // Check if child node has already been handled. Should not be the case if the calling // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(child_node_unit) != 0) { + if (node_unit_to_qnn_node_group.count(child_node_unit) != 0) { return nullptr; } @@ -146,10 +146,11 @@ static std::vector FindQDQNodes(const GraphViewer& graph_viewer, co return nodes; } -static std::optional GetConvQDQNodeGroup( +static Status GetConvDQNodeUnits( + /*out*/ std::vector& dq_node_units, const GraphViewer& graph_viewer, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, const Node& conv_node, const Node& q_node) { assert((conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose") && @@ -159,64 +160,50 @@ static std::optional GetConvQDQNodeGroup( int num_dq_inputs = NumActualValues(conv_node, /*input*/ true); // Within a QDQ node group, a target node input is the only consumer of each DQ. 
- if (num_dq_inputs != static_cast(dq_nodes.size())) { - return std::nullopt; - } + ORT_RETURN_IF_NOT(num_dq_inputs == static_cast(dq_nodes.size()), + "Conv should be the only consumer of each DQ"); for (const auto* dq_node : dq_nodes) { - if (graph_viewer.NodeProducesGraphOutput(*dq_node)) { - return std::nullopt; - } + ORT_RETURN_IF(graph_viewer.NodeProducesGraphOutput(*dq_node), + "QDQ ", conv_node.OpType(), "'s input DQ node must not produce a graph output"); const bool dq_has_single_output_edge_to_target = dq_node->GetOutputEdgesCount() == 1 && dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index(); - if (!dq_has_single_output_edge_to_target) { - return std::nullopt; - } - - const auto dq_node_unit_it = node_unit_map.find(dq_node); - assert(dq_node_unit_it != node_unit_map.end()); - const NodeUnit* dq_node_unit = dq_node_unit_it->second; - - if (handled_node_units.count(dq_node_unit) != 0) { - return std::nullopt; - } - - if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; - } + ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, "DQ should have a single output to Conv"); } // input and output types need to be same int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_weight = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - if (dt_input != dt_output) { - return std::nullopt; - } + ORT_RETURN_IF(dt_input != dt_output, "Conv input[0] and output quantization types must match"); if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { - if (dt_weight != dt_input) { - return std::nullopt; - } + ORT_RETURN_IF(dt_weight != dt_input, + conv_node.OpType(), "'s input[0] and input[1] quantization types must match if input[0] is int8"); } if (dq_nodes.size() == 3) { // has bias int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { - return std::nullopt; - } + ORT_RETURN_IF(dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, + "QDQ ", conv_node.OpType(), " must have int32 quantized bias"); + } + + dq_node_units.reserve(dq_nodes.size()); + for (const auto* dq_node : dq_nodes) { + const auto it = node_to_node_unit.find(dq_node); + assert(it != node_to_node_unit.end()); + const NodeUnit* dq_node_unit = it->second; + + ORT_RETURN_IF_NOT(node_unit_to_qnn_node_group.count(dq_node_unit) == 0, + "DQ NodeUnit ", dq_node_unit->Name(), " has already been added to another QnnNodeGroup"); + ORT_RETURN_IF_NOT(dq_node_unit->UnitType() == NodeUnit::Type::SingleNode, + "Expect DQ to be a NodeUnit of type SingleNode"); + dq_node_units.push_back(dq_node_unit); } - QDQ::NodeGroup node_group; - node_group.dq_nodes.reserve(dq_nodes.size()); - node_group.q_nodes.reserve(q_nodes.size()); - node_group.target_node = conv_node.Index(); - auto get_node_idx = [&](const Node* n) { return n->Index(); }; - std::transform(dq_nodes.begin(), dq_nodes.end(), std::back_inserter(node_group.dq_nodes), get_node_idx); - std::transform(q_nodes.begin(), q_nodes.end(), std::back_inserter(node_group.q_nodes), get_node_idx); - return node_group; + return Status::OK(); } Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, @@ -247,13 +234,12 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, return 
conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); } -Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes, +Status TryConvActivationFusion(/*out*/ std::optional& qnn_node_group, QnnModelWrapper& qnn_model_wrapper, const NodeUnit& conv_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { // Expect that this function is called with a standalone Conv or ConvTranspose. assert((conv_node_unit.OpType() == "Conv" || conv_node_unit.OpType() == "ConvTranspose") && conv_node_unit.UnitType() == NodeUnit::Type::SingleNode); @@ -263,7 +249,7 @@ Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes // Conv must have a single Relu or Clip child. const std::array activation_op_types = {"Relu", "Clip"}; const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types, - node_unit_map, handled_node_units); + node_to_node_unit, node_unit_to_qnn_node_group); if (activation_node_unit == nullptr) { return Status::OK(); } @@ -271,7 +257,7 @@ Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes // Relu/Clip must have a single Q child. const std::array q_op_types = {QDQ::QOpName}; const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, *activation_node_unit, q_op_types, - node_unit_map, handled_node_units); + node_to_node_unit, node_unit_to_qnn_node_group); if (q_node_unit == nullptr) { return Status::OK(); @@ -286,17 +272,16 @@ Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes const Node& conv_node = conv_node_unit.GetNode(); const Node& activation_node = activation_node_unit->GetNode(); const Node& q_node = q_node_unit->GetNode(); - std::optional qdq_node_group = GetConvQDQNodeGroup(graph_viewer, - node_unit_map, - handled_node_units, - conv_node, - q_node); + std::vector dq_node_units; + QNN_RETURN_OK_IF_ERROR(GetConvDQNodeUnits(dq_node_units, + graph_viewer, + node_to_node_unit, + node_unit_to_qnn_node_group, + conv_node, + q_node), + logger); - if (!qdq_node_group.has_value()) { - return Status::OK(); - } - - NodeUnit qdq_node_unit(graph_viewer, *qdq_node_group, activation_node); + assert(dq_node_units.size() == 3 || dq_node_units.size() == 2); // Create a temporary QnnModelWrapper for validation only. We need to be sure that this fusion will work before // modifying the actual QnnModelWrapper. This allows us to revert to the traditional OpBuilder workflow if this @@ -310,37 +295,28 @@ Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes qnn_model_wrapper.GetInitializerLookup(), qnn_model_wrapper.GetQnnBackendType()); - const auto* conv_op_builder = qnn::GetOpBuilder(qdq_node_unit.OpType()); - if (conv_op_builder == nullptr) { - return Status::OK(); - } - - QNN_RETURN_OK_IF_ERROR(conv_op_builder->IsOpSupported(tmp_model_wrapper, qdq_node_unit, logger), logger); + QNN_RETURN_OK_IF_ERROR(QnnConvActivationFusionAdd(tmp_model_wrapper, + dq_node_units, + &conv_node_unit, + activation_node_unit, + q_node_unit, + logger, + /*validate*/ true), + logger); - // ====== The following statements modify qnn_model_wrapper. ======== - // Validation passed, so we're now committed to doing a fusion. + // Validation passed, so create a QnnNodeGroup. // If we encounter an error, we return it directly to caller. 
- LOGS(logger, VERBOSE) << " Adding Conv + Activation via fusion. conv_node name: [" << conv_node.Name() + LOGS(logger, VERBOSE) << "Will use Conv + Activation via fusion. conv_node name: [" << conv_node.Name() << "] activation_node optype: [" << activation_node.OpType() << "] activation_node name: [" << activation_node.Name() << "]"; - if (do_op_validation) { - ORT_RETURN_IF_ERROR(conv_op_builder->IsOpSupported(qnn_model_wrapper, qdq_node_unit, logger)); - } else { - ORT_RETURN_IF_ERROR(conv_op_builder->AddToModelBuilder(qnn_model_wrapper, qdq_node_unit, logger)); - } - - // Success. Add all nodes to fused_nodes so that caller can mark them as handled. - for (const Node* dq_node : qdq_node_unit.GetDQNodes()) { - const auto dq_node_unit_it = node_unit_map.find(dq_node); - ORT_RETURN_IF(dq_node_unit_it == node_unit_map.end(), "DQ node does not have a NodeUnit"); - fused_nodes.push_back(dq_node_unit_it->second); - } - - fused_nodes.push_back(&conv_node_unit); - fused_nodes.push_back(activation_node_unit); - fused_nodes.push_back(q_node_unit); + qnn_node_group = QnnNodeGroup{}; + qnn_node_group->type_ = QnnNodeGroup::Type::ConvActivationFusion; + qnn_node_group->node_units_ = std::move(dq_node_units); + qnn_node_group->node_units_.push_back(&conv_node_unit); + qnn_node_group->node_units_.push_back(activation_node_unit); + qnn_node_group->node_units_.push_back(q_node_unit); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h index 6ad37650122d6..9cca16536ad95 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h @@ -3,12 +3,14 @@ #pragma once +#include #include #include #include #include "core/framework/node_unit.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_fusions.h" namespace onnxruntime { namespace qnn { @@ -21,12 +23,11 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger, bool validate = false); -Status TryConvActivationFusion(/*out*/ std::vector& fused_nodes, +Status TryConvActivationFusion(/*out*/ std::optional& qnn_node_group, QnnModelWrapper& qnn_model_wrapper, const NodeUnit& conv_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation); + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc index 18d5216f65e60..8b7bf899cd622 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc @@ -26,6 +26,7 @@ static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit, const logging::Logger& logger, bool validate = false) { + ORT_UNUSED_PARAMETER(logger); assert(dq_node_unit.OpType() == QDQ::DQOpName && q_node_unit.OpType() == QDQ::QOpName); const auto& node_name = utils::GetNodeName(dq_node_unit); const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; @@ -34,8 +35,8 @@ static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, QnnTensorWrapper input_tensor; QnnTensorWrapper output_tensor; - 
ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); if (validate) { ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, @@ -43,8 +44,7 @@ static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, QNN_OP_CONVERT, {input_tensor.GetQnnTensor()}, {output_tensor.GetQnnTensor()}, - {}), - logger); + {})); } else { ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); @@ -57,6 +57,8 @@ static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, validate), "Failed to add fused Convert node."); } + + return Status::OK(); } static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, @@ -64,6 +66,7 @@ static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& mul_node_unit, const logging::Logger& logger, bool validate = false) { + ORT_UNUSED_PARAMETER(logger); assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; @@ -72,8 +75,8 @@ static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, QnnTensorWrapper input_tensor; QnnTensorWrapper output_tensor; - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); if (validate) { ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, @@ -81,8 +84,7 @@ static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, QNN_OP_HARD_SWISH, {input_tensor.GetQnnTensor()}, {output_tensor.GetQnnTensor()}, - {}), - logger); + {})); } else { LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. 
HardSigmoid name: [" << hardsigmoid_node_unit.Name()
                            << "] Mul name: [" << mul_node_unit.Name() << "]";
 
@@ -98,6 +100,8 @@ static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper,
                                           validate),
                     "Failed to add fused HardSwish node.");
   }
+
+  return Status::OK();
 }
 
 std::string_view QnnNodeGroup::TypeToString(QnnNodeGroup::Type type) {
@@ -116,9 +120,24 @@ Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& lo
   switch (type_) {
     case Type::NodeUnit: {
       ORT_RETURN_IF_NOT(node_units_.size() == 1 && node_units_[0] != nullptr, "");
-      const auto* op_builder = qnn::GetOpBuilder(node_units_[0]->OpType());
-      ORT_RETURN_IF_NOT(op_builder != nullptr, "");
-      return op_builder->IsOpSupported(qmw, *node_units_[0], logger);
+      const NodeUnit& node_unit = *node_units_[0];
+      const std::string& op_type = node_unit.OpType();
+      const auto* op_builder = qnn::GetOpBuilder(op_type);
+
+      if (op_builder == nullptr) {
+        std::string err_msg = MakeString("Operators of type `", op_type,
+                                         "` are not supported by QNN EP. ", op_type, " node `",
+                                         node_unit.Name(), "` will not be assigned to QNN EP.");
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, err_msg);
+      }
+
+      Status status = op_builder->IsOpSupported(qmw, node_unit, logger);
+      if (!status.IsOK()) {
+        LOGS(logger, WARNING) << op_type << " node `" << node_unit.Name()
+                              << "` is not supported: " << status.ErrorMessage();
+      }
+
+      return status;
     }
     case Type::ConvActivationFusion: {
       const size_t num_node_units = node_units_.size();
@@ -133,27 +152,55 @@ Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& lo
       if (has_bias_dq) {
         dq_node_units.push_back(node_units_[2]);
       }
-      return QnnConvActivationFusionAdd(qmw,
-                                        dq_node_units,
-                                        conv_node_unit,
-                                        activation_node_unit,
-                                        q_node_unit,
-                                        logger,
-                                        /*validate*/ true);
+      Status status = QnnConvActivationFusionAdd(qmw,
+                                                 dq_node_units,
+                                                 conv_node_unit,
+                                                 activation_node_unit,
+                                                 q_node_unit,
+                                                 logger,
+                                                 /*validate*/ true);
+
+      if (!status.IsOK()) {
+        LOGS(logger, ERROR) << conv_node_unit->OpType() << "/" << activation_node_unit->OpType()
+                            << " fusion is not supported, but should be according to initial validation."
+                            << " Node names: " << conv_node_unit->Name() << ", " << activation_node_unit->Name()
+                            << " Error: " << status.ErrorMessage();
+      }
+
+      return status;
     }
     case Type::DQQFusion: {
-      ORT_RETURN_IF_NOT(node_units_.size() == 2, "");
+      ORT_RETURN_IF_NOT(node_units_.size() == 2, "Expected 2 NodeUnits for DQ -> Q fusion");
       const NodeUnit* dq_node_unit = node_units_[0];
       const NodeUnit* q_node_unit = node_units_[1];
       ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, "");
-      return QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ true);
+      Status status = QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ true);
+
+      if (!status.IsOK()) {
+        LOGS(logger, ERROR) << "(DQ -> Q) into QNN Convert fusion is not supported, "
+                            << "but should be according to initial validation. "
+                            << "Node names: " << dq_node_unit->Name() << ", " << q_node_unit->Name()
+                            << " Error: " << status.ErrorMessage();
+      }
+
+      return status;
     }
     case Type::HardSigmoidMulFusion: {
-      ORT_RETURN_IF_NOT(node_units_.size() == 2, "");
+      ORT_RETURN_IF_NOT(node_units_.size() == 2, "Expected 2 NodeUnits for HardSigmoid -> Mul fusion");
       const NodeUnit* hardsigmoid_node_unit = node_units_[0];
       const NodeUnit* mul_node_unit = node_units_[1];
       ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, "");
-      return QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, /*validate*/ true);
+      Status status = QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger,
+                                                 /*validate*/ true);
+
+      if (!status.IsOK()) {
+        LOGS(logger, ERROR) << "(HardSigmoid -> Mul) into QNN HardSwish fusion is not supported, "
+                            << "but should be according to initial validation. "
+                            << "Node names: " << hardsigmoid_node_unit->Name() << ", " << mul_node_unit->Name()
+                            << " Error: " << status.ErrorMessage();
+      }
+
+      return status;
     }
     default:
       std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(type_),
@@ -168,7 +215,7 @@ Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logg
     case Type::NodeUnit: {
       ORT_RETURN_IF_NOT(node_units_.size() == 1 && node_units_[0] != nullptr, "");
       const auto* op_builder = qnn::GetOpBuilder(node_units_[0]->OpType());
-      ORT_RETURN_IF_NOT(op_builder != nullptr, "");
+      ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", node_units_[0]->OpType());
       return op_builder->AddToModelBuilder(qmw, *node_units_[0], logger, /*do_op_validation*/ false);
     }
     case Type::ConvActivationFusion: {
@@ -193,14 +240,14 @@ Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logg
                                         /*validate*/ false);
     }
     case Type::DQQFusion: {
-      ORT_RETURN_IF_NOT(node_units_.size() == 2, "");
+      ORT_RETURN_IF_NOT(node_units_.size() == 2, "Expected 2 NodeUnits for DQ -> Q fusion");
       const NodeUnit* dq_node_unit = node_units_[0];
       const NodeUnit* q_node_unit = node_units_[1];
       ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, "");
       return QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ false);
     }
     case Type::HardSigmoidMulFusion: {
-      ORT_RETURN_IF_NOT(node_units_.size() == 2, "");
+      ORT_RETURN_IF_NOT(node_units_.size() == 2, "Expected 2 NodeUnits for HardSigmoid -> Mul fusion");
       const NodeUnit* hardsigmoid_node_unit = node_units_[0];
       const NodeUnit* mul_node_unit = node_units_[1];
       ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, "");
@@ -261,96 +308,12 @@ const NodeUnit* QnnNodeGroup::GetTargetNodeUnit(const logging::Logger& logger) c
 * \param do_op_validation True if should call QNN operator validation APIs.
 * \return An onnxruntime::Status
 */
-static Status TryHandleConvertSequence(std::vector<const NodeUnit*>& fused_nodes,
-                                       QnnModelWrapper& qnn_model_wrapper,
-                                       const NodeUnit& dq_node_unit,
-                                       const NodeUnit& q_node_unit,
-                                       const logging::Logger& logger,
-                                       bool do_op_validation) {
-  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
-
-  // Looking for a standalone DQ to start the sequence.
- assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode); - assert(q_node_unit.OpType() == QDQ::QOpName && q_node_unit.UnitType() == NodeUnit::Type::SingleNode); - - const Node& dq_node = dq_node_unit.GetNode(); - const Node& q_node = q_node_unit.GetNode(); - - auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { - return graph_viewer.GetConstantInitializer(initializer_name, true); - }; - - // DQ and Q must have equal scale type and different zp type. - if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { - return Status::OK(); - } - - const auto& node_name = utils::GetNodeName(dq_node_unit); - const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - // Run QNN validation on the final fused node before committing to doing a fusion. - // Importantly, this validation process does not modify the qnn_model_wrapper. - // If validation fails here, we return Status::OK() to allow QNN EP to use the normal OpBuilder workflow. - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {}), - logger); - - // Validation passed, so we're now committed to doing a fusion. The following statements modify qnn_model_wrapper. - // If we encounter an error, we return it directly to caller. - LOGS(logger, VERBOSE) << " Adding QNN Convert via fusion. dq_node name: [" << dq_node.Name() - << "] dq_node optype: [" << dq_node.OpType() - << "] q_node name: [" << q_node_unit.Name() - << "] q_node optype: [" << q_node_unit.OpType() - << "]"; - - // Add a QNN Convert to the model. Get the input from the DQ node, and the output from the Q node. - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - do_op_validation), - "Failed to add fused Convert node."); - - fused_nodes.push_back(&dq_node_unit); - fused_nodes.push_back(&q_node_unit); - - return Status::OK(); -} - -/** - * Tries to fuse sequences that start with a DQ node. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param dq_node_unit The DQ node unit that could potentially start a sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. 
- * \return An onnxruntime::Status - */ -static Status TryHandleDequantize(std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& dq_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { +static Status TryDQQFusion(std::optional& qnn_node_group, + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { // Expect that this function is called with a standalone DQ. assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode); @@ -362,35 +325,49 @@ static Status TryHandleDequantize(std::vector& fused_nodes, return Status::OK(); } - const Node& child_node = dq_node.OutputEdgesBegin()->GetNode(); - const auto child_node_unit_it = node_unit_map.find(&child_node); - ORT_RETURN_IF(child_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); - const NodeUnit* child_node_unit = child_node_unit_it->second; + const Node& q_node = dq_node.OutputEdgesBegin()->GetNode(); + if (q_node.OpType() != QDQ::QOpName) { + return Status::OK(); + } + + const auto q_node_unit_it = node_to_node_unit.find(&q_node); + ORT_RETURN_IF(q_node_unit_it == node_to_node_unit.end(), "Node does not have a corresponding NodeUnit"); + const NodeUnit* q_node_unit = q_node_unit_it->second; + + // child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return Status::OK(); + } // Check if child node has already been handled. Should not be the case if this // fusion function has been called in topological order, but check to be safe. - if (handled_node_units.count(child_node_unit) != 0) { + if (node_unit_to_qnn_node_group.count(q_node_unit) != 0) { return Status::OK(); } - // child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (child_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + // DQ and Q must have equal scale type and different zp type. + if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { return Status::OK(); } - const std::string& child_type = child_node.OpType(); + QNN_RETURN_OK_IF_ERROR(QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit, logger, /*validate*/ true), + logger); - // Try (DQ -> Q) into QNN's Convert op. - if (child_type == QDQ::QOpName) { - return TryHandleConvertSequence(fused_nodes, qnn_model_wrapper, dq_node_unit, *child_node_unit, - logger, do_op_validation); - } + // Validation passed, so create a QnnNodeGroup. + LOGS(logger, VERBOSE) << " Will use QNN Convert via fusion. dq_node name: [" << dq_node.Name() + << "] dq_node optype: [" << dq_node.OpType() + << "] q_node name: [" << q_node_unit->Name() + << "] q_node optype: [" << q_node_unit->OpType() + << "]"; - // Try (DQ -> Conv/ConvTranspose -> Relu/Clip -> Q) into QNN Conv/ConvTranspose. 
- if (child_type == "Conv" || child_type == "ConvTranspose") { - return TryConvActivationFusion(fused_nodes, qnn_model_wrapper, *child_node_unit, node_unit_map, - handled_node_units, logger, do_op_validation); - } + qnn_node_group = QnnNodeGroup{}; + qnn_node_group->type_ = QnnNodeGroup::Type::DQQFusion; + qnn_node_group->node_units_.push_back(&dq_node_unit); + qnn_node_group->node_units_.push_back(q_node_unit); return Status::OK(); } @@ -409,19 +386,19 @@ static Status TryHandleDequantize(std::vector& fused_nodes, * \param do_op_validation True if should call QNN operator validation APIs. * \return A Status indicating a potential failure. */ -static Status TryHandleHardSigmoidSequence(std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& start_node_unit, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation) { +static Status TryHardSigmoidMulFusion(std::optional& qnn_node_group, + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { // Looking for a standalone HardSigmoid to start the sequence. - if (start_node_unit.OpType() != "HardSigmoid" || start_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || + hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { return Status::OK(); } - NodeAttrHelper hs_attr_helper(start_node_unit); + NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit); float alpha = hs_attr_helper.Get("alpha", 0.2f); float beta = hs_attr_helper.Get("beta", 0.5f); constexpr float req_alpha = 1.0f / 6.0f; @@ -435,22 +412,25 @@ static Status TryHandleHardSigmoidSequence(std::vector& fused_n } const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& hs_node = start_node_unit.GetNode(); + const Node& hs_node = hardsigmoid_node_unit.GetNode(); - // HardSigmoid must have a single Mul child. HardSigmoid must not produce a graph output. - auto children = graph_utils::FindChildrenByType(hs_node, "Mul"); - if (children.size() != 1 || hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { + // HardSigmoid must have a single child (1 output edge) and must not produce a graph output. + if (hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { return Status::OK(); } - const Node& mul_node = *children[0]; - const auto mul_node_unit_it = node_unit_map.find(&mul_node); - ORT_RETURN_IF(mul_node_unit_it == node_unit_map.end(), "Node does not have a corresponding NodeUnit"); + const Node& mul_node = hs_node.OutputEdgesBegin()->GetNode(); + if (mul_node.OpType() != "Mul") { + return Status::OK(); + } + + const auto mul_node_unit_it = node_to_node_unit.find(&mul_node); + ORT_RETURN_IF(mul_node_unit_it == node_to_node_unit.end(), "Mul Node does not have a corresponding NodeUnit"); const NodeUnit* mul_node_unit = mul_node_unit_it->second; // Check if Mul node has already been handled. Should not be the case if this // fusion function has been called in topological order, but check to be safe. 
- if (handled_node_units.count(mul_node_unit) != 0) { + if (node_unit_to_qnn_node_group.count(mul_node_unit) != 0) { return Status::OK(); } @@ -460,7 +440,7 @@ static Status TryHandleHardSigmoidSequence(std::vector& fused_n } // Input to HardSigmoid must also be the other input to the Mul. - auto& hs_input_name = start_node_unit.Inputs()[0].node_arg.Name(); + auto& hs_input_name = hardsigmoid_node_unit.Inputs()[0].node_arg.Name(); const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || mul_node.InputDefs()[1]->Name() == hs_input_name; @@ -468,56 +448,30 @@ static Status TryHandleHardSigmoidSequence(std::vector& fused_n return Status::OK(); } - const auto& node_name = utils::GetNodeName(start_node_unit); - const NodeUnitIODef& input_def = start_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = mul_node_unit->Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - // Run QNN validation on the final fused node before committing to doing a fusion. - // Importantly, this validation process does not modify the qnn_model_wrapper. - // If validation fails here, we return Status::OK() to allow QNN EP to use the normal OpBuilder workflow. - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor), logger); - QNN_RETURN_OK_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {}), + QNN_RETURN_OK_IF_ERROR(QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit, + logger, /*validate*/ true), logger); - // Validation passed, so we're now committed to doing a fusion. The following statements modify qnn_model_wrapper. - // If we encounter an error, we return it directly to caller. - LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. HardSigmoid name: [" << start_node_unit.Name() + // Validation passed, so create a QnnNodeGroup. Any errors are now passed back to the caller. + LOGS(logger, VERBOSE) << "Will use QNN HardSwish via fusion. 
HardSigmoid name: [" << hardsigmoid_node_unit.Name() << "] Mul name: [" << mul_node_unit->Name() << "]"; - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - do_op_validation), - "Failed to add fused HardSwish node."); - - fused_nodes.push_back(&start_node_unit); - fused_nodes.push_back(mul_node_unit); + qnn_node_group = QnnNodeGroup{}; + qnn_node_group->type_ = QnnNodeGroup::Type::HardSigmoidMulFusion; + qnn_node_group->node_units_.push_back(&hardsigmoid_node_unit); + qnn_node_group->node_units_.push_back(mul_node_unit); return Status::OK(); } -using FusionFunc = Status (*)(std::vector&, +using FusionFunc = Status (*)(std::optional&, QnnModelWrapper&, const NodeUnit&, const std::unordered_map&, - const std::unordered_set&, - const logging::Logger&, - bool); + const std::unordered_map&, + const logging::Logger&); +#if 0 Status TryFusions(/*out*/ std::vector& fused_nodes, QnnModelWrapper& qnn_model_wrapper, const NodeUnit& starting_node, @@ -547,6 +501,7 @@ Status TryFusions(/*out*/ std::vector& fused_nodes, return Status::OK(); } +#endif static Status TryQnnFusions(/*out*/ std::optional& fused_node_group, QnnModelWrapper& qnn_model_wrapper, @@ -554,27 +509,46 @@ static Status TryQnnFusions(/*out*/ std::optional& fused_node_grou const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { + // Maps a starting operator type to the fusion function. + static std::unordered_map fusions = { + {"DequantizeLinear", TryDQQFusion}, + {"HardSigmoid", TryHardSigmoidMulFusion}, + {"Conv", TryConvActivationFusion}, + {"ConvTranspose", TryConvActivationFusion}, + }; + + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). + if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return Status::OK(); + } + + auto iter = fusions.find(starting_node_unit.OpType()); + if (iter != fusions.end()) { + FusionFunc fusion_func = iter->second; + ORT_RETURN_IF_ERROR(fusion_func(fused_node_group, qnn_model_wrapper, starting_node_unit, node_to_node_unit, + node_unit_to_qnn_node_group, logger)); + } return Status::OK(); } Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, QnnModelWrapper& qnn_model_wrapper, const std::unordered_map& node_to_node_unit, + const size_t num_node_units, const logging::Logger& logger) { const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); const std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); - const size_t approx_num_node_units = static_cast(graph_viewer.NumberOfNodes() / 2); std::vector sorted_qnn_node_group_indices; - sorted_qnn_node_group_indices.reserve(approx_num_node_units); + sorted_qnn_node_group_indices.reserve(num_node_units); std::vector tmp_qnn_node_groups; - tmp_qnn_node_groups.reserve(approx_num_node_units); + tmp_qnn_node_groups.reserve(num_node_units); { std::unordered_map node_unit_to_qnn_node_group; std::vector> sorted_node_units; - sorted_node_units.reserve(approx_num_node_units); + sorted_node_units.reserve(num_node_units); // Create QnnNodeGroups for fusions first. 
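+    // First pass: try fusions in topological order. Any NodeUnit not claimed by a fusion here
+    // is expected to be wrapped in its own single-NodeUnit QnnNodeGroup in a later pass.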
for (NodeIndex node_index : sorted_node_indices) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h b/onnxruntime/core/providers/qnn/builder/qnn_fusions.h index 5bb652df1fa44..779e04ed91b41 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.h @@ -39,29 +39,8 @@ struct QnnNodeGroup { Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, QnnModelWrapper& qnn_model_wrapper, - const std::unordered_map& node_unit_map, + const std::unordered_map& node_to_node_unit, + size_t num_node_units, const logging::Logger& logger); - -/** - * Tries to fuse a node sequence starting from the given starting node. Should be called in a topologically ordered - * walk of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. - */ -Status TryFusions(/*out*/ std::vector& fused_nodes, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& starting_node, - const std::unordered_map& node_unit_map, - const std::unordered_set& handled_node_units, - const logging::Logger& logger, - bool do_op_validation); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 503943dfb636b..4a74527566b7d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -117,49 +117,20 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } - std::unordered_set handled_node_units; + std::vector qnn_node_groups; + qnn_node_groups.reserve(node_unit_holder.size()); - // Op builer - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer.GetNode(node_indices[i])); + ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, node_unit_map, + node_unit_holder.size(), logger_)); - // Check whether it's part of NodeUnit - const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); - // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) - const std::string& op_type = node_unit.OpType(); + for (const qnn::QnnNodeGroup& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group.AddToModelBuilder(qnn_model_wrapper, logger_); - if (node != &node_unit.GetNode()) { - continue; + if (!status.IsOK()) { + LOGS(logger_, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: " + << status.ErrorMessage() << std::endl; + return status; } - - if (handled_node_units.count(&node_unit) != 0) { - continue; // Already handled. - } - - // Try to see if this node unit can be fused. 
- std::vector fused_nodes; - ORT_RETURN_IF_ERROR(TryFusions(fused_nodes, qnn_model_wrapper, node_unit, node_unit_map, - handled_node_units, logger_, false /*do_op_validation*/)); - - if (!fused_nodes.empty()) { - for (auto fused_node_unit : fused_nodes) { - handled_node_units.insert(fused_node_unit); - } - continue; - } - - LOGS(logger_, VERBOSE) << " node name: [" << node->Name() - << "] node optype: [" << op_type - << "] as part of the NodeUnit type: [" << node_unit.OpType() - << "] name: [" << node_unit.Name() - << "]"; - if (const auto* op_builder = GetOpBuilder(op_type)) { - ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); - } - - handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 0ddaa97694217..15306731b1f3e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -405,27 +405,6 @@ QNNExecutionProvider::~QNNExecutionProvider() { #endif } -bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const { - const std::string& op_type = node_unit.OpType(); - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." - << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); - } - supported = (Status::OK() == status); - } - return supported; -} - std::unordered_set QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, @@ -462,68 +441,65 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, initializer_input_lookup, qnn_backend_manager_->GetQnnBackendType()); - std::unordered_set handled_node_units; - handled_node_units.reserve(node_unit_size); + std::vector qnn_node_groups; + qnn_node_groups.reserve(node_unit_size); - auto add_supported_nodes = [](std::unordered_set& supported_nodes, const NodeUnit* node_unit) { - for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { - supported_nodes.insert(node_in_group); + if (Status status = qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, + node_unit_map, node_unit_size, logger); + !status.IsOK()) { + LOGS(logger, ERROR) << status.ErrorMessage(); + return {}; + } + + auto add_supported_nodes = [](std::unordered_set& supported_nodes, + const qnn::QnnNodeGroup& qnn_node_group) { + for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + supported_nodes.insert(node); + } } }; - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - gsl::not_null node(graph_viewer.GetNode(node_indices[i])); - - // Get the node_unit associated with the node. Note that the node may not be the node_unit's target node. 
- const NodeUnit* node_unit = node_unit_map.at(node); - - // Visiting 'nodes' in topological order does not guarantee that 'node_units' are - // also visited in topological order. Skip this node if it is not the node_unit's target node - // to ensure 'node_units' are visited in topological order. - if (node != &node_unit->GetNode()) { - continue; + auto log_node_support = [](const logging::Logger& logger, + logging::Severity log_severity, + logging::DataType log_data_type, + const onnxruntime::CodeLocation& call_site, + const qnn::QnnNodeGroup& qnn_node_group, + bool supported) { + if (!logger.OutputIsEnabled(log_severity, log_data_type)) { + return; } - if (handled_node_units.count(node_unit) != 0) { - continue; // Already handled this node unit + std::ostringstream oss; + oss << "[QNN EP] " << (supported ? "Supports " : "Does NOT support ") << "the following nodes as part of a " + << qnn::QnnNodeGroup::TypeToString(qnn_node_group.type_) << " group:" << std::endl; + for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + oss << "\tOperator type: " << node->OpType() + << " Node name: " << node->Name() + << " Node index: " << node->Index() << std::endl; + } } - // Try to see if this node unit can be fused. - std::vector fused_nodes; - Status fusion_status = TryFusions(fused_nodes, qnn_model_wrapper, *node_unit, node_unit_map, - handled_node_units, logger, true /*do_op_validation*/); + logging::Capture(logger, log_severity, logging::Category::onnxruntime, + log_data_type, call_site) + .Stream() + << oss.str(); + }; - if (!fusion_status.IsOK()) { - LOGS(logger, WARNING) << "Failed to apply fusion: " << fusion_status.ErrorMessage(); - handled_node_units.insert(node_unit); - continue; - } + for (const qnn::QnnNodeGroup& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group.IsSupported(qnn_model_wrapper, logger); + const bool supported = status.IsOK(); - if (!fused_nodes.empty()) { - for (auto fused_node_unit : fused_nodes) { - handled_node_units.insert(fused_node_unit); - add_supported_nodes(supported_nodes, fused_node_unit); - } - continue; + constexpr auto log_severity = logging::Severity::kVERBOSE; + constexpr auto log_data_type = logging::DataType::SYSTEM; + if (logger.OutputIsEnabled(log_severity, log_data_type)) { + log_node_support(logger, log_severity, log_data_type, ORT_WHERE, qnn_node_group, supported); } - // Couldn't fuse the node unit. See if it is supported by itself. 
- const bool supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; - if (supported) { - add_supported_nodes(supported_nodes, node_unit); + add_supported_nodes(supported_nodes, qnn_node_group); } - - handled_node_units.insert(node_unit); } return supported_nodes; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index e7419dabb14d1..b9e3608856b65 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -53,9 +53,6 @@ class QNNExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; private: - bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - const logging::Logger& logger) const; - std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const size_t node_unit_size, diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 0ea638567e83a..37eeac5101feb 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -484,7 +484,7 @@ static GetTestModelFn F32BuildAdd3Tensors(const TestInputDef& input0_def, } // Tests running a single session in multiple threads on the CPU backend. -TEST_F(QnnCPUBackendTests, MultithreadSessionRun) { +TEST_F(QnnCPUBackendTests, DISABLED_MultithreadSessionRun) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -564,7 +564,7 @@ static GetTestModelFn QDQBuildAdd3Tensors(const TestInputDef& input0_def, } // Tests running a single session in multiple threads on the HTP backend. 
-TEST_F(QnnHTPBackendTests, MultithreadSessionRun) { +TEST_F(QnnHTPBackendTests, DISABLED_MultithreadSessionRun) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -616,7 +616,7 @@ TEST_F(QnnHTPBackendTests, MultithreadSessionRun) { } // Tests running a single session in multiple threads on the HTP backend with run option to set power config -TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgSessionRunOption) { +TEST_F(QnnHTPBackendTests, DISABLED_MultithreadHtpPowerCfgSessionRunOption) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -678,7 +678,7 @@ TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgSessionRunOption) { } // Tests running a single session in multiple threads on the HTP backend with EP option to set default power config -TEST_F(QnnHTPBackendTests, MultithreadDefaultHtpPowerCfgFromEpOption) { +TEST_F(QnnHTPBackendTests, DISABLED_MultithreadDefaultHtpPowerCfgFromEpOption) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -732,7 +732,7 @@ TEST_F(QnnHTPBackendTests, MultithreadDefaultHtpPowerCfgFromEpOption) { // Tests running a single session in multiple threads on the HTP backend with // EP option to set default power config + run option to set power config for each run -TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) { +TEST_F(QnnHTPBackendTests, DISABLED_MultithreadHtpPowerCfgDefaultAndRunOption) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -953,18 +953,18 @@ TEST_F(QnnHTPBackendTests, TestOD) { #if 1 const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "od_current_tf2onnx.onnx"; - // so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + //so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); #else const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "unet.preprocessed.quant.onnx_ctx.onnx"; #endif - auto& logging_manager = DefaultLoggingManager(); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); + //auto& logging_manager = DefaultLoggingManager(); + //logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); // Ensure all type/shape inference warnings result in errors! so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "0"); // Disable fallback to the CPU EP. 
so.AddConfigEntry(kDebugLayoutTransformation, "1"); - so.SetGraphOptimizationLevel(ORT_ENABLE_ALL); - so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); + so.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); + //so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); onnxruntime::ProviderOptions options; #if defined(_WIN32) From e2c9c00959cf53b1ef808633927687180dfe4fe7 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Fri, 26 Jul 2024 23:11:29 -0700 Subject: [PATCH 04/20] Finish clip removal calcs --- .../qnn/builder/qnn_conv_activation_fusion.cc | 182 ++++++++++++++++-- .../core/providers/qnn/builder/qnn_fusions.cc | 32 --- .../providers/qnn/qnn_execution_provider.cc | 15 +- 3 files changed, 168 insertions(+), 61 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc index 26e7a5ff4cbf8..0aba738627e57 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc @@ -8,6 +8,7 @@ #include "core/graph/graph_utils.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" @@ -60,19 +61,12 @@ static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, return child_node_unit; } -static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& clip_node_unit, - const NodeUnit& q_node_unit) { - assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QDQ::QOpName); - // TODO(adrianlizarraga): Implement. - (void)qnn_model_wrapper; - return true; -} - -static bool CanReluBeRemoved(const QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& relu_node_unit, - const NodeUnit& q_node_unit) { - assert(relu_node_unit.OpType() == "Relu" && q_node_unit.OpType() == QDQ::QOpName); +static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& q_node_unit, + /*out*/ float& scale, + /*out*/ int32_t& zero_point, + /*out*/ int32_t& zp_data_type) { + assert(q_node_unit.OpType() == QDQ::QOpName); const auto& q_inputs = q_node_unit.GetNode().InputDefs(); // Require an explicit zero-point input for now. @@ -81,7 +75,6 @@ static bool CanReluBeRemoved(const QnnModelWrapper& qnn_model_wrapper, } std::vector zero_points; - int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; Status status = qnn_model_wrapper.UnpackZeroPoints(q_inputs[QDQ::ZERO_POINT_ID]->Name(), zero_points, zp_data_type); @@ -89,19 +82,170 @@ static bool CanReluBeRemoved(const QnnModelWrapper& qnn_model_wrapper, if (!status.IsOK() || zero_points.size() != 1) { return false; } + zero_point = -zero_points[0]; // QNN zero-points are negated. + + std::vector scales; + status = qnn_model_wrapper.UnpackScales(q_inputs[QDQ::SCALE_ID]->Name(), scales); + + // Should only have one scale (per-tensor). 
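+  // Note: a per-channel quantized Q would unpack more than one scale value, so this check
+  // also restricts the fusion to per-tensor quantization.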
+  if (!status.IsOK() || scales.size() != 1) {
+    return false;
+  }
+
+  scale = scales[0];
+  return true;
+}
+
+static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper,
+                         const NodeUnit& q_node_unit,
+                         /*out*/ float& rmin,
+                         /*out*/ float& rmax) {
+  int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED;
+  int32_t zero_point = 0;
+  float scale = 0.0f;
+
+  if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) {
+    return false;
+  }
+
+  switch (zp_data_type) {
+    case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
+      rmin = scale * (std::numeric_limits<int8_t>::lowest() - zero_point);
+      rmax = scale * (std::numeric_limits<int8_t>::max() - zero_point);
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
+      rmin = scale * (std::numeric_limits<uint8_t>::lowest() - zero_point);
+      rmax = scale * (std::numeric_limits<uint8_t>::max() - zero_point);
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
+      rmin = scale * (std::numeric_limits<int16_t>::lowest() - zero_point);
+      rmax = scale * (std::numeric_limits<int16_t>::max() - zero_point);
+      break;
+    }
+    case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
+      rmin = scale * (std::numeric_limits<uint16_t>::lowest() - zero_point);
+      rmax = scale * (std::numeric_limits<uint16_t>::max() - zero_point);
+      break;
+    }
+    default:
+      return false;
+  }
+
+  return true;
+}
+
+static bool GetClipMinMax(const QnnModelWrapper& qnn_model_wrapper,
+                          const NodeUnit& clip_node_unit,
+                          /*out*/ float& clip_min,
+                          /*out*/ float& clip_max) {
+  clip_min = std::numeric_limits<float>::lowest();
+  clip_max = std::numeric_limits<float>::max();
+
+  // Clip's min and max are attributes before opset 11.
+  if (clip_node_unit.GetNode().SinceVersion() < 11) {
+    NodeAttrHelper attr_helper(clip_node_unit);
+    std::optional<float> min_opt = attr_helper.GetFloat("min");
+    std::optional<float> max_opt = attr_helper.GetFloat("max");
+
+    if (min_opt.has_value()) {
+      clip_min = min_opt.value();
+    }
+
+    if (max_opt.has_value()) {
+      clip_max = max_opt.value();
+    }
+
+    return true;
+  }
+
+  // After opset 11, min and max are inputs.
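+  // Both inputs are optional; a missing input keeps the +/-FLT_MAX defaults set above. The
+  // helper below only accepts float32 constant initializers, so a dynamic min or max input
+  // makes this function return false and blocks removal of the Clip.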
+  const auto& inputs = clip_node_unit.Inputs();
+  const size_t num_inputs = inputs.size();
+  auto get_min_or_max = [&qnn_model_wrapper](const NodeUnitIODef& input, /*out*/ float& result) -> bool {
+    TensorInfo input_info = {};
+    std::vector<uint8_t> raw_bytes;
+    if (Status status = qnn_model_wrapper.GetTensorInfo(input, input_info); !status.IsOK()) {
+      return false;
+    }
+    if (!input_info.is_initializer) {
+      return false;
+    }
+    if (Status status = qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, raw_bytes);
+        !status.IsOK()) {
+      return false;
+    }
+    if (input_info.qnn_data_type != QNN_DATATYPE_FLOAT_32) {
+      return false;
+    }
+    result = static_cast<float>(*reinterpret_cast<const float*>(raw_bytes.data()));
+    return true;
+  };
+
+  if (num_inputs > 1 && inputs[1].node_arg.Exists()) {
+    if (!get_min_or_max(inputs[1], clip_min)) {
+      return false;
+    }
+  }
+
+  if (num_inputs > 2 && inputs[2].node_arg.Exists()) {
+    if (!get_min_or_max(inputs[2], clip_max)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
+                             const NodeUnit& clip_node_unit,
+                             const NodeUnit& q_node_unit) {
+  assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QDQ::QOpName);
+  float rmin = 0.0f;
+  float rmax = 0.0f;
+
+  if (!GetQRminRmax(qnn_model_wrapper, q_node_unit, rmin, rmax)) {
+    return false;
+  }
+
+  float clip_min = std::numeric_limits<float>::lowest();
+  float clip_max = std::numeric_limits<float>::max();
+
+  if (!GetClipMinMax(qnn_model_wrapper, clip_node_unit, clip_min, clip_max)) {
+    return false;
+  }
+
+  constexpr float epsilon = std::numeric_limits<float>::epsilon();
+  if ((epsilon < clip_min - rmin) || (epsilon < rmax - clip_max)) {
+    return false;
+  }
-
-  int32_t onnx_zero_point = -zero_points[0];  // QNN zero-points are negated.
+
+  return true;
+}
+
+static bool CanReluBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
+                             const NodeUnit& relu_node_unit,
+                             const NodeUnit& q_node_unit) {
+  assert(relu_node_unit.OpType() == "Relu" && q_node_unit.OpType() == QDQ::QOpName);
+  int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED;
+  int32_t zero_point = 0;
+  float scale = 0.0f;
+
+  if (!GetQScalarScaleZeroPoint(qnn_model_wrapper, q_node_unit, scale, zero_point, zp_data_type)) {
+    return false;
+  }
 
   // Relu is redundant if the zero-point is set to the smallest quantized value.
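+  // Illustration: for uint8 with zero_point == 0, the Q op computes
+  // q = clamp(round(x / scale) + zero_point, 0, 255), so negative inputs already saturate to
+  // the quantized value that dequantizes to 0.0f, exactly as Relu would produce.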
   switch (zp_data_type) {
     case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT8:
-      return onnx_zero_point == static_cast<int32_t>(std::numeric_limits<int8_t>::lowest());
+      return zero_point == static_cast<int32_t>(std::numeric_limits<int8_t>::lowest());
     case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT8:
-      return onnx_zero_point == static_cast<int32_t>(std::numeric_limits<uint8_t>::lowest());
+      return zero_point == static_cast<int32_t>(std::numeric_limits<uint8_t>::lowest());
     case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_INT16:
-      return onnx_zero_point == static_cast<int32_t>(std::numeric_limits<int16_t>::lowest());
+      return zero_point == static_cast<int32_t>(std::numeric_limits<int16_t>::lowest());
     case ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UINT16:
-      return onnx_zero_point == static_cast<int32_t>(std::numeric_limits<uint16_t>::lowest());
+      return zero_point == static_cast<int32_t>(std::numeric_limits<uint16_t>::lowest());
     default:
       return false;
   }
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc
index 8b7bf899cd622..358292cb3cc17 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc
@@ -471,38 +471,6 @@ using FusionFunc = Status (*)(std::optional<QnnNodeGroup>&,
                               const std::unordered_map&,
                               const logging::Logger&);
 
-#if 0
-Status TryFusions(/*out*/ std::vector<const NodeUnit*>& fused_nodes,
-                  QnnModelWrapper& qnn_model_wrapper,
-                  const NodeUnit& starting_node,
-                  const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
-                  const std::unordered_set<const NodeUnit*>& handled_node_units,
-                  const logging::Logger& logger,
-                  bool validate) {
-  // Maps a starting operator type to the fusion function.
-  static std::unordered_map<std::string, FusionFunc> fusions = {
-      {"DequantizeLinear", TryHandleDequantize},
-      {"HardSigmoid", TryHandleHardSigmoidSequence},
-  };
-
-  // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).
-  if (starting_node.UnitType() != NodeUnit::Type::SingleNode) {
-    return Status::OK();
-  }
-
-  auto iter = fusions.find(starting_node.OpType());
-  if (iter != fusions.end()) {
-    fused_nodes.clear();
-
-    FusionFunc fusion_func = iter->second;
-    ORT_RETURN_IF_ERROR(fusion_func(fused_nodes, qnn_model_wrapper, starting_node, node_unit_map,
-                                    handled_node_units, logger, validate));
-  }
-
-  return Status::OK();
-}
-#endif
-
 static Status TryQnnFusions(/*out*/ std::optional<QnnNodeGroup>& fused_node_group,
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 15306731b1f3e..eabd1a5043ba1 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -451,15 +451,6 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
     return {};
   }
 
-  auto add_supported_nodes = [](std::unordered_set<const Node*>& supported_nodes,
-                                const qnn::QnnNodeGroup& qnn_node_group) {
-    for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) {
-      for (const Node* node : node_unit->GetAllNodesInGroup()) {
-        supported_nodes.insert(node);
-      }
-    }
-  };
-
   auto log_node_support = [](const logging::Logger& logger,
@@ -498,7 +489,11 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
     }
 
     if (supported) {
-      add_supported_nodes(supported_nodes, qnn_node_group);
+      for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) {
+        for (const Node* node : node_unit->GetAllNodesInGroup()) {
+          supported_nodes.insert(node);
+        }
+      }
     }
   }
 
From 6f042c86e1b367f04ec8fe698c37f8266086d006 Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Sat, 27 Jul 2024 00:46:22 -0700
Subject: [PATCH 05/20] Don't always return Status

---
 .../qnn/builder/qnn_conv_activation_fusion.cc | 153 ++++++++++--------
 .../qnn/builder/qnn_conv_activation_fusion.h  |  12 +-
 .../core/providers/qnn/builder/qnn_fusions.cc | 127 ++++++++-------
 3 files changed, 166 insertions(+), 126 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
index 0aba738627e57..b62c5f21f82ba 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
@@ -29,6 +29,9 @@ static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer,
 
   // Child must be of a valid type.
   const Node& child_node = parent_node.OutputEdgesBegin()->GetNode();
+  if (graph_viewer.GetNode(child_node.Index()) == nullptr) {
+    return nullptr;  // Node is not in this GraphViewer
+  }
   const std::string& child_type = child_node.OpType();
   bool is_valid_child_type = false;
 
@@ -44,7 +47,9 @@
   }
 
   const auto child_node_unit_it = node_unit_map.find(&child_node);
-  assert(child_node_unit_it != node_unit_map.end());
+  if (child_node_unit_it == node_unit_map.end()) {
+    return nullptr;
+  }
   const NodeUnit* child_node_unit = child_node_unit_it->second;
 
   // Check if child node has already been handled.
Should not be the case if the calling @@ -290,64 +295,84 @@ static std::vector FindQDQNodes(const GraphViewer& graph_viewer, co return nodes; } -static Status GetConvDQNodeUnits( - /*out*/ std::vector& dq_node_units, +static std::vector GetConvDQs( const GraphViewer& graph_viewer, const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, - const Node& conv_node, - const Node& q_node) { - assert((conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose") && - q_node.OpType() == QDQ::QOpName); + const Node& conv_node) { + assert(conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose"); std::vector dq_nodes = FindQDQNodes(graph_viewer, conv_node, /*find_dq_nodes*/ true); - std::vector q_nodes = {&q_node}; int num_dq_inputs = NumActualValues(conv_node, /*input*/ true); // Within a QDQ node group, a target node input is the only consumer of each DQ. - ORT_RETURN_IF_NOT(num_dq_inputs == static_cast(dq_nodes.size()), - "Conv should be the only consumer of each DQ"); + if (num_dq_inputs != static_cast(dq_nodes.size())) { + return {}; + } + std::vector dq_node_units; for (const auto* dq_node : dq_nodes) { - ORT_RETURN_IF(graph_viewer.NodeProducesGraphOutput(*dq_node), - "QDQ ", conv_node.OpType(), "'s input DQ node must not produce a graph output"); + if (graph_viewer.NodeProducesGraphOutput(*dq_node)) { + return {}; + } const bool dq_has_single_output_edge_to_target = dq_node->GetOutputEdgesCount() == 1 && dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index(); - ORT_RETURN_IF_NOT(dq_has_single_output_edge_to_target, "DQ should have a single output to Conv"); + if (!dq_has_single_output_edge_to_target) { + return {}; + } + + const auto it = node_to_node_unit.find(dq_node); + if (it == node_to_node_unit.end()) { + return {}; + } + + const NodeUnit* dq_node_unit = it->second; + + if (!dq_node_unit || node_unit_to_qnn_node_group.count(dq_node_unit) != 0) { + return {}; + } + + if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + dq_node_units.push_back(dq_node_unit); } - // input and output types need to be same - int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_weight = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - ORT_RETURN_IF(dt_input != dt_output, "Conv input[0] and output quantization types must match"); + return dq_node_units; +} - if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { - ORT_RETURN_IF(dt_weight != dt_input, - conv_node.OpType(), "'s input[0] and input[1] quantization types must match if input[0] is int8"); +static bool IsValidQDQConv(gsl::span dq_node_units, + gsl::not_null q_node_unit) { + assert(q_node_unit->OpType() == QDQ::QOpName); + const size_t num_dqs = dq_node_units.size(); + if (num_dqs != 2 && num_dqs != 3) { + return false; } - if (dq_nodes.size() == 3) { // has bias - int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); - ORT_RETURN_IF(dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, - "QDQ ", conv_node.OpType(), " must have int32 quantized bias"); + // input and output types need to be same + int32_t dt_input = dq_node_units[0]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_weight = 
dq_node_units[1]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node_unit->GetNode().OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_input != dt_output) { + return false; } - dq_node_units.reserve(dq_nodes.size()); - for (const auto* dq_node : dq_nodes) { - const auto it = node_to_node_unit.find(dq_node); - assert(it != node_to_node_unit.end()); - const NodeUnit* dq_node_unit = it->second; + if (dt_input == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) { + if (dt_weight != dt_input) { + return false; + } + } - ORT_RETURN_IF_NOT(node_unit_to_qnn_node_group.count(dq_node_unit) == 0, - "DQ NodeUnit ", dq_node_unit->Name(), " has already been added to another QnnNodeGroup"); - ORT_RETURN_IF_NOT(dq_node_unit->UnitType() == NodeUnit::Type::SingleNode, - "Expect DQ to be a NodeUnit of type SingleNode"); - dq_node_units.push_back(dq_node_unit); + if (num_dqs == 3) { // has bias + int32_t dt_bias = dq_node_units[2]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { + return false; + } } - return Status::OK(); + return true; } Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, @@ -378,12 +403,11 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); } -Status TryConvActivationFusion(/*out*/ std::optional& qnn_node_group, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& conv_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { +std::optional TryConvActivationFusion(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { // Expect that this function is called with a standalone Conv or ConvTranspose. assert((conv_node_unit.OpType() == "Conv" || conv_node_unit.OpType() == "ConvTranspose") && conv_node_unit.UnitType() == NodeUnit::Type::SingleNode); @@ -395,7 +419,7 @@ Status TryConvActivationFusion(/*out*/ std::optional& qnn_node_gro const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types, node_to_node_unit, node_unit_to_qnn_node_group); if (activation_node_unit == nullptr) { - return Status::OK(); + return std::nullopt; } // Relu/Clip must have a single Q child. @@ -404,28 +428,25 @@ Status TryConvActivationFusion(/*out*/ std::optional& qnn_node_gro node_to_node_unit, node_unit_to_qnn_node_group); if (q_node_unit == nullptr) { - return Status::OK(); + return std::nullopt; } // Check if Clip/Relu can be removed because the Q node provides an equivalent effect. 
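  // (For Relu this holds when the Q zero-point equals the lowest value of the quantized type:
  // quantization already clamps everything below the zero-point, so an explicit Relu is redundant.
  // The Clip variant is expected to check the analogous range condition; see CanClipBeRemoved.)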
if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit)) { - return Status::OK(); + return std::nullopt; } // Create a QDQ node group with DQ* -> Conv -> Q const Node& conv_node = conv_node_unit.GetNode(); const Node& activation_node = activation_node_unit->GetNode(); - const Node& q_node = q_node_unit->GetNode(); - std::vector dq_node_units; - QNN_RETURN_OK_IF_ERROR(GetConvDQNodeUnits(dq_node_units, - graph_viewer, - node_to_node_unit, - node_unit_to_qnn_node_group, - conv_node, - q_node), - logger); + std::vector dq_node_units = GetConvDQs(graph_viewer, + node_to_node_unit, + node_unit_to_qnn_node_group, + conv_node); - assert(dq_node_units.size() == 3 || dq_node_units.size() == 2); + if (!IsValidQDQConv(dq_node_units, q_node_unit)) { + return std::nullopt; + } // Create a temporary QnnModelWrapper for validation only. We need to be sure that this fusion will work before // modifying the actual QnnModelWrapper. This allows us to revert to the traditional OpBuilder workflow if this @@ -439,14 +460,16 @@ Status TryConvActivationFusion(/*out*/ std::optional& qnn_node_gro qnn_model_wrapper.GetInitializerLookup(), qnn_model_wrapper.GetQnnBackendType()); - QNN_RETURN_OK_IF_ERROR(QnnConvActivationFusionAdd(tmp_model_wrapper, - dq_node_units, - &conv_node_unit, - activation_node_unit, - q_node_unit, - logger, - /*validate*/ true), - logger); + if (Status status = QnnConvActivationFusionAdd(tmp_model_wrapper, + dq_node_units, + &conv_node_unit, + activation_node_unit, + q_node_unit, + logger, + /*validate*/ true); + !status.IsOK()) { + return std::nullopt; + } // Validation passed, so create a QnnNodeGroup. // If we encounter an error, we return it directly to caller. @@ -455,14 +478,14 @@ Status TryConvActivationFusion(/*out*/ std::optional& qnn_node_gro << "] activation_node name: [" << activation_node.Name() << "]"; - qnn_node_group = QnnNodeGroup{}; + std::optional qnn_node_group = QnnNodeGroup{}; qnn_node_group->type_ = QnnNodeGroup::Type::ConvActivationFusion; qnn_node_group->node_units_ = std::move(dq_node_units); qnn_node_group->node_units_.push_back(&conv_node_unit); qnn_node_group->node_units_.push_back(activation_node_unit); qnn_node_group->node_units_.push_back(q_node_unit); - return Status::OK(); + return qnn_node_group; } } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h index 9cca16536ad95..3f5bdfc3078dc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h @@ -23,11 +23,11 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger, bool validate = false); -Status TryConvActivationFusion(/*out*/ std::optional& qnn_node_group, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& conv_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger); +std::optional TryConvActivationFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc index 
358292cb3cc17..04b727f534424 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc @@ -308,12 +308,12 @@ const NodeUnit* QnnNodeGroup::GetTargetNodeUnit(const logging::Logger& logger) c * \param do_op_validation True if should call QNN operator validation APIs. * \return An onnxruntime::Status */ -static Status TryDQQFusion(std::optional& qnn_node_group, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& dq_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { +static std::optional TryDQQFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { // Expect that this function is called with a standalone DQ. assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode); @@ -322,27 +322,33 @@ static Status TryDQQFusion(std::optional& qnn_node_group, // DQ must have a single child (1 output edge) and must not produce a graph output. if (dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { - return Status::OK(); + return std::nullopt; } const Node& q_node = dq_node.OutputEdgesBegin()->GetNode(); if (q_node.OpType() != QDQ::QOpName) { - return Status::OK(); + return std::nullopt; + } + + if (graph_viewer.GetNode(q_node.Index()) == nullptr) { + return std::nullopt; // Node is not in this GraphViewer } const auto q_node_unit_it = node_to_node_unit.find(&q_node); - ORT_RETURN_IF(q_node_unit_it == node_to_node_unit.end(), "Node does not have a corresponding NodeUnit"); + if (q_node_unit_it == node_to_node_unit.end()) { + return std::nullopt; + } const NodeUnit* q_node_unit = q_node_unit_it->second; // child must not already be part of a QDQ NodeUnit (i.e., be standalone). if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); + return std::nullopt; } // Check if child node has already been handled. Should not be the case if this // fusion function has been called in topological order, but check to be safe. if (node_unit_to_qnn_node_group.count(q_node_unit) != 0) { - return Status::OK(); + return std::nullopt; } auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { @@ -351,11 +357,14 @@ static Status TryDQQFusion(std::optional& qnn_node_group, // DQ and Q must have equal scale type and different zp type. if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { - return Status::OK(); + return std::nullopt; } - QNN_RETURN_OK_IF_ERROR(QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit, logger, /*validate*/ true), - logger); + if (Status status = QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit, + logger, /*validate*/ true); + !status.IsOK()) { + return std::nullopt; + } // Validation passed, so create a QnnNodeGroup. LOGS(logger, VERBOSE) << " Will use QNN Convert via fusion. 
dq_node name: [" << dq_node.Name() @@ -364,12 +373,12 @@ static Status TryDQQFusion(std::optional& qnn_node_group, << "] q_node optype: [" << q_node_unit->OpType() << "]"; - qnn_node_group = QnnNodeGroup{}; + std::optional qnn_node_group = QnnNodeGroup{}; qnn_node_group->type_ = QnnNodeGroup::Type::DQQFusion; qnn_node_group->node_units_.push_back(&dq_node_unit); qnn_node_group->node_units_.push_back(q_node_unit); - return Status::OK(); + return qnn_node_group; } /** @@ -386,16 +395,16 @@ static Status TryDQQFusion(std::optional& qnn_node_group, * \param do_op_validation True if should call QNN operator validation APIs. * \return A Status indicating a potential failure. */ -static Status TryHardSigmoidMulFusion(std::optional& qnn_node_group, - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& hardsigmoid_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { +static std::optional TryHardSigmoidMulFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { // Looking for a standalone HardSigmoid to start the sequence. if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); + return std::nullopt; } NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit); @@ -408,7 +417,7 @@ static Status TryHardSigmoidMulFusion(std::optional& qnn_node_grou // Check for explicit values of alpha and beta. if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { - return Status::OK(); + return std::nullopt; } const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); @@ -416,27 +425,33 @@ static Status TryHardSigmoidMulFusion(std::optional& qnn_node_grou // HardSigmoid must have a single child (1 output edge) and must not produce a graph output. if (hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { - return Status::OK(); + return std::nullopt; } const Node& mul_node = hs_node.OutputEdgesBegin()->GetNode(); if (mul_node.OpType() != "Mul") { - return Status::OK(); + return std::nullopt; + } + + if (graph_viewer.GetNode(mul_node.Index()) == nullptr) { + return std::nullopt; // Node is not in this GraphViewer } const auto mul_node_unit_it = node_to_node_unit.find(&mul_node); - ORT_RETURN_IF(mul_node_unit_it == node_to_node_unit.end(), "Mul Node does not have a corresponding NodeUnit"); + if (mul_node_unit_it == node_to_node_unit.end()) { + return std::nullopt; + } const NodeUnit* mul_node_unit = mul_node_unit_it->second; // Check if Mul node has already been handled. Should not be the case if this // fusion function has been called in topological order, but check to be safe. if (node_unit_to_qnn_node_group.count(mul_node_unit) != 0) { - return Status::OK(); + return std::nullopt; } // Mul child must not already be part of a QDQ NodeUnit (i.e., be standalone). if (mul_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); + return std::nullopt; } // Input to HardSigmoid must also be the other input to the Mul. 
@@ -445,38 +460,40 @@ static Status TryHardSigmoidMulFusion(std::optional<QnnNodeGroup>& qnn_node_grou
                                 mul_node.InputDefs()[1]->Name() == hs_input_name;
 
   if (!same_root_input) {
-    return Status::OK();
+    return std::nullopt;
   }
 
-  QNN_RETURN_OK_IF_ERROR(QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit,
-                                                    logger, /*validate*/ true),
-                         logger);
+  if (Status status = QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit,
+                                                 logger, /*validate*/ true);
+      !status.IsOK()) {
+    return std::nullopt;
+  }
 
   // Validation passed, so create a QnnNodeGroup. Any errors are now passed back to the caller.
   LOGS(logger, VERBOSE) << "Will use QNN HardSwish via fusion. HardSigmoid name: [" << hardsigmoid_node_unit.Name()
                         << "] Mul name: [" << mul_node_unit->Name() << "]";
 
-  qnn_node_group = QnnNodeGroup{};
+  std::optional<QnnNodeGroup> qnn_node_group = QnnNodeGroup{};
   qnn_node_group->type_ = QnnNodeGroup::Type::HardSigmoidMulFusion;
   qnn_node_group->node_units_.push_back(&hardsigmoid_node_unit);
   qnn_node_group->node_units_.push_back(mul_node_unit);
 
-  return Status::OK();
+  return qnn_node_group;
 }
 
-using FusionFunc = Status (*)(std::optional<QnnNodeGroup>&,
-                              QnnModelWrapper&,
-                              const NodeUnit&,
-                              const std::unordered_map<const Node*, const NodeUnit*>&,
-                              const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>&,
-                              const logging::Logger&);
-
-static Status TryQnnFusions(/*out*/ std::optional<QnnNodeGroup>& fused_node_group,
-                            QnnModelWrapper& qnn_model_wrapper,
-                            const NodeUnit& starting_node_unit,
-                            const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
-                            const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
-                            const logging::Logger& logger) {
+using FusionFunc = std::optional<QnnNodeGroup> (*)(
+    QnnModelWrapper&,
+    const NodeUnit&,
+    const std::unordered_map<const Node*, const NodeUnit*>&,
+    const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>&,
+    const logging::Logger&);
+
+static std::optional<QnnNodeGroup> TryQnnFusions(
+    QnnModelWrapper& qnn_model_wrapper,
+    const NodeUnit& starting_node_unit,
+    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+    const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group,
+    const logging::Logger& logger) {
   // Maps a starting operator type to the fusion function.
   static std::unordered_map<std::string, FusionFunc> fusions = {
       {"DequantizeLinear", TryDQQFusion},
      {"HardSigmoid", TryHardSigmoidMulFusion},
      {"Conv", TryConvActivationFusion},
      {"ConvTranspose", TryConvActivationFusion},
  };
@@ -487,16 +504,16 @@ static Status TryQnnFusions(/*out*/ std::optional<QnnNodeGroup>& fused_node_grou
   // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).
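  // A node unit that is already a QDQ group was formed by the QDQ selectors and is handled
  // by its regular op builder, so only SingleNode units can start a fusion sequence.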
if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return Status::OK(); + return std::nullopt; } auto iter = fusions.find(starting_node_unit.OpType()); if (iter != fusions.end()) { FusionFunc fusion_func = iter->second; - ORT_RETURN_IF_ERROR(fusion_func(fused_node_group, qnn_model_wrapper, starting_node_unit, node_to_node_unit, - node_unit_to_qnn_node_group, logger)); + return fusion_func(qnn_model_wrapper, starting_node_unit, node_to_node_unit, + node_unit_to_qnn_node_group, logger); } - return Status::OK(); + return std::nullopt; } Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, @@ -538,9 +555,9 @@ Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, continue; // Already handled this node unit } - std::optional fused_node_group; - ORT_RETURN_IF_ERROR(TryQnnFusions(fused_node_group, qnn_model_wrapper, *node_unit, - node_to_node_unit, node_unit_to_qnn_node_group, logger)); + std::optional fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit, + node_to_node_unit, node_unit_to_qnn_node_group, + logger); if (fused_node_group.has_value()) { const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size(); From dd8dc3d4457babce0309f7d850cb58ac0a5ccdbc Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sat, 27 Jul 2024 02:07:53 -0700 Subject: [PATCH 06/20] Remove macro --- onnxruntime/core/providers/qnn/builder/qnn_utils.h | 9 --------- 1 file changed, 9 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index 4305bf56f522e..2392040d284b7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -13,15 +13,6 @@ #include "core/framework/node_unit.h" #include "core/util/qmath.h" -#define QNN_RETURN_OK_IF_ERROR(expr, logger) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) { \ - LOGS((logger), VERBOSE) << _status.ErrorMessage(); \ - return Status::OK(); \ - } \ - } while (0) - namespace onnxruntime { namespace qnn { class QnnOpConfigWrapper; From fa79a2f0e1931b005617dcc4d19812a1e66b9eb0 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sat, 27 Jul 2024 03:00:19 -0700 Subject: [PATCH 07/20] Use different NodeUnit constructor --- onnxruntime/core/framework/node_unit.cc | 47 ++++---------- onnxruntime/core/framework/node_unit.h | 6 +- .../qnn/builder/qnn_conv_activation_fusion.cc | 62 +++++++++++++++---- .../qnn/builder/qnn_conv_activation_fusion.h | 1 - .../core/providers/qnn/builder/qnn_fusions.cc | 3 - .../test/providers/qnn/qnn_basic_test.cc | 2 +- 6 files changed, 68 insertions(+), 53 deletions(-) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 54964b0275fc8..a491edb9699b3 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -272,41 +272,18 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } -NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group, - const Node& output_activation_node) - : dq_nodes_{GetQDQIONodes(graph_viewer, node_group, true /* is_input */)}, - target_node_(*graph_viewer.GetNode(node_group.target_node)), - q_nodes_{GetQDQIONodes(graph_viewer, node_group, false /* is_input */)}, - type_(Type::QDQGroup), - inputs_{GetQDQIODefs(target_node_, node_group, true /* is_input */)}, - outputs_{GetQDQIODefs(output_activation_node, node_group, false /* is_input */)} { - input_edge_count_ = 
std::accumulate(dq_nodes_.cbegin(), dq_nodes_.cend(), size_t(0), - [](size_t acc, const Node* node) { return acc + node->GetInputEdgesCount(); }); - - // add edges for inputs that are not from DQ nodes. there is one edge to each DQ node. - // other inputs could come from initializers or graph inputs (no edges) or other nodes (edge). - input_edge_count_ += target_node_.GetInputEdgesCount() - dq_nodes_.size(); - - // create output edges. each target node output either goes to Q node/s or non-Q node/s. - // ValidateNodeGroupQDQNodes ensures this. - auto cur_edge = output_activation_node.OutputEdgesBegin(); - auto end_edge = output_activation_node.OutputEdgesEnd(); - for (; cur_edge != end_edge; ++cur_edge) { - const Node& node = cur_edge->GetNode(); - - // if node is in q_nodes we hide the Q node. - if (std::find(q_nodes_.cbegin(), q_nodes_.cend(), &node) != q_nodes_.cend()) { - auto src_idx = cur_edge->GetSrcArgIndex(); - auto q_cur_edge = node.OutputEdgesBegin(); - auto q_end_edge = node.OutputEdgesEnd(); - for (; q_cur_edge != q_end_edge; ++q_cur_edge) { - output_edges_.insert(Node::EdgeEnd{q_cur_edge->GetNode(), src_idx, q_cur_edge->GetDstArgIndex()}); - } - } else { - // non-Q node, or Q node that isn't in the QDQ node group (unexpected but may be possible). add as-is. - output_edges_.insert(*cur_edge); - } - } +NodeUnit::NodeUnit(std::vector dq_nodes, const Node& target_node, + std::vector q_nodes, Type type, + std::vector inputs, std::vector outputs, + size_t input_edge_count, Node::EdgeSet output_edges) + : dq_nodes_(std::move(dq_nodes)), + target_node_(target_node), + q_nodes_(std::move(q_nodes)), + type_(type), + inputs_(std::move(inputs)), + outputs_(std::move(outputs)), + input_edge_count_(input_edge_count), + output_edges_(std::move(output_edges)) { } const std::string& NodeUnit::Domain() const noexcept { return target_node_.Domain(); } diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index 494d7bd849b4b..060653170fde0 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,8 +68,10 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); - explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group, - const Node& output_activation_node); + NodeUnit(std::vector dq_nodes, const Node& target_node, + std::vector q_nodes, Type type, + std::vector inputs, std::vector outputs, + size_t input_edge_count, Node::EdgeSet output_edges); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc index b62c5f21f82ba..e5891baf4ac50 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc @@ -378,19 +378,60 @@ static bool IsValidQDQConv(gsl::span dq_node_units, Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, gsl::span dq_node_units, const NodeUnit* conv_node_unit, - const NodeUnit* activation_node_unit, const NodeUnit* q_node_unit, const logging::Logger& logger, bool validate) { - QDQ::NodeGroup custom_node_group; - custom_node_group.dq_nodes.reserve(dq_node_units.size()); - custom_node_group.q_nodes = std::vector{q_node_unit->Index()}; - custom_node_group.target_node = conv_node_unit->Index(); - auto get_node_idx = [](const 
NodeUnit* n) { return n->Index(); }; - std::transform(dq_node_units.begin(), dq_node_units.end(), std::back_inserter(custom_node_group.dq_nodes), - get_node_idx); - - NodeUnit custom_node_unit(qnn_model_wrapper.GetGraphViewer(), custom_node_group, activation_node_unit->GetNode()); + std::vector dq_nodes; + dq_nodes.reserve(dq_node_units.size()); + for (const NodeUnit* dq_node_unit : dq_node_units) { + dq_nodes.push_back(&dq_node_unit->GetNode()); + } + std::vector q_nodes = {&q_node_unit->GetNode()}; + const Node& target_node = conv_node_unit->GetNode(); + + // Populate NodeUnit inputs + std::vector inputs; + inputs.reserve(dq_node_units.size()); + for (const Node* dq_node : dq_nodes) { + const auto dq_inputs = dq_node->InputDefs(); + const auto& dq_attrs = dq_node->GetAttributes(); + + std::optional axis; + if (auto entry = dq_attrs.find("axis"); entry != dq_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*dq_inputs[1], dq_inputs.size() == 3 ? dq_inputs[2] : nullptr, axis}; + inputs.push_back(NodeUnitIODef{*dq_inputs[0], quant_param}); + } + + // Populate NodeUnit outputs and output edges + std::vector outputs; + Node::EdgeSet output_edges; + for (const Node* q_node : q_nodes) { + const auto q_inputs = q_node->InputDefs(); + const auto& q_attrs = q_node->GetAttributes(); + const auto q_outputs = q_node->OutputDefs(); + + std::optional axis; + if (auto entry = q_attrs.find("axis"); entry != q_attrs.end()) { + axis = entry->second.i(); + } + + // quantization scale and zp are always the input[1, 2] + NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis}; + outputs.push_back(NodeUnitIODef{*q_outputs[0], quant_param}); + + auto q_cur_edge = q_node->OutputEdgesBegin(); + auto q_end_edge = q_node->OutputEdgesEnd(); + for (; q_cur_edge != q_end_edge; ++q_cur_edge) { + output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()}); + } + } + + NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup, + inputs, outputs, dq_nodes.size(), output_edges); const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType()); if (conv_op_builder == nullptr) { return Status::OK(); @@ -463,7 +504,6 @@ std::optional TryConvActivationFusion(QnnModelWrapper& qnn_model_w if (Status status = QnnConvActivationFusionAdd(tmp_model_wrapper, dq_node_units, &conv_node_unit, - activation_node_unit, q_node_unit, logger, /*validate*/ true); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h index 3f5bdfc3078dc..b019dbb9205d6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h @@ -18,7 +18,6 @@ namespace qnn { Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, gsl::span dq_node_units, const NodeUnit* conv_node_unit, - const NodeUnit* activation_node_unit, const NodeUnit* q_node_unit, const logging::Logger& logger, bool validate = false); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc index 04b727f534424..033ac3ce4a555 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc @@ -155,7 +155,6 @@ Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const 
logging::Logger& lo Status status = QnnConvActivationFusionAdd(qmw, dq_node_units, conv_node_unit, - activation_node_unit, q_node_unit, logger, /*validate*/ true); @@ -225,7 +224,6 @@ Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logg const bool has_bias_dq = num_node_units == 6; std::vector dq_node_units = {node_units_[0], node_units_[1]}; const NodeUnit* conv_node_unit = node_units_[num_node_units - 3]; - const NodeUnit* activation_node_unit = node_units_[num_node_units - 2]; const NodeUnit* q_node_unit = node_units_[num_node_units - 1]; if (has_bias_dq) { @@ -234,7 +232,6 @@ Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logg return QnnConvActivationFusionAdd(qmw, dq_node_units, conv_node_unit, - activation_node_unit, q_node_unit, logger, /*validate*/ false); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 37eeac5101feb..5146a1cc14865 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -953,7 +953,7 @@ TEST_F(QnnHTPBackendTests, TestOD) { #if 1 const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "od_current_tf2onnx.onnx"; - //so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); #else const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "unet.preprocessed.quant.onnx_ctx.onnx"; #endif From f21dd7748c12c7fc82a77a395049b1d121ce46c6 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sat, 27 Jul 2024 12:30:33 -0700 Subject: [PATCH 08/20] Update NodeUnit constructor to use gsl::span --- onnxruntime/core/framework/node_unit.cc | 14 ++++++------- onnxruntime/core/framework/node_unit.h | 6 +++--- .../qnn/builder/qnn_conv_activation_fusion.cc | 20 ++++++++++++------- .../test/providers/qnn/qnn_basic_test.cc | 2 +- 4 files changed, 24 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index a491edb9699b3..84d6ccb4d7acb 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -272,16 +272,16 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } -NodeUnit::NodeUnit(std::vector dq_nodes, const Node& target_node, - std::vector q_nodes, Type type, - std::vector inputs, std::vector outputs, +NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type type, + gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges) - : dq_nodes_(std::move(dq_nodes)), + : dq_nodes_(dq_nodes.begin(), dq_nodes.end()), target_node_(target_node), - q_nodes_(std::move(q_nodes)), + q_nodes_(q_nodes.begin(), q_nodes.end()), type_(type), - inputs_(std::move(inputs)), - outputs_(std::move(outputs)), + inputs_(inputs.begin(), inputs.end()), + outputs_(outputs.begin(), outputs.end()), input_edge_count_(input_edge_count), output_edges_(std::move(output_edges)) { } diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index 060653170fde0..c2297c13a41e6 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,9 +68,9 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); - NodeUnit(std::vector dq_nodes, const Node& target_node, - std::vector q_nodes, Type type, - std::vector inputs, 
std::vector outputs, + NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type type, + gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc index e5891baf4ac50..f04ecb01f4f82 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc @@ -381,17 +381,22 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, const NodeUnit* q_node_unit, const logging::Logger& logger, bool validate) { - std::vector dq_nodes; - dq_nodes.reserve(dq_node_units.size()); - for (const NodeUnit* dq_node_unit : dq_node_units) { - dq_nodes.push_back(&dq_node_unit->GetNode()); + const size_t num_dqs = dq_node_units.size(); + constexpr size_t max_num_dqs = 3; + ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); + + std::array dq_nodes_buf = {}; + for (size_t i = 0; i < num_dqs; i++) { + dq_nodes_buf[i] = &dq_node_units[i]->GetNode(); } - std::vector q_nodes = {&q_node_unit->GetNode()}; + gsl::span dq_nodes(dq_nodes_buf.data(), num_dqs); + + std::array q_nodes = {&q_node_unit->GetNode()}; const Node& target_node = conv_node_unit->GetNode(); // Populate NodeUnit inputs std::vector inputs; - inputs.reserve(dq_node_units.size()); + inputs.reserve(num_dqs); for (const Node* dq_node : dq_nodes) { const auto dq_inputs = dq_node->InputDefs(); const auto& dq_attrs = dq_node->GetAttributes(); @@ -423,6 +428,7 @@ Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, NodeUnitIODef::QuantParam quant_param{*q_inputs[1], q_inputs.size() == 3 ? q_inputs[2] : nullptr, axis}; outputs.push_back(NodeUnitIODef{*q_outputs[0], quant_param}); + // Gather output edges out of the Q node. 
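+      // A QuantizeLinear node has exactly one output, which is why the source arg index of
+      // each re-created edge below is pinned to 0.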
      auto q_cur_edge = q_node->OutputEdgesBegin();
      auto q_end_edge = q_node->OutputEdgesEnd();
      for (; q_cur_edge != q_end_edge; ++q_cur_edge) {
        output_edges.insert(Node::EdgeEnd{q_cur_edge->GetNode(), 0, q_cur_edge->GetDstArgIndex()});
      }
    }
 
   NodeUnit custom_node_unit(dq_nodes, target_node, q_nodes, NodeUnit::Type::QDQGroup,
-                            inputs, outputs, dq_nodes.size(), output_edges);
+                            inputs, outputs, num_dqs, output_edges);
   const auto* conv_op_builder = qnn::GetOpBuilder(custom_node_unit.OpType());
   if (conv_op_builder == nullptr) {
     return Status::OK();
   }
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
index 5146a1cc14865..37eeac5101feb 100644
--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -953,7 +953,7 @@ TEST_F(QnnHTPBackendTests, TestOD) {
 
 #if 1
   const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "od_current_tf2onnx.onnx";
-  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
+  //so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
 #else
   const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "unet.preprocessed.quant.onnx_ctx.onnx";
 #endif

From a96d342476b3a16437f50999d4728938cdd6010e Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Sat, 27 Jul 2024 15:20:02 -0700
Subject: [PATCH 09/20] Remove initial conv qnn validation when grouping conv+activation

---
 .../qnn/builder/qnn_conv_activation_fusion.cc | 32 ++-----------------
 1 file changed, 3 insertions(+), 29 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
index f04ecb01f4f82..2cfdeeaac157e 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc
@@ -229,10 +229,8 @@ static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
   return true;
 }
 
-static bool CanReluBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
-                             const NodeUnit& relu_node_unit,
-                             const NodeUnit& q_node_unit) {
-  assert(relu_node_unit.OpType() == "Relu" && q_node_unit.OpType() == QDQ::QOpName);
+static bool CanQReplaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit) {
+  assert(q_node_unit.OpType() == QDQ::QOpName);
   int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED;
   int32_t zero_point = 0;
   float scale = 0.0f;
@@ -262,7 +260,7 @@ static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
   const std::string& activation_type = activation_node_unit.OpType();
 
   if (activation_type == "Relu") {
-    return CanReluBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit);
+    return CanQReplaceRelu(qnn_model_wrapper, q_node_unit);
   }
 
   if (activation_type == "Clip") {
@@ -495,30 +493,6 @@ std::optional<QnnNodeGroup> TryConvActivationFusion(QnnModelWrapper& qnn_model_w
     return std::nullopt;
   }
 
-  // Create a temporary QnnModelWrapper for validation only. We need to be sure that this fusion will work before
-  // modifying the actual QnnModelWrapper. This allows us to revert to the traditional OpBuilder workflow if this
- QnnModelWrapper tmp_model_wrapper(graph_viewer, - logger, - qnn_model_wrapper.GetQnnInterface(), - qnn_model_wrapper.GetQnnBackendHandle(), - qnn_model_wrapper.GetInputIndexMap(), - qnn_model_wrapper.GetOutputIndexMap(), - qnn_model_wrapper.GetInitializerLookup(), - qnn_model_wrapper.GetQnnBackendType()); - - if (Status status = QnnConvActivationFusionAdd(tmp_model_wrapper, - dq_node_units, - &conv_node_unit, - q_node_unit, - logger, - /*validate*/ true); - !status.IsOK()) { - return std::nullopt; - } - - // Validation passed, so create a QnnNodeGroup. - // If we encounter an error, we return it directly to caller. LOGS(logger, VERBOSE) << "Will use Conv + Activation via fusion. conv_node name: [" << conv_node.Name() << "] activation_node optype: [" << activation_node.OpType() << "] activation_node name: [" << activation_node.Name() From dfee4738f38e1f2ca08d033e2ba95f86553d767d Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sat, 27 Jul 2024 17:23:20 -0700 Subject: [PATCH 10/20] Refactor and use jump tables to compare cpu usage on session creation --- .../core/providers/qnn/builder/qnn_fusions.cc | 614 ------------------ .../core/providers/qnn/builder/qnn_model.cc | 2 +- .../{qnn_fusions.h => qnn_node_group.h} | 0 .../conv_activation_fusion.cc} | 82 ++- .../conv_activation_fusion.h} | 17 +- .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 174 +++++ .../qnn/builder/qnn_node_group/dq_q_fusion.h | 45 ++ .../qnn_node_group/hardsigmoid_mul_fusion.cc | 191 ++++++ .../qnn_node_group/hardsigmoid_mul_fusion.h | 47 ++ .../builder/qnn_node_group/qnn_node_group.cc | 264 ++++++++ .../qnn/builder/qnn_node_group/utils.cc | 0 .../qnn/builder/qnn_node_group/utils.h | 0 .../providers/qnn/qnn_execution_provider.cc | 1 - .../providers/qnn/qnn_execution_provider.h | 1 + 14 files changed, 807 insertions(+), 631 deletions(-) delete mode 100644 onnxruntime/core/providers/qnn/builder/qnn_fusions.cc rename onnxruntime/core/providers/qnn/builder/{qnn_fusions.h => qnn_node_group.h} (100%) rename onnxruntime/core/providers/qnn/builder/{qnn_conv_activation_fusion.cc => qnn_node_group/conv_activation_fusion.cc} (84%) rename onnxruntime/core/providers/qnn/builder/{qnn_conv_activation_fusion.h => qnn_node_group/conv_activation_fusion.h} (58%) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc b/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc deleted file mode 100644 index 033ac3ce4a555..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_fusions.cc +++ /dev/null @@ -1,614 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/providers/qnn/builder/qnn_fusions.h" - -#include -#include -#include -#include -#include -#include -#include -#include "core/graph/graph_utils.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" -#include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_conv_activation_fusion.h" - -namespace onnxruntime { -namespace qnn { - -static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& dq_node_unit, - const NodeUnit& q_node_unit, - const logging::Logger& logger, - bool validate = false) { - ORT_UNUSED_PARAMETER(logger); - assert(dq_node_unit.OpType() == QDQ::DQOpName && q_node_unit.OpType() == QDQ::QOpName); - const auto& node_name = utils::GetNodeName(dq_node_unit); - const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); - - if (validate) { - ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {})); - } else { - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_CONVERT, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - validate), - "Failed to add fused Convert node."); - } - - return Status::OK(); -} - -static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& hardsigmoid_node_unit, - const NodeUnit& mul_node_unit, - const logging::Logger& logger, - bool validate = false) { - ORT_UNUSED_PARAMETER(logger); - assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); - const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); - const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; - - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); - - if (validate) { - ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {})); - } else { - LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. 
HardSigmoid name: [" << hardsigmoid_node_unit.Name() - << "] Mul name: [" << mul_node_unit.Name() << "]"; - - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - validate), - "Failed to add fused HardSwish node."); - } - - return Status::OK(); -} - -std::string_view QnnNodeGroup::TypeToString(QnnNodeGroup::Type type) { - static std::array(QnnNodeGroup::Type::COUNT)> type_names = { - "Undefined", - "NodeUnit", - "ConvActivationFusion", - "DQQFusion", - "HardSigmoidMulFusion", - }; - - return type_names[static_cast(type)]; -} - -Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - switch (type_) { - case Type::NodeUnit: { - ORT_RETURN_IF_NOT(node_units_.size() == 1 && node_units_[0] != nullptr, ""); - const NodeUnit& node_unit = *node_units_[0]; - const std::string& op_type = node_unit.OpType(); - const auto* op_builder = qnn::GetOpBuilder(op_type); - - if (op_builder == nullptr) { - std::string err_msg = MakeString("Operators of type `", op_type, - "` are not supported by QNN EP.", op_type, " node `", - node_unit.Name(), "` will not be assigned to QNN EP."); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, err_msg); - } - - Status status = op_builder->IsOpSupported(qmw, *node_units_[0], logger); - if (!status.IsOK()) { - LOGS(logger, WARNING) << op_type << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); - } - - return status; - } - case Type::ConvActivationFusion: { - const size_t num_node_units = node_units_.size(); - ORT_RETURN_IF_NOT((num_node_units == 5 || num_node_units == 6), ""); - - const bool has_bias_dq = num_node_units == 6; - std::vector dq_node_units = {node_units_[0], node_units_[1]}; - const NodeUnit* conv_node_unit = node_units_[num_node_units - 3]; - const NodeUnit* activation_node_unit = node_units_[num_node_units - 2]; - const NodeUnit* q_node_unit = node_units_[num_node_units - 1]; - - if (has_bias_dq) { - dq_node_units.push_back(node_units_[2]); - } - Status status = QnnConvActivationFusionAdd(qmw, - dq_node_units, - conv_node_unit, - q_node_unit, - logger, - /*validate*/ true); - - if (!status.IsOK()) { - LOGS(logger, ERROR) << conv_node_unit->OpType() << "/" << activation_node_unit->OpType() - << " fusion is not supported, but should be according to initial validation." - << " Node names: " << conv_node_unit->Name() << ", " << activation_node_unit->Name() - << " Error: " << status.ErrorMessage(); - } - - return status; - } - case Type::DQQFusion: { - ORT_RETURN_IF_NOT(node_units_.size() == 2, "Expected 2 NodeUnits for DQ -> Q fusion"); - const NodeUnit* dq_node_unit = node_units_[0]; - const NodeUnit* q_node_unit = node_units_[1]; - ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, ""); - Status status = QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ true); - - if (!status.IsOK()) { - LOGS(logger, ERROR) << "(DQ -> Q) into QNN Convert fusion is not supported, " - << "but should be according to initial validation. 
" - << "Node names: " << dq_node_unit->Name() << ", " << q_node_unit->Name() - << " Error: " << status.ErrorMessage(); - } - - return status; - } - case Type::HardSigmoidMulFusion: { - ORT_RETURN_IF_NOT(node_units_.size() == 2, "Expected 2 NodeUnits for HardSimoid -> Mul fusion"); - const NodeUnit* hardsigmoid_node_unit = node_units_[0]; - const NodeUnit* mul_node_unit = node_units_[1]; - ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, ""); - Status status = QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, - /*validate*/ true); - - if (!status.IsOK()) { - LOGS(logger, ERROR) << "(HardSigmoid -> Mul) into QNN HardSwish fusion is not supported, " - << "but should be according to initial validation. " - << "Node names: " << hardsigmoid_node_unit->Name() << ", " << mul_node_unit->Name() - << " Error: " << status.ErrorMessage(); - } - - return status; - } - default: - std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(type_), - " in QnnNodeGroup::IsSupported()"); - LOGS(logger, ERROR) << error_msg; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); - } -} - -Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - switch (type_) { - case Type::NodeUnit: { - ORT_RETURN_IF_NOT(node_units_.size() == 1 && node_units_[0] != nullptr, ""); - const auto* op_builder = qnn::GetOpBuilder(node_units_[0]->OpType()); - ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", node_units_[0]->OpType()); - return op_builder->AddToModelBuilder(qmw, *node_units_[0], logger, /*do_op_validation*/ false); - } - case Type::ConvActivationFusion: { - const size_t num_node_units = node_units_.size(); - ORT_RETURN_IF_NOT((num_node_units == 5 || num_node_units == 6), ""); - - const bool has_bias_dq = num_node_units == 6; - std::vector dq_node_units = {node_units_[0], node_units_[1]}; - const NodeUnit* conv_node_unit = node_units_[num_node_units - 3]; - const NodeUnit* q_node_unit = node_units_[num_node_units - 1]; - - if (has_bias_dq) { - dq_node_units.push_back(node_units_[2]); - } - return QnnConvActivationFusionAdd(qmw, - dq_node_units, - conv_node_unit, - q_node_unit, - logger, - /*validate*/ false); - } - case Type::DQQFusion: { - ORT_RETURN_IF_NOT(node_units_.size() == 2, "Expected 2 NodeUnits for DQ -> Q fusion"); - const NodeUnit* dq_node_unit = node_units_[0]; - const NodeUnit* q_node_unit = node_units_[1]; - ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, ""); - return QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ false); - } - case Type::HardSigmoidMulFusion: { - ORT_RETURN_IF_NOT(node_units_.size() == 2, "Expected 2 NodeUnits for HardSimoid -> Mul fusion"); - const NodeUnit* hardsigmoid_node_unit = node_units_[0]; - const NodeUnit* mul_node_unit = node_units_[1]; - ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, ""); - return QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, /*validate*/ false); - } - default: - std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(type_), - " in QnnNodeGroup::AddToModelBuilder()"); - LOGS(logger, ERROR) << error_msg; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); - } -} - -const NodeUnit* QnnNodeGroup::GetTargetNodeUnit(const logging::Logger& logger) const { - switch (type_) { - case Type::NodeUnit: { - if (node_units_.size() != 1) { - return nullptr; - } - return 
node_units_[0]; - } - case Type::ConvActivationFusion: { - const size_t num_node_units = node_units_.size(); - if (!(num_node_units == 5 || num_node_units == 6)) { - return nullptr; - } - return node_units_[num_node_units - 3]; - } - case Type::DQQFusion: { - if (node_units_.size() != 2) { - return nullptr; - } - return node_units_[0]; - } - case Type::HardSigmoidMulFusion: { - if (node_units_.size() != 2) { - return nullptr; - } - return node_units_[0]; - } - default: - std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(type_), - " in QnnNodeGroup::AddToModelBuilder()"); - LOGS(logger, ERROR) << error_msg; - return nullptr; - } -} - -/** - * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from - * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param dq_node_unit The DQ node unit. - * \param q_node_unit The Q node unit. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return An onnxruntime::Status - */ -static std::optional TryDQQFusion( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& dq_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { - // Expect that this function is called with a standalone DQ. - assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode); - - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& dq_node = dq_node_unit.GetNode(); - - // DQ must have a single child (1 output edge) and must not produce a graph output. - if (dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { - return std::nullopt; - } - - const Node& q_node = dq_node.OutputEdgesBegin()->GetNode(); - if (q_node.OpType() != QDQ::QOpName) { - return std::nullopt; - } - - if (graph_viewer.GetNode(q_node.Index()) == nullptr) { - return std::nullopt; // Node is not in this GraphViewer - } - - const auto q_node_unit_it = node_to_node_unit.find(&q_node); - if (q_node_unit_it == node_to_node_unit.end()) { - return std::nullopt; - } - const NodeUnit* q_node_unit = q_node_unit_it->second; - - // child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; - } - - // Check if child node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (node_unit_to_qnn_node_group.count(q_node_unit) != 0) { - return std::nullopt; - } - - auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { - return graph_viewer.GetConstantInitializer(initializer_name, true); - }; - - // DQ and Q must have equal scale type and different zp type. - if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { - return std::nullopt; - } - - if (Status status = QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit, - logger, /*validate*/ true); - !status.IsOK()) { - return std::nullopt; - } - - // Validation passed, so create a QnnNodeGroup. - LOGS(logger, VERBOSE) << " Will use QNN Convert via fusion. 
dq_node name: [" << dq_node.Name() - << "] dq_node optype: [" << dq_node.OpType() - << "] q_node name: [" << q_node_unit->Name() - << "] q_node optype: [" << q_node_unit->OpType() - << "]"; - - std::optional qnn_node_group = QnnNodeGroup{}; - qnn_node_group->type_ = QnnNodeGroup::Type::DQQFusion; - qnn_node_group->node_units_.push_back(&dq_node_unit); - qnn_node_group->node_units_.push_back(q_node_unit); - - return qnn_node_group; -} - -/** - * Tries to fuse the sequence `x * HardSigmoid(x)` into a single HardSwish(x) operator. - * Should be called in a topologically ordered iteration of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. - */ -static std::optional TryHardSigmoidMulFusion( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& hardsigmoid_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { - // Looking for a standalone HardSigmoid to start the sequence. - if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || - hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; - } - - NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit); - float alpha = hs_attr_helper.Get("alpha", 0.2f); - float beta = hs_attr_helper.Get("beta", 0.5f); - constexpr float req_alpha = 1.0f / 6.0f; - constexpr float req_beta = 0.5f; - constexpr float alpha_eps = std::numeric_limits::epsilon() * req_alpha; - constexpr float beta_eps = std::numeric_limits::epsilon() * req_beta; - - // Check for explicit values of alpha and beta. - if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { - return std::nullopt; - } - - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& hs_node = hardsigmoid_node_unit.GetNode(); - - // HardSigmoid must have a single child (1 output edge) and must not produce a graph output. - if (hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { - return std::nullopt; - } - - const Node& mul_node = hs_node.OutputEdgesBegin()->GetNode(); - if (mul_node.OpType() != "Mul") { - return std::nullopt; - } - - if (graph_viewer.GetNode(mul_node.Index()) == nullptr) { - return std::nullopt; // Node is not in this GraphViewer - } - - const auto mul_node_unit_it = node_to_node_unit.find(&mul_node); - if (mul_node_unit_it == node_to_node_unit.end()) { - return std::nullopt; - } - const NodeUnit* mul_node_unit = mul_node_unit_it->second; - - // Check if Mul node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (node_unit_to_qnn_node_group.count(mul_node_unit) != 0) { - return std::nullopt; - } - - // Mul child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (mul_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; - } - - // Input to HardSigmoid must also be the other input to the Mul. 
- auto& hs_input_name = hardsigmoid_node_unit.Inputs()[0].node_arg.Name(); - const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || - mul_node.InputDefs()[1]->Name() == hs_input_name; - - if (!same_root_input) { - return std::nullopt; - } - - if (Status status = QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit, - logger, /*validate*/ true); - !status.IsOK()) { - return std::nullopt; - } - - // Validation passed, so create a QnnNodeGroup. Any errors are now passed back to the caller. - LOGS(logger, VERBOSE) << "Will use QNN HardSwish via fusion. HardSigmoid name: [" << hardsigmoid_node_unit.Name() - << "] Mul name: [" << mul_node_unit->Name() << "]"; - - std::optional qnn_node_group = QnnNodeGroup{}; - qnn_node_group->type_ = QnnNodeGroup::Type::HardSigmoidMulFusion; - qnn_node_group->node_units_.push_back(&hardsigmoid_node_unit); - qnn_node_group->node_units_.push_back(mul_node_unit); - - return qnn_node_group; -} - -using FusionFunc = std::optional (*)( - QnnModelWrapper&, - const NodeUnit&, - const std::unordered_map&, - const std::unordered_map&, - const logging::Logger&); - -static std::optional TryQnnFusions( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& starting_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { - // Maps a starting operator type to the fusion function. - static std::unordered_map fusions = { - {"DequantizeLinear", TryDQQFusion}, - {"HardSigmoid", TryHardSigmoidMulFusion}, - {"Conv", TryConvActivationFusion}, - {"ConvTranspose", TryConvActivationFusion}, - }; - - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). - if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; - } - - auto iter = fusions.find(starting_node_unit.OpType()); - if (iter != fusions.end()) { - FusionFunc fusion_func = iter->second; - return fusion_func(qnn_model_wrapper, starting_node_unit, node_to_node_unit, - node_unit_to_qnn_node_group, logger); - } - return std::nullopt; -} - -Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, - QnnModelWrapper& qnn_model_wrapper, - const std::unordered_map& node_to_node_unit, - const size_t num_node_units, - const logging::Logger& logger) { - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); - - std::vector sorted_qnn_node_group_indices; - sorted_qnn_node_group_indices.reserve(num_node_units); - - std::vector tmp_qnn_node_groups; - tmp_qnn_node_groups.reserve(num_node_units); - - { - std::unordered_map node_unit_to_qnn_node_group; - std::vector> sorted_node_units; - sorted_node_units.reserve(num_node_units); - - // Create QnnNodeGroups for fusions first. - for (NodeIndex node_index : sorted_node_indices) { - gsl::not_null node = graph_viewer.GetNode(node_index); - - // Get the NodeUnit associated with the node. - const auto node_unit_it = node_to_node_unit.find(node); - ORT_RETURN_IF_NOT(node_unit_it != node_to_node_unit.end(), "Could not find NodeUnit for Node ", node->Name()); - gsl::not_null node_unit = node_unit_it->second; - - // Skip this node if it is not the NodeUnit's target node to ensure NodeUnits are visited in topological order. 
- if (node != &node_unit->GetNode()) { - continue; - } - - sorted_node_units.push_back(node_unit); - - if (node_unit_to_qnn_node_group.count(node_unit) != 0) { - continue; // Already handled this node unit - } - - std::optional fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit, - node_to_node_unit, node_unit_to_qnn_node_group, - logger); - - if (fused_node_group.has_value()) { - const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size(); - fused_node_group->index_ = index; - - for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { - assert(fused_node_unit != nullptr); - node_unit_to_qnn_node_group.insert({fused_node_unit, index}); - } - - tmp_qnn_node_groups.push_back(std::move(*fused_node_group)); - } - } - - // Create QnnNodeGroups for the leftover NodeUnits. - for (gsl::not_null node_unit : sorted_node_units) { - const auto it = node_unit_to_qnn_node_group.find(node_unit); - if (it != node_unit_to_qnn_node_group.end()) { - // Already handled this NodeUnit. - const QnnNodeGroup& qnn_node_group = tmp_qnn_node_groups[it->second]; - if (node_unit == qnn_node_group.GetTargetNodeUnit(logger)) { - sorted_qnn_node_group_indices.push_back(qnn_node_group.index_); - } - continue; - } - - const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size(); - QnnNodeGroup fused_node_group = {}; - fused_node_group.type_ = QnnNodeGroup::Type::NodeUnit; - fused_node_group.index_ = index; - fused_node_group.node_units_.resize(1); - fused_node_group.node_units_[0] = node_unit; - tmp_qnn_node_groups.push_back(std::move(fused_node_group)); - - node_unit_to_qnn_node_group.insert({node_unit, index}); - sorted_qnn_node_group_indices.push_back(index); - } - - assert(tmp_qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); - } - - // Copy QnnNodeGroups to output in sorted (topological) order. 
- qnn_node_groups.resize(0); - qnn_node_groups.reserve(tmp_qnn_node_groups.size()); - for (auto index : sorted_qnn_node_group_indices) { - assert(index < tmp_qnn_node_groups.size()); - QnnNodeGroup qnn_node_group = std::move(tmp_qnn_node_groups[index]); - qnn_node_group.index_ = qnn_node_groups.size(); - qnn_node_groups.push_back(std::move(qnn_node_group)); - } - - assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); - - return Status::OK(); -} -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 4a74527566b7d..47d4a13b071ab 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -7,7 +7,7 @@ #include "QnnOpDef.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_fusions.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_fusions.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h similarity index 100% rename from onnxruntime/core/providers/qnn/builder/qnn_fusions.h rename to onnxruntime/core/providers/qnn/builder/qnn_node_group.h diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc similarity index 84% rename from onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc rename to onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index 2cfdeeaac157e..5e947446021e6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -1,4 +1,4 @@ -#include "core/providers/qnn/builder/qnn_conv_activation_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" #include #include @@ -373,12 +373,12 @@ static bool IsValidQDQConv(gsl::span dq_node_units, return true; } -Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, - gsl::span dq_node_units, - const NodeUnit* conv_node_unit, - const NodeUnit* q_node_unit, - const logging::Logger& logger, - bool validate) { +static Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, + gsl::span dq_node_units, + const NodeUnit* conv_node_unit, + const NodeUnit* q_node_unit, + const logging::Logger& logger, + bool validate) { const size_t num_dqs = dq_node_units.size(); constexpr size_t max_num_dqs = 3; ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); @@ -507,5 +507,73 @@ std::optional TryConvActivationFusion(QnnModelWrapper& qnn_model_w return qnn_node_group; } + +namespace conv_act_fusion { + +Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + const size_t num_node_units = qnn_node_group.node_units_.size(); + ORT_RETURN_IF_NOT((num_node_units == 5 || num_node_units == 6), ""); + + const bool has_bias_dq = num_node_units == 6; + std::vector dq_node_units = {qnn_node_group.node_units_[0], qnn_node_group.node_units_[1]}; + const NodeUnit* conv_node_unit = qnn_node_group.node_units_[num_node_units - 3]; + const NodeUnit* activation_node_unit = 
qnn_node_group.node_units_[num_node_units - 2]; + const NodeUnit* q_node_unit = qnn_node_group.node_units_[num_node_units - 1]; + + if (has_bias_dq) { + dq_node_units.push_back(qnn_node_group.node_units_[2]); + } + Status status = QnnConvActivationFusionAdd(qmw, + dq_node_units, + conv_node_unit, + q_node_unit, + logger, + /*validate*/ true); + + if (!status.IsOK()) { + LOGS(logger, ERROR) << conv_node_unit->OpType() << "/" << activation_node_unit->OpType() + << " fusion is not supported, but should be according to initial validation." + << " Node names: " << conv_node_unit->Name() << ", " << activation_node_unit->Name() + << " Error: " << status.ErrorMessage(); + } + + return status; +} + +Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + const size_t num_node_units = qnn_node_group.node_units_.size(); + ORT_RETURN_IF_NOT((num_node_units == 5 || num_node_units == 6), ""); + + const bool has_bias_dq = num_node_units == 6; + std::vector dq_node_units = {qnn_node_group.node_units_[0], qnn_node_group.node_units_[1]}; + const NodeUnit* conv_node_unit = qnn_node_group.node_units_[num_node_units - 3]; + const NodeUnit* q_node_unit = qnn_node_group.node_units_[num_node_units - 1]; + + if (has_bias_dq) { + dq_node_units.push_back(qnn_node_group.node_units_[2]); + } + return QnnConvActivationFusionAdd(qmw, + dq_node_units, + conv_node_unit, + q_node_unit, + logger, + /*validate*/ false); +} + +#if 0 +const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group) { + return qnn_node_group.node_units_; +} +#endif + +const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + const size_t num_node_units = qnn_node_group.node_units_.size(); + if (!(num_node_units == 5 || num_node_units == 6)) { + return nullptr; + } + return qnn_node_group.node_units_[num_node_units - 3]; +} +} // namespace conv_act_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h similarity index 58% rename from onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h rename to onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h index b019dbb9205d6..50b02595f5f72 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -10,23 +10,24 @@ #include "core/framework/node_unit.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" -#include "core/providers/qnn/builder/qnn_fusions.h" +#include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { namespace qnn { -Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, - gsl::span dq_node_units, - const NodeUnit* conv_node_unit, - const NodeUnit* q_node_unit, - const logging::Logger& logger, - bool validate = false); - std::optional TryConvActivationFusion( QnnModelWrapper& qnn_model_wrapper, const NodeUnit& conv_node_unit, const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger); + +namespace conv_act_fusion { + +Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& 
logger); +// const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group); +const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +} // namespace conv_act_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc new file mode 100644 index 0000000000000..25bce1fe39a7c --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -0,0 +1,174 @@ +#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" + +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" + +namespace onnxruntime { +namespace qnn { + +static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool validate = false) { + ORT_UNUSED_PARAMETER(logger); + assert(dq_node_unit.OpType() == QDQ::DQOpName && q_node_unit.OpType() == QDQ::QOpName); + const auto& node_name = utils::GetNodeName(dq_node_unit); + const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_CONVERT, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused Convert node."); + } + + return Status::OK(); +} + +std::optional TryDQQFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Expect that this function is called with a standalone DQ. + assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode); + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& dq_node = dq_node_unit.GetNode(); + + // DQ must have a single child (1 output edge) and must not produce a graph output. 
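+  // The pattern handled here is: input (quant type A) -> DQ -> Q -> output (quant type B),
+  // which QNN can express as a single Convert op. Fusing consumes the DQ's output entirely,
+  // so the fusion is skipped when any other consumer (or a graph output) still needs it.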
+ if (dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { + return std::nullopt; + } + + const Node& q_node = dq_node.OutputEdgesBegin()->GetNode(); + if (q_node.OpType() != QDQ::QOpName) { + return std::nullopt; + } + + if (graph_viewer.GetNode(q_node.Index()) == nullptr) { + return std::nullopt; // Node is not in this GraphViewer + } + + const auto q_node_unit_it = node_to_node_unit.find(&q_node); + if (q_node_unit_it == node_to_node_unit.end()) { + return std::nullopt; + } + const NodeUnit* q_node_unit = q_node_unit_it->second; + + // child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return std::nullopt; + } + + // Check if child node has already been handled. Should not be the case if this + // fusion function has been called in topological order, but check to be safe. + if (node_unit_to_qnn_node_group.count(q_node_unit) != 0) { + return std::nullopt; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + // DQ and Q must have equal scale type and different zp type. + if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { + return std::nullopt; + } + + if (Status status = QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit, + logger, /*validate*/ true); + !status.IsOK()) { + return std::nullopt; + } + + // Validation passed, so create a QnnNodeGroup. + LOGS(logger, VERBOSE) << " Will use QNN Convert via fusion. dq_node name: [" << dq_node.Name() + << "] dq_node optype: [" << dq_node.OpType() + << "] q_node name: [" << q_node_unit->Name() + << "] q_node optype: [" << q_node_unit->OpType() + << "]"; + + std::optional qnn_node_group = QnnNodeGroup{}; + qnn_node_group->type_ = QnnNodeGroup::Type::DQQFusion; + qnn_node_group->node_units_.push_back(&dq_node_unit); + qnn_node_group->node_units_.push_back(q_node_unit); + + return qnn_node_group; +} + +namespace dq_q_fusion { + +Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 2, "Expected 2 NodeUnits for DQ -> Q fusion"); + const NodeUnit* dq_node_unit = qnn_node_group.node_units_[0]; + const NodeUnit* q_node_unit = qnn_node_group.node_units_[1]; + ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, ""); + Status status = QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ true); + + if (!status.IsOK()) { + LOGS(logger, ERROR) << "(DQ -> Q) into QNN Convert fusion is not supported, " + << "but should be according to initial validation. 
" + << "Node names: " << dq_node_unit->Name() << ", " << q_node_unit->Name() + << " Error: " << status.ErrorMessage(); + } + + return status; +} + +Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 2, "Expected 2 NodeUnits for DQ -> Q fusion"); + const NodeUnit* dq_node_unit = qnn_node_group.node_units_[0]; + const NodeUnit* q_node_unit = qnn_node_group.node_units_[1]; + ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, ""); + return QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ false); +} + +#if 0 +const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group) { + return qnn_node_group.node_units_; +} +#endif + +const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + if (qnn_node_group.node_units_.size() != 2) { + return nullptr; + } + return qnn_node_group.node_units_[0]; +} + +} // namespace dq_q_fusion +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h new file mode 100644 index 0000000000000..5a0529e9c4fda --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +/** + * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from + * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). + * + * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. + * \param qnn_model_wrapper The QNN model that is being built. + * \param dq_node_unit The DQ node unit. + * \param q_node_unit The Q node unit. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. 
+ * \return An onnxruntime::Status + */ +std::optional TryDQQFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + +namespace dq_q_fusion { + +Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +// const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group); +const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +} // namespace dq_q_fusion +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc new file mode 100644 index 0000000000000..beee1d2ea5249 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -0,0 +1,191 @@ +#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" + +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/framework/node_unit.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" + +namespace onnxruntime { +namespace qnn { + +static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, + const logging::Logger& logger, + bool validate = false) { + ORT_UNUSED_PARAMETER(logger); + assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); + const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); + const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. 
HardSigmoid name: [" << hardsigmoid_node_unit.Name() + << "] Mul name: [" << mul_node_unit.Name() << "]"; + + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused HardSwish node."); + } + + return Status::OK(); +} + +std::optional TryHardSigmoidMulFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Looking for a standalone HardSigmoid to start the sequence. + if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || + hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return std::nullopt; + } + + NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit); + float alpha = hs_attr_helper.Get("alpha", 0.2f); + float beta = hs_attr_helper.Get("beta", 0.5f); + constexpr float req_alpha = 1.0f / 6.0f; + constexpr float req_beta = 0.5f; + constexpr float alpha_eps = std::numeric_limits::epsilon() * req_alpha; + constexpr float beta_eps = std::numeric_limits::epsilon() * req_beta; + + // Check for explicit values of alpha and beta. + if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { + return std::nullopt; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& hs_node = hardsigmoid_node_unit.GetNode(); + + // HardSigmoid must have a single child (1 output edge) and must not produce a graph output. + if (hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { + return std::nullopt; + } + + const Node& mul_node = hs_node.OutputEdgesBegin()->GetNode(); + if (mul_node.OpType() != "Mul") { + return std::nullopt; + } + + if (graph_viewer.GetNode(mul_node.Index()) == nullptr) { + return std::nullopt; // Node is not in this GraphViewer + } + + const auto mul_node_unit_it = node_to_node_unit.find(&mul_node); + if (mul_node_unit_it == node_to_node_unit.end()) { + return std::nullopt; + } + const NodeUnit* mul_node_unit = mul_node_unit_it->second; + + // Check if Mul node has already been handled. Should not be the case if this + // fusion function has been called in topological order, but check to be safe. + if (node_unit_to_qnn_node_group.count(mul_node_unit) != 0) { + return std::nullopt; + } + + // Mul child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (mul_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return std::nullopt; + } + + // Input to HardSigmoid must also be the other input to the Mul. + auto& hs_input_name = hardsigmoid_node_unit.Inputs()[0].node_arg.Name(); + const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || + mul_node.InputDefs()[1]->Name() == hs_input_name; + + if (!same_root_input) { + return std::nullopt; + } + + if (Status status = QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit, + logger, /*validate*/ true); + !status.IsOK()) { + return std::nullopt; + } + + // Validation passed, so create a QnnNodeGroup. Any errors are now passed back to the caller. + LOGS(logger, VERBOSE) << "Will use QNN HardSwish via fusion. 
HardSigmoid name: [" << hardsigmoid_node_unit.Name() + << "] Mul name: [" << mul_node_unit->Name() << "]"; + + std::optional qnn_node_group = QnnNodeGroup{}; + qnn_node_group->type_ = QnnNodeGroup::Type::HardSigmoidMulFusion; + qnn_node_group->node_units_.push_back(&hardsigmoid_node_unit); + qnn_node_group->node_units_.push_back(mul_node_unit); + + return qnn_node_group; +} + +namespace hs_mul_fusion { + +Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 2, "Expected 2 NodeUnits for HardSimoid -> Mul fusion"); + const NodeUnit* hardsigmoid_node_unit = qnn_node_group.node_units_[0]; + const NodeUnit* mul_node_unit = qnn_node_group.node_units_[1]; + ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, ""); + Status status = QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, + /*validate*/ true); + + if (!status.IsOK()) { + LOGS(logger, ERROR) << "(HardSigmoid -> Mul) into QNN HardSwish fusion is not supported, " + << "but should be according to initial validation. " + << "Node names: " << hardsigmoid_node_unit->Name() << ", " << mul_node_unit->Name() + << " Error: " << status.ErrorMessage(); + } + + return status; +} + +Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 2, "Expected 2 NodeUnits for HardSimoid -> Mul fusion"); + const NodeUnit* hardsigmoid_node_unit = qnn_node_group.node_units_[0]; + const NodeUnit* mul_node_unit = qnn_node_group.node_units_[1]; + ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, ""); + return QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, /*validate*/ false); +} + +#if 0 +const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group) { + return qnn_node_group.node_units_; +} +#endif + +const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + if (qnn_node_group.node_units_.size() != 2) { + return nullptr; + } + return qnn_node_group.node_units_[0]; +} + +} // namespace hs_mul_fusion +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h new file mode 100644 index 0000000000000..01bb0a39b4990 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +/** + * Tries to fuse the sequence `x * HardSigmoid(x)` into a single HardSwish(x) operator. + * Should be called in a topologically ordered iteration of node units. + * + * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. + * \param qnn_model_wrapper The QNN model that is being built. + * \param starting_node The node unit that could potentially start the sequence. + * \param node_unit_map Maps a node to its node unit. 
+ * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes + * in this set. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. + * \return A Status indicating a potential failure. + */ +std::optional TryHardSigmoidMulFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + +namespace hs_mul_fusion { + +Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +// const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group); +const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +} // namespace hs_mul_fusion +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc new file mode 100644 index 0000000000000..38c5027fcb8ac --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -0,0 +1,264 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group.h" + +#include +#include +#include +#include +#include +#include +#include +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" + +namespace onnxruntime { +namespace qnn { + +std::string_view QnnNodeGroup::TypeToString(QnnNodeGroup::Type type) { + static std::array(QnnNodeGroup::Type::COUNT)> type_names = { + "Undefined", + "NodeUnit", + "ConvActivationFusion", + "DQQFusion", + "HardSigmoidMulFusion", + }; + + return type_names[static_cast(type)]; +} + +Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + using Func = Status (*)( + QnnModelWrapper&, + const QnnNodeGroup&, + const logging::Logger&); + + static std::array(QnnNodeGroup::Type::COUNT)> funcs = { + [](QnnModelWrapper&, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> Status { + std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(qnn_node_group.type_), + " in QnnNodeGroup::IsSupported()"); + LOGS(logger, ERROR) << error_msg; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); + }, + [](QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> Status { + ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 1 && qnn_node_group.node_units_[0] != nullptr, ""); + const NodeUnit& node_unit = *qnn_node_group.node_units_[0]; + const std::string& op_type = node_unit.OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + + if (op_builder == nullptr) { + std::string err_msg = MakeString("Operators of type `", op_type, + "` are 
not supported by QNN EP.", op_type, " node `", + node_unit.Name(), "` will not be assigned to QNN EP."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, err_msg); + } + + Status status = op_builder->IsOpSupported(qmw, *qnn_node_group.node_units_[0], logger); + if (!status.IsOK()) { + LOGS(logger, WARNING) << op_type << " node `" << node_unit.Name() + << "` is not supported: " << status.ErrorMessage(); + } + + return status; + }, + conv_act_fusion::IsSupported, + dq_q_fusion::IsSupported, + hs_mul_fusion::IsSupported, + }; + + return funcs[static_cast(type_)](qmw, *this, logger); +} + +Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + using Func = Status (*)( + QnnModelWrapper&, + const QnnNodeGroup&, + const logging::Logger&); + + static std::array(QnnNodeGroup::Type::COUNT)> funcs = { + [](QnnModelWrapper&, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> Status { + std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(qnn_node_group.type_), + " in QnnNodeGroup::AddToModelBuilder()"); + LOGS(logger, ERROR) << error_msg; + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); + }, + [](QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> Status { + ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 1 && qnn_node_group.node_units_[0] != nullptr, ""); + const auto* op_builder = qnn::GetOpBuilder(qnn_node_group.node_units_[0]->OpType()); + ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", qnn_node_group.node_units_[0]->OpType()); + return op_builder->AddToModelBuilder(qmw, *qnn_node_group.node_units_[0], logger, /*do_op_validation*/ false); + }, + conv_act_fusion::AddToModelBuilder, + dq_q_fusion::AddToModelBuilder, + hs_mul_fusion::AddToModelBuilder, + }; + + return funcs[static_cast(type_)](qmw, *this, logger); +} + +const NodeUnit* QnnNodeGroup::GetTargetNodeUnit(const logging::Logger& logger) const { + using Func = const NodeUnit* (*)(const QnnNodeGroup&, const logging::Logger&); + + static std::array(QnnNodeGroup::Type::COUNT)> funcs = { + [](const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> const NodeUnit* { + std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(qnn_node_group.type_), + " in QnnNodeGroup::AddToModelBuilder()"); + LOGS(logger, ERROR) << error_msg; + return nullptr; + }, + [](const QnnNodeGroup& qnn_node_group, const logging::Logger&) -> const NodeUnit* { + if (qnn_node_group.node_units_.size() != 1) { + return nullptr; + } + return qnn_node_group.node_units_[0]; + }, + conv_act_fusion::GetTargetNodeUnit, + dq_q_fusion::GetTargetNodeUnit, + hs_mul_fusion::GetTargetNodeUnit, + }; + + return funcs[static_cast(type_)](*this, logger); +} + +using FusionFunc = std::optional (*)( + QnnModelWrapper&, + const NodeUnit&, + const std::unordered_map&, + const std::unordered_map&, + const logging::Logger&); + +static std::optional TryQnnFusions( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& starting_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + // Maps a starting operator type to the fusion function. 
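+  // Each fusion function returns a fused QnnNodeGroup on success or std::nullopt if the pattern
+  // does not match, so several op types (e.g., Conv and ConvTranspose) can share one entry.
+  // A hypothetical new pattern would plug in the same way, e.g.:
+  //   {"Gelu", TryGeluFusion},  // TryGeluFusion: illustrative name only, same FusionFunc signature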
+ static std::unordered_map fusions = { + {"DequantizeLinear", TryDQQFusion}, + {"HardSigmoid", TryHardSigmoidMulFusion}, + {"Conv", TryConvActivationFusion}, + {"ConvTranspose", TryConvActivationFusion}, + }; + + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). + if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return std::nullopt; + } + + auto iter = fusions.find(starting_node_unit.OpType()); + if (iter != fusions.end()) { + FusionFunc fusion_func = iter->second; + return fusion_func(qnn_model_wrapper, starting_node_unit, node_to_node_unit, + node_unit_to_qnn_node_group, logger); + } + return std::nullopt; +} + +Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, + QnnModelWrapper& qnn_model_wrapper, + const std::unordered_map& node_to_node_unit, + const size_t num_node_units, + const logging::Logger& logger) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + std::vector sorted_qnn_node_group_indices; + sorted_qnn_node_group_indices.reserve(num_node_units); + + std::vector tmp_qnn_node_groups; + tmp_qnn_node_groups.reserve(num_node_units); + + { + std::unordered_map node_unit_to_qnn_node_group; + std::vector> sorted_node_units; + sorted_node_units.reserve(num_node_units); + + // Create QnnNodeGroups for fusions first. + for (NodeIndex node_index : sorted_node_indices) { + gsl::not_null node = graph_viewer.GetNode(node_index); + + // Get the NodeUnit associated with the node. + const auto node_unit_it = node_to_node_unit.find(node); + ORT_RETURN_IF_NOT(node_unit_it != node_to_node_unit.end(), "Could not find NodeUnit for Node ", node->Name()); + gsl::not_null node_unit = node_unit_it->second; + + // Skip this node if it is not the NodeUnit's target node to ensure NodeUnits are visited in topological order. + if (node != &node_unit->GetNode()) { + continue; + } + + sorted_node_units.push_back(node_unit); + + if (node_unit_to_qnn_node_group.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + std::optional fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit, + node_to_node_unit, node_unit_to_qnn_node_group, + logger); + + if (fused_node_group.has_value()) { + const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size(); + fused_node_group->index_ = index; + + for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { + assert(fused_node_unit != nullptr); + node_unit_to_qnn_node_group.insert({fused_node_unit, index}); + } + + tmp_qnn_node_groups.push_back(std::move(*fused_node_group)); + } + } + + // Create QnnNodeGroups for the leftover NodeUnits. + for (gsl::not_null node_unit : sorted_node_units) { + const auto it = node_unit_to_qnn_node_group.find(node_unit); + if (it != node_unit_to_qnn_node_group.end()) { + // Already handled this NodeUnit. 
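+      // Emit the fused group's index only when visiting its target NodeUnit so that each
+      // QnnNodeGroup appears exactly once, at the topological position of its target node.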
+ const QnnNodeGroup& qnn_node_group = tmp_qnn_node_groups[it->second]; + if (node_unit == qnn_node_group.GetTargetNodeUnit(logger)) { + sorted_qnn_node_group_indices.push_back(qnn_node_group.index_); + } + continue; + } + + const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size(); + QnnNodeGroup fused_node_group = {}; + fused_node_group.type_ = QnnNodeGroup::Type::NodeUnit; + fused_node_group.index_ = index; + fused_node_group.node_units_.resize(1); + fused_node_group.node_units_[0] = node_unit; + tmp_qnn_node_groups.push_back(std::move(fused_node_group)); + + node_unit_to_qnn_node_group.insert({node_unit, index}); + sorted_qnn_node_group_indices.push_back(index); + } + + assert(tmp_qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + } + + // Copy QnnNodeGroups to output in sorted (topological) order. + qnn_node_groups.resize(0); + qnn_node_groups.reserve(tmp_qnn_node_groups.size()); + for (auto index : sorted_qnn_node_group_indices) { + assert(index < tmp_qnn_node_groups.size()); + QnnNodeGroup qnn_node_group = std::move(tmp_qnn_node_groups[index]); + qnn_node_group.index_ = qnn_node_groups.size(); + qnn_node_groups.push_back(std::move(qnn_node_group)); + } + + assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size()); + + return Status::OK(); +} +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index eabd1a5043ba1..dbc8394da34a6 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -16,7 +16,6 @@ #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/qnn/builder/qnn_fusions.h" #include "core/providers/partitioning_utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/op_builder_factory.h" diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index b9e3608856b65..0e677e156a491 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -11,6 +11,7 @@ #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "HTP/QnnHtpGraph.h" #include #include From 3dbd2d6385ff1a6a3e2a42bb464f61e5f14623db Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sat, 27 Jul 2024 17:48:47 -0700 Subject: [PATCH 11/20] Reuse utility func --- .../qnn_node_group/conv_activation_fusion.cc | 52 +-------------- .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 39 +++-------- .../qnn_node_group/hardsigmoid_mul_fusion.cc | 36 ++-------- .../qnn/builder/qnn_node_group/utils.cc | 66 +++++++++++++++++++ .../qnn/builder/qnn_node_group/utils.h | 23 +++++++ 5 files changed, 107 insertions(+), 109 deletions(-) diff --git 
a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index 5e947446021e6..ac31a588a70f5 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -11,61 +11,11 @@ #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" namespace onnxruntime { namespace qnn { -static const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, - const NodeUnit& parent_node_unit, - gsl::span child_op_types, - const std::unordered_map& node_unit_map, - const std::unordered_map& node_unit_to_qnn_node_group) { - const Node& parent_node = parent_node_unit.GetNode(); - - // Parent must have a single child (1 output edge) and must not produce a graph output. - if (parent_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(parent_node)) { - return nullptr; - } - - // Child must be of a valid type. - const Node& child_node = parent_node.OutputEdgesBegin()->GetNode(); - if (graph_viewer.GetNode(child_node.Index()) == nullptr) { - return nullptr; // Node is not in this GraphViewer - } - const std::string& child_type = child_node.OpType(); - bool is_valid_child_type = false; - - for (const auto& valid_op_type : child_op_types) { - if (valid_op_type == child_type) { - is_valid_child_type = true; - break; - } - } - - if (!is_valid_child_type) { - return nullptr; - } - - const auto child_node_unit_it = node_unit_map.find(&child_node); - if (child_node_unit_it == node_unit_map.end()) { - return nullptr; - } - const NodeUnit* child_node_unit = child_node_unit_it->second; - - // Check if child node has already been handled. Should not be the case if the calling - // fusion function has been called in topological order, but check to be safe. - if (node_unit_to_qnn_node_group.count(child_node_unit) != 0) { - return nullptr; - } - - // child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (child_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return nullptr; - } - - return child_node_unit; -} - static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit, /*out*/ float& scale, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index 25bce1fe39a7c..3f513af8b666c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -11,6 +11,7 @@ #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" namespace onnxruntime { namespace qnn { @@ -62,39 +63,19 @@ std::optional TryDQQFusion( const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { // Expect that this function is called with a standalone DQ. 
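+ // Checking (rather than asserting) the standalone-DQ precondition lets a non-matching caller
+ // fall through to "no fusion" instead of aborting debug builds.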
- assert(dq_node_unit.OpType() == QDQ::DQOpName && dq_node_unit.UnitType() == NodeUnit::Type::SingleNode); - - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& dq_node = dq_node_unit.GetNode(); - - // DQ must have a single child (1 output edge) and must not produce a graph output. - if (dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { + if (dq_node_unit.OpType() != "DequantizeLinear" || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { return std::nullopt; } - const Node& q_node = dq_node.OutputEdgesBegin()->GetNode(); - if (q_node.OpType() != QDQ::QOpName) { - return std::nullopt; - } - - if (graph_viewer.GetNode(q_node.Index()) == nullptr) { - return std::nullopt; // Node is not in this GraphViewer - } - - const auto q_node_unit_it = node_to_node_unit.find(&q_node); - if (q_node_unit_it == node_to_node_unit.end()) { - return std::nullopt; - } - const NodeUnit* q_node_unit = q_node_unit_it->second; + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& dq_node = dq_node_unit.GetNode(); - // child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; - } + // DQ must have a single Q child (1 output edge) and must not produce a graph output. + const std::array child_types = {"QuantizeLinear"}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, dq_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); - // Check if child node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (node_unit_to_qnn_node_group.count(q_node_unit) != 0) { + if (q_node_unit == nullptr) { return std::nullopt; } @@ -103,7 +84,7 @@ std::optional TryDQQFusion( }; // DQ and Q must have equal scale type and different zp type. - if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { + if (!QDQ::IsDQQConversion(dq_node, q_node_unit->GetNode(), get_const_initializer, graph_viewer.ModelPath())) { return std::nullopt; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc index beee1d2ea5249..6411fb1b68e67 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -11,6 +11,7 @@ #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" namespace onnxruntime { namespace qnn { @@ -83,41 +84,18 @@ std::optional TryHardSigmoidMulFusion( return std::nullopt; } + // HardSigmoid must have a single Mul child (1 output edge) and must not produce a graph output. const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& hs_node = hardsigmoid_node_unit.GetNode(); + const std::array child_types = {"Mul"}; + const NodeUnit* mul_node_unit = GetOnlyChildOfType(graph_viewer, hardsigmoid_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); - // HardSigmoid must have a single child (1 output edge) and must not produce a graph output. 
- if (hs_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(hs_node)) { - return std::nullopt; - } - - const Node& mul_node = hs_node.OutputEdgesBegin()->GetNode(); - if (mul_node.OpType() != "Mul") { - return std::nullopt; - } - - if (graph_viewer.GetNode(mul_node.Index()) == nullptr) { - return std::nullopt; // Node is not in this GraphViewer - } - - const auto mul_node_unit_it = node_to_node_unit.find(&mul_node); - if (mul_node_unit_it == node_to_node_unit.end()) { - return std::nullopt; - } - const NodeUnit* mul_node_unit = mul_node_unit_it->second; - - // Check if Mul node has already been handled. Should not be the case if this - // fusion function has been called in topological order, but check to be safe. - if (node_unit_to_qnn_node_group.count(mul_node_unit) != 0) { - return std::nullopt; - } - - // Mul child must not already be part of a QDQ NodeUnit (i.e., be standalone). - if (mul_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + if (mul_node_unit == nullptr) { return std::nullopt; } // Input to HardSigmoid must also be the other input to the Mul. + const Node& mul_node = mul_node_unit->GetNode(); auto& hs_input_name = hardsigmoid_node_unit.Inputs()[0].node_arg.Name(); const bool same_root_input = mul_node.InputDefs()[0]->Name() == hs_input_name || mul_node.InputDefs()[1]->Name() == hs_input_name; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index e69de29bb2d1d..923ea8314c907 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -0,0 +1,66 @@ +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +#include +#include +#include + +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { + +const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_map& node_unit_to_qnn_node_group) { + const Node& parent_node = parent_node_unit.GetNode(); + + // Parent must have a single child (1 output edge) and must not produce a graph output. + if (parent_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(parent_node)) { + return nullptr; + } + + // Child must be of a valid type. + const Node& child_node = parent_node.OutputEdgesBegin()->GetNode(); + if (graph_viewer.GetNode(child_node.Index()) == nullptr) { + return nullptr; // Node is not in this GraphViewer + } + const std::string& child_type = child_node.OpType(); + bool is_valid_child_type = false; + + for (const auto& valid_op_type : child_op_types) { + if (valid_op_type == child_type) { + is_valid_child_type = true; + break; + } + } + + if (!is_valid_child_type) { + return nullptr; + } + + const auto child_node_unit_it = node_unit_map.find(&child_node); + if (child_node_unit_it == node_unit_map.end()) { + return nullptr; + } + const NodeUnit* child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. + if (node_unit_to_qnn_node_group.count(child_node_unit) != 0) { + return nullptr; + } + + // child must not already be part of a QDQ NodeUnit (i.e., be standalone). 
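+  // A child that is already wrapped inside a QDQ NodeUnit is grouped with its own DQ/Q nodes;
+  // claiming it here would place the same nodes in two QnnNodeGroups.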
+ if (child_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + return child_node_unit; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h index e69de29bb2d1d..4f67fabab9133 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/graph/graph_viewer.h" +#include "core/framework/node_unit.h" +#include "core/providers/qnn/builder/qnn_node_group.h" + +namespace onnxruntime { +namespace qnn { +const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, + const NodeUnit& parent_node_unit, + gsl::span child_op_types, + const std::unordered_map& node_unit_map, + const std::unordered_map& node_unit_to_qnn_node_group); + +} // namespace qnn +} // namespace onnxruntime From 44dc6962e7a3671f9beef9d9c5aceb2a05d2fe85 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sat, 27 Jul 2024 22:15:45 -0700 Subject: [PATCH 12/20] Use virtual base class instead of enum type --- .../core/providers/qnn/builder/qnn_model.cc | 6 +- .../providers/qnn/builder/qnn_node_group.h | 34 ++-- .../qnn_node_group/conv_activation_fusion.cc | 139 ++++++------- .../qnn_node_group/conv_activation_fusion.h | 34 +++- .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 68 ++----- .../qnn/builder/qnn_node_group/dq_q_fusion.h | 28 ++- .../qnn_node_group/hardsigmoid_mul_fusion.cc | 67 ++----- .../qnn_node_group/hardsigmoid_mul_fusion.h | 27 ++- .../builder/qnn_node_group/qnn_node_group.cc | 182 +++++------------- .../qnn/builder/qnn_node_group/utils.cc | 2 +- .../qnn/builder/qnn_node_group/utils.h | 2 +- .../providers/qnn/qnn_execution_provider.cc | 15 +- .../providers/qnn/qnn_execution_provider.h | 1 - 13 files changed, 241 insertions(+), 364 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 47d4a13b071ab..83f9184d33611 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -117,14 +117,14 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } - std::vector qnn_node_groups; + std::vector> qnn_node_groups; qnn_node_groups.reserve(node_unit_holder.size()); ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper, node_unit_map, node_unit_holder.size(), logger_)); - for (const qnn::QnnNodeGroup& qnn_node_group : qnn_node_groups) { - Status status = qnn_node_group.AddToModelBuilder(qnn_model_wrapper, logger_); + for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group->AddToModelBuilder(qnn_model_wrapper, logger_); if (!status.IsOK()) { LOGS(logger_, ERROR) << "[QNN EP] Failed to add supported node to QNN graph during EP's compile call: " diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h index 779e04ed91b41..fb6aa221aac3e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -3,9 +3,8 @@ #pragma once -#include +#include #include -#include 
#include #include "core/framework/node_unit.h" @@ -14,30 +13,19 @@ namespace onnxruntime { namespace qnn { -struct QnnNodeGroup { - using IndexType = size_t; - enum class Type : uint8_t { - Undefined = 0, - NodeUnit, - ConvActivationFusion, - DQQFusion, - HardSigmoidMulFusion, - COUNT, - }; +class IQnnNodeGroup { + public: + virtual ~IQnnNodeGroup() = default; + virtual Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0; + virtual Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0; + virtual std::vector GetNodeUnits() const = 0; + virtual const NodeUnit* GetTargetNodeUnit() const = 0; + virtual std::string_view Type() const = 0; - static std::string_view TypeToString(QnnNodeGroup::Type type); - - Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const; - Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const; - const std::vector& GetNodeUnits() const { return node_units_; } - const NodeUnit* GetTargetNodeUnit(const logging::Logger& logger) const; - - QnnNodeGroup::Type type_ = QnnNodeGroup::Type::Undefined; - IndexType index_ = 0; - std::vector node_units_; + size_t index_ = 0; }; -Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, +Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, QnnModelWrapper& qnn_model_wrapper, const std::unordered_map& node_to_node_unit, size_t num_node_units, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index ac31a588a70f5..4da4a748f801d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -246,7 +246,7 @@ static std::vector FindQDQNodes(const GraphViewer& graph_viewer, co static std::vector GetConvDQs( const GraphViewer& graph_viewer, const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, + const std::unordered_map& node_unit_to_qnn_node_group, const Node& conv_node) { assert(conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose"); std::vector dq_nodes = FindQDQNodes(graph_viewer, conv_node, /*find_dq_nodes*/ true); @@ -324,7 +324,7 @@ static bool IsValidQDQConv(gsl::span dq_node_units, } static Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, - gsl::span dq_node_units, + gsl::span dq_node_units, const NodeUnit* conv_node_unit, const NodeUnit* q_node_unit, const logging::Logger& logger, @@ -398,14 +398,19 @@ static Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); } -std::optional TryConvActivationFusion(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& conv_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { +std::unique_ptr TryConvActivationFusion(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); // Expect that this function is called with a standalone Conv or ConvTranspose. 
- assert((conv_node_unit.OpType() == "Conv" || conv_node_unit.OpType() == "ConvTranspose") && - conv_node_unit.UnitType() == NodeUnit::Type::SingleNode); + const auto& conv_type = conv_node_unit.OpType(); + + if ((conv_type != "Conv" && conv_type != "ConvTranspose") || + (conv_node_unit.UnitType() != NodeUnit::Type::SingleNode)) { + return nullptr; + } const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); @@ -414,7 +419,7 @@ std::optional TryConvActivationFusion(QnnModelWrapper& qnn_model_w const NodeUnit* activation_node_unit = GetOnlyChildOfType(graph_viewer, conv_node_unit, activation_op_types, node_to_node_unit, node_unit_to_qnn_node_group); if (activation_node_unit == nullptr) { - return std::nullopt; + return nullptr; } // Relu/Clip must have a single Q child. @@ -423,107 +428,85 @@ std::optional TryConvActivationFusion(QnnModelWrapper& qnn_model_w node_to_node_unit, node_unit_to_qnn_node_group); if (q_node_unit == nullptr) { - return std::nullopt; + return nullptr; } // Check if Clip/Relu can be removed because the Q node provides an equivalent effect. if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit)) { - return std::nullopt; + return nullptr; } // Create a QDQ node group with DQ* -> Conv -> Q const Node& conv_node = conv_node_unit.GetNode(); - const Node& activation_node = activation_node_unit->GetNode(); std::vector dq_node_units = GetConvDQs(graph_viewer, node_to_node_unit, node_unit_to_qnn_node_group, conv_node); if (!IsValidQDQConv(dq_node_units, q_node_unit)) { - return std::nullopt; + return nullptr; } - LOGS(logger, VERBOSE) << "Will use Conv + Activation via fusion. conv_node name: [" << conv_node.Name() - << "] activation_node optype: [" << activation_node.OpType() - << "] activation_node name: [" << activation_node.Name() - << "]"; - - std::optional qnn_node_group = QnnNodeGroup{}; - qnn_node_group->type_ = QnnNodeGroup::Type::ConvActivationFusion; - qnn_node_group->node_units_ = std::move(dq_node_units); - qnn_node_group->node_units_.push_back(&conv_node_unit); - qnn_node_group->node_units_.push_back(activation_node_unit); - qnn_node_group->node_units_.push_back(q_node_unit); - - return qnn_node_group; + return std::make_unique(dq_node_units, conv_node_unit, + *activation_node_unit, *q_node_unit); } namespace conv_act_fusion { +QnnNodeGroup::QnnNodeGroup(gsl::span dq_node_units, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit) + : dq_node_units_{}, + conv_node_unit_(conv_node_unit), + activation_node_unit_(activation_node_unit), + q_node_unit_(q_node_unit) { + assert(dq_node_units.size() <= dq_node_units_.size()); + std::copy(dq_node_units.begin(), dq_node_units.end(), dq_node_units_.data()); +} -Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - const size_t num_node_units = qnn_node_group.node_units_.size(); - ORT_RETURN_IF_NOT((num_node_units == 5 || num_node_units == 6), ""); - - const bool has_bias_dq = num_node_units == 6; - std::vector dq_node_units = {qnn_node_group.node_units_[0], qnn_node_group.node_units_[1]}; - const NodeUnit* conv_node_unit = qnn_node_group.node_units_[num_node_units - 3]; - const NodeUnit* activation_node_unit = qnn_node_group.node_units_[num_node_units - 2]; - const NodeUnit* q_node_unit = qnn_node_group.node_units_[num_node_units - 1]; - - if (has_bias_dq) { - dq_node_units.push_back(qnn_node_group.node_units_[2]); - } - Status status = QnnConvActivationFusionAdd(qmw, - 
dq_node_units, - conv_node_unit, - q_node_unit, - logger, - /*validate*/ true); +Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(dq_node_units_.data(), num_dqs); - if (!status.IsOK()) { - LOGS(logger, ERROR) << conv_node_unit->OpType() << "/" << activation_node_unit->OpType() - << " fusion is not supported, but should be according to initial validation." - << " Node names: " << conv_node_unit->Name() << ", " << activation_node_unit->Name() - << " Error: " << status.ErrorMessage(); - } - - return status; + return QnnConvActivationFusionAdd(qmw, + dq_node_units, + &conv_node_unit_, + &q_node_unit_, + logger, + /*validate*/ true); } -Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - const size_t num_node_units = qnn_node_group.node_units_.size(); - ORT_RETURN_IF_NOT((num_node_units == 5 || num_node_units == 6), ""); +Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(dq_node_units_.data(), num_dqs); - const bool has_bias_dq = num_node_units == 6; - std::vector dq_node_units = {qnn_node_group.node_units_[0], qnn_node_group.node_units_[1]}; - const NodeUnit* conv_node_unit = qnn_node_group.node_units_[num_node_units - 3]; - const NodeUnit* q_node_unit = qnn_node_group.node_units_[num_node_units - 1]; - - if (has_bias_dq) { - dq_node_units.push_back(qnn_node_group.node_units_[2]); - } return QnnConvActivationFusionAdd(qmw, dq_node_units, - conv_node_unit, - q_node_unit, + &conv_node_unit_, + &q_node_unit_, logger, /*validate*/ false); } -#if 0 -const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group) { - return qnn_node_group.node_units_; -} -#endif +std::vector QnnNodeGroup::GetNodeUnits() const { + const size_t num_dqs = dq_node_units_.back() != nullptr ? 
3 : 2; -const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - ORT_UNUSED_PARAMETER(logger); - const size_t num_node_units = qnn_node_group.node_units_.size(); - if (!(num_node_units == 5 || num_node_units == 6)) { - return nullptr; + std::vector node_units; + node_units.reserve(6); + for (size_t i = 0; i < num_dqs; i++) { + node_units.push_back(dq_node_units_[i]); } - return qnn_node_group.node_units_[num_node_units - 3]; + node_units.push_back(&conv_node_unit_); + node_units.push_back(&activation_node_unit_); + node_units.push_back(&q_node_unit_); + + return node_units; } + +const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { + return &conv_node_unit_; +} + } // namespace conv_act_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h index 50b02595f5f72..a195c86d2393a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -3,9 +3,10 @@ #pragma once -#include +#include +#include +#include #include -#include #include #include "core/framework/node_unit.h" @@ -15,19 +16,36 @@ namespace onnxruntime { namespace qnn { -std::optional TryConvActivationFusion( +std::unique_ptr TryConvActivationFusion( QnnModelWrapper& qnn_model_wrapper, const NodeUnit& conv_node_unit, const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, + const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger); namespace conv_act_fusion { -Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); -Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); -// const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group); -const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +class QnnNodeGroup : public IQnnNodeGroup { + public: + QnnNodeGroup(gsl::span dq_node_units, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + std::vector GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "ConvActivationFusion"; } + + private: + std::array dq_node_units_; // Last DQ is nullptr if bias is missing. 
+ const NodeUnit& conv_node_unit_; + const NodeUnit& activation_node_unit_; + const NodeUnit& q_node_unit_; +}; + } // namespace conv_act_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index 3f513af8b666c..ac782b5b1420a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -56,15 +56,15 @@ static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -std::optional TryDQQFusion( +std::unique_ptr TryDQQFusion( QnnModelWrapper& qnn_model_wrapper, const NodeUnit& dq_node_unit, const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, + const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { // Expect that this function is called with a standalone DQ. if (dq_node_unit.OpType() != "DequantizeLinear" || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; + return nullptr; } const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); @@ -76,7 +76,7 @@ std::optional TryDQQFusion( node_to_node_unit, node_unit_to_qnn_node_group); if (q_node_unit == nullptr) { - return std::nullopt; + return nullptr; } auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { @@ -85,69 +85,39 @@ std::optional TryDQQFusion( // DQ and Q must have equal scale type and different zp type. if (!QDQ::IsDQQConversion(dq_node, q_node_unit->GetNode(), get_const_initializer, graph_viewer.ModelPath())) { - return std::nullopt; + return nullptr; } if (Status status = QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit, logger, /*validate*/ true); !status.IsOK()) { - return std::nullopt; + return nullptr; } - // Validation passed, so create a QnnNodeGroup. - LOGS(logger, VERBOSE) << " Will use QNN Convert via fusion. dq_node name: [" << dq_node.Name() - << "] dq_node optype: [" << dq_node.OpType() - << "] q_node name: [" << q_node_unit->Name() - << "] q_node optype: [" << q_node_unit->OpType() - << "]"; - - std::optional qnn_node_group = QnnNodeGroup{}; - qnn_node_group->type_ = QnnNodeGroup::Type::DQQFusion; - qnn_node_group->node_units_.push_back(&dq_node_unit); - qnn_node_group->node_units_.push_back(q_node_unit); - + std::unique_ptr qnn_node_group = std::make_unique(dq_node_unit, + *q_node_unit); return qnn_node_group; } namespace dq_q_fusion { +QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit) + : dq_node_unit_(dq_node_unit), q_node_unit_(q_node_unit) { +} -Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 2, "Expected 2 NodeUnits for DQ -> Q fusion"); - const NodeUnit* dq_node_unit = qnn_node_group.node_units_[0]; - const NodeUnit* q_node_unit = qnn_node_group.node_units_[1]; - ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, ""); - Status status = QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ true); - - if (!status.IsOK()) { - LOGS(logger, ERROR) << "(DQ -> Q) into QNN Convert fusion is not supported, " - << "but should be according to initial validation. 
" - << "Node names: " << dq_node_unit->Name() << ", " << q_node_unit->Name() - << " Error: " << status.ErrorMessage(); - } - - return status; +Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return QnnDQQFusionAdd(qmw, dq_node_unit_, q_node_unit_, logger, /*validate*/ true); } -Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 2, "Expected 2 NodeUnits for DQ -> Q fusion"); - const NodeUnit* dq_node_unit = qnn_node_group.node_units_[0]; - const NodeUnit* q_node_unit = qnn_node_group.node_units_[1]; - ORT_RETURN_IF_NOT(dq_node_unit != nullptr && q_node_unit != nullptr, ""); - return QnnDQQFusionAdd(qmw, *dq_node_unit, *q_node_unit, logger, /*validate*/ false); +Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return QnnDQQFusionAdd(qmw, dq_node_unit_, q_node_unit_, logger, /*validate*/ false); } -#if 0 -const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group) { - return qnn_node_group.node_units_; +std::vector QnnNodeGroup::GetNodeUnits() const { + return std::vector{&dq_node_unit_, &q_node_unit_}; } -#endif -const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - ORT_UNUSED_PARAMETER(logger); - if (qnn_node_group.node_units_.size() != 2) { - return nullptr; - } - return qnn_node_group.node_units_[0]; +const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { + return &dq_node_unit_; } } // namespace dq_q_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h index 5a0529e9c4fda..2e5b612c41a81 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -3,11 +3,11 @@ #pragma once -#include +#include #include -#include #include +#include "core/common/common.h" #include "core/framework/node_unit.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_node_group.h" @@ -27,19 +27,31 @@ namespace qnn { * \param do_op_validation True if should call QNN operator validation APIs. 
* \return An onnxruntime::Status */ -std::optional TryDQQFusion( +std::unique_ptr TryDQQFusion( QnnModelWrapper& qnn_model_wrapper, const NodeUnit& dq_node_unit, const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, + const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger); namespace dq_q_fusion { -Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); -Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); -// const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group); -const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +class QnnNodeGroup : public IQnnNodeGroup { + public: + QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + std::vector GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "DQQFusion"; } + + private: + const NodeUnit& dq_node_unit_; + const NodeUnit& q_node_unit_; +}; + } // namespace dq_q_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc index 6411fb1b68e67..817e2190e7825 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -59,16 +59,16 @@ static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -std::optional TryHardSigmoidMulFusion( +std::unique_ptr TryHardSigmoidMulFusion( QnnModelWrapper& qnn_model_wrapper, const NodeUnit& hardsigmoid_node_unit, const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, + const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { // Looking for a standalone HardSigmoid to start the sequence. if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; + return nullptr; } NodeAttrHelper hs_attr_helper(hardsigmoid_node_unit); @@ -81,7 +81,7 @@ std::optional TryHardSigmoidMulFusion( // Check for explicit values of alpha and beta. if (std::abs(alpha - req_alpha) > alpha_eps || std::abs(beta - req_beta) > beta_eps) { - return std::nullopt; + return nullptr; } // HardSigmoid must have a single Mul child (1 output edge) and must not produce a graph output. @@ -91,7 +91,7 @@ std::optional TryHardSigmoidMulFusion( node_to_node_unit, node_unit_to_qnn_node_group); if (mul_node_unit == nullptr) { - return std::nullopt; + return nullptr; } // Input to HardSigmoid must also be the other input to the Mul. 
@@ -101,67 +101,38 @@ std::optional TryHardSigmoidMulFusion( mul_node.InputDefs()[1]->Name() == hs_input_name; if (!same_root_input) { - return std::nullopt; + return nullptr; } if (Status status = QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit, logger, /*validate*/ true); !status.IsOK()) { - return std::nullopt; + return nullptr; } - // Validation passed, so create a QnnNodeGroup. Any errors are now passed back to the caller. - LOGS(logger, VERBOSE) << "Will use QNN HardSwish via fusion. HardSigmoid name: [" << hardsigmoid_node_unit.Name() - << "] Mul name: [" << mul_node_unit->Name() << "]"; - - std::optional qnn_node_group = QnnNodeGroup{}; - qnn_node_group->type_ = QnnNodeGroup::Type::HardSigmoidMulFusion; - qnn_node_group->node_units_.push_back(&hardsigmoid_node_unit); - qnn_node_group->node_units_.push_back(mul_node_unit); - - return qnn_node_group; + return std::make_unique(hardsigmoid_node_unit, *mul_node_unit); } namespace hs_mul_fusion { -Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 2, "Expected 2 NodeUnits for HardSimoid -> Mul fusion"); - const NodeUnit* hardsigmoid_node_unit = qnn_node_group.node_units_[0]; - const NodeUnit* mul_node_unit = qnn_node_group.node_units_[1]; - ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, ""); - Status status = QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, - /*validate*/ true); - - if (!status.IsOK()) { - LOGS(logger, ERROR) << "(HardSigmoid -> Mul) into QNN HardSwish fusion is not supported, " - << "but should be according to initial validation. " - << "Node names: " << hardsigmoid_node_unit->Name() << ", " << mul_node_unit->Name() - << " Error: " << status.ErrorMessage(); - } +QnnNodeGroup::QnnNodeGroup(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit) + : hardsigmoid_node_unit_(hardsigmoid_node_unit), mul_node_unit_(mul_node_unit) { +} - return status; +Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return QnnHardSigmoidMulFusionAdd(qmw, hardsigmoid_node_unit_, mul_node_unit_, logger, /*validate*/ true); } -Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 2, "Expected 2 NodeUnits for HardSimoid -> Mul fusion"); - const NodeUnit* hardsigmoid_node_unit = qnn_node_group.node_units_[0]; - const NodeUnit* mul_node_unit = qnn_node_group.node_units_[1]; - ORT_RETURN_IF_NOT(hardsigmoid_node_unit != nullptr && mul_node_unit != nullptr, ""); - return QnnHardSigmoidMulFusionAdd(qmw, *hardsigmoid_node_unit, *mul_node_unit, logger, /*validate*/ false); +Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + return QnnHardSigmoidMulFusionAdd(qmw, hardsigmoid_node_unit_, mul_node_unit_, logger, /*validate*/ false); } -#if 0 -const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group) { - return qnn_node_group.node_units_; +std::vector QnnNodeGroup::GetNodeUnits() const { + return std::vector{&hardsigmoid_node_unit_, &mul_node_unit_}; } -#endif -const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) { - ORT_UNUSED_PARAMETER(logger); - if (qnn_node_group.node_units_.size() != 2) { - return nullptr; - } - return qnn_node_group.node_units_[0]; +const 
NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { + return &hardsigmoid_node_unit_; } } // namespace hs_mul_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h index 01bb0a39b4990..1cfb6119e3acc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -3,9 +3,8 @@ #pragma once -#include +#include #include -#include #include #include "core/framework/node_unit.h" @@ -29,19 +28,31 @@ namespace qnn { * \param do_op_validation True if should call QNN operator validation APIs. * \return A Status indicating a potential failure. */ -std::optional TryHardSigmoidMulFusion( +std::unique_ptr TryHardSigmoidMulFusion( QnnModelWrapper& qnn_model_wrapper, const NodeUnit& hardsigmoid_node_unit, const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, + const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger); namespace hs_mul_fusion { -Status IsSupported(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); -Status AddToModelBuilder(QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); -// const std::vector& GetNodeUnits(const QnnNodeGroup& qnn_node_group); -const NodeUnit* GetTargetNodeUnit(const QnnNodeGroup& qnn_node_group, const logging::Logger& logger); +class QnnNodeGroup : public IQnnNodeGroup { + public: + QnnNodeGroup(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + std::vector GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "HardSigmoidMulFusion"; } + + private: + const NodeUnit& hardsigmoid_node_unit_; + const NodeUnit& mul_node_unit_; +}; + } // namespace hs_mul_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 38c5027fcb8ac..8486d20dd6065 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -4,10 +4,9 @@ #include "core/providers/qnn/builder/qnn_node_group.h" #include -#include +#include #include #include -#include #include #include #include "core/graph/graph_utils.h" @@ -23,123 +22,51 @@ namespace onnxruntime { namespace qnn { -std::string_view QnnNodeGroup::TypeToString(QnnNodeGroup::Type type) { - static std::array(QnnNodeGroup::Type::COUNT)> type_names = { - "Undefined", - "NodeUnit", - "ConvActivationFusion", - "DQQFusion", - "HardSigmoidMulFusion", - }; +class QnnNodeUnitWrapper : public IQnnNodeGroup { + public: + QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(node_unit) {} + ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeUnitWrapper); - return type_names[static_cast(type)]; -} + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override { + const std::string& op_type = node_unit_.OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + 
ORT_RETURN_IF_NOT(op_builder != nullptr, "Operators of type `", op_type, + "` are not supported by QNN EP.", op_type, " node `", + node_unit_.Name(), "` will not be assigned to QNN EP."); -Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - using Func = Status (*)( - QnnModelWrapper&, - const QnnNodeGroup&, - const logging::Logger&); - - static std::array(QnnNodeGroup::Type::COUNT)> funcs = { - [](QnnModelWrapper&, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> Status { - std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(qnn_node_group.type_), - " in QnnNodeGroup::IsSupported()"); - LOGS(logger, ERROR) << error_msg; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); - }, - [](QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> Status { - ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 1 && qnn_node_group.node_units_[0] != nullptr, ""); - const NodeUnit& node_unit = *qnn_node_group.node_units_[0]; - const std::string& op_type = node_unit.OpType(); - const auto* op_builder = qnn::GetOpBuilder(op_type); - - if (op_builder == nullptr) { - std::string err_msg = MakeString("Operators of type `", op_type, - "` are not supported by QNN EP.", op_type, " node `", - node_unit.Name(), "` will not be assigned to QNN EP."); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, err_msg); - } - - Status status = op_builder->IsOpSupported(qmw, *qnn_node_group.node_units_[0], logger); - if (!status.IsOK()) { - LOGS(logger, WARNING) << op_type << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); - } - - return status; - }, - conv_act_fusion::IsSupported, - dq_q_fusion::IsSupported, - hs_mul_fusion::IsSupported, - }; - - return funcs[static_cast(type_)](qmw, *this, logger); -} + return op_builder->IsOpSupported(qmw, node_unit_, logger); + } -Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - using Func = Status (*)( - QnnModelWrapper&, - const QnnNodeGroup&, - const logging::Logger&); - - static std::array(QnnNodeGroup::Type::COUNT)> funcs = { - [](QnnModelWrapper&, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> Status { - std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(qnn_node_group.type_), - " in QnnNodeGroup::AddToModelBuilder()"); - LOGS(logger, ERROR) << error_msg; - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); - }, - [](QnnModelWrapper& qmw, const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> Status { - ORT_RETURN_IF_NOT(qnn_node_group.node_units_.size() == 1 && qnn_node_group.node_units_[0] != nullptr, ""); - const auto* op_builder = qnn::GetOpBuilder(qnn_node_group.node_units_[0]->OpType()); - ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", qnn_node_group.node_units_[0]->OpType()); - return op_builder->AddToModelBuilder(qmw, *qnn_node_group.node_units_[0], logger, /*do_op_validation*/ false); - }, - conv_act_fusion::AddToModelBuilder, - dq_q_fusion::AddToModelBuilder, - hs_mul_fusion::AddToModelBuilder, - }; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override { + const std::string& op_type = node_unit_.OpType(); + const auto* op_builder = qnn::GetOpBuilder(op_type); + ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", op_type); + return op_builder->AddToModelBuilder(qmw, node_unit_, 
logger, /*do_op_validation*/ false); + } - return funcs[static_cast(type_)](qmw, *this, logger); -} + std::vector GetNodeUnits() const override { + return std::vector{&node_unit_}; + } -const NodeUnit* QnnNodeGroup::GetTargetNodeUnit(const logging::Logger& logger) const { - using Func = const NodeUnit* (*)(const QnnNodeGroup&, const logging::Logger&); - - static std::array(QnnNodeGroup::Type::COUNT)> funcs = { - [](const QnnNodeGroup& qnn_node_group, const logging::Logger& logger) -> const NodeUnit* { - std::string error_msg = MakeString("Unhandled QnnNodeGroup::Type ", TypeToString(qnn_node_group.type_), - " in QnnNodeGroup::AddToModelBuilder()"); - LOGS(logger, ERROR) << error_msg; - return nullptr; - }, - [](const QnnNodeGroup& qnn_node_group, const logging::Logger&) -> const NodeUnit* { - if (qnn_node_group.node_units_.size() != 1) { - return nullptr; - } - return qnn_node_group.node_units_[0]; - }, - conv_act_fusion::GetTargetNodeUnit, - dq_q_fusion::GetTargetNodeUnit, - hs_mul_fusion::GetTargetNodeUnit, - }; + const NodeUnit* GetTargetNodeUnit() const override { return &node_unit_; } + std::string_view Type() const override { return "NodeUnitWrapper"; } - return funcs[static_cast(type_)](*this, logger); -} + private: + const NodeUnit& node_unit_; +}; -using FusionFunc = std::optional (*)( +using FusionFunc = std::unique_ptr (*)( QnnModelWrapper&, const NodeUnit&, const std::unordered_map&, - const std::unordered_map&, + const std::unordered_map&, const logging::Logger&); -static std::optional TryQnnFusions( +static std::unique_ptr TryQnnFusions( QnnModelWrapper& qnn_model_wrapper, const NodeUnit& starting_node_unit, const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, + const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { // Maps a starting operator type to the fusion function. static std::unordered_map fusions = { @@ -151,7 +78,7 @@ static std::optional TryQnnFusions( // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes). 
if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return std::nullopt; + return nullptr; } auto iter = fusions.find(starting_node_unit.OpType()); @@ -160,10 +87,10 @@ static std::optional TryQnnFusions( return fusion_func(qnn_model_wrapper, starting_node_unit, node_to_node_unit, node_unit_to_qnn_node_group, logger); } - return std::nullopt; + return nullptr; } -Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, +Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, QnnModelWrapper& qnn_model_wrapper, const std::unordered_map& node_to_node_unit, const size_t num_node_units, @@ -171,14 +98,14 @@ Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); const std::vector sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder(); - std::vector sorted_qnn_node_group_indices; + std::vector sorted_qnn_node_group_indices; sorted_qnn_node_group_indices.reserve(num_node_units); - std::vector tmp_qnn_node_groups; + std::vector> tmp_qnn_node_groups; tmp_qnn_node_groups.reserve(num_node_units); { - std::unordered_map node_unit_to_qnn_node_group; + std::unordered_map node_unit_to_qnn_node_group; std::vector> sorted_node_units; sorted_node_units.reserve(num_node_units); @@ -202,20 +129,20 @@ Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, continue; // Already handled this node unit } - std::optional fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit, - node_to_node_unit, node_unit_to_qnn_node_group, - logger); + std::unique_ptr fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit, + node_to_node_unit, node_unit_to_qnn_node_group, + logger); - if (fused_node_group.has_value()) { - const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size(); + if (fused_node_group) { + const size_t index = tmp_qnn_node_groups.size(); fused_node_group->index_ = index; for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { assert(fused_node_unit != nullptr); - node_unit_to_qnn_node_group.insert({fused_node_unit, index}); + node_unit_to_qnn_node_group.insert({fused_node_unit, fused_node_group.get()}); } - tmp_qnn_node_groups.push_back(std::move(*fused_node_group)); + tmp_qnn_node_groups.push_back(std::move(fused_node_group)); } } @@ -224,22 +151,19 @@ Status GetQnnNodeGroups(/*out*/ std::vector& qnn_node_groups, const auto it = node_unit_to_qnn_node_group.find(node_unit); if (it != node_unit_to_qnn_node_group.end()) { // Already handled this NodeUnit. 
-      const QnnNodeGroup& qnn_node_group = tmp_qnn_node_groups[it->second];
-      if (node_unit == qnn_node_group.GetTargetNodeUnit(logger)) {
-        sorted_qnn_node_group_indices.push_back(qnn_node_group.index_);
+      gsl::not_null<const IQnnNodeGroup*> qnn_node_group = it->second;
+      if (node_unit == qnn_node_group->GetTargetNodeUnit()) {
+        sorted_qnn_node_group_indices.push_back(qnn_node_group->index_);
       }
       continue;
     }
 
-    const QnnNodeGroup::IndexType index = tmp_qnn_node_groups.size();
-    QnnNodeGroup fused_node_group = {};
-    fused_node_group.type_ = QnnNodeGroup::Type::NodeUnit;
-    fused_node_group.index_ = index;
-    fused_node_group.node_units_.resize(1);
-    fused_node_group.node_units_[0] = node_unit;
+    const size_t index = tmp_qnn_node_groups.size();
+    auto fused_node_group = std::make_unique<QnnNodeUnitWrapper>(*node_unit);
+    fused_node_group->index_ = index;
 
     tmp_qnn_node_groups.push_back(std::move(fused_node_group));
-    node_unit_to_qnn_node_group.insert({node_unit, index});
+    node_unit_to_qnn_node_group.insert({node_unit, fused_node_group.get()});
     sorted_qnn_node_group_indices.push_back(index);
   }
 
@@ -251,8 +175,8 @@ Status GetQnnNodeGroups(/*out*/ std::vector<QnnNodeGroup>& qnn_node_groups,
   qnn_node_groups.reserve(tmp_qnn_node_groups.size());
   for (auto index : sorted_qnn_node_group_indices) {
     assert(index < tmp_qnn_node_groups.size());
-    QnnNodeGroup qnn_node_group = std::move(tmp_qnn_node_groups[index]);
-    qnn_node_group.index_ = qnn_node_groups.size();
+    std::unique_ptr<IQnnNodeGroup> qnn_node_group = std::move(tmp_qnn_node_groups[index]);
+    qnn_node_group->index_ = qnn_node_groups.size();
     qnn_node_groups.push_back(std::move(qnn_node_group));
   }
 
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc
index 923ea8314c907..1bcdb26be3400 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc
@@ -15,7 +15,7 @@ const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer,
                                    const NodeUnit& parent_node_unit,
                                    gsl::span<const std::string_view> child_op_types,
                                    const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
-                                   const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group) {
+                                   const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group) {
   const Node& parent_node = parent_node_unit.GetNode();
 
   // Parent must have a single child (1 output edge) and must not produce a graph output.
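Note: a minimal sketch (not part of the patch) of how the polymorphic node groups introduced by this commit are consumed. It mirrors the call sites in QnnModel::ComposeGraph and QNNExecutionProvider::GetSupportedNodes above; the names qnn_model_wrapper, node_unit_map, num_node_units, and logger are assumed to be set up as at those call sites, and calling both virtual methods unconditionally is a simplification (the EP only builds groups that already passed validation during partitioning).

    // Sketch only; assumes setup as in QnnModel::ComposeGraph.
    std::vector<std::unique_ptr<qnn::IQnnNodeGroup>> qnn_node_groups;
    qnn_node_groups.reserve(num_node_units);
    ORT_RETURN_IF_ERROR(qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper,
                                              node_unit_map, num_node_units, logger));
    for (const std::unique_ptr<qnn::IQnnNodeGroup>& group : qnn_node_groups) {
      // Single-op wrappers and fusions share one virtual interface; Type() reports
      // which concrete group (e.g., "NodeUnitWrapper", "DQQFusion") was selected.
      ORT_RETURN_IF_ERROR(group->IsSupported(qnn_model_wrapper, logger));
      ORT_RETURN_IF_ERROR(group->AddToModelBuilder(qnn_model_wrapper, logger));
    }
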
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
index 4f67fabab9133..308d08d42d87c 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
@@ -17,7 +17,7 @@ const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer,
                                    const NodeUnit& parent_node_unit,
                                    gsl::span<const std::string_view> child_op_types,
                                    const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
-                                   const std::unordered_map<const NodeUnit*, QnnNodeGroup::IndexType>& node_unit_to_qnn_node_group);
+                                   const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group);
 
 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index dbc8394da34a6..58bfacc5cd73d 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -19,6 +19,7 @@
 #include "core/providers/partitioning_utils.h"
 #include "core/providers/qnn/builder/qnn_model_wrapper.h"
 #include "core/providers/qnn/builder/op_builder_factory.h"
+#include "core/providers/qnn/builder/qnn_node_group.h"
 #include "core/providers/qnn/builder/qnn_def.h"
 #include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
 #include "core/framework/run_options.h"
@@ -440,7 +441,7 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
                                                 initializer_input_lookup,
                                                 qnn_backend_manager_->GetQnnBackendType());
 
-  std::vector<qnn::QnnNodeGroup> qnn_node_groups;
+  std::vector<std::unique_ptr<qnn::IQnnNodeGroup>> qnn_node_groups;
   qnn_node_groups.reserve(node_unit_size);
 
   if (Status status = qnn::GetQnnNodeGroups(qnn_node_groups, qnn_model_wrapper,
@@ -454,7 +455,7 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
                                  logging::Severity log_severity,
                                  logging::DataType log_data_type,
                                  const onnxruntime::CodeLocation& call_site,
-                                 const qnn::QnnNodeGroup& qnn_node_group,
+                                 const qnn::IQnnNodeGroup& qnn_node_group,
                                  bool supported) {
     if (!logger.OutputIsEnabled(log_severity, log_data_type)) {
       return;
@@ -462,7 +463,7 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer,
     std::ostringstream oss;
     oss << "[QNN EP] " << (supported ?
"Supports " : "Does NOT support ") << "the following nodes as part of a " - << qnn::QnnNodeGroup::TypeToString(qnn_node_group.type_) << " group:" << std::endl; + << qnn_node_group.Type() << " group:" << std::endl; for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { for (const Node* node : node_unit->GetAllNodesInGroup()) { oss << "\tOperator type: " << node->OpType() @@ -477,18 +478,18 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, << oss.str(); }; - for (const qnn::QnnNodeGroup& qnn_node_group : qnn_node_groups) { - Status status = qnn_node_group.IsSupported(qnn_model_wrapper, logger); + for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { + Status status = qnn_node_group->IsSupported(qnn_model_wrapper, logger); const bool supported = status.IsOK(); constexpr auto log_severity = logging::Severity::kVERBOSE; constexpr auto log_data_type = logging::DataType::SYSTEM; if (logger.OutputIsEnabled(log_severity, log_data_type)) { - log_node_support(logger, log_severity, log_data_type, ORT_WHERE, qnn_node_group, supported); + log_node_support(logger, log_severity, log_data_type, ORT_WHERE, *qnn_node_group, supported); } if (supported) { - for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { + for (const NodeUnit* node_unit : qnn_node_group->GetNodeUnits()) { for (const Node* node : node_unit->GetAllNodesInGroup()) { supported_nodes.insert(node); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 0e677e156a491..b9e3608856b65 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -11,7 +11,6 @@ #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" -#include "core/providers/qnn/builder/qnn_node_group.h" #include "HTP/QnnHtpGraph.h" #include #include From 576c2f87294ad8e7e9ea7e9ad18e0694bf51d7a6 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 28 Jul 2024 10:12:28 -0700 Subject: [PATCH 13/20] Clean up --- onnxruntime/core/framework/node_unit.cc | 6 +- onnxruntime/core/framework/node_unit.h | 6 +- .../providers/qnn/builder/qnn_model_wrapper.h | 6 -- .../providers/qnn/builder/qnn_node_group.h | 5 +- .../qnn_node_group/conv_activation_fusion.cc | 65 ++++++++++--------- .../qnn_node_group/conv_activation_fusion.h | 11 ++-- .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 12 ++-- .../qnn/builder/qnn_node_group/dq_q_fusion.h | 5 +- .../qnn_node_group/hardsigmoid_mul_fusion.cc | 12 ++-- .../qnn_node_group/hardsigmoid_mul_fusion.h | 5 +- .../builder/qnn_node_group/qnn_node_group.cc | 38 +++++------ 11 files changed, 82 insertions(+), 89 deletions(-) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 84d6ccb4d7acb..d2930a770c0a0 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -272,9 +272,9 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } } -NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, - gsl::span q_nodes, Type type, - gsl::span inputs, gsl::span outputs, +NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type type, + gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges) : dq_nodes_(dq_nodes.begin(), dq_nodes.end()), 
target_node_(target_node), diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index c2297c13a41e6..8bc2f79c4a372 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -68,9 +68,9 @@ class NodeUnit { public: explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); - NodeUnit(gsl::span dq_nodes, const Node& target_node, - gsl::span q_nodes, Type type, - gsl::span inputs, gsl::span outputs, + NodeUnit(gsl::span dq_nodes, const Node& target_node, + gsl::span q_nodes, Type type, + gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges); Type UnitType() const noexcept { return type_; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index fdf6616393ff8..9ab122b7f8e28 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -52,12 +52,6 @@ class QnnModelWrapper { ~QnnModelWrapper() = default; - const QNN_INTERFACE_VER_TYPE& GetQnnInterface() const { return qnn_interface_; } - const Qnn_BackendHandle_t& GetQnnBackendHandle() const { return backend_handle_; } - const std::unordered_map& GetInputIndexMap() const { return input_index_map_; } - const std::unordered_map& GetOutputIndexMap() const { return output_index_map_; } - const std::unordered_set& GetInitializerLookup() const { return initializer_lookup_; } - bool CreateQnnGraph(const Qnn_ContextHandle_t& context, const std::string& graph_name, const QnnGraph_Config_t** graph_configs = nullptr); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h index fb6aa221aac3e..bd2e58c2d3973 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -18,11 +19,9 @@ class IQnnNodeGroup { virtual ~IQnnNodeGroup() = default; virtual Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0; virtual Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0; - virtual std::vector GetNodeUnits() const = 0; + virtual gsl::span GetNodeUnits() const = 0; virtual const NodeUnit* GetTargetNodeUnit() const = 0; virtual std::string_view Type() const = 0; - - size_t index_ = 0; }; Status GetQnnNodeGroups(/*out*/ std::vector>& qnn_node_groups, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index 4da4a748f801d..065a2810a2920 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -332,6 +332,7 @@ static Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, const size_t num_dqs = dq_node_units.size(); constexpr size_t max_num_dqs = 3; ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); + ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == "QuantizeLinear"); std::array dq_nodes_buf = {}; for (size_t i = 0; i < num_dqs; i++) { @@ -447,64 +448,66 @@ std::unique_ptr TryConvActivationFusion(QnnModelWrapper& qnn_mode return 
nullptr; } - return std::make_unique(dq_node_units, conv_node_unit, - *activation_node_unit, *q_node_unit); + return std::make_unique(*dq_node_units[0], + *dq_node_units[1], + dq_node_units.size() == 3 ? dq_node_units[2] : nullptr, + conv_node_unit, + *activation_node_unit, + *q_node_unit); } namespace conv_act_fusion { -QnnNodeGroup::QnnNodeGroup(gsl::span dq_node_units, +QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, const NodeUnit& conv_node_unit, const NodeUnit& activation_node_unit, const NodeUnit& q_node_unit) - : dq_node_units_{}, - conv_node_unit_(conv_node_unit), - activation_node_unit_(activation_node_unit), - q_node_unit_(q_node_unit) { - assert(dq_node_units.size() <= dq_node_units_.size()); - std::copy(dq_node_units.begin(), dq_node_units.end(), dq_node_units_.data()); + : node_units_{} { + size_t i = 0; + node_units_[i++] = &dq_node_unit_0; + node_units_[i++] = &dq_node_unit_1; + if (dq_node_unit_2 != nullptr) { + node_units_[i++] = dq_node_unit_2; + } + node_units_[i++] = &conv_node_unit; + node_units_[i++] = &activation_node_unit; + node_units_[i++] = &q_node_unit; + assert((!dq_node_unit_2 && i == 5) || (dq_node_unit_2 && i == 6)); } Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2; - gsl::span dq_node_units(dq_node_units_.data(), num_dqs); + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); return QnnConvActivationFusionAdd(qmw, dq_node_units, - &conv_node_unit_, - &q_node_unit_, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q logger, /*validate*/ true); } Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2; - gsl::span dq_node_units(dq_node_units_.data(), num_dqs); + const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; + gsl::span dq_node_units(node_units_.data(), num_dqs); return QnnConvActivationFusionAdd(qmw, dq_node_units, - &conv_node_unit_, - &q_node_unit_, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q logger, /*validate*/ false); } -std::vector QnnNodeGroup::GetNodeUnits() const { - const size_t num_dqs = dq_node_units_.back() != nullptr ? 3 : 2; - - std::vector node_units; - node_units.reserve(6); - for (size_t i = 0; i < num_dqs; i++) { - node_units.push_back(dq_node_units_[i]); - } - node_units.push_back(&conv_node_unit_); - node_units.push_back(&activation_node_unit_); - node_units.push_back(&q_node_unit_); - - return node_units; +gsl::span QnnNodeGroup::GetNodeUnits() const { + const size_t num_node_units = node_units_.back() != nullptr ? 6 : 5; + return gsl::make_span(node_units_.data(), num_node_units); } const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { - return &conv_node_unit_; + const size_t conv_index = node_units_.back() != nullptr ? 
3 : 2; + return node_units_[conv_index]; } } // namespace conv_act_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h index a195c86d2393a..43a3aa63fe9ea 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -27,7 +27,9 @@ namespace conv_act_fusion { class QnnNodeGroup : public IQnnNodeGroup { public: - QnnNodeGroup(gsl::span dq_node_units, + QnnNodeGroup(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, const NodeUnit& conv_node_unit, const NodeUnit& activation_node_unit, const NodeUnit& q_node_unit); @@ -35,15 +37,12 @@ class QnnNodeGroup : public IQnnNodeGroup { Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; - std::vector GetNodeUnits() const override; + gsl::span GetNodeUnits() const override; const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "ConvActivationFusion"; } private: - std::array dq_node_units_; // Last DQ is nullptr if bias is missing. - const NodeUnit& conv_node_unit_; - const NodeUnit& activation_node_unit_; - const NodeUnit& q_node_unit_; + std::array node_units_; // Last elem is nullptr if bias DQ is missing. }; } // namespace conv_act_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index ac782b5b1420a..e31219c8b3b76 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -101,23 +101,23 @@ std::unique_ptr TryDQQFusion( namespace dq_q_fusion { QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit) - : dq_node_unit_(dq_node_unit), q_node_unit_(q_node_unit) { + : node_units_{&dq_node_unit, &q_node_unit} { } Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnDQQFusionAdd(qmw, dq_node_unit_, q_node_unit_, logger, /*validate*/ true); + return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true); } Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnDQQFusionAdd(qmw, dq_node_unit_, q_node_unit_, logger, /*validate*/ false); + return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false); } -std::vector QnnNodeGroup::GetNodeUnits() const { - return std::vector{&dq_node_unit_, &q_node_unit_}; +gsl::span QnnNodeGroup::GetNodeUnits() const { + return node_units_; } const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { - return &dq_node_unit_; + return node_units_[0]; } } // namespace dq_q_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h index 2e5b612c41a81..c5d779c8234ff 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h @@ -43,13 +43,12 @@ class QnnNodeGroup : public IQnnNodeGroup { Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status 
AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; - std::vector GetNodeUnits() const override; + gsl::span GetNodeUnits() const override; const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "DQQFusion"; } private: - const NodeUnit& dq_node_unit_; - const NodeUnit& q_node_unit_; + std::array node_units_; }; } // namespace dq_q_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc index 817e2190e7825..e77d613d607c6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -116,23 +116,23 @@ std::unique_ptr TryHardSigmoidMulFusion( namespace hs_mul_fusion { QnnNodeGroup::QnnNodeGroup(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit) - : hardsigmoid_node_unit_(hardsigmoid_node_unit), mul_node_unit_(mul_node_unit) { + : node_units_{&hardsigmoid_node_unit, &mul_node_unit} { } Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnHardSigmoidMulFusionAdd(qmw, hardsigmoid_node_unit_, mul_node_unit_, logger, /*validate*/ true); + return QnnHardSigmoidMulFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true); } Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnHardSigmoidMulFusionAdd(qmw, hardsigmoid_node_unit_, mul_node_unit_, logger, /*validate*/ false); + return QnnHardSigmoidMulFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false); } -std::vector QnnNodeGroup::GetNodeUnits() const { - return std::vector{&hardsigmoid_node_unit_, &mul_node_unit_}; +gsl::span QnnNodeGroup::GetNodeUnits() const { + return node_units_; } const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { - return &hardsigmoid_node_unit_; + return node_units_[0]; } } // namespace hs_mul_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h index 1cfb6119e3acc..3b04dccf1f6a5 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -44,13 +44,12 @@ class QnnNodeGroup : public IQnnNodeGroup { Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; - std::vector GetNodeUnits() const override; + gsl::span GetNodeUnits() const override; const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "HardSigmoidMulFusion"; } private: - const NodeUnit& hardsigmoid_node_unit_; - const NodeUnit& mul_node_unit_; + std::array node_units_; }; } // namespace hs_mul_fusion diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 8486d20dd6065..7a5abd6c9c9e2 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -3,6 +3,7 @@ #include "core/providers/qnn/builder/qnn_node_group.h" +#include #include #include #include @@ 
-24,35 +25,35 @@ namespace qnn { class QnnNodeUnitWrapper : public IQnnNodeGroup { public: - QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(node_unit) {} + QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {} ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeUnitWrapper); Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override { - const std::string& op_type = node_unit_.OpType(); + const std::string& op_type = node_unit_->OpType(); const auto* op_builder = qnn::GetOpBuilder(op_type); ORT_RETURN_IF_NOT(op_builder != nullptr, "Operators of type `", op_type, "` are not supported by QNN EP.", op_type, " node `", - node_unit_.Name(), "` will not be assigned to QNN EP."); + node_unit_->Name(), "` will not be assigned to QNN EP."); - return op_builder->IsOpSupported(qmw, node_unit_, logger); + return op_builder->IsOpSupported(qmw, *node_unit_, logger); } Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override { - const std::string& op_type = node_unit_.OpType(); + const std::string& op_type = node_unit_->OpType(); const auto* op_builder = qnn::GetOpBuilder(op_type); ORT_RETURN_IF_NOT(op_builder != nullptr, "[QNN EP]: Missing OpBuilder for OpType ", op_type); - return op_builder->AddToModelBuilder(qmw, node_unit_, logger, /*do_op_validation*/ false); + return op_builder->AddToModelBuilder(qmw, *node_unit_, logger, /*do_op_validation*/ false); } - std::vector GetNodeUnits() const override { - return std::vector{&node_unit_}; + gsl::span GetNodeUnits() const override { + return gsl::span{&node_unit_, 1ULL}; } - const NodeUnit* GetTargetNodeUnit() const override { return &node_unit_; } + const NodeUnit* GetTargetNodeUnit() const override { return node_unit_; } std::string_view Type() const override { return "NodeUnitWrapper"; } private: - const NodeUnit& node_unit_; + const NodeUnit* node_unit_; }; using FusionFunc = std::unique_ptr (*)( @@ -106,6 +107,7 @@ Status GetQnnNodeGroups(/*out*/ std::vector>& qnn { std::unordered_map node_unit_to_qnn_node_group; + std::unordered_map fused_qnn_node_group_indices; std::vector> sorted_node_units; sorted_node_units.reserve(num_node_units); @@ -135,7 +137,7 @@ Status GetQnnNodeGroups(/*out*/ std::vector>& qnn if (fused_node_group) { const size_t index = tmp_qnn_node_groups.size(); - fused_node_group->index_ = index; + fused_qnn_node_group_indices[fused_node_group.get()] = index; for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) { assert(fused_node_unit != nullptr); @@ -151,19 +153,18 @@ Status GetQnnNodeGroups(/*out*/ std::vector>& qnn const auto it = node_unit_to_qnn_node_group.find(node_unit); if (it != node_unit_to_qnn_node_group.end()) { // Already handled this NodeUnit. 
- gsl::not_null qnn_node_group = it->second; - if (node_unit == qnn_node_group->GetTargetNodeUnit()) { - sorted_qnn_node_group_indices.push_back(qnn_node_group->index_); + gsl::not_null fused_qnn_node_group = it->second; + if (node_unit == fused_qnn_node_group->GetTargetNodeUnit()) { + sorted_qnn_node_group_indices.push_back(fused_qnn_node_group_indices[fused_qnn_node_group]); } continue; } const size_t index = tmp_qnn_node_groups.size(); - auto fused_node_group = std::make_unique(*node_unit); - fused_node_group->index_ = index; - tmp_qnn_node_groups.push_back(std::move(fused_node_group)); + auto qnn_node_group = std::make_unique(*node_unit); - node_unit_to_qnn_node_group.insert({node_unit, fused_node_group.get()}); + node_unit_to_qnn_node_group.insert({node_unit, qnn_node_group.get()}); + tmp_qnn_node_groups.push_back(std::move(qnn_node_group)); sorted_qnn_node_group_indices.push_back(index); } @@ -176,7 +177,6 @@ Status GetQnnNodeGroups(/*out*/ std::vector>& qnn for (auto index : sorted_qnn_node_group_indices) { assert(index < tmp_qnn_node_groups.size()); std::unique_ptr qnn_node_group = std::move(tmp_qnn_node_groups[index]); - qnn_node_group->index_ = qnn_node_groups.size(); qnn_node_groups.push_back(std::move(qnn_node_group)); } From 0e7eeceede95c729935a39bdda5d06a068f2db5c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 28 Jul 2024 15:42:06 -0700 Subject: [PATCH 14/20] Remove use of optimizer qdq utils from fusion code; Rename fusion classes --- .../qnn/builder/qnn_model_wrapper.cc | 2 + .../providers/qnn/builder/qnn_node_group.h | 4 +- .../qnn_node_group/conv_activation_fusion.cc | 283 +++++++----------- .../qnn_node_group/conv_activation_fusion.h | 34 +-- .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 175 +++++++---- .../qnn/builder/qnn_node_group/dq_q_fusion.h | 50 ++-- .../qnn_node_group/hardsigmoid_mul_fusion.cc | 117 ++++---- .../qnn_node_group/hardsigmoid_mul_fusion.h | 55 ++-- .../builder/qnn_node_group/qnn_node_group.cc | 11 +- .../qnn/builder/qnn_node_group/utils.h | 6 + .../providers/qnn/qnn_execution_provider.cc | 60 ++-- .../test/providers/qnn/qnn_basic_test.cc | 2 +- 12 files changed, 404 insertions(+), 395 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index c8537307ef3ba..fb1011bcf8055 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -239,6 +239,8 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, std::string error_msg; bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg); if (!rt) { + // TODO(adrianlizarraga): Return a Status with the error message so that aggregated logs show a more + // specific validation error (instead of "failed to add node"). 
LOGS(logger_, WARNING) << error_msg; } return rt; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h index bd2e58c2d3973..a3c1b1bcdd407 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -8,12 +8,14 @@ #include #include +#include "core/common/logging/logging.h" #include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" namespace onnxruntime { namespace qnn { +class QnnModelWrapper; + class IQnnNodeGroup { public: virtual ~IQnnNodeGroup() = default; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index 065a2810a2920..f5ddee6b1f78e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -6,12 +6,12 @@ #include #include #include "core/graph/graph_utils.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/framework/node_unit.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" namespace onnxruntime { namespace qnn { @@ -21,16 +21,16 @@ static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, /*out*/ float& scale, /*out*/ int32_t& zero_point, /*out*/ int32_t& zp_data_type) { - assert(q_node_unit.OpType() == QDQ::QOpName); + assert(q_node_unit.OpType() == QUANTIZE_LINEAR); const auto& q_inputs = q_node_unit.GetNode().InputDefs(); // Require an explicit zero-point input for now. - if (q_inputs.size() != 3 || !q_inputs[QDQ::ZERO_POINT_ID]->Exists()) { + if (q_inputs.size() != 3 || !q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Exists()) { return false; } std::vector zero_points; - Status status = qnn_model_wrapper.UnpackZeroPoints(q_inputs[QDQ::ZERO_POINT_ID]->Name(), + Status status = qnn_model_wrapper.UnpackZeroPoints(q_inputs[QDQ_ZERO_POINT_INPUT_IDX]->Name(), zero_points, zp_data_type); // Should only have one zero-point (per-tensor). @@ -40,7 +40,7 @@ static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, zero_point = -zero_points[0]; // QNN zero-points are negated. std::vector scales; - status = qnn_model_wrapper.UnpackScales(q_inputs[QDQ::SCALE_ID]->Name(), scales); + status = qnn_model_wrapper.UnpackScales(q_inputs[QDQ_SCALE_INPUT_IDX]->Name(), scales); // Should only have one scale (per-tensor). if (!status.IsOK() || scales.size() != 1) { @@ -91,72 +91,11 @@ static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper, return true; } -static bool GetClipMinMax(const QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& clip_node_unit, - /*out*/ float& clip_min, - /*out*/ float& clip_max) { - clip_min = std::numeric_limits::lowest(); - clip_max = std::numeric_limits::max(); - - // Clip's min and max are attributes before opset 11. 
- if (clip_node_unit.GetNode().SinceVersion() < 11) { - NodeAttrHelper attr_helper(clip_node_unit); - std::optional min_opt = attr_helper.GetFloat("min"); - std::optional max_opt = attr_helper.GetFloat("max"); - - if (min_opt.has_value()) { - clip_min = min_opt.value(); - } - - if (max_opt.has_value()) { - clip_max = max_opt.value(); - } - - return true; - } - - // After opset 11, min and max are inputs. - const auto& inputs = clip_node_unit.Inputs(); - const size_t num_inputs = inputs.size(); - auto get_min_or_max = [&qnn_model_wrapper](const NodeUnitIODef& input, /*out*/ float& result) -> bool { - TensorInfo input_info = {}; - std::vector raw_bytes; - if (Status status = qnn_model_wrapper.GetTensorInfo(input, input_info); !status.IsOK()) { - return false; - } - if (!input_info.is_initializer) { - return false; - } - if (Status status = qnn_model_wrapper.UnpackInitializerData(*input_info.initializer_tensor, raw_bytes); - !status.IsOK()) { - return false; - } - if (input_info.qnn_data_type != QNN_DATATYPE_FLOAT_32) { - return false; - } - result = static_cast(*reinterpret_cast(raw_bytes.data())); - return true; - }; - - if (num_inputs > 1 && inputs[1].node_arg.Exists()) { - if (!get_min_or_max(inputs[1], clip_min)) { - return false; - } - } - - if (num_inputs > 2 && inputs[2].node_arg.Exists()) { - if (!get_min_or_max(inputs[2], clip_max)) { - return false; - } - } - - return true; -} - static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& clip_node_unit, - const NodeUnit& q_node_unit) { - assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QDQ::QOpName); + const NodeUnit& q_node_unit, + const logging::Logger& logger) { + assert(clip_node_unit.OpType() == "Clip" && q_node_unit.OpType() == QUANTIZE_LINEAR); float rmin = 0.0f; float rmax = 0.0f; @@ -167,7 +106,8 @@ static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, float clip_min = std::numeric_limits::lowest(); float clip_max = std::numeric_limits::max(); - if (!GetClipMinMax(qnn_model_wrapper, clip_node_unit, clip_min, clip_max)) { + if (!onnxruntime::GetClipMinMax(qnn_model_wrapper.GetGraphViewer(), clip_node_unit.GetNode(), + clip_min, clip_max, logger)) { return false; } @@ -180,7 +120,7 @@ static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, } static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit) { - assert(q_node_unit.OpType() == QDQ::QOpName); + assert(q_node_unit.OpType() == QUANTIZE_LINEAR); int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED; int32_t zero_point = 0; float scale = 0.0f; @@ -206,7 +146,8 @@ static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeU static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& activation_node_unit, - const NodeUnit& q_node_unit) { + const NodeUnit& q_node_unit, + const logging::Logger& logger) { const std::string& activation_type = activation_node_unit.OpType(); if (activation_type == "Relu") { @@ -214,91 +155,94 @@ static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper, } if (activation_type == "Clip") { - return CanClipBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit); + return CanClipBeRemoved(qnn_model_wrapper, activation_node_unit, q_node_unit, logger); } return false; } -// adjust for an optional input/output that has an entry but does not exist -static int NumActualValues(const Node& node, bool input) { - const auto& defs 
= input ? node.InputDefs() : node.OutputDefs(); - return gsl::narrow_cast(std::count_if(defs.cbegin(), defs.cend(), - [](const NodeArg* def) { return def && def->Exists(); })); -} - -static std::vector FindQDQNodes(const GraphViewer& graph_viewer, const Node& node, bool find_dq_nodes) { - // First get all the upstream (DQ) or downstream (Q) nodes - std::vector nodes = - find_dq_nodes ? graph_utils::FindParentsByType(node, QDQ::DQOpName) - : graph_utils::FindChildrenByType(node, QDQ::QOpName); +static std::vector FindParentDQNodes(const GraphViewer& graph_viewer, const Node& node) { + // Get all parent DQ nodes sorted by destination argument index. + std::vector parents(node.InputDefs().size(), nullptr); + for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) { + if (it->GetNode().OpType().compare(DEQUANTIZE_LINEAR) == 0) { + parents[it->GetDstArgIndex()] = &(it->GetNode()); + } + } // Remove all the nodes which are not in the graph_viewer - nodes.erase(std::remove_if(nodes.begin(), nodes.end(), - [&graph_viewer](const Node* _node) { - return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr; - }), - nodes.end()); + parents.erase(std::remove_if(parents.begin(), parents.end(), + [&graph_viewer](const Node* _node) { + return _node == nullptr || graph_viewer.GetNode(_node->Index()) == nullptr; + }), + parents.end()); - return nodes; + return parents; } -static std::vector GetConvDQs( +static bool GetConvDQs( const GraphViewer& graph_viewer, const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, - const Node& conv_node) { - assert(conv_node.OpType() == "Conv" || conv_node.OpType() == "ConvTranspose"); - std::vector dq_nodes = FindQDQNodes(graph_viewer, conv_node, /*find_dq_nodes*/ true); - int num_dq_inputs = NumActualValues(conv_node, /*input*/ true); + const Node& conv_node, + /*out*/ std::array& dq_node_units) { + if (conv_node.OpType() != "Conv" && conv_node.OpType() != "ConvTranspose") { + return false; + } + + // Count number of inputs to Conv node. + const auto& conv_inputs = conv_node.InputDefs(); + const size_t num_conv_inputs = std::count_if(conv_inputs.cbegin(), conv_inputs.cend(), + [](const NodeArg* input) { return input && input->Exists(); }); + + // Get the Conv's parent DQ nodes. + std::vector dq_nodes = FindParentDQNodes(graph_viewer, conv_node); + const size_t num_dqs = dq_nodes.size(); // Within a QDQ node group, a target node input is the only consumer of each DQ. - if (num_dq_inputs != static_cast(dq_nodes.size())) { - return {}; + if ((num_conv_inputs != num_dqs) || (num_dqs > dq_node_units.size())) { + return false; } - std::vector dq_node_units; - for (const auto* dq_node : dq_nodes) { - if (graph_viewer.NodeProducesGraphOutput(*dq_node)) { - return {}; + dq_node_units.fill(nullptr); + for (size_t i = 0; i < num_dqs; i++) { + const Node* dq_node = dq_nodes[i]; + + // DQ must not produce a graph output. + if (!dq_node || graph_viewer.NodeProducesGraphOutput(*dq_node)) { + return false; } + // Conv should be the only consumer of a parent DQ. const bool dq_has_single_output_edge_to_target = dq_node->GetOutputEdgesCount() == 1 && dq_node->OutputEdgesBegin()->GetNode().Index() == conv_node.Index(); if (!dq_has_single_output_edge_to_target) { - return {}; + return false; } + // DQ node must be part of a "standalone" NodeUnit. 
const auto it = node_to_node_unit.find(dq_node); if (it == node_to_node_unit.end()) { - return {}; + return false; } - const NodeUnit* dq_node_unit = it->second; - if (!dq_node_unit || node_unit_to_qnn_node_group.count(dq_node_unit) != 0) { - return {}; + return false; } - if (dq_node_unit->UnitType() != NodeUnit::Type::SingleNode) { - return {}; + return false; } - dq_node_units.push_back(dq_node_unit); + dq_node_units[i] = dq_node_unit; } - return dq_node_units; + return true; } -static bool IsValidQDQConv(gsl::span dq_node_units, - gsl::not_null q_node_unit) { - assert(q_node_unit->OpType() == QDQ::QOpName); - const size_t num_dqs = dq_node_units.size(); - if (num_dqs != 2 && num_dqs != 3) { - return false; - } - +static bool CheckQDQConvDataTypes(std::array& dq_node_units, + gsl::not_null q_node_unit) { + assert(q_node_unit->OpType() == QUANTIZE_LINEAR); // input and output types need to be same int32_t dt_input = dq_node_units[0]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); int32_t dt_weight = dq_node_units[1]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); @@ -313,7 +257,7 @@ static bool IsValidQDQConv(gsl::span dq_node_units, } } - if (num_dqs == 3) { // has bias + if (dq_node_units[2] != nullptr) { // has bias int32_t dt_bias = dq_node_units[2]->GetNode().InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); if (dt_bias != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32) { return false; @@ -323,12 +267,16 @@ static bool IsValidQDQConv(gsl::span dq_node_units, return true; } -static Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, - gsl::span dq_node_units, - const NodeUnit* conv_node_unit, - const NodeUnit* q_node_unit, - const logging::Logger& logger, - bool validate) { +#define ValidateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), true) +#define CreateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span dq_node_units, + const NodeUnit* conv_node_unit, + const NodeUnit* q_node_unit, + const logging::Logger& logger, + bool validate) { const size_t num_dqs = dq_node_units.size(); constexpr size_t max_num_dqs = 3; ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); @@ -399,12 +347,12 @@ static Status QnnConvActivationFusionAdd(QnnModelWrapper& qnn_model_wrapper, return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate); } -std::unique_ptr TryConvActivationFusion(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& conv_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { - ORT_UNUSED_PARAMETER(logger); +std::unique_ptr ConvActivationFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { // Expect that this function is called with a standalone Conv or ConvTranspose. 
const auto& conv_type = conv_node_unit.OpType(); @@ -424,7 +372,7 @@ std::unique_ptr TryConvActivationFusion(QnnModelWrapper& qnn_mode } // Relu/Clip must have a single Q child. - const std::array q_op_types = {QDQ::QOpName}; + const std::array q_op_types = {QUANTIZE_LINEAR}; const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, *activation_node_unit, q_op_types, node_to_node_unit, node_unit_to_qnn_node_group); @@ -433,36 +381,38 @@ std::unique_ptr TryConvActivationFusion(QnnModelWrapper& qnn_mode } // Check if Clip/Relu can be removed because the Q node provides an equivalent effect. - if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit)) { + if (!CanActivationBeRemoved(qnn_model_wrapper, *activation_node_unit, *q_node_unit, logger)) { return nullptr; } // Create a QDQ node group with DQ* -> Conv -> Q const Node& conv_node = conv_node_unit.GetNode(); - std::vector dq_node_units = GetConvDQs(graph_viewer, - node_to_node_unit, - node_unit_to_qnn_node_group, - conv_node); + std::array dq_node_units = {}; + if (!GetConvDQs(graph_viewer, + node_to_node_unit, + node_unit_to_qnn_node_group, + conv_node, dq_node_units)) { + return nullptr; + } - if (!IsValidQDQConv(dq_node_units, q_node_unit)) { + if (!CheckQDQConvDataTypes(dq_node_units, q_node_unit)) { return nullptr; } - return std::make_unique(*dq_node_units[0], - *dq_node_units[1], - dq_node_units.size() == 3 ? dq_node_units[2] : nullptr, - conv_node_unit, - *activation_node_unit, - *q_node_unit); + return std::make_unique(*dq_node_units[0], + *dq_node_units[1], + dq_node_units[2], + conv_node_unit, + *activation_node_unit, + *q_node_unit); } -namespace conv_act_fusion { -QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit_0, - const NodeUnit& dq_node_unit_1, - const NodeUnit* dq_node_unit_2, - const NodeUnit& conv_node_unit, - const NodeUnit& activation_node_unit, - const NodeUnit& q_node_unit) +ConvActivationFusion::ConvActivationFusion(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit) : node_units_{} { size_t i = 0; node_units_[i++] = &dq_node_unit_0; @@ -476,40 +426,35 @@ QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit_0, assert((!dq_node_unit_2 && i == 5) || (dq_node_unit_2 && i == 6)); } -Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { +Status ConvActivationFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { const size_t num_dqs = node_units_.back() != nullptr ? 3 : 2; gsl::span dq_node_units(node_units_.data(), num_dqs); - return QnnConvActivationFusionAdd(qmw, - dq_node_units, - node_units_[num_dqs], // Conv - node_units_[num_dqs + 2], // Q - logger, - /*validate*/ true); + return ValidateOnQnn(qmw, dq_node_units, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q + logger); } -Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { +Status ConvActivationFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { const size_t num_dqs = node_units_.back() != nullptr ? 
3 : 2; gsl::span dq_node_units(node_units_.data(), num_dqs); - return QnnConvActivationFusionAdd(qmw, - dq_node_units, - node_units_[num_dqs], // Conv - node_units_[num_dqs + 2], // Q - logger, - /*validate*/ false); + return CreateOnQnn(qmw, dq_node_units, + node_units_[num_dqs], // Conv + node_units_[num_dqs + 2], // Q + logger); } -gsl::span QnnNodeGroup::GetNodeUnits() const { +gsl::span ConvActivationFusion::GetNodeUnits() const { const size_t num_node_units = node_units_.back() != nullptr ? 6 : 5; return gsl::make_span(node_units_.data(), num_node_units); } -const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { +const NodeUnit* ConvActivationFusion::GetTargetNodeUnit() const { const size_t conv_index = node_units_.back() != nullptr ? 3 : 2; return node_units_[conv_index]; } -} // namespace conv_act_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h index 43a3aa63fe9ea..f0e140addbd7a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -10,30 +10,22 @@ #include #include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { namespace qnn { -std::unique_ptr TryConvActivationFusion( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& conv_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger); +class QnnModelWrapper; -namespace conv_act_fusion { - -class QnnNodeGroup : public IQnnNodeGroup { +class ConvActivationFusion : public IQnnNodeGroup { public: - QnnNodeGroup(const NodeUnit& dq_node_unit_0, - const NodeUnit& dq_node_unit_1, - const NodeUnit* dq_node_unit_2, - const NodeUnit& conv_node_unit, - const NodeUnit& activation_node_unit, - const NodeUnit& q_node_unit); - ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup); + ConvActivationFusion(const NodeUnit& dq_node_unit_0, + const NodeUnit& dq_node_unit_1, + const NodeUnit* dq_node_unit_2, + const NodeUnit& conv_node_unit, + const NodeUnit& activation_node_unit, + const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ConvActivationFusion); Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; @@ -41,10 +33,16 @@ class QnnNodeGroup : public IQnnNodeGroup { const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "ConvActivationFusion"; } + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& conv_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + private: std::array node_units_; // Last elem is nullptr if bias DQ is missing. 
}; -} // namespace conv_act_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index e31219c8b3b76..4b2c96aa9e4d7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -6,23 +6,89 @@ #include #include #include "core/graph/graph_utils.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/framework/node_unit.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_node_group/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" namespace onnxruntime { namespace qnn { -static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& dq_node_unit, - const NodeUnit& q_node_unit, - const logging::Logger& logger, - bool validate = false) { +// Forward declarations. +#define ValidateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, dq_node_unit, q_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_unit), (q_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, bool validate); +static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node); + +std::unique_ptr DQQFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { ORT_UNUSED_PARAMETER(logger); - assert(dq_node_unit.OpType() == QDQ::DQOpName && q_node_unit.OpType() == QDQ::QOpName); + // Expect that this function is called with a standalone DQ. + if (dq_node_unit.OpType() != DEQUANTIZE_LINEAR || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const Node& dq_node = dq_node_unit.GetNode(); + + // DQ must have a single Q child (1 output edge) and must not produce a graph output. + const std::array child_types = {QUANTIZE_LINEAR}; + const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, dq_node_unit, child_types, + node_to_node_unit, node_unit_to_qnn_node_group); + + if (q_node_unit == nullptr) { + return nullptr; + } + + // DQ and Q must have equal scale type and different zp type. 
+ if (!IsDQQConversion(graph_viewer, dq_node, q_node_unit->GetNode())) { + return nullptr; + } + + if (Status status = ValidateOnQnn(qnn_model_wrapper, dq_node_unit, *q_node_unit); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(dq_node_unit, *q_node_unit); +} + +DQQFusion::DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit) + : node_units_{&dq_node_unit, &q_node_unit} { +} + +Status DQQFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +Status DQQFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); +} + +gsl::span DQQFusion::GetNodeUnits() const { + return node_units_; +} + +const NodeUnit* DQQFusion::GetTargetNodeUnit() const { + return node_units_[0]; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + bool validate) { + assert(dq_node_unit.OpType() == DEQUANTIZE_LINEAR && q_node_unit.OpType() == QUANTIZE_LINEAR); const auto& node_name = utils::GetNodeName(dq_node_unit); const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; @@ -56,70 +122,57 @@ static Status QnnDQQFusionAdd(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } -std::unique_ptr TryDQQFusion( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& dq_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger) { - // Expect that this function is called with a standalone DQ. - if (dq_node_unit.OpType() != "DequantizeLinear" || dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { - return nullptr; - } - - const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); - const Node& dq_node = dq_node_unit.GetNode(); - - // DQ must have a single Q child (1 output edge) and must not produce a graph output. - const std::array child_types = {"QuantizeLinear"}; - const NodeUnit* q_node_unit = GetOnlyChildOfType(graph_viewer, dq_node_unit, child_types, - node_to_node_unit, node_unit_to_qnn_node_group); +static bool IsDQQConversion(const GraphViewer& graph_viewer, const Node& dq_node, const Node& q_node) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); - if (q_node_unit == nullptr) { - return nullptr; - } + auto is_scalar_shape = [](const NodeArg& input_arg) -> bool { + auto shape = input_arg.Shape(); + if (shape == nullptr) { + return false; + } - auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { - return graph_viewer.GetConstantInitializer(initializer_name, true); + auto dim_size = shape->dim_size(); + return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1); }; - // DQ and Q must have equal scale type and different zp type. 
-  if (!QDQ::IsDQQConversion(dq_node, q_node_unit->GetNode(), get_const_initializer, graph_viewer.ModelPath())) {
-    return nullptr;
+  // Q/DQ with omitted optional inputs are not supported.
+  // Non-scalar Q/DQ scale and zero-point inputs are not supported.
+  if (dq_input_defs.size() != QDQ_MAX_NUM_INPUTS ||
+      q_input_defs.size() != QDQ_MAX_NUM_INPUTS ||
+      !is_scalar_shape(*q_input_defs[QDQ_SCALE_INPUT_IDX]) ||
+      !is_scalar_shape(*q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]) ||
+      !is_scalar_shape(*dq_input_defs[QDQ_SCALE_INPUT_IDX]) ||
+      !is_scalar_shape(*dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX])) {
+    return false;
   }

-  if (Status status = QnnDQQFusionAdd(qnn_model_wrapper, dq_node_unit, *q_node_unit,
-                                      logger, /*validate*/ true);
-      !status.IsOK()) {
-    return nullptr;
+  // If the Q/DQ scale and zero point are not constant initializers, return false.
+  const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
+      graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_SCALE_INPUT_IDX]->Name());
+  const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
+      graph_viewer.GetConstantInitializer(q_input_defs[QDQ_SCALE_INPUT_IDX]->Name());
+  const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
+      graph_viewer.GetConstantInitializer(dq_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name());
+  const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
+      graph_viewer.GetConstantInitializer(q_input_defs[QDQ_ZERO_POINT_INPUT_IDX]->Name());
+  if (nullptr == q_zp_tensor_proto ||
+      nullptr == dq_zp_tensor_proto ||
+      nullptr == q_scale_tensor_proto ||
+      nullptr == dq_scale_tensor_proto) {
+    return false;
   }

-  std::unique_ptr<IQnnNodeGroup> qnn_node_group = std::make_unique<dq_q_fusion::QnnNodeGroup>(dq_node_unit,
-                                                                                              *q_node_unit);
-  return qnn_node_group;
-}
-
-namespace dq_q_fusion {
-QnnNodeGroup::QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit)
-    : node_units_{&dq_node_unit, &q_node_unit} {
-}
-
-Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const {
-  return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true);
-}
-
-Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const {
-  return QnnDQQFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false);
-}
-
-gsl::span<const NodeUnit* const> QnnNodeGroup::GetNodeUnits() const {
-  return node_units_;
-}
+  // All TensorProtos must have a data type.
+  if (!q_zp_tensor_proto->has_data_type() || !dq_zp_tensor_proto->has_data_type() ||
+      !q_scale_tensor_proto->has_data_type() || !dq_scale_tensor_proto->has_data_type()) {
+    return false;
+  }

-const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const {
-  return node_units_[0];
+  // Check that Q/DQ have the same scale type and different zero-point types.
+  return (dq_zp_tensor_proto->data_type() != q_zp_tensor_proto->data_type()) &&
+         (dq_scale_tensor_proto->data_type() == q_scale_tensor_proto->data_type());
 }
sequence into a QNN Convert operator. The DQ -> Q must be converting from - * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param dq_node_unit The DQ node unit. - * \param q_node_unit The Q node unit. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return An onnxruntime::Status - */ -std::unique_ptr TryDQQFusion( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& dq_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger); - -namespace dq_q_fusion { - -class QnnNodeGroup : public IQnnNodeGroup { +class QnnModelWrapper; + +class DQQFusion : public IQnnNodeGroup { public: - QnnNodeGroup(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit); - ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup); + DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(DQQFusion); Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; @@ -47,10 +27,28 @@ class QnnNodeGroup : public IQnnNodeGroup { const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "DQQFusion"; } + /** + * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from + * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). + * + * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied. + * \param qnn_model_wrapper The QNN model that is being built. + * \param dq_node_unit The DQ node unit. + * \param q_node_unit The Q node unit. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. 
+ * \return An onnxruntime::Status + */ + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + private: std::array node_units_; }; -} // namespace dq_q_fusion } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc index e77d613d607c6..90f2b1ef29f9a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -6,65 +6,32 @@ #include #include #include "core/graph/graph_utils.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/framework/node_unit.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_node_group/utils.h" namespace onnxruntime { namespace qnn { -static Status QnnHardSigmoidMulFusionAdd(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& hardsigmoid_node_unit, - const NodeUnit& mul_node_unit, - const logging::Logger& logger, - bool validate = false) { - ORT_UNUSED_PARAMETER(logger); - assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); - const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); - const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; - const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; +// Forward declarations. +#define ValidateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, mul_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (hardsigmoid_node_unit), (mul_node_unit), true) +#define CreateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, mul_node_unit) \ + CreateOrValidateOnQnn((qnn_model_wrapper), (hardsigmoid_node_unit), (mul_node_unit), false) +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, bool validate); - QnnTensorWrapper input_tensor; - QnnTensorWrapper output_tensor; - - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); - - if (validate) { - ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_tensor.GetQnnTensor()}, - {output_tensor.GetQnnTensor()}, - {})); - } else { - LOGS(logger, VERBOSE) << " Adding QNN HardSwish via fusion. 
HardSigmoid name: [" << hardsigmoid_node_unit.Name() - << "] Mul name: [" << mul_node_unit.Name() << "]"; - - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - QNN_OP_HARD_SWISH, - {input_def.node_arg.Name()}, - {output_def.node_arg.Name()}, - {}, - validate), - "Failed to add fused HardSwish node."); - } - - return Status::OK(); -} - -std::unique_ptr TryHardSigmoidMulFusion( +std::unique_ptr HardSigmoidMulFusion::TryFusion( QnnModelWrapper& qnn_model_wrapper, const NodeUnit& hardsigmoid_node_unit, const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + // Looking for a standalone HardSigmoid to start the sequence. if (hardsigmoid_node_unit.OpType() != "HardSigmoid" || hardsigmoid_node_unit.UnitType() != NodeUnit::Type::SingleNode) { @@ -104,37 +71,73 @@ std::unique_ptr TryHardSigmoidMulFusion( return nullptr; } - if (Status status = QnnHardSigmoidMulFusionAdd(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit, - logger, /*validate*/ true); + if (Status status = ValidateOnQnn(qnn_model_wrapper, hardsigmoid_node_unit, *mul_node_unit); !status.IsOK()) { return nullptr; } - return std::make_unique(hardsigmoid_node_unit, *mul_node_unit); + return std::make_unique(hardsigmoid_node_unit, *mul_node_unit); } -namespace hs_mul_fusion { - -QnnNodeGroup::QnnNodeGroup(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit) +HardSigmoidMulFusion::HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit) : node_units_{&hardsigmoid_node_unit, &mul_node_unit} { } -Status QnnNodeGroup::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnHardSigmoidMulFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ true); +Status HardSigmoidMulFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return ValidateOnQnn(qmw, *node_units_[0], *node_units_[1]); } -Status QnnNodeGroup::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { - return QnnHardSigmoidMulFusionAdd(qmw, *node_units_[0], *node_units_[1], logger, /*validate*/ false); +Status HardSigmoidMulFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + return CreateOnQnn(qmw, *node_units_[0], *node_units_[1]); } -gsl::span QnnNodeGroup::GetNodeUnits() const { +gsl::span HardSigmoidMulFusion::GetNodeUnits() const { return node_units_; } -const NodeUnit* QnnNodeGroup::GetTargetNodeUnit() const { +const NodeUnit* HardSigmoidMulFusion::GetTargetNodeUnit() const { return node_units_[0]; } -} // namespace hs_mul_fusion +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& hardsigmoid_node_unit, + const NodeUnit& mul_node_unit, + bool validate) { + assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); + const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); + const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; + const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + 
ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_def, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_def, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_HARD_SWISH, + {input_def.node_arg.Name()}, + {output_def.node_arg.Name()}, + {}, + validate), + "Failed to add fused HardSwish node."); + } + + return Status::OK(); +} + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h index 3b04dccf1f6a5..505ec17d0eb29 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h @@ -7,40 +7,19 @@ #include #include +#include "core/common/common.h" #include "core/framework/node_unit.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_node_group.h" namespace onnxruntime { namespace qnn { -/** - * Tries to fuse the sequence `x * HardSigmoid(x)` into a single HardSwish(x) operator. - * Should be called in a topologically ordered iteration of node units. - * - * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied. - * \param qnn_model_wrapper The QNN model that is being built. - * \param starting_node The node unit that could potentially start the sequence. - * \param node_unit_map Maps a node to its node unit. - * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes - * in this set. - * \param logger The logger. - * \param do_op_validation True if should call QNN operator validation APIs. - * \return A Status indicating a potential failure. - */ -std::unique_ptr TryHardSigmoidMulFusion( - QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& hardsigmoid_node_unit, - const std::unordered_map& node_to_node_unit, - const std::unordered_map& node_unit_to_qnn_node_group, - const logging::Logger& logger); - -namespace hs_mul_fusion { - -class QnnNodeGroup : public IQnnNodeGroup { +class QnnModelWrapper; + +class HardSigmoidMulFusion : public IQnnNodeGroup { public: - QnnNodeGroup(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit); - ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeGroup); + HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(HardSigmoidMulFusion); Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; @@ -48,10 +27,30 @@ class QnnNodeGroup : public IQnnNodeGroup { const NodeUnit* GetTargetNodeUnit() const override; std::string_view Type() const override { return "HardSigmoidMulFusion"; } + /** + * Tries to fuse the sequence `x * HardSigmoid(x)` into a single HardSwish(x) operator. 
+   * Should be called in a topologically ordered iteration of node units.
+   *
+   * \param qnn_model_wrapper The QNN model that is being built.
+   * \param hardsigmoid_node_unit The HardSigmoid node unit that could potentially start the sequence.
+   * \param node_to_node_unit Maps a node to its node unit.
+   * \param node_unit_to_qnn_node_group Maps a node unit to the IQnnNodeGroup that contains it.
+   * \param logger The logger.
+   * \return A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise.
+   */
+  static std::unique_ptr<IQnnNodeGroup> TryFusion(
+      QnnModelWrapper& qnn_model_wrapper,
+      const NodeUnit& hardsigmoid_node_unit,
+      const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+      const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
+      const logging::Logger& logger);
+
  private:
   std::array<const NodeUnit*, 2> node_units_;
 };

-}  // namespace hs_mul_fusion
 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
index 950aa1392fc89..e4b9e77a04cd7 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
@@ -11,7 +11,6 @@
 #include
 #include
 #include "core/graph/graph_utils.h"
-#include "core/optimizer/qdq_transformer/qdq_util.h"
 #include "core/framework/node_unit.h"
 #include "core/providers/qnn/builder/qnn_utils.h"
 #include "core/providers/qnn/builder/qnn_model_wrapper.h"
@@ -50,7 +49,7 @@ class QnnNodeUnitWrapper : public IQnnNodeGroup {
   }

   const NodeUnit* GetTargetNodeUnit() const override { return node_unit_; }
-  std::string_view Type() const override { return "NodeUnitWrapper"; }
+  std::string_view Type() const override { return "NodeUnit"; }

  private:
   const NodeUnit* node_unit_;
@@ -71,10 +70,10 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
     const logging::Logger& logger) {
   // Maps a starting operator type to the fusion function.
   static std::unordered_map<std::string, FusionFunc> fusions = {
-      {"DequantizeLinear", TryDQQFusion},
-      {"HardSigmoid", TryHardSigmoidMulFusion},
-      {"Conv", TryConvActivationFusion},
-      {"ConvTranspose", TryConvActivationFusion},
+      {"DequantizeLinear", DQQFusion::TryFusion},
+      {"HardSigmoid", HardSigmoidMulFusion::TryFusion},
+      {"Conv", ConvActivationFusion::TryFusion},
+      {"ConvTranspose", ConvActivationFusion::TryFusion},
   };

   // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).
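Note: the HardSigmoid -> Mul fusion registered in the table above relies on an algebraic identity rather
than anything QNN-specific: x * HardSigmoid(x) equals HardSwish(x) only when HardSigmoid uses alpha = 1/6
and beta = 0.5 (the ONNX defaults are alpha = 0.2, beta = 0.5), attribute values that TryFusion is expected
to verify before fusing. A minimal standalone sketch, not part of this patch, that checks the identity
numerically:

#include <algorithm>
#include <cassert>
#include <cmath>

// HardSigmoid(x) = clamp(alpha * x + beta, 0, 1), per the ONNX operator definition.
static float HardSigmoid(float x, float alpha, float beta) {
  return std::min(1.0f, std::max(0.0f, alpha * x + beta));
}

// HardSwish(x) = x * ReLU6(x + 3) / 6.
static float HardSwish(float x) {
  return x * std::min(6.0f, std::max(0.0f, x + 3.0f)) / 6.0f;
}

int main() {
  // With alpha = 1/6 and beta = 1/2: x * clamp(x/6 + 1/2, 0, 1) == x * ReLU6(x + 3) / 6.
  for (float x = -8.0f; x <= 8.0f; x += 0.25f) {
    assert(std::fabs(x * HardSigmoid(x, 1.0f / 6.0f, 0.5f) - HardSwish(x)) < 1e-5f);
  }
  return 0;
}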
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h index 308d08d42d87c..915eff3fbd418 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -13,6 +13,12 @@ namespace onnxruntime { namespace qnn { +constexpr const char* QUANTIZE_LINEAR = "QuantizeLinear"; +constexpr const char* DEQUANTIZE_LINEAR = "DequantizeLinear"; +constexpr size_t QDQ_MAX_NUM_INPUTS = 3; +constexpr size_t QDQ_SCALE_INPUT_IDX = 1; +constexpr size_t QDQ_ZERO_POINT_INPUT_IDX = 2; + const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, const NodeUnit& parent_node_unit, gsl::span child_op_types, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 58bfacc5cd73d..bb676b94d4927 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -405,6 +405,37 @@ QNNExecutionProvider::~QNNExecutionProvider() { #endif } +// Logs information about the supported/unsupported nodes. +static void LogNodeSupport(const logging::Logger& logger, + logging::Severity log_severity, + logging::DataType log_data_type, + const onnxruntime::CodeLocation& call_site, + const qnn::IQnnNodeGroup& qnn_node_group, + Status support_status) { + if (!logger.OutputIsEnabled(log_severity, log_data_type)) { + return; + } + + std::ostringstream oss; + oss << (support_status.IsOK() ? "Validation PASSED " : "Validation FAILED ") << "for nodes (" + << qnn_node_group.Type() << "):" << std::endl; + for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { + for (const Node* node : node_unit->GetAllNodesInGroup()) { + oss << "\tOperator type: " << node->OpType() + << " Node name: " << node->Name() + << " Node index: " << node->Index() << std::endl; + } + } + if (!support_status.IsOK()) { + oss << "\tREASON : " << support_status.ErrorMessage() << std::endl; + } + + logging::Capture(logger, log_severity, logging::Category::onnxruntime, + log_data_type, call_site) + .Stream() + << oss.str(); +} + std::unordered_set QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, @@ -451,33 +482,6 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, return {}; } - auto log_node_support = [](const logging::Logger& logger, - logging::Severity log_severity, - logging::DataType log_data_type, - const onnxruntime::CodeLocation& call_site, - const qnn::IQnnNodeGroup& qnn_node_group, - bool supported) { - if (!logger.OutputIsEnabled(log_severity, log_data_type)) { - return; - } - - std::ostringstream oss; - oss << "[QNN EP] " << (supported ? 
"Supports " : "Does NOT support ") << "the following nodes as part of a " - << qnn_node_group.Type() << " group:" << std::endl; - for (const NodeUnit* node_unit : qnn_node_group.GetNodeUnits()) { - for (const Node* node : node_unit->GetAllNodesInGroup()) { - oss << "\tOperator type: " << node->OpType() - << " Node name: " << node->Name() - << " Node index: " << node->Index() << std::endl; - } - } - - logging::Capture(logger, log_severity, logging::Category::onnxruntime, - log_data_type, call_site) - .Stream() - << oss.str(); - }; - for (const std::unique_ptr& qnn_node_group : qnn_node_groups) { Status status = qnn_node_group->IsSupported(qnn_model_wrapper, logger); const bool supported = status.IsOK(); @@ -485,7 +489,7 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, constexpr auto log_severity = logging::Severity::kVERBOSE; constexpr auto log_data_type = logging::DataType::SYSTEM; if (logger.OutputIsEnabled(log_severity, log_data_type)) { - log_node_support(logger, log_severity, log_data_type, ORT_WHERE, *qnn_node_group, supported); + LogNodeSupport(logger, log_severity, log_data_type, ORT_WHERE, *qnn_node_group, status); } if (supported) { diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 37eeac5101feb..a7732fd641e38 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -964,7 +964,7 @@ TEST_F(QnnHTPBackendTests, TestOD) { so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "0"); // Disable fallback to the CPU EP. so.AddConfigEntry(kDebugLayoutTransformation, "1"); so.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); - //so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE); + //so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_INFO); onnxruntime::ProviderOptions options; #if defined(_WIN32) From 94faeaf7e83e85fdd15811bc928bf08841df30aa Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 28 Jul 2024 17:20:54 -0700 Subject: [PATCH 15/20] Add more comments --- .../providers/qnn/builder/qnn_node_group.h | 37 +++- .../qnn_node_group/conv_activation_fusion.h | 17 +- .../qnn/builder/qnn_node_group/dq_q_fusion.h | 27 +-- .../qnn_node_group/hardsigmoid_mul_fusion.h | 5 + .../builder/qnn_node_group/qnn_node_group.cc | 164 +++++++++++------- 5 files changed, 170 insertions(+), 80 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h index a3c1b1bcdd407..f9ef01411310f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group.h @@ -16,16 +16,49 @@ namespace qnn { class QnnModelWrapper; +/// +/// Represents a group of NodeUnits that QNN EP translates into a core QNN operator. Can represent a single NodeUnit +/// or a fusion of multiple NodeUnits (e.g., DQ* -> Conv -> Relu -> Q). +/// class IQnnNodeGroup { public: virtual ~IQnnNodeGroup() = default; - virtual Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0; - virtual Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const = 0; + + // Returns an OK status if this IQnnNodeGroup is supported by QNN. + virtual Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const = 0; + + // Adds this IQnnNodeGroup to the QNN model wrapper. 
+  virtual Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const = 0;
+
+  // Returns a list of NodeUnits contained by this IQnnNodeGroup.
   virtual gsl::span<const NodeUnit* const> GetNodeUnits() const = 0;
+
+  /// <summary>
+  /// Returns the "target" NodeUnit of the group. This is important for topological ordering of IQnnNodeGroups.
+  /// The target should be the first NodeUnit where all input paths (of the IQnnNodeGroup) converge.
+  /// For example, "Conv" should be the target NodeUnit for the following IQnnNodeGroup with 6 NodeUnits.
+  ///   input0 -> DQ -> Conv -> Relu -> Q
+  ///                    ^
+  ///                    |
+  ///   input1 -> DQ ----+
+  /// </summary>
+  /// <returns>Target NodeUnit in IQnnNodeGroup</returns>
   virtual const NodeUnit* GetTargetNodeUnit() const = 0;
+
+  // Returns a string representation of the IQnnNodeGroup's type.
   virtual std::string_view Type() const = 0;
 };

+/// <summary>
+/// Traverses the ONNX graph to create IQnnNodeGroup objects, each containing one or more NodeUnits.
+/// The returned IQnnNodeGroup objects are sorted in topological order.
+/// </summary>
+/// <param name="qnn_node_groups">Output vector into which the resulting IQnnNodeGroup objects are stored.</param>
+/// <param name="qnn_model_wrapper">Contains a reference to the ONNX GraphViewer; used for validation on QNN.</param>
+/// <param name="node_to_node_unit">Maps a Node* to a NodeUnit*.</param>
+/// <param name="num_node_units">The number of NodeUnits in the ONNX graph.</param>
+/// <param name="logger">Logger</param>
+/// <returns>Status with a potential error</returns>
 Status GetQnnNodeGroups(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn_node_groups,
                         QnnModelWrapper& qnn_model_wrapper,
                         const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h
index f0e140addbd7a..71fc71434c5e5 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h
@@ -17,6 +17,11 @@ namespace qnn {

 class QnnModelWrapper;

+/// <summary>
+/// Represents a fusion of a DQ* -> Conv -> Relu/Clip -> Q sequence where the Relu (or Clip) is redundant
+/// due to the quantization effects of the Q. This sequence is translated to a quantized QNN Conv.
+/// All contained NodeUnits are of type SingleNode since they are not a part of an existing QDQ node unit.
+/// </summary>
 class ConvActivationFusion : public IQnnNodeGroup {
  public:
   ConvActivationFusion(const NodeUnit& dq_node_unit_0,
@@ -33,6 +38,16 @@ class ConvActivationFusion : public IQnnNodeGroup {
   const NodeUnit* GetTargetNodeUnit() const override;
   std::string_view Type() const override { return "ConvActivationFusion"; }

+  /// <summary>
+  /// Traverses the graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence.
+  /// If so, returns an IQnnNodeGroup that contains the constituent NodeUnits.
+  /// </summary>
+  /// <param name="qnn_model_wrapper">Used for validation and to traverse/query the graph.</param>
+  /// <param name="conv_node_unit">Conv node unit (of type SingleNode) that could be part of the sequence.</param>
+  /// <param name="node_to_node_unit">Maps a Node to a NodeUnit.</param>
+  /// <param name="node_unit_to_qnn_node_group">Maps a NodeUnit to an IQnnNodeGroup.</param>
+  /// <returns>A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise</returns>
   static std::unique_ptr<IQnnNodeGroup> TryFusion(
       QnnModelWrapper& qnn_model_wrapper,
       const NodeUnit& conv_node_unit,
       const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
       const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
       const logging::Logger& logger);

  private:
-  std::array<const NodeUnit*, 6> node_units_;  // Last elem is nullptr if bias DQ is missing.
+  std::array<const NodeUnit*, 6> node_units_;  // Last elem is nullptr if the optional bias DQ is missing.
 };

 }  // namespace qnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h
index dbfc852e8e7fa..90fe44c3af059 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h
@@ -16,6 +16,11 @@ namespace qnn {

 class QnnModelWrapper;

+/// <summary>
+/// Represents a fusion of a DQ -> Q sequence that converts from one quantization type (e.g., uint8_t) to
+/// another (e.g., uint16_t). This is translated into a QNN Convert operator, which is much faster than individual
+/// ops. The DQ and Q are standalone NodeUnits that are not part of a QDQ node unit.
+/// </summary>
 class DQQFusion : public IQnnNodeGroup {
  public:
   DQQFusion(const NodeUnit& dq_node_unit, const NodeUnit& q_node_unit);
@@ -27,18 +32,16 @@ class DQQFusion : public IQnnNodeGroup {
   const NodeUnit* GetTargetNodeUnit() const override;
   std::string_view Type() const override { return "DQQFusion"; }

-  /**
-   * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from
-   * one quantization type (e.g., uint8_t) to another (e.g., uint16_t).
-   *
-   * \param fused_nodes Output list of node units that were fused. Remains empty if fusion is not applied.
-   * \param qnn_model_wrapper The QNN model that is being built.
-   * \param dq_node_unit The DQ node unit.
-   * \param q_node_unit The Q node unit.
-   * \param logger The logger.
-   * \param do_op_validation True if should call QNN operator validation APIs.
-   * \return An onnxruntime::Status
-   */
+  /// <summary>
+  /// Traverses the graph to check if the given starting NodeUnit is part of a valid DQ -> Q sequence.
+  /// If so, returns an IQnnNodeGroup that contains the DQ and Q NodeUnits.
+  /// </summary>
+  /// <param name="qnn_model_wrapper">Used for validation and to traverse/query the graph.</param>
+  /// <param name="dq_node_unit">DQ node unit that could start the DQ -> Q sequence.</param>
+  /// <param name="node_to_node_unit">Maps a Node to a NodeUnit.</param>
+  /// <param name="node_unit_to_qnn_node_group">Maps a NodeUnit to an IQnnNodeGroup.</param>
+  /// <returns>A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise</returns>
   static std::unique_ptr<IQnnNodeGroup> TryFusion(
       QnnModelWrapper& qnn_model_wrapper,
       const NodeUnit& dq_node_unit,
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h
index 505ec17d0eb29..e4a87983fc9ef 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h
@@ -16,6 +16,11 @@ namespace qnn {

 class QnnModelWrapper;

+/// <summary>
+/// Represents a fusion of a HardSigmoid -> Mul sequence that computes `x * HardSigmoid(x)`.
+/// This is translated into a QNN HardSwish operator.
+/// The contained NodeUnits are of type SingleNode since they are not a part of a QDQ node unit.
+/// </summary>
 class HardSigmoidMulFusion : public IQnnNodeGroup {
  public:
   HardSigmoidMulFusion(const NodeUnit& hardsigmoid_node_unit, const NodeUnit& mul_node_unit);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
index 950aa1392fc89..e4b9e77a04cd7 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
@@ -22,6 +22,10 @@ namespace onnxruntime {

 namespace qnn {

+/// <summary>
+/// An IQnnNodeGroup class that wraps a single NodeUnit.
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
index 950aa1392fc89..e4b9e77a04cd7 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc
@@ -22,6 +22,10 @@ namespace onnxruntime {
 namespace qnn {
 
+/// <summary>
+/// An IQnnNodeGroup class that wraps a single NodeUnit. Most NodeUnits in the ONNX graph will
+/// be wrapped by this class.
+/// </summary>
 class QnnNodeUnitWrapper : public IQnnNodeGroup {
  public:
   QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {}
@@ -55,6 +59,9 @@ class QnnNodeUnitWrapper : public IQnnNodeGroup {
   const NodeUnit* node_unit_;
 };
 
+/// <summary>
+/// The type of a function that tries to fuse NodeUnits into an IQnnNodeGroup.
+/// </summary>
 using FusionFunc = std::unique_ptr<IQnnNodeGroup> (*)(
     QnnModelWrapper&,
     const NodeUnit&,
@@ -62,6 +69,17 @@ using FusionFunc = std::unique_ptr<IQnnNodeGroup> (*)(
     const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>&,
     const logging::Logger&);
 
+/// <summary>
+/// Given a starting NodeUnit, this function tries all possible fusions that start with that NodeUnit.
+/// If successful, returns an IQnnNodeGroup object that represents the fusion of various NodeUnits.
+/// Currently only handles standalone NodeUnits that are not in a QDQ unit, but that can change in the future.
+/// </summary>
+/// <param name="qnn_model_wrapper">QnnModelWrapper that contains the ONNX GraphViewer. Used for validation.</param>
+/// <param name="starting_node_unit">NodeUnit that potentially starts a fusion.</param>
+/// <param name="node_to_node_unit">Maps a Node* to a NodeUnit*.</param>
+/// <param name="node_unit_to_qnn_node_group">Maps a NodeUnit* to an IQnnNodeGroup*.</param>
+/// <param name="logger"></param>
+/// <returns>IQnnNodeGroup representing the fusion or an empty std::unique_ptr</returns>
 static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
     QnnModelWrapper& qnn_model_wrapper,
     const NodeUnit& starting_node_unit,
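The FusionFunc alias suggests a table-driven dispatch. The sketch below is one plausible shape for TryQnnFusions, assuming the declarations in this file; the registry contents are an assumption for illustration rather than the literal code.

    // Illustrative sketch only -- assumes the declarations in this file; the
    // actual lookup logic in TryQnnFusions may differ.
    static std::unique_ptr<IQnnNodeGroup> TryQnnFusionsSketch(
        QnnModelWrapper& qnn_model_wrapper,
        const NodeUnit& starting_node_unit,
        const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
        const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
        const logging::Logger& logger) {
      // Hypothetical registry keyed by the op type that starts each fusable sequence.
      static const std::unordered_map<std::string, FusionFunc> fusions = {
          {"DequantizeLinear", DQQFusion::TryFusion},
          {"HardSigmoid", HardSigmoidMulFusion::TryFusion},
          {"Conv", ConvActivationFusion::TryFusion},
          {"ConvTranspose", ConvActivationFusion::TryFusion},
      };

      const auto it = fusions.find(starting_node_unit.OpType());
      if (it == fusions.end()) {
        return nullptr;
      }
      return it->second(qnn_model_wrapper, starting_node_unit, node_to_node_unit,
                        node_unit_to_qnn_node_group, logger);
    }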
@@ -90,92 +108,108 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
   return nullptr;
 }
 
-Status GetQnnNodeGroups(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn_node_groups,
-                        QnnModelWrapper& qnn_model_wrapper,
-                        const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
-                        const size_t num_node_units,
-                        const logging::Logger& logger) {
+// Traverses the ONNX Graph and groups NodeUnits into IQnnNodeGroup objects. Some IQnnNodeGroup objects
+// represent a fusion of various NodeUnits. This function generates a vector of indices that
+// represent the topological order of the qnn_node_groups.
+static Status GetQnnNodeGroupsImpl(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn_node_groups,
+                                   /*out*/ std::vector<size_t>& sorted_qnn_node_group_indices,
+                                   QnnModelWrapper& qnn_model_wrapper,
+                                   const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+                                   const size_t num_node_units,
+                                   const logging::Logger& logger) {
   const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
   const std::vector<NodeIndex> sorted_node_indices = graph_viewer.GetNodesInTopologicalOrder();
 
-  std::vector<size_t> sorted_qnn_node_group_indices;
   sorted_qnn_node_group_indices.reserve(num_node_units);
+  qnn_node_groups.reserve(num_node_units);
+
+  std::unordered_map<const NodeUnit*, const IQnnNodeGroup*> node_unit_to_qnn_node_group;
+  std::unordered_map<const IQnnNodeGroup*, size_t> fused_qnn_node_group_indices;
+  std::vector<gsl::not_null<const NodeUnit*>> sorted_node_units;
+  sorted_node_units.reserve(num_node_units);
+
+  // Process just the fusions of NodeUnits first to ensure a correct topological order of all IQnnNodeGroups.
+  // This is the same approach taken by ORT utilities for grouping Nodes into NodeUnits.
+  for (NodeIndex node_index : sorted_node_indices) {
+    gsl::not_null<const Node*> node = graph_viewer.GetNode(node_index);
+
+    // Get the NodeUnit associated with the node.
+    const auto node_unit_it = node_to_node_unit.find(node);
+    ORT_RETURN_IF_NOT(node_unit_it != node_to_node_unit.end(), "Could not find NodeUnit for Node ", node->Name());
+    gsl::not_null<const NodeUnit*> node_unit = node_unit_it->second;
+
+    // Skip this node if it is not the NodeUnit's target node to ensure NodeUnits are visited in topological order.
+    if (node != &node_unit->GetNode()) {
+      continue;
+    }
 
-  std::vector<std::unique_ptr<IQnnNodeGroup>> tmp_qnn_node_groups;
-  tmp_qnn_node_groups.reserve(num_node_units);
-
-  {
-    std::unordered_map<const NodeUnit*, const IQnnNodeGroup*> node_unit_to_qnn_node_group;
-    std::unordered_map<const IQnnNodeGroup*, size_t> fused_qnn_node_group_indices;
-    std::vector<gsl::not_null<const NodeUnit*>> sorted_node_units;
-    sorted_node_units.reserve(num_node_units);
-
-    // Create QnnNodeGroups for fusions first.
-    for (NodeIndex node_index : sorted_node_indices) {
-      gsl::not_null<const Node*> node = graph_viewer.GetNode(node_index);
+    sorted_node_units.push_back(node_unit);
 
-      // Get the NodeUnit associated with the node.
-      const auto node_unit_it = node_to_node_unit.find(node);
-      ORT_RETURN_IF_NOT(node_unit_it != node_to_node_unit.end(), "Could not find NodeUnit for Node ", node->Name());
-      gsl::not_null<const NodeUnit*> node_unit = node_unit_it->second;
+    if (node_unit_to_qnn_node_group.count(node_unit) != 0) {
+      continue;  // Already handled this node unit
+    }
 
-      // Skip this node if it is not the NodeUnit's target node to ensure NodeUnits are visited in topological order.
-      if (node != &node_unit->GetNode()) {
-        continue;
-      }
+    std::unique_ptr<IQnnNodeGroup> fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit,
+                                                                    node_to_node_unit, node_unit_to_qnn_node_group,
+                                                                    logger);
 
-      sorted_node_units.push_back(node_unit);
+    if (fused_node_group) {
+      const size_t index = qnn_node_groups.size();
+      fused_qnn_node_group_indices[fused_node_group.get()] = index;
 
-      if (node_unit_to_qnn_node_group.count(node_unit) != 0) {
-        continue;  // Already handled this node unit
+      for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) {
+        assert(fused_node_unit != nullptr);
+        node_unit_to_qnn_node_group.insert({fused_node_unit, fused_node_group.get()});
       }
 
-      std::unique_ptr<IQnnNodeGroup> fused_node_group = TryQnnFusions(qnn_model_wrapper, *node_unit,
-                                                                      node_to_node_unit, node_unit_to_qnn_node_group,
-                                                                      logger);
-
-      if (fused_node_group) {
-        const size_t index = tmp_qnn_node_groups.size();
-        fused_qnn_node_group_indices[fused_node_group.get()] = index;
-
-        for (const NodeUnit* fused_node_unit : fused_node_group->GetNodeUnits()) {
-          assert(fused_node_unit != nullptr);
-          node_unit_to_qnn_node_group.insert({fused_node_unit, fused_node_group.get()});
-        }
+      qnn_node_groups.push_back(std::move(fused_node_group));
+    }
+  }
 
-        tmp_qnn_node_groups.push_back(std::move(fused_node_group));
+  // Create IQnnNodeGroups for the leftover NodeUnits that were not fused.
+  for (gsl::not_null<const NodeUnit*> node_unit : sorted_node_units) {
+    const auto it = node_unit_to_qnn_node_group.find(node_unit);
+
+    if (it != node_unit_to_qnn_node_group.end()) {
+      // Already added this NodeUnit to an IQnnNodeGroup, so we'll skip it.
+      // However, if this NodeUnit is the "target" for the IQnnNodeGroup, then add its index to
+      // the sorted list of indices.
+      gsl::not_null<const IQnnNodeGroup*> fused_qnn_node_group = it->second;
+      if (node_unit == fused_qnn_node_group->GetTargetNodeUnit()) {
+        sorted_qnn_node_group_indices.push_back(fused_qnn_node_group_indices[fused_qnn_node_group]);
       }
+      continue;
     }
 
-    // Create QnnNodeGroups for the leftover NodeUnits.
-    for (gsl::not_null<const NodeUnit*> node_unit : sorted_node_units) {
-      const auto it = node_unit_to_qnn_node_group.find(node_unit);
-      if (it != node_unit_to_qnn_node_group.end()) {
-        // Already handled this NodeUnit.
-        gsl::not_null<const IQnnNodeGroup*> fused_qnn_node_group = it->second;
-        if (node_unit == fused_qnn_node_group->GetTargetNodeUnit()) {
-          sorted_qnn_node_group_indices.push_back(fused_qnn_node_group_indices[fused_qnn_node_group]);
-        }
-        continue;
-      }
+    const size_t index = qnn_node_groups.size();
+    auto qnn_node_group = std::make_unique<QnnNodeUnitWrapper>(*node_unit);
 
-      const size_t index = tmp_qnn_node_groups.size();
-      auto qnn_node_group = std::make_unique<QnnNodeUnitWrapper>(*node_unit);
+    node_unit_to_qnn_node_group.insert({node_unit, qnn_node_group.get()});
+    qnn_node_groups.push_back(std::move(qnn_node_group));
+    sorted_qnn_node_group_indices.push_back(index);
+  }
 
-      node_unit_to_qnn_node_group.insert({node_unit, qnn_node_group.get()});
-      tmp_qnn_node_groups.push_back(std::move(qnn_node_group));
-      sorted_qnn_node_group_indices.push_back(index);
-    }
+  assert(qnn_node_groups.size() == sorted_qnn_node_group_indices.size());
 
-    assert(tmp_qnn_node_groups.size() == sorted_qnn_node_group_indices.size());
-  }
+  return Status::OK();
+}
+
+Status GetQnnNodeGroups(/*out*/ std::vector<std::unique_ptr<IQnnNodeGroup>>& qnn_node_groups,
+                        QnnModelWrapper& qnn_model_wrapper,
+                        const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
+                        const size_t num_node_units,
+                        const logging::Logger& logger) {
+  std::vector<size_t> sorted_qnn_node_group_indices;
+  std::vector<std::unique_ptr<IQnnNodeGroup>> qnn_node_groups_holder;
+  ORT_RETURN_IF_ERROR(GetQnnNodeGroupsImpl(qnn_node_groups_holder, sorted_qnn_node_group_indices, qnn_model_wrapper,
+                                           node_to_node_unit, num_node_units, logger));
 
-  // Copy QnnNodeGroups to output in sorted (topological) order.
+  // Move IQnnNodeGroups to the output std::vector in sorted (topological) order.
   qnn_node_groups.resize(0);
-  qnn_node_groups.reserve(tmp_qnn_node_groups.size());
+  qnn_node_groups.reserve(qnn_node_groups_holder.size());
   for (auto index : sorted_qnn_node_group_indices) {
-    assert(index < tmp_qnn_node_groups.size());
-    std::unique_ptr<IQnnNodeGroup> qnn_node_group = std::move(tmp_qnn_node_groups[index]);
+    assert(index < qnn_node_groups_holder.size());
+    std::unique_ptr<IQnnNodeGroup> qnn_node_group = std::move(qnn_node_groups_holder[index]);
     qnn_node_groups.push_back(std::move(qnn_node_group));
   }
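The index bookkeeping above exists because groups are created in fusion-first order, not topological order. The following self-contained toy (illustrative only, using strings in place of IQnnNodeGroup objects) shows the final reordering step in miniature.

    // Illustrative only: groups are created in pass order (fusions first,
    // leftovers second), then consumed in the order given by sorted indices.
    #include <cassert>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<std::unique_ptr<std::string>> holder;
      holder.push_back(std::make_unique<std::string>("fused group"));        // created in pass 1
      holder.push_back(std::make_unique<std::string>("unfused producer"));   // created in pass 2
      holder.push_back(std::make_unique<std::string>("unfused consumer"));   // created in pass 2

      // Topological order discovered while visiting target NodeUnits:
      // the producer feeds the fused group, which feeds the consumer.
      std::vector<size_t> sorted_indices = {1, 0, 2};

      std::vector<std::unique_ptr<std::string>> out;
      out.reserve(holder.size());
      for (size_t index : sorted_indices) {
        out.push_back(std::move(holder[index]));
      }
      assert(*out.front() == "unfused producer");  // creation order != consumption order
      return 0;
    }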
From a946b5e5b3fa0c3d44d4e29470b65b7299c5d945 Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Sun, 28 Jul 2024 18:54:24 -0700
Subject: [PATCH 16/20] Add unit tests

---
 .../qnn_node_group/hardsigmoid_mul_fusion.h   |  24 +-
 onnxruntime/test/providers/qnn/conv_test.cc   | 261 +++++++++++++++---
 .../test/providers/qnn/qnn_basic_test.cc      |  66 +----
 3 files changed, 239 insertions(+), 112 deletions(-)

diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h
index e4a87983fc9ef..3b67f13492a46 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h
@@ -32,20 +32,16 @@ class HardSigmoidMulFusion : public IQnnNodeGroup {
   const NodeUnit* GetTargetNodeUnit() const override;
   std::string_view Type() const override { return "HardSigmoidMulFusion"; }
 
-  /**
-   * Tries to fuse the sequence `x * HardSigmoid(x)` into a single HardSwish(x) operator.
-   * Should be called in a topologically ordered iteration of node units.
-   *
-   * \param fused_nodes Output list of node units that were fused. Remains empty if fusion was not applied.
-   * \param qnn_model_wrapper The QNN model that is being built.
-   * \param starting_node The node unit that could potentially start the sequence.
-   * \param node_unit_map Maps a node to its node unit.
-   * \param handled_node_units Set of node units that have already been processed. Fusion will not fuse nodes
-   *                           in this set.
-   * \param logger The logger.
-   * \param do_op_validation True if should call QNN operator validation APIs.
-   * \return A Status indicating a potential failure.
-   */
+  /// <summary>
+  /// Traverses the graph to check if the given starting NodeUnit is part of a valid HardSigmoid -> Mul sequence.
+  /// If so, returns an IQnnNodeGroup that contains the HardSigmoid and Mul NodeUnits.
+  /// </summary>
+  /// <param name="qnn_model_wrapper">Used for validation and to traverse/query the graph.</param>
+  /// <param name="hardsigmoid_node_unit">HardSigmoid node unit that could start the sequence.</param>
+  /// <param name="node_to_node_unit">Maps a Node to a NodeUnit.</param>
+  /// <param name="node_unit_to_qnn_node_group">Maps a NodeUnit to an IQnnNodeGroup.</param>
+  /// <param name="logger"></param>
+  /// <returns>A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise</returns>
   static std::unique_ptr<IQnnNodeGroup> TryFusion(
       QnnModelWrapper& qnn_model_wrapper,
       const NodeUnit& hardsigmoid_node_unit,
diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc
index 99636976b9c05..4f7d3dea66db6 100644
--- a/onnxruntime/test/providers/qnn/conv_test.cc
+++ b/onnxruntime/test/providers/qnn/conv_test.cc
@@ -15,6 +15,13 @@ namespace onnxruntime {
 namespace test {
 
+// Information for activation node placed between the Conv and Q.
+struct OutputActivationInfo {
+  std::string op_type;  // Relu or Clip
+  std::vector<ONNX_NAMESPACE::AttributeProto> attrs;
+  std::vector<float> const_inputs;
+};
+
 // Creates a graph with a single float32 Conv operator. Used for testing CPU backend.
 static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, const TestInputDef<float>& input_def,
                                            const TestInputDef<float>& weights_def,
@@ -23,9 +30,10 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons
                                            const std::vector<int64_t>& pads,
                                            const std::vector<int64_t>& dilations,
                                            std::optional<int64_t> group,
-                                           const std::string& auto_pad = "NOTSET") {
+                                           const std::string& auto_pad = "NOTSET",
+                                           std::optional<OutputActivationInfo> output_activation = std::nullopt) {
   return [conv_op_type, input_def, weights_def, bias_def, strides, pads,
-          dilations, group, auto_pad](ModelTestBuilder& builder) {
+          dilations, group, auto_pad, output_activation](ModelTestBuilder& builder) {
     std::vector<NodeArg*> conv_inputs = {
         MakeTestInput(builder, input_def),
         MakeTestInput(builder, weights_def)};
@@ -34,9 +42,9 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons
       conv_inputs.push_back(MakeTestInput(builder, bias_def));
     }
 
-    auto* output = builder.MakeOutput();
+    auto* conv_output = output_activation.has_value() ? builder.MakeIntermediate() : builder.MakeOutput();
 
-    Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {output});
+    Node& conv_node = builder.AddNode(conv_op_type, conv_inputs, {conv_output});
     conv_node.AddAttribute("auto_pad", auto_pad);
 
     if (group.has_value()) {
@@ -54,6 +62,18 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons
     if (!dilations.empty()) {
       conv_node.AddAttribute("dilations", dilations);
     }
+
+    if (output_activation.has_value()) {
+      NodeArg* output = builder.MakeOutput();
+      std::vector<NodeArg*> activation_inputs = {conv_output};
+      for (auto val : output_activation->const_inputs) {
+        activation_inputs.push_back(builder.MakeScalarInitializer(val));
+      }
+      Node& activation_node = builder.AddNode(output_activation->op_type, activation_inputs, {output});
+      for (const auto& attr : output_activation->attrs) {
+        activation_node.AddAttributeProto(attr);
+      }
+    }
   };
 }
 
@@ -88,19 +108,22 @@ static void RunCPUConvOpTest(const std::string& conv_op_type, const TestInputDef
 
 // Creates a graph with a single Q/DQ Conv operator. Used for testing HTP backend.
 template <typename ActivationQType, typename WeightQType>
-static GetTestQDQModelFn<ActivationQType> BuildQDQConvTestCase(const std::string& conv_op_type,
-                                                               const TestInputDef<float>& input_def,
-                                                               const TestInputDef<float>& weights_def,
-                                                               const TestInputDef<float>& bias_def,
-                                                               const std::vector<int64_t>& strides,
-                                                               const std::vector<int64_t>& pads,
-                                                               const std::vector<int64_t>& dilations,
-                                                               std::optional<int64_t> group,
-                                                               const std::string& auto_pad = "NOTSET",
-                                                               bool use_contrib_qdq = false) {
+static GetTestQDQModelFn<ActivationQType> BuildQDQConvTestCase(
+    const std::string& conv_op_type,
+    const TestInputDef<float>& input_def,
+    const TestInputDef<float>& weights_def,
+    const TestInputDef<float>& bias_def,
+    const std::vector<int64_t>& strides,
+    const std::vector<int64_t>& pads,
+    const std::vector<int64_t>& dilations,
+    std::optional<int64_t> group,
+    const std::string& auto_pad = "NOTSET",
+    bool use_contrib_qdq = false,
+    std::optional<OutputActivationInfo> output_activation = std::nullopt) {
   return [conv_op_type, input_def, weights_def, bias_def, strides, pads,
-          dilations, group, auto_pad, use_contrib_qdq](ModelTestBuilder& builder,
-                                                       std::vector<QuantParams<ActivationQType>>& output_qparams) {
+          dilations, group, auto_pad,
+          use_contrib_qdq, output_activation](ModelTestBuilder& builder,
+                                              std::vector<QuantParams<ActivationQType>>& output_qparams) {
     std::vector<NodeArg*> conv_inputs;
 
     // input -> Q/DQ ->
@@ -144,27 +167,42 @@ static GetTestQDQModelFn<ActivationQType> BuildQDQConvTestCase(const std::string
       conv_node.AddAttribute("dilations", dilations);
     }
 
-    AddQDQNodePairWithOutputAsGraphOutput<ActivationQType>(builder, conv_output, output_qparams[0].scale,
+    NodeArg* q_input = conv_output;
+    if (output_activation.has_value()) {
+      q_input = builder.MakeIntermediate();
+      std::vector<NodeArg*> activation_inputs = {conv_output};
+      for (auto val : output_activation->const_inputs) {
+        activation_inputs.push_back(builder.MakeScalarInitializer(val));
+      }
+      Node& activation_node = builder.AddNode(output_activation->op_type, activation_inputs, {q_input});
+      for (const auto& attr : output_activation->attrs) {
+        activation_node.AddAttributeProto(attr);
+      }
+    }
+
+    AddQDQNodePairWithOutputAsGraphOutput<ActivationQType>(builder, q_input, output_qparams[0].scale,
                                                            output_qparams[0].zero_point, use_contrib_qdq);
   };
 }
 
 template <typename ActivationQType, typename WeightQType>
-static GetTestQDQModelFn<ActivationQType> BuildQDQPerChannelConvTestCase(const std::string& conv_op_type,
-                                                                         const TestInputDef<float>& input_def,
-                                                                         const TestInputDef<float>& weights_def,
-                                                                         const TestInputDef<float>& bias_def,
-                                                                         int64_t weight_quant_axis,
-                                                                         const std::vector<int64_t>& strides,
-                                                                         const std::vector<int64_t>& pads,
-                                                                         const std::vector<int64_t>& dilations,
-                                                                         std::optional<int64_t> group,
-                                                                         const std::string& auto_pad = "NOTSET",
-                                                                         bool use_contrib_qdq = false) {
+static GetTestQDQModelFn<ActivationQType> BuildQDQPerChannelConvTestCase(
+    const std::string& conv_op_type,
+    const TestInputDef<float>& input_def,
+    const TestInputDef<float>& weights_def,
+    const TestInputDef<float>& bias_def,
+    int64_t weight_quant_axis,
+    const std::vector<int64_t>& strides,
+    const std::vector<int64_t>& pads,
+    const std::vector<int64_t>& dilations,
+    std::optional<int64_t> group,
+    const std::string& auto_pad = "NOTSET",
+    bool use_contrib_qdq = false,
+    std::optional<OutputActivationInfo> output_activation = std::nullopt) {
   return [conv_op_type, input_def, weights_def, bias_def, strides, pads,
           dilations, group, auto_pad, use_contrib_qdq,
-          weight_quant_axis](ModelTestBuilder& builder,
-                             std::vector<QuantParams<ActivationQType>>& output_qparams) {
+          weight_quant_axis, output_activation](ModelTestBuilder& builder,
+                                                std::vector<QuantParams<ActivationQType>>& output_qparams) {
     std::vector<NodeArg*> conv_inputs;
 
     // input -> Q/DQ ->
@@ -248,7 +286,20 @@ static GetTestQDQModelFn<ActivationQType> BuildQDQPerChannelConvTestCase(const s
       conv_node.AddAttribute("dilations", dilations);
     }
 
-    AddQDQNodePairWithOutputAsGraphOutput<ActivationQType>(builder, conv_output, output_qparams[0].scale,
+    NodeArg* q_input = conv_output;
+    if (output_activation.has_value()) {
+      q_input = builder.MakeIntermediate();
+      std::vector<NodeArg*> activation_inputs = {conv_output};
+      for (auto val : output_activation->const_inputs) {
+        activation_inputs.push_back(builder.MakeScalarInitializer(val));
+      }
+      Node& activation_node = builder.AddNode(output_activation->op_type, activation_inputs, {q_input});
+      for (const auto& attr : output_activation->attrs) {
+        activation_node.AddAttributeProto(attr);
+      }
+    }
+
+    AddQDQNodePairWithOutputAsGraphOutput<ActivationQType>(builder, q_input, output_qparams[0].scale,
                                                            output_qparams[0].zero_point, use_contrib_qdq);
   };
 }
 
@@ -267,7 +318,8 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef
                              ExpectedEPNodeAssignment expected_ep_assignment,
                              bool use_contrib_qdq = false,
                              int opset = 13,
-                             QDQTolerance tolerance = QDQTolerance()) {
+                             QDQTolerance tolerance = QDQTolerance(),
+                             std::optional<OutputActivationInfo> output_activation = std::nullopt) {
   ProviderOptions provider_options;
 
 #if defined(_WIN32)
@@ -277,10 +329,11 @@ static void RunHTPConvOpTest(const std::string& conv_op_type, const TestInputDef
 #endif
 
   TestQDQModelAccuracy(BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations,
-                                            group, auto_pad),
+                                            group, auto_pad, output_activation),
                        BuildQDQConvTestCase<ActivationQType, WeightQType>(conv_op_type, input_def, weights_def,
                                                                           bias_def, strides, pads, dilations,
-                                                                          group, auto_pad, use_contrib_qdq),
+                                                                          group, auto_pad, use_contrib_qdq,
+                                                                          output_activation),
                        provider_options,
                        opset,
                        expected_ep_assignment,
@@ -302,7 +355,8 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te
                                        ExpectedEPNodeAssignment expected_ep_assignment,
                                        bool use_contrib_qdq = false,
                                        int opset = 13,
-                                       QDQTolerance tolerance = QDQTolerance()) {
+                                       QDQTolerance tolerance = QDQTolerance(),
+                                       std::optional<OutputActivationInfo> output_activation = std::nullopt) {
   ProviderOptions provider_options;
 
 #if defined(_WIN32)
@@ -312,11 +366,11 @@ static void RunHTPConvOpPerChannelTest(const std::string& conv_op_type, const Te
 #endif
 
   auto f32_fn = BuildF32ConvTestCase(conv_op_type, input_def, weights_def, bias_def, strides, pads, dilations,
-                                     group, auto_pad);
+                                     group, auto_pad, output_activation);
   auto qdq_fn = BuildQDQPerChannelConvTestCase<ActivationQType, WeightQType>(conv_op_type, input_def, weights_def,
                                                                              bias_def, weight_quant_axis, strides,
                                                                              pads, dilations, group, auto_pad,
-                                                                             use_contrib_qdq);
+                                                                             use_contrib_qdq, output_activation);
 
   TestQDQModelAccuracy(f32_fn, qdq_fn, provider_options, opset, expected_ep_assignment, tolerance);
 }
 
@@ -764,6 +818,139 @@ TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel) {
                                            21);  // opset
 }
 
+// Test fusion of DQs -> Conv -> Relu/Clip -> Q.
+// Uses per-tensor quantization.
+TEST_F(QnnHTPBackendTests, ConvU8U8S32_ReluClipFusion) {
+  std::vector<int64_t> input_shape = {1, 2, 4, 4};
+  std::vector<int64_t> weight_shape = {3, 2, 2, 2};
+  std::vector<int64_t> bias_shape = {3};
+
+  TestInputDef<float> input_def(input_shape, false,
+                                GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size()));
+  TestInputDef<float> weight_def(weight_shape, true,
+                                 GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size()));
+  TestInputDef<float> bias_def(bias_shape, true,
+                               GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size()));
+
+  // DQs -> Conv (w/ bias) -> Relu -> Q
+  OutputActivationInfo relu_info = {"Relu", {}};
+  RunHTPConvOpTest<uint8_t, uint8_t>("Conv",
+                                     input_def,
+                                     weight_def,
+                                     bias_def,
+                                     {1, 1},        // Strides
+                                     {0, 0, 0, 0},  // Pads
+                                     {1, 1},        // Dilations
+                                     1,             // default group
+                                     "NOTSET",
+                                     ExpectedEPNodeAssignment::All,
+                                     false,  // use_qdq_contrib_ops
+                                     21,     // opset
+                                     QDQTolerance(),
+                                     relu_info);
+
+  // DQs -> Conv (NO bias) -> Relu -> Q
+  RunHTPConvOpTest<uint8_t, uint8_t>("Conv",
+                                     input_def,
+                                     weight_def,
+                                     TestInputDef<float>(),
+                                     {1, 1},        // Strides
+                                     {0, 0, 0, 0},  // Pads
+                                     {1, 1},        // Dilations
+                                     1,             // default group
+                                     "NOTSET",
+                                     ExpectedEPNodeAssignment::All,
+                                     false,  // use_qdq_contrib_ops
+                                     21,     // opset
+                                     QDQTolerance(),
+                                     relu_info);
+
+  // DQs -> Conv (w/ bias) -> Clip -> Q
+  OutputActivationInfo clip_info = {"Clip", {}, {0.0f, 6.0f}};
+  RunHTPConvOpTest<uint8_t, uint8_t>("Conv",
+                                     input_def,
+                                     weight_def,
+                                     bias_def,
+                                     {1, 1},        // Strides
+                                     {0, 0, 0, 0},  // Pads
+                                     {1, 1},        // Dilations
+                                     1,             // default group
+                                     "NOTSET",
+                                     ExpectedEPNodeAssignment::All,
+                                     false,  // use_qdq_contrib_ops
+                                     21,     // opset
+                                     QDQTolerance(),
+                                     clip_info);
+
+  // DQs -> Conv (NO bias) -> Clip -> Q
+  OutputActivationInfo clip_info_2 = {"Clip", {}, {-6.0f, 6.0f}};
+  RunHTPConvOpTest<uint8_t, uint8_t>("Conv",
+                                     input_def,
+                                     weight_def,
+                                     TestInputDef<float>(),
+                                     {1, 1},        // Strides
+                                     {0, 0, 0, 0},  // Pads
+                                     {1, 1},        // Dilations
+                                     1,             // default group
+                                     "NOTSET",
+                                     ExpectedEPNodeAssignment::All,
+                                     false,  // use_qdq_contrib_ops
+                                     21,     // opset
+                                     QDQTolerance(),
+                                     clip_info_2);
+}
+
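The Clip variants in these tests lean on a small piece of arithmetic: if the Q node's scale and zero-point already restrict the representable outputs to the Clip range, the Clip is a no-op. This standalone snippet (illustrative only; the test harness computes its own quantization parameters) checks that arithmetic for a uint8 output quantized to roughly [0, 6], mirroring the shape of the CanClipBeRemoved check added later in this series, with a looser tolerance for float rounding.

    // Illustrative only: why Clip(0, 6) is removable for a uint8 QuantizeLinear
    // with scale = 6/255 and zero_point = 0.
    #include <cassert>
    #include <cstdint>
    #include <limits>

    int main() {
      const float scale = 6.0f / 255.0f;
      const int32_t zero_point = 0;
      // Dequantized range representable by the Q node: scale * (q - zero_point).
      const float rmin = scale * static_cast<float>(std::numeric_limits<uint8_t>::lowest() - zero_point);
      const float rmax = scale * static_cast<float>(std::numeric_limits<uint8_t>::max() - zero_point);
      const float clip_min = 0.0f;
      const float clip_max = 6.0f;

      constexpr float epsilon = 1e-5f;  // tolerance for float rounding (assumption for this demo)
      const bool clip_is_redundant = !(epsilon < clip_min - rmin) && !(epsilon < rmax - clip_max);
      assert(clip_is_redundant);  // The Clip range covers the entire quantization range.
      return 0;
    }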
+// Test fusion of DQs -> Conv -> Relu/Clip -> Q.
+// Uses per-channel quantization.
+TEST_F(QnnHTPBackendTests, ConvS8S8S32_PerChannel_ReluClipFusion) {
+  std::vector<int64_t> input_shape = {1, 2, 4, 4};
+  std::vector<int64_t> weight_shape = {3, 2, 2, 2};
+  std::vector<int64_t> bias_shape = {3};
+
+  TestInputDef<float> input_def(input_shape, false,
+                                GetFloatDataInRange(0.0f, 1.0f, TensorShape(input_shape).Size()));
+  TestInputDef<float> weight_def(weight_shape, true,
+                                 GetFloatDataInRange(-1.0f, 5.0f, TensorShape(weight_shape).Size()));
+  TestInputDef<float> bias_def(bias_shape, true,
+                               GetFloatDataInRange(-1.0f, 1.0f, TensorShape(bias_shape).Size()));
+
+  // DQs -> Conv (w/ bias) -> Relu -> Q
+  OutputActivationInfo relu_info = {"Relu", {}};
+  RunHTPConvOpPerChannelTest<int8_t, int8_t>("Conv",
+                                             input_def,
+                                             weight_def,
+                                             bias_def,
+                                             0,             // weight quant axis
+                                             {1, 1},        // Strides
+                                             {0, 0, 0, 0},  // Pads
+                                             {1, 1},        // Dilations
+                                             1,             // default group
+                                             "NOTSET",
+                                             ExpectedEPNodeAssignment::All,
+                                             false,  // use_qdq_contrib_ops
+                                             21,     // opset
+                                             QDQTolerance(),
+                                             relu_info);
+
+  // DQs -> Conv (w/ bias) -> Clip -> Q
+  OutputActivationInfo clip_info = {"Clip", {}, {0.0f, 6.0f}};
+  RunHTPConvOpPerChannelTest<int8_t, int8_t>("Conv",
+                                             input_def,
+                                             weight_def,
+                                             bias_def,
+                                             0,             // weight quant axis
+                                             {1, 1},        // Strides
+                                             {0, 0, 0, 0},  // Pads
+                                             {1, 1},        // Dilations
+                                             1,             // default group
+                                             "NOTSET",
+                                             ExpectedEPNodeAssignment::All,
+                                             false,  // use_qdq_contrib_ops
+                                             21,     // opset
+                                             QDQTolerance(),
+                                             clip_info);
+}
+
 // Test per-channel QDQ Conv with INT4 weights and a negative weight quantization axis that still points to dimension 0.
 TEST_F(QnnHTPBackendTests, ConvU16S4S32_PerChannel_NegativeWeightQuantAxis) {
   std::vector<int64_t> input_shape = {1, 2, 4, 4};
diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
index a7732fd641e38..9489d354755e4 100644
--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -484,7 +484,7 @@ static GetTestModelFn F32BuildAdd3Tensors(const TestInputDef<float>& input0_def,
 }
 
 // Tests running a single session in multiple threads on the CPU backend.
-TEST_F(QnnCPUBackendTests, DISABLED_MultithreadSessionRun) {
+TEST_F(QnnCPUBackendTests, MultithreadSessionRun) {
   std::unique_ptr model;
   std::vector<float> input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
   std::vector<int64_t> shape = {1, 3, 2};
@@ -564,7 +564,7 @@ static GetTestModelFn QDQBuildAdd3Tensors(const TestInputDef<float>& input0_def,
 }
 
 // Tests running a single session in multiple threads on the HTP backend.
-TEST_F(QnnHTPBackendTests, DISABLED_MultithreadSessionRun) { +TEST_F(QnnHTPBackendTests, MultithreadSessionRun) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -616,7 +616,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_MultithreadSessionRun) { } // Tests running a single session in multiple threads on the HTP backend with run option to set power config -TEST_F(QnnHTPBackendTests, DISABLED_MultithreadHtpPowerCfgSessionRunOption) { +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgSessionRunOption) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -678,7 +678,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_MultithreadHtpPowerCfgSessionRunOption) { } // Tests running a single session in multiple threads on the HTP backend with EP option to set default power config -TEST_F(QnnHTPBackendTests, DISABLED_MultithreadDefaultHtpPowerCfgFromEpOption) { +TEST_F(QnnHTPBackendTests, MultithreadDefaultHtpPowerCfgFromEpOption) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -732,7 +732,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_MultithreadDefaultHtpPowerCfgFromEpOption) { // Tests running a single session in multiple threads on the HTP backend with // EP option to set default power config + run option to set power config for each run -TEST_F(QnnHTPBackendTests, DISABLED_MultithreadHtpPowerCfgDefaultAndRunOption) { +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) { std::unique_ptr model; std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; std::vector shape = {1, 3, 2}; @@ -948,62 +948,6 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) { 0.008f); } -TEST_F(QnnHTPBackendTests, TestOD) { - Ort::SessionOptions so; - -#if 1 - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "od_current_tf2onnx.onnx"; - //so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); -#else - const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "unet.preprocessed.quant.onnx_ctx.onnx"; -#endif - //auto& logging_manager = DefaultLoggingManager(); - //logging_manager.SetDefaultLoggerSeverity(logging::Severity::kVERBOSE); - - // Ensure all type/shape inference warnings result in errors! - so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "0"); // Disable fallback to the CPU EP. 
- so.AddConfigEntry(kDebugLayoutTransformation, "1"); - so.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); - //so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_INFO); - onnxruntime::ProviderOptions options; - -#if defined(_WIN32) - options["backend_path"] = "QnnHtp.dll"; -#else - options["backend_path"] = "libQnnHtp.so"; -#endif - - so.AppendExecutionProvider("QNN", options); - - Ort::Session session(*ort_env, ort_model_path, so); - - std::vector input_data(300 * 300 * 3, 0.5f); - - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector ort_inputs; - std::vector ort_input_names; - - // Add input "serving_default_input_3:0" - std::array input_1_shape{1, 300, 300, 3}; - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, input_data.data(), input_data.size(), input_1_shape.data(), input_1_shape.size())); - ort_input_names.push_back("serving_default_input_3:0"); - - // Run session and get outputs - std::array output_names{"StatefulPartitionedCall:1", "StatefulPartitionedCall:0"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check output shape. - Ort::Value& ort_output = ort_outputs[0]; - auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); - const float* results = ort_output.GetTensorData(); - - for (size_t i = 0; i < typeshape.GetElementCount() && i < 20; i++) { - std::cout << i << ": " << results[i] << std::endl; - } -} - #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) From 5850900a19c8c3b7519713f28a65eb4e109f4228 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 28 Jul 2024 19:02:49 -0700 Subject: [PATCH 17/20] Run lintrunner --- onnxruntime/test/providers/qnn/qnn_basic_test.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 9489d354755e4..9d19c36dc94b2 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -835,14 +835,14 @@ TEST_F(QnnHTPBackendTests, HTPGraphFinalizationOptimizationModes) { // Test that models run with various SoC model values TEST_F(QnnHTPBackendTests, HTPSocModels) { - constexpr std::array soc_models = { "", // No explicit SoC model specified - "0", // "Unknown" + constexpr std::array soc_models = {"", // No explicit SoC model specified + "0", // "Unknown" #if defined(_M_ARM64) - "37" }; // SC8280X + "37"}; // SC8280X #elif defined(__linux__) - "30" }; // SM8350 + "30"}; // SM8350 #else - "" }; + ""}; #endif for (auto soc_model : soc_models) { From 7b430f7d3c86f4eb00e831414809f1523f1f737d Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 28 Jul 2024 20:58:23 -0700 Subject: [PATCH 18/20] Forgot error message --- .../qnn/builder/qnn_node_group/conv_activation_fusion.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index f5ddee6b1f78e..1deb36f7ce5fb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -280,7 +280,9 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const size_t num_dqs = 
dq_node_units.size(); constexpr size_t max_num_dqs = 3; ORT_RETURN_IF_NOT(num_dqs == 2 || num_dqs == max_num_dqs, "QDQ Conv should have 2 or 3 DQs"); - ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == "QuantizeLinear"); + ORT_RETURN_IF_NOT(conv_node_unit->OpType() == "Conv" && q_node_unit->OpType() == QUANTIZE_LINEAR, + "Expected Conv/ConvTranspose and QuantizeLinear but got ", conv_node_unit->OpType(), " and ", + q_node_unit->OpType()); std::array dq_nodes_buf = {}; for (size_t i = 0; i < num_dqs; i++) { From dab7496a29bef99e7f2e051a489bddb2ce28d196 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 28 Jul 2024 21:54:28 -0700 Subject: [PATCH 19/20] Fix linux compiler error due to struct initialization --- onnxruntime/core/framework/node_unit.cc | 1 + .../qnn_node_group/conv_activation_fusion.cc | 6 ++++- .../qnn_node_group/conv_activation_fusion.h | 2 +- .../qnn/builder/qnn_node_group/dq_q_fusion.cc | 3 ++- .../qnn_node_group/hardsigmoid_mul_fusion.cc | 3 ++- .../builder/qnn_node_group/qnn_node_group.cc | 2 +- .../qnn/builder/qnn_node_group/utils.cc | 4 +-- onnxruntime/test/providers/qnn/conv_test.cc | 25 ++++++------------- 8 files changed, 22 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index d2930a770c0a0..516d7425d4989 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -4,6 +4,7 @@ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) #include "node_unit.h" +#include #include "core/graph/graph_viewer.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index 1deb36f7ce5fb..bf479cf86cfc0 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -1,10 +1,11 @@ #include "core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h" +#include #include #include -#include #include #include +#include #include "core/graph/graph_utils.h" #include "core/framework/node_unit.h" #include "core/providers/shared/utils/utils.h" @@ -111,6 +112,9 @@ static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, return false; } + // The clip range must entirely overlap the quantization range (quantization can be smaller). 
+ // Clip range: [------------------] + // Quant range: [-------------] constexpr float epsilon = std::numeric_limits::epsilon(); if ((epsilon < clip_min - rmin) || (epsilon < rmax - clip_max)) { return false; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h index 71fc71434c5e5..b604b25e943e6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.h @@ -3,8 +3,8 @@ #pragma once -#include #include +#include #include #include #include diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index 4b2c96aa9e4d7..ce87ac4a3d21c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -1,10 +1,11 @@ #include "core/providers/qnn/builder/qnn_node_group/dq_q_fusion.h" +#include #include #include -#include #include #include +#include #include "core/graph/graph_utils.h" #include "core/framework/node_unit.h" #include "core/providers/shared/utils/utils.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc index 90f2b1ef29f9a..76b1726646486 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -1,10 +1,11 @@ #include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h" +#include #include #include -#include #include #include +#include #include "core/graph/graph_utils.h" #include "core/framework/node_unit.h" #include "core/providers/shared/utils/utils.h" diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index e4b9e77a04cd7..9fb9e815321c0 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -28,7 +28,7 @@ namespace qnn { /// class QnnNodeUnitWrapper : public IQnnNodeGroup { public: - QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {} + explicit QnnNodeUnitWrapper(const NodeUnit& node_unit) : node_unit_(&node_unit) {} ORT_DISALLOW_COPY_AND_ASSIGNMENT(QnnNodeUnitWrapper); Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 1bcdb26be3400..5548d7d37c378 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -15,7 +15,7 @@ const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, const NodeUnit& parent_node_unit, gsl::span child_op_types, const std::unordered_map& node_unit_map, - const std::unordered_map& node_unit_to_qnn_node_group) { + const std::unordered_map& qnn_node_group_map) { const Node& parent_node = parent_node_unit.GetNode(); // Parent must have a single child (1 output edge) and must not produce a graph output. 
@@ -50,7 +50,7 @@ const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer, // Check if child node has already been handled. Should not be the case if the calling // fusion function has been called in topological order, but check to be safe. - if (node_unit_to_qnn_node_group.count(child_node_unit) != 0) { + if (qnn_node_group_map.count(child_node_unit) != 0) { return nullptr; } diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index 4f7d3dea66db6..2bfcc81f975d3 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -18,7 +18,6 @@ namespace test { // Information for activation node placed between the Conv and Q. struct OutputActivationInfo { std::string op_type; // Relu or Clip - std::vector attrs; std::vector const_inputs; }; @@ -69,10 +68,7 @@ static GetTestModelFn BuildF32ConvTestCase(const std::string& conv_op_type, cons for (auto val : output_activation->const_inputs) { activation_inputs.push_back(builder.MakeScalarInitializer(val)); } - Node& activation_node = builder.AddNode(output_activation->op_type, activation_inputs, {output}); - for (const auto& attr : output_activation->attrs) { - activation_node.AddAttributeProto(attr); - } + builder.AddNode(output_activation->op_type, activation_inputs, {output}); } }; } @@ -174,10 +170,7 @@ static GetTestQDQModelFn BuildQDQConvTestCase( for (auto val : output_activation->const_inputs) { activation_inputs.push_back(builder.MakeScalarInitializer(val)); } - Node& activation_node = builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); - for (const auto& attr : output_activation->attrs) { - activation_node.AddAttributeProto(attr); - } + builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); } AddQDQNodePairWithOutputAsGraphOutput(builder, q_input, output_qparams[0].scale, @@ -293,10 +286,7 @@ static GetTestQDQModelFn BuildQDQPerChannelConvTestCase( for (auto val : output_activation->const_inputs) { activation_inputs.push_back(builder.MakeScalarInitializer(val)); } - Node& activation_node = builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); - for (const auto& attr : output_activation->attrs) { - activation_node.AddAttributeProto(attr); - } + builder.AddNode(output_activation->op_type, activation_inputs, {q_input}); } AddQDQNodePairWithOutputAsGraphOutput(builder, q_input, output_qparams[0].scale, @@ -866,7 +856,8 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_ReluClipFusion) { relu_info); // DQs -> Conv (w/ bias) -> Clip -> Q - OutputActivationInfo clip_info = {"Clip", {}, {0.0f, 6.0f}}; + // Opset 6 Clip uses attributes for min/max + OutputActivationInfo clip_info = {"Clip", {0.0f, 2.0f}}; RunHTPConvOpTest("Conv", input_def, weight_def, @@ -878,12 +869,12 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_ReluClipFusion) { "NOTSET", ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops - 21, // opset + 19, // opset QDQTolerance(), clip_info); // DQs -> Conv (NO bias) -> Clip -> Q - OutputActivationInfo clip_info_2 = {"Clip", {}, {-6.0f, 6.0f}}; + OutputActivationInfo clip_info_2 = {"Clip", {-6.0f, 6.0f}}; RunHTPConvOpTest("Conv", input_def, weight_def, @@ -933,7 +924,7 @@ TEST_F(QnnHTPBackendTests, ConvS8S8S32_PerChannel_ReluClipFusion) { relu_info); // DQs -> Conv (w/ bias) -> Clip -> Q - OutputActivationInfo clip_info = {"Clip", {}, {0.0f, 6.0f}}; + OutputActivationInfo clip_info = {"Clip", {0.0f, 6.0f}}; RunHTPConvOpPerChannelTest("Conv", input_def, weight_def, From 
2fd58631e071e18975d47337cc57a743b0b67c4a Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 30 Jul 2024 12:59:03 -0700 Subject: [PATCH 20/20] Rename variable according to review comment; Add function comments. --- onnxruntime/core/framework/node_unit.cc | 4 ++-- onnxruntime/core/framework/node_unit.h | 2 +- .../qnn_node_group/conv_activation_fusion.cc | 14 ++++++++++++++ .../providers/qnn/builder/qnn_node_group/utils.h | 11 +++++++++++ 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/framework/node_unit.cc b/onnxruntime/core/framework/node_unit.cc index 516d7425d4989..850cb167a3ece 100644 --- a/onnxruntime/core/framework/node_unit.cc +++ b/onnxruntime/core/framework/node_unit.cc @@ -274,13 +274,13 @@ NodeUnit::NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_g } NodeUnit::NodeUnit(gsl::span dq_nodes, const Node& target_node, - gsl::span q_nodes, Type type, + gsl::span q_nodes, Type unit_type, gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges) : dq_nodes_(dq_nodes.begin(), dq_nodes.end()), target_node_(target_node), q_nodes_(q_nodes.begin(), q_nodes.end()), - type_(type), + type_(unit_type), inputs_(inputs.begin(), inputs.end()), outputs_(outputs.begin(), outputs.end()), input_edge_count_(input_edge_count), diff --git a/onnxruntime/core/framework/node_unit.h b/onnxruntime/core/framework/node_unit.h index 8bc2f79c4a372..50bd423d2f547 100644 --- a/onnxruntime/core/framework/node_unit.h +++ b/onnxruntime/core/framework/node_unit.h @@ -69,7 +69,7 @@ class NodeUnit { explicit NodeUnit(const Node& node); explicit NodeUnit(const GraphViewer& graph_viewer, const QDQ::NodeGroup& node_group); NodeUnit(gsl::span dq_nodes, const Node& target_node, - gsl::span q_nodes, Type type, + gsl::span q_nodes, Type unit_type, gsl::span inputs, gsl::span outputs, size_t input_edge_count, Node::EdgeSet output_edges); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc index bf479cf86cfc0..813bba8a5952b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/conv_activation_fusion.cc @@ -17,6 +17,7 @@ namespace onnxruntime { namespace qnn { +// Gets the scale, zero-point, and zero-point type for a QuantizeLinear node that uses per-tensor quantization. static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit, /*out*/ float& scale, @@ -52,6 +53,7 @@ static bool GetQScalarScaleZeroPoint(const QnnModelWrapper& qnn_model_wrapper, return true; } +// Computes the floating point range (rmin, rmax) from a QuantizeLinear node's scale/zero-point. static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit, /*out*/ float& rmin, @@ -92,6 +94,7 @@ static bool GetQRminRmax(const QnnModelWrapper& qnn_model_wrapper, return true; } +// Returns true if the Clip in the sequence (Clip -> Q) can be removed because it is made redundant by the Q. static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& clip_node_unit, const NodeUnit& q_node_unit, @@ -123,6 +126,7 @@ static bool CanClipBeRemoved(const QnnModelWrapper& qnn_model_wrapper, return true; } +// Returns true if the Relu in the sequence (Relu -> Q) can be removed because it is made redundant by the Q. 
static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit) {
   assert(q_node_unit.OpType() == QUANTIZE_LINEAR);
   int32_t zp_data_type = ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_UNDEFINED;
@@ -148,6 +152,7 @@ static bool CanQRelaceRelu(const QnnModelWrapper& qnn_model_wrapper, const NodeU
   }
 }
 
+// Returns true if the Clip/Relu in the sequence (Clip/Relu -> Q) can be removed because it is made redundant by the Q.
 static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
                                    const NodeUnit& activation_node_unit,
                                    const NodeUnit& q_node_unit,
@@ -165,6 +170,7 @@ static bool CanActivationBeRemoved(const QnnModelWrapper& qnn_model_wrapper,
   return false;
 }
 
+// Returns the parent DQ nodes for a given node.
 static std::vector<const Node*> FindParentDQNodes(const GraphViewer& graph_viewer, const Node& node) {
   // Get all parent DQ nodes sorted by destination argument index.
   std::vector<const Node*> parents(node.InputDefs().size(), nullptr);
@@ -184,6 +190,8 @@ static std::vector<const Node*> FindParentDQNodes(const GraphViewer& graph_viewe
   return parents;
 }
 
+// Gets the parent DQ nodes for the given Conv node. This function checks that the DQs are not a part of
+// any other NodeUnit and that every Conv input comes from a parent DQ.
 static bool GetConvDQs(
     const GraphViewer& graph_viewer,
     const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
@@ -244,6 +252,7 @@ static bool GetConvDQs(
   return true;
 }
 
+// Checks that the input and output data types are valid for a QDQ Conv.
 static bool CheckQDQConvDataTypes(std::array<const NodeUnit*, 3>& dq_node_units,
                                   gsl::not_null<const NodeUnit*> q_node_unit) {
   assert(q_node_unit->OpType() == QUANTIZE_LINEAR);
@@ -271,6 +280,9 @@ static bool CheckQDQConvDataTypes(std::array<const NodeUnit*, 3>& dq_node_units,
   return true;
 }
 
+// Utility function to either validate or create a quantized QNN Conv node. The function creates a temporary
+// custom NodeUnit that excludes the Clip/Relu because it is redundant. This custom NodeUnit is passed to our
+// existing Conv OpBuilder for creation or validation via QNN APIs.
 #define ValidateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \
   CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), true)
 #define CreateOnQnn(qnn_model_wrapper, dq_node_units, conv_node_unit, q_node_unit, logger) \
   CreateOrValidateOnQnn((qnn_model_wrapper), (dq_node_units), (conv_node_unit), (q_node_unit), (logger), false)
@@ -353,6 +365,8 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper,
   return conv_op_builder->AddToModelBuilder(qnn_model_wrapper, custom_node_unit, logger, validate);
 }
 
+// Traverses the graph to check if the given NodeUnit is part of a valid DQ* -> Conv -> Relu/Clip -> Q sequence.
+// If so, returns an IQnnNodeGroup that contains the constituent NodeUnits.
 std::unique_ptr<IQnnNodeGroup> ConvActivationFusion::TryFusion(
     QnnModelWrapper& qnn_model_wrapper,
     const NodeUnit& conv_node_unit,
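To make the Relu condition above concrete: dequantized values are scale * (q - zero_point), so when the zero-point equals the lowest quantized value, no dequantized value can be negative and Relu(x) == x. A standalone check (illustrative only; the scale value is an arbitrary assumption for the demo):

    // Illustrative only: with zero_point == lowest(q), every dequantized value
    // is >= 0, so a following Relu cannot change anything.
    #include <cassert>
    #include <cstdint>
    #include <limits>

    int main() {
      const float scale = 0.02f;  // arbitrary positive scale
      const int32_t zero_point = std::numeric_limits<uint8_t>::lowest();  // 0 for uint8
      for (int q = 0; q <= 255; ++q) {
        const float dequantized = scale * static_cast<float>(q - zero_point);
        assert(dequantized >= 0.0f);  // Relu(dequantized) == dequantized
      }
      return 0;
    }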
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
index 915eff3fbd418..0d11d21906ccb 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h
@@ -19,6 +19,17 @@ constexpr size_t QDQ_MAX_NUM_INPUTS = 3;
 constexpr size_t QDQ_SCALE_INPUT_IDX = 1;
 constexpr size_t QDQ_ZERO_POINT_INPUT_IDX = 2;
 
+/// <summary>
+/// Utility function to get a child NodeUnit. The returned NodeUnit must be the parent's only child, must be
+/// of the expected type, and must not be a part of another IQnnNodeGroup.
+/// </summary>
+/// <param name="graph_viewer">GraphViewer containing all Nodes</param>
+/// <param name="parent_node_unit">Parent NodeUnit</param>
+/// <param name="child_op_types">Valid child types</param>
+/// <param name="node_unit_map">Maps a Node to its NodeUnit</param>
+/// <param name="qnn_node_group_map">Maps a NodeUnit to its IQnnNodeGroup.
+/// Used to check that the child has not already been added to another IQnnNodeGroup.</param>
+/// <returns></returns>
 const NodeUnit* GetOnlyChildOfType(const GraphViewer& graph_viewer,
                                    const NodeUnit& parent_node_unit,
                                    gsl::span<const std::string_view> child_op_types,