From 11bf3097360d271e8c0d6f26683a8b16477e6c42 Mon Sep 17 00:00:00 2001 From: Jing Fang <126209182+fajin-corp@users.noreply.github.com> Date: Fri, 19 Jul 2024 22:55:15 -0700 Subject: [PATCH] add transform part of the dq matmul tool chain (#21374) ### Description This is a partial change from [fajin/qdqmatmulnbitstoolchain](https://github.com/microsoft/onnxruntime/pull/21180). The original PR is blocked by Web CI failures. MatMulNBits is a heavily optimized matmul operation. Currently a MatMul can be converted to MatMulNBits to speed up the model inference. However, MatMulNBits is an ORT only op. To make the graph compatible with ONNX ops and utilize MatMulNBits at the same time, we introduce Q/DQ support for MatMulNBits. To convert MatMul ops in a model to MatMulNBits: 1. use matmul_4bits_quantizer.py to convert MatMul to DQ + MatMul using QDQ mode. 2. In ORT session, DQ + MatMul is fused to MatMulNBits #### Note MatMulNBits assume B weight is uint4. When no zp is provided, zp defaults to 8, which is different from DQ. DQ defaults zp to 0 when no zp provided. And DQ supports int4. Therefore some conversions are introduced during DQ + MatMul --> MatMulNBits step. #### Perf Using QDQ format will increase the model initialization time and memory consumption. With current implement, model init time increased from ~4s to ~9s, and memory consumption increased from ~2.8GB to ~4.8GB. The memory increase is due to 1. in optimizer, after transpose the B weight, a in-memory tensor proto is created using protobuf's arena. 2. in finalize step, when saving initializer and prepacking, ORT arena is used to create buffers for initializers. The memory allocated by arenas cannot be fully deallocated. If disable ORT arena memory allocation, the memory consumptions of both QDQ format and original format are ~2.2GB. The time increase is mainly due to multiple memory copy, but can be further optimized. ### Motivation and Context Please see description for details. --- .../core/optimizer/graph_transformer_utils.h | 7 +- .../onnxruntime_session_options_config_keys.h | 5 + .../core/optimizer/graph_transformer_utils.cc | 26 +- .../selectors_actions/qdq_actions.cc | 173 ++++++- .../selectors_actions/qdq_actions.h | 29 ++ .../qdq_selector_action_transformer.cc | 39 +- .../qdq_selector_action_transformer.h | 6 +- .../selectors_actions/qdq_selectors.cc | 85 ++++ .../selectors_actions/qdq_selectors.h | 15 + .../optimizer/selectors_actions/actions.cc | 4 +- .../optimizer/selectors_actions/actions.h | 7 +- onnxruntime/core/session/inference_session.cc | 14 +- onnxruntime/test/common/random_generator.h | 17 + .../optimizer/graph_transform_test_builder.h | 16 - .../qdq_matmulnbits_transformer_test.cc | 425 ++++++++++++++++++ onnxruntime/test/optimizer/qdq_test_utils.h | 2 +- 16 files changed, 833 insertions(+), 37 deletions(-) create mode 100644 onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index e609745b5e03f..0bb5c7432f0a7 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -10,6 +10,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/session_options.h" #include "core/optimizer/graph_transformer.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/optimizer/rule_based_graph_transformer.h" @@ -49,7 +50,8 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, - const InlinedHashSet& rules_and_transformers_to_disable = {}); + const InlinedHashSet& rules_and_transformers_to_disable = {}, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -78,7 +80,8 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, - const InlinedHashSet& rules_and_transformers_to_disable = {}); + const InlinedHashSet& rules_and_transformers_to_disable = {}, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index c32e2a77e8453..17ae649e6f174 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -270,3 +270,8 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed // - "0": Gemm FastMath mode is not enabled. [DEFAULT] // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; + +// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. +// Refer to MatMulNBits op schema for more details. +// If not provided, default is 4. +static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index e6feb3e7ddbe2..7da65f18ccacb 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -13,6 +13,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) @@ -187,7 +188,8 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ - const InlinedHashSet& rules_and_transformers_to_disable) { + const InlinedHashSet& rules_and_transformers_to_disable, + concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -287,6 +289,10 @@ InlinedVector> GenerateTransformers( onnxruntime::kJsExecutionProvider}; const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const int64_t qdq_matmulnbits_accuracy_level = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -300,7 +306,10 @@ InlinedVector> GenerateTransformers( if (!qdq_is_int8_allowed) { transformers.emplace_back(std::make_unique(avx2_precision_mode, cpu_ep)); } - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + SatApplyContextVariant{}, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -409,7 +418,8 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, - const InlinedHashSet& rules_and_transformers_to_disable) { + const InlinedHashSet& rules_and_transformers_to_disable, + concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -423,12 +433,18 @@ InlinedVector> GenerateTransformersForMinimalB const bool qdq_is_int8_allowed = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, QDQIsInt8Allowed() ? "1" : "0") == "1"; - + const int64_t qdq_matmulnbits_accuracy_level = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; if (!disable_quant_qdq) { - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + apply_context, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 3497ea4c85523..74fecb0427e14 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -2,10 +2,12 @@ // Licensed under the MIT License. #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" - #include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/optimizer/initializer.h" #include "core/graph/node_attr_utils.h" #include "core/framework/tensorprotoutils.h" +#include "core/mlas/inc/mlas_q4.h" + namespace onnxruntime { namespace QDQ { @@ -273,6 +275,175 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } +DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) + : accuracy_level_{accuracy_level}, + domain_{kMSDomain}, + op_type_{"MatMulNBits"}, + value_moves_{[]() { + NTO::NodeLocation target{NTO::NodeType::kTarget, 0}; + return std::vector{ + MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), + MoveAll(target, ArgType::kOutput)}; + }()}, + intra_op_thread_pool_{intra_op_thread_pool} { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); +} + +NodeAttributes +DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) const { + NodeAttributes extra_attributes; + + const auto* dq_node = runtime_state.selected_nodes.Input(0); + auto& attrs = dq_node->GetAttributes(); + const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); + + utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); + // currently only 4bits is supported. In the future, derive bits from DQ's weight type. + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); + + return extra_attributes; +} + +Status DQMatMulToMatMulNBitsAction::ProcessNewNode(Graph& graph, + const NodesToOptimize& selected_nodes, + Node& replacement_node) const { + const auto* dq_node = selected_nodes.Input(0); + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + const auto& attrs = dq_node->GetAttributes(); + + const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; + const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; + graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto); + graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto); + if (zp_arg) { + graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); + } + + auto K = weight_arg->Shape()->dim(0).dim_value(); + auto N = weight_arg->Shape()->dim(1).dim_value(); + auto block_size = attrs.at("block_size").i(); + auto quant_num = (K + block_size - 1) / block_size; + auto blob_bytes = (block_size + 1) / 2; + + // Unfortunately iterating the source data is complicated, the data maybe in + // external file, a raw buffer, or a repeated field depending on the data + // type. UnpackTensor() already contains some of these logic and is closest + // to what we need. But it does not handle external data. + Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); + Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); + std::optional zp_src; + Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(weight_arg->Name() + "_T"), + std::vector{N, quant_num, blob_bytes}); + Initializer scale_dst(static_cast(scale_src.data_type()), + graph.GenerateNodeArgName(scale_arg->Name() + "_T"), + std::vector{N * quant_num}); + std::optional zp_dst; + + if (zp_tensor_proto) { + zp_src.emplace(*zp_tensor_proto, graph.ModelPath()); + zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(zp_arg->Name() + "_T"), + std::vector{N * ((quant_num + 1) / 2)}); + } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + zp_dst.emplace(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), + std::vector{N * ((quant_num + 1) / 2)}); + } + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst ? zp_dst->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } else { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst ? zp_dst->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } + } else { + if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst ? zp_dst->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + + } else { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src ? zp_src->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst ? zp_dst->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } + } + + ONNX_NAMESPACE::TensorProto weight_T_tp; + ONNX_NAMESPACE::TensorProto scale_T_tp; + std::optional zp_T_tp; + + // TODO(fajin): external_data to memory location to avoid arena allocation + // https://github.com/microsoft/onnxruntime/pull/12465 + weight_dst.ToProto(weight_T_tp); + scale_dst.ToProto(scale_T_tp); + if (zp_dst) { + zp_T_tp.emplace(); + zp_dst->ToProto(zp_T_tp.value()); + } + + auto& input_defs = replacement_node.MutableInputDefs(); + input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); + replacement_node.MutableInputArgsCount().push_back(1); + input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); + replacement_node.MutableInputArgsCount().push_back(1); + + if (zp_T_tp) { + input_defs.push_back(&graph_utils::AddInitializer(graph, zp_T_tp.value())); + replacement_node.MutableInputArgsCount().push_back(1); + } + + return Status::OK(); +} + static std::vector GetGemmMoveInfo(bool does_q_node_exist) { NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0}; NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1}; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 8179a030508a5..47821619db65a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -3,7 +3,12 @@ #pragma once +#include +#include +#include + #include "core/optimizer/selectors_actions/actions.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -76,6 +81,30 @@ struct MatMulReplaceWithQLinear : public Action { BinaryReplaceWithQLinear qlinear_matmul_replacer_; }; +// used together with DQMatMulNodeGroupSelector, which does the sanity check +struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew { + DQMatMulToMatMulNBitsAction(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool); + + private: + std::string OpType(const RuntimeState&) const override { return op_type_; } + + std::string Domain(const RuntimeState&) const override { return domain_; } + + NodeAttributes ExtraAttributes(const RuntimeState&) const override; + + std::vector ValueMoves(const RuntimeState&) const override { return value_moves_; } + + // transpose initializers, and add to the MatMulNBits inputs + Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const override; + + const int64_t accuracy_level_; + const std::string domain_; + const std::string op_type_; + const std::vector value_moves_; + concurrency::ThreadPool* intra_op_thread_pool_; +}; + struct GemmReplaceWithQuant : public Action { GemmReplaceWithQuant(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 80ead8f8c68d6..17e66a3953b97 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -228,6 +228,30 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i #endif } +void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) { + // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. + // DQ's weight is int4/uint4. DQ's scale is float/float16. + // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. + const std::string action_name{"DQMatMulToMatMulNBits"}; + + std::unique_ptr action = + std::make_unique(qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); + +#if !defined(ORT_MINIMAL_BUILD) + std::unique_ptr selector = std::make_unique(); + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"MatMul", {}}}, + std::move(selector), + std::move(action)); + +#else + qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); +#endif +} + void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // 3 to 5 nodes. 0=DQ A, 1=DQ B, 2=DQ C(optional), 3=Gemm, 4=Q Y(optional) // Replace with QGemm @@ -271,7 +295,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { #endif } -SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { +SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -283,17 +309,22 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed); GemmQDQRules(qdq_selector_action_registry); WhereQDQRules(qdq_selector_action_registry); + DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); return qdq_selector_action_registry; } } // namespace -QDQSelectorActionTransformer::QDQSelectorActionTransformer( - bool is_int8_allowed, const SatApplyContextVariant& apply_context) +QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed, + const SatApplyContextVariant& apply_context, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index 1780923f3f273..ba636f76d1900 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -5,6 +5,7 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" #include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -21,7 +22,10 @@ Transformer that fuses QDQ and fp32 ops into quantized ops. */ class QDQSelectorActionTransformer : public SelectorActionTransformer { public: - QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}); + QDQSelectorActionTransformer(bool is_int8_allowed, + const SatApplyContextVariant& apply_context = {}, + int64_t qdq_matmulnbits_accuracy_level = 4, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 09705f61c82ce..6e93445c7c5c7 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -414,6 +414,91 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } +bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + // Should not have any Q nodes + if (!q_nodes.empty()) { + return false; + } + + const auto& graph = graph_viewer.GetGraph(); + + // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output + if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { + return false; + } + + // DQ must be MatMul's the second input + if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { + return false; + } + + // DQ weight/zero points types are int4/uint4, scales/output types are float or float16 + const auto* weight_arg = dq_nodes[0]->InputDefs()[0]; + const auto* scale_arg = dq_nodes[0]->InputDefs()[1]; + const auto* zero_point_arg = dq_nodes[0]->InputDefs().size() == 3 ? dq_nodes[0]->InputDefs()[2] : nullptr; + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type(); + if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && + dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { + return false; + } + + if (!Is4BitIntType(dt_weight)) { + return false; + } + + // DQ is blockwise quantized along axis 0, and block_size must be 2's power and >= 16 + const auto& dq_attrs = dq_nodes[0]->GetAttributes(); + if (const auto a_iter = dq_attrs.find("axis"); + a_iter == dq_attrs.end() || a_iter->second.i() != 0) { + return false; + } + + const auto a_iter = dq_attrs.find("block_size"); + if (a_iter == dq_attrs.end()) { + return false; + } + + auto block_size = a_iter->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) { + return false; + } + + // weight, scale and zero points (if exists) must be constants + const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); + const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); + const auto* zp_tensor_proto = zero_point_arg ? graph.GetConstantInitializer(zero_point_arg->Name(), true) : nullptr; + + if (!weight_tensor_proto || !scale_tensor_proto) { + return false; + } + + if (zero_point_arg && !zp_tensor_proto) { + return false; + } + + // weight, scale and zero points (if exists) must have the rank 2 + if (weight_tensor_proto->dims_size() != 2 || + scale_tensor_proto->dims_size() != 2 || + (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + return false; + } + + // check weight, scale and zero points (if exists) shapes + if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || + weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || + (zp_tensor_proto && + (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || + zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + return false; + } + + return true; +} + bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 1a2a620acb480..491a15b62cb03 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -204,6 +204,14 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { bool allow_4bit_; }; +// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +class DQMatMulNodeGroupSelector : public NodeGroupSelector { + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmNodeGroupSelector : public NodeGroupSelector { @@ -358,6 +366,13 @@ class MatMulSelector : public BaseSelector { allow_16bit, allow_4bit)) {} }; +// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +class DQMatMulToMatMulNBitsSelector : public BaseSelector { + public: + explicit DQMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) + : BaseSelector(std::make_unique(), compatible_providers) {} +}; + // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmSelector : public BaseSelector { diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.cc b/onnxruntime/core/optimizer/selectors_actions/actions.cc index c8d5acbf66b78..bb4033afedc49 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.cc +++ b/onnxruntime/core/optimizer/selectors_actions/actions.cc @@ -102,12 +102,14 @@ static Status CreateReplacementNode(Graph& graph, Status ReplaceWithNew::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { const RuntimeState runtime_state{graph, selected_nodes}; + Node* replacement{}; ORT_RETURN_IF_ERROR(CreateReplacementNode(graph, selected_nodes, OpType(runtime_state), Domain(runtime_state), ExtraAttributes(runtime_state), ValueMoves(runtime_state), - /* only_update_dest_definitions */ false, nullptr)); + /* only_update_dest_definitions */ false, &replacement)); + ORT_RETURN_IF_ERROR(ProcessNewNode(graph, selected_nodes, *replacement)); return node_remover_.Run(graph, selected_nodes); } diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h index 9384bfa7027cd..465ae38565b15 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.h +++ b/onnxruntime/core/optimizer/selectors_actions/actions.h @@ -158,6 +158,12 @@ struct ReplaceWithNew : public Action { // specifies how the inputs and outputs for the replaced nodes are moved to the new node virtual std::vector ValueMoves(const RuntimeState&) const = 0; + // For the changes that cannot be done by simply moving node args around, use this method to make + // additional changes to the new node and the graph. e.g., DQMatMulToMatMulNBitsAction transposes + // the second weight of MatMul ops and create new node args. + // Note: This method is only used in Run(), but not in RunForSave(). + virtual Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const { return Status::OK(); } + RemoveNodes node_remover_; }; @@ -187,5 +193,4 @@ struct ReplaceWithNewFixed : public ReplaceWithNew { const NodeAttributes extra_attrs_; const std::vector value_moves_; }; - } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f0eed91d70440..3fd6e84e0e5ce 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1609,7 +1609,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, - const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep) { + const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, + concurrency::ThreadPool* intra_op_thread_pool) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1617,7 +1618,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable); + optimizers_to_disable, intra_op_thread_pool); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2005,7 +2006,8 @@ common::Status InferenceSession::Initialize() { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( - ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep)); + ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, + cpu_ep, GetIntraOpThreadPoolToUse())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3167,7 +3169,8 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, - optimizers_to_disable_); + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3176,7 +3179,8 @@ common::Status InferenceSession::AddPredefinedTransformers( record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, - optimizers_to_disable_); + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); } }(); diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 9ab4a82463d51..9bc50ce88ef16 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -12,6 +12,7 @@ #include "core/common/common.h" #include "core/common/optional.h" #include "core/common/type_utils.h" +#include "core/framework/int4.h" #include "test/util/include/test_random_seed.h" namespace onnxruntime { @@ -108,6 +109,22 @@ class RandomValueGenerator { return val; } + template + typename std::enable_if< + std::is_same_v || std::is_same_v, + std::vector>::type + Uniform(gsl::span dims, TInt4 min, TInt4 max) { + using UnpackedType = typename TInt4::UnpackedType; + std::vector data_int8 = Uniform(dims, min.GetElem(0), max.GetElem(0)); + std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); + for (size_t i = 0; i < data_int8.size(); i++) { + size_t r = i >> 1; + size_t c = i & 0x1; + data[r].SetElem(c, data_int8[i]); + } + return data; + } + // Gaussian distribution for float template typename std::enable_if< diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 6214094a26c4f..b9af675afe74d 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -117,22 +117,6 @@ class ModelTestBuilder { return MakeInput(shape, data); } - template - typename std::enable_if< - std::is_same_v || std::is_same_v, - NodeArg*>::type - MakeInputInt4(const std::vector& shape, typename TInt4::UnpackedType min, typename TInt4::UnpackedType max) { - using UnpackedType = typename TInt4::UnpackedType; - std::vector data_int8 = rand_gen_.Uniform(shape, min, max); - std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); - for (size_t i = 0; i < data_int8.size(); i++) { - size_t r = i >> 1; - size_t c = i & 0x1; - data[r].SetElem(c, data_int8[i]); - } - return MakeInput(shape, data); - } - template NodeArg* MakeInput(const std::optional>& shape, std::optional input_name = std::nullopt) { diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc new file mode 100644 index 0000000000000..3d117794104fa --- /dev/null +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -0,0 +1,425 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/span_utils.h" +#include "core/framework/int4.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gtest/gtest.h" + +#if defined(_MSC_VER) +#pragma warning(disable : 4127) +#endif // #if defined(_MSC_VER) + +struct QDQOpKeys { + const char* quantize_linear; + const char* dequantize_linear; +}; + +constexpr QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) { + if (use_contrib_qdq) { + return {"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}; + } + return {"QuantizeLinear", "DequantizeLinear"}; +} + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +// Input1 Input2 +// | | +// \ DQ +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* input2_arg = builder.MakeInput(input2_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{input2_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input1_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); +} + +// Input2 +// | +// DQ / +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* input2_arg = builder.MakeInput(input2_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {dq_output, input2_arg}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); +} + +// Input1 +// | +// \ DQ +// \ / +// MatMul +// | +// output +template +void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + NodeArg* weight_arg = nullptr; + + // add DQ + if constexpr (std::is_same_v || std::is_same_v) { + weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + } else { + weight_arg = builder.MakeInitializer(weight_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + NodeArg* zp_arg; + if constexpr (std::is_same_v || std::is_same_v) { + zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + } else { + zp_arg = builder.MakeInitializer(scale_shape, 0, 2); + } + + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { + // DQ contrib op schema is not updated to support blocked quantization + // block size too small + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + // block size not 2's power + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + // not axis 0 + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + // not rank 2 + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ +// \ / +// MatMul +// | DQ +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulConverted(const std::vector& input1_shape, + const std::vector& weight1_shape, + const std::vector& weight2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + auto scale1_shape = std::vector{weight1_shape}; + auto scale2_shape = std::vector{weight2_shape}; + scale1_shape[axis] = (scale1_shape[axis] + block_size - 1) / block_size; + scale2_shape[axis] = (scale2_shape[axis] + block_size - 1) / block_size; + + auto* weight1_arg = builder.MakeInitializer(weight1_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* weight2_arg = builder.MakeInitializer(weight2_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq1_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + auto* matmul1_output = builder.MakeIntermediate(); + + auto* scales1_arg = builder.MakeInitializer(scale1_shape, 8.0f, 12.0f); + auto* scales2_arg = builder.MakeInitializer(scale2_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp1_arg = builder.MakeInitializer(scale1_shape, T(0, 0), T(2, 0)); + auto* zp2_arg = builder.MakeInitializer(scale2_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); + builder.AddNode("MatMul", {matmul1_output, dq2_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 862408f31f004..52ac2a2541a79 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -517,7 +517,7 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, NodeArg* input_arg = nullptr; if constexpr (std::is_same_v || std::is_same_v) { - input_arg = builder.MakeInputInt4(input_shape, InputType::min_val, InputType::max_val); + input_arg = builder.MakeInput(input_shape, InputType(InputType::min_val, 0), InputType(InputType::max_val, 0)); dq_zp = InputType(static_cast(InputType::max_val / 2)); q_zp = OutputType(static_cast(OutputType::max_val / 2)); } else {