diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ad4bab1d09745..35ac6e0e5b154 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -405,7 +405,7 @@ Do not modify directly.* |Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* value:**T**
*out* output:**T**|1+|**T** = tensor(float)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)
**T4** = tensor(int32)| |QEmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding_quant:**T2**
*in* position_embedding_quant:**T2**
*in* segment_embedding:**T2**
*in* gamma_quant:**T2**
*in* beta_quant:**T2**
*in* mask:**T1**
*in* word_embedding_scale:**T**
*in* position_embedding_scale:**T**
*in* segment_embedding_scale:**T**
*in* gamma_scale:**T**
*in* beta_scale:**T**
*in* word_embedding_zero_point:**T2**
*in* position_embedding_zero_point:**T2**
*in* segment_embedding_zero_point:**T2**
*in* gamma_zero_point:**T2**
*in* beta_zero_point:**T2**
*out* layernorm_out:**T**
*out* mask_index_out:**T1**|1+|**T** = tensor(float)| -|QGemm|*in* A:**TA**
*in* a_scale:**T**
*in* a_zero_point:**TA**
*in* B:**TB**
*in* b_scale:**T**
*in* b_zero_point:**TB**
*in* C:**TC**
*in* y_scale:**T**
*in* y_zero_point:**TYZ**
*out* Y:**TY**|1+|**T** = tensor(float)
**TA** = tensor(int8), tensor(uint8)
**TB** = tensor(int8), tensor(uint8)
**TC** = tensor(int32)
**TY** = tensor(float), tensor(int8), tensor(uint8)
**TYZ** = tensor(int8), tensor(uint8)| +|QGemm|*in* A:**TA**
*in* a_scale:**T**
*in* a_zero_point:**TA**
*in* B:**TB**
*in* b_scale:**T**
*in* b_zero_point:**TB**
*in* C:**TC**
*in* y_scale:**T**
*in* y_zero_point:**TYZ**
*out* Y:**TY**|1+|**T** = tensor(float)
**TA** = tensor(uint8)
**TB** = tensor(int8), tensor(uint8)
**TC** = tensor(int32)
**TY** = tensor(float), tensor(uint8)
**TYZ** = tensor(uint8)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearLeakyRelu|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 08125ce51bca2..e9d22000b07fb 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -78,7 +78,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); // ******** End: Quantization ******************* // @@ -170,7 +169,6 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, }; diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc index 8d83ab7f681db..9bc7b8e76907e 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc @@ -44,10 +44,8 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - bool a_is_signed = a->IsDataType(); - const uint8_t* a_data = static_cast(a->DataRaw()); - BufferUniquePtr a_trans_buffer; + const uint8_t* a_data = a->template Data(); if (trans_A_ == CblasTrans) { a_data = quantization::TransPoseInputData(a_data, a_trans_buffer, allocator, K, M); } @@ -84,12 +82,12 @@ class QGemm : protected GemmBase, public MatMulIntegerBase { GemmBroadcastBias(M, N, 1.f, c->template Data(), &(c->Shape()), gemm_output_data); } - MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape{M, N, K, a_is_signed, b_is_signed, c != nullptr}; + MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape{M, N, K, false /*AIsSigned*/, b_is_signed, c != nullptr}; MLAS_GEMM_QUANT_DATA_PARAMS gemm_param; gemm_param.A = a_data; gemm_param.lda = gemm_shape.K; - gemm_param.ZeroPointA = *(static_cast(a_zp->DataRaw())); + gemm_param.ZeroPointA = *(a_zp->template Data()); gemm_param.B = b_data; gemm_param.ldb = gemm_shape.N; @@ -222,20 +220,5 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("TY", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), QGemm); -ONNX_OPERATOR_TYPED_KERNEL_EX( - QGemm, - kMSDomain, - 1, - int8_t, - kCpuExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) - .TypeConstraint("TA", DataTypeImpl::GetTensorType()) - .TypeConstraint("TB", DataTypeImpl::GetTensorType()) - .TypeConstraint("TC", DataTypeImpl::GetTensorType()) - .TypeConstraint("TYZ", DataTypeImpl::GetTensorType()) - .TypeConstraint("TY", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - QGemm); - } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index f89f201fd7aab..63741cfb9773f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ 
@@ -223,67 +223,5 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select
   }
 }

-static std::vector<NodeAndMoveInfo> GetGemmMoveInfo(bool does_q_node_exist) {
-  NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0};
-  NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1};
-  NTO::NodeLocation dq_bias{NTO::NodeType::kInput, 2};
-  NTO::NodeLocation target{NTO::NodeType::kTarget, 0};
-  NTO::NodeLocation q{NTO::NodeType::kOutput, 0};
-
-  std::vector<NodeAndMoveInfo> moves{
-      MoveAll(dq_A, ArgType::kInput),                                            // append all inputs from DQ of A
-      MoveAll(dq_B, ArgType::kInput),                                            // append all inputs from DQ of B
-      MoveAndAppend(dq_bias, ArgType::kInput, 0, ArgType::kInput, true, true)};  // (optional) append bias
-
-  if (does_q_node_exist) {
-    moves.push_back(MoveAndAppend(q, ArgType::kInput, 1, ArgType::kInput));  // append scale (input 1) from Q
-    moves.push_back(MoveAndAppend(q, ArgType::kInput, 2, ArgType::kInput));  // append zp (input 2) from Q
-    moves.push_back(MoveAll(q, ArgType::kOutput));                           // and use the outputs from Q
-  } else {
-    moves.push_back(MoveAll(target, ArgType::kOutput));
-  }
-
-  return moves;
-}
-
-GemmReplaceWithQuant::GemmReplaceWithQuant()
-    : qgemm_with_float_as_output_replacer_(kMSDomain, GetGemmMoveInfo(false), "QGemm"),
-      qgemm_with_8bits_as_output_replacer_(kMSDomain, GetGemmMoveInfo(true), "QGemm") {
-}
-
-Status GemmReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
-  RemoveAttrBeta(selected_nodes);
-  bool is_output_float = selected_nodes.num_outputs == 0;
-  if (is_output_float) {
-    return qgemm_with_float_as_output_replacer_.Run(graph, selected_nodes);
-  }
-
-  return qgemm_with_8bits_as_output_replacer_.Run(graph, selected_nodes);
-}
-
-#if !defined(ORT_MINIMAL_BUILD)
-Status GemmReplaceWithQuant::RunForSave(Graph& graph,
-                                        const NodesToOptimize& selected_nodes,
-                                        const SatRuntimeOptimizationSaveContext& save_context,
-                                        SavedState& saved_state,
-                                        bool& graph_modified) const {
-  RemoveAttrBeta(selected_nodes);
-  bool is_output_float = selected_nodes.num_outputs == 0;
-  if (is_output_float) {
-    return qgemm_with_float_as_output_replacer_.RunForSave(graph,
-                                                           selected_nodes,
-                                                           save_context,
-                                                           saved_state,
-                                                           graph_modified);
-  }
-
-  return qgemm_with_8bits_as_output_replacer_.RunForSave(graph,
-                                                         selected_nodes,
-                                                         save_context,
-                                                         saved_state,
-                                                         graph_modified);
-}
-#endif  // !defined(ORT_MINIMAL_BUILD)
-
 }  // namespace QDQ
 }  // namespace onnxruntime

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
index a81a4ec3ea0c0..2fa72a19c03b1 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -67,25 +67,5 @@ struct MatMulReplaceWithQLinear : public Action {
   BinaryReplaceWithQLinear qlinear_matmul_replacer_;
 };

-struct GemmReplaceWithQuant : public Action {
-  GemmReplaceWithQuant();
-
-  Status Run(Graph&, const NodesToOptimize& selected_nodes) const override;
-
-#if !defined(ORT_MINIMAL_BUILD)
-  Status RunForSave(Graph& /*graph*/, const NodesToOptimize& /*selected_nodes*/,
-                    const SatRuntimeOptimizationSaveContext& /*save_context*/,
-                    SavedState& /*saved_state*/, bool& /*graph_modified*/) const override;
-#endif  // !defined(ORT_MINIMAL_BUILD)
-
-  static inline void RemoveAttrBeta(const NodesToOptimize& selected_nodes) {
-    selected_nodes.Target().ClearAttribute("beta");
-  }
-
- private:
-  QDQReplaceWithNew qgemm_with_float_as_output_replacer_;
-  QDQReplaceWithNew qgemm_with_8bits_as_output_replacer_;
-};
-
 }  // namespace QDQ
 }  // namespace onnxruntime

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
index ab1fcdcec3e73..50bc405378fb0 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -167,26 +167,6 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i
 #endif
 }

-void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
-  // 3 to 5 nodes. 0=DQ A, 1=DQ B, 2=DQ C(optional), 3=Gemm, 4=Q Y(optional)
-  // Replace with QGemm
-  // Delete all original nodes.
-  const std::string action_name{"Gemm"};
-
-  std::unique_ptr<Action> action = std::make_unique<QDQ::GemmReplaceWithQuant>();
-
-#if !defined(ORT_MINIMAL_BUILD)
-  std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>();
-  qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
-                                                         {{"Gemm", {}}},
-                                                         std::move(selector),
-                                                         std::move(action));
-
-#else
-  qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
-#endif
-}
-
 SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) {
   SelectorActionRegistry qdq_selector_action_registry;

@@ -197,7 +177,6 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) {
   VariadicOpQDQRules(qdq_selector_action_registry);
   ConvQDQRules(qdq_selector_action_registry, is_int8_allowed);
   MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed);
-  GemmQDQRules(qdq_selector_action_registry);

   return qdq_selector_action_registry;
 }

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 3722504695176..db0c1dfe58646 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -40,23 +40,16 @@ static std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, co
 bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
                                       const std::vector<const Node*>& dq_nodes,
                                       const std::vector<const Node*>& q_nodes,
-                                      int num_dq_inputs,
-                                      bool is_empty_q_nodes_allowed) const {
+                                      int num_dq_inputs) const {
   if (num_dq_inputs == -1) {
     num_dq_inputs = NumActualValues(node, true);
   }

-  // The input is a Graph Viewer, so cannot use graph_utils or optimizer_utils
-  if (num_dq_inputs != gsl::narrow_cast<int>(dq_nodes.size())) {
-    return false;
-  }
-
-  if (q_nodes.empty()) {
-    return is_empty_q_nodes_allowed;
-  }
-
   int num_outputs = NumActualValues(node, false);  // number of outputs that exist
-  return (num_outputs == gsl::narrow_cast<int>(q_nodes.size())) &&
+
+  // The input is a Graph Viewer, so cannot use graph_utils or optimizer_utils
+  return num_dq_inputs == gsl::narrow_cast<int>(dq_nodes.size()) &&
+         num_outputs == gsl::narrow_cast<int>(q_nodes.size()) &&
          q_nodes.size() == node.GetOutputEdgesCount() &&
          !graph_viewer.NodeProducesGraphOutput(node);
 }
@@ -255,48 +248,6 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
   }
 }

-bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer,
-                                  const Node& node,
-                                  const std::vector<const Node*>& dq_nodes,
-                                  const std::vector<const Node*>& q_nodes) const {
-  if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes,
-                     -1 /*num_dq_inputs*/, true /*is_empty_q_nodes_allowed*/)) {
-    return false;
-  }
-
-  // input and output types need to be same
-  int32_t dt_A = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-  int32_t dt_B = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-
-  if (dt_A == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
-    if (dt_A != dt_B) {  // if A is signed int, B must be signed int
-      return false;
-    }
-  }
-
-  if (!q_nodes.empty()) {
-    int32_t dt_Y = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-    if (dt_A != dt_Y) {  // activation and output must be same type
-      return false;
-    }
-  }
-
-  if (dq_nodes.size() < 3) {  // no bias
-    return true;
-  }
-
-  if (node.GetAttributes().at("beta").f() != 1.0) {  // beta needs to be 1.0
-    return false;
-  }
-
-  int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-  return dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
-}
-
-void GemmSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
-  builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex);
-}
-
 }  // namespace QDQ
 }  // namespace onnxruntime

diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index ddafc20d0bb85..7bef53e8c0465 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -34,8 +34,7 @@ class NodeGroupSelector {
   bool CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
                      const std::vector<const Node*>& dq_nodes,
                      const std::vector<const Node*>& q_nodes,
-                     int num_dq_inputs = -1,
-                     bool is_empty_q_nodes_allowed = false) const;
+                     int num_dq_inputs = -1) const;

 private:
  // derived classes should implement this check
@@ -121,15 +120,6 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
   bool matmulintegertofloat_allowed_;
 };

-// Input: DQ nodes for A, B and optional C
-// Output: optional Q node for Y
-class GemmNodeGroupSelector : public NodeGroupSelector {
- private:
-  bool Check(const GraphViewer& graph_viewer, const Node& node,
-             const std::vector<const Node*>& dq_nodes,
-             const std::vector<const Node*>& q_nodes) const override;
-};
-
 /*
  * NodeSelector instances for use in the QDQ::SelectorActionTransformer.
 */
@@ -200,16 +190,6 @@ class MatMulSelector : public BaseSelector {
       : BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed, /*matmulintegertofloat_allowed*/ true)) {}
 };

-// Input: DQ nodes for A, B and optional C
-// Output: optional Q node for Y
-class GemmSelector : public BaseSelector {
- public:
-  GemmSelector()
-      : BaseSelector(std::make_unique<GemmNodeGroupSelector>()) {}
-
-  void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
-};
-
 }  // namespace QDQ
 }  // namespace onnxruntime

diff --git a/onnxruntime/core/optimizer/selectors_actions/helpers.cc b/onnxruntime/core/optimizer/selectors_actions/helpers.cc
index 41444ba8000b4..c83f30f3fc474 100644
--- a/onnxruntime/core/optimizer/selectors_actions/helpers.cc
+++ b/onnxruntime/core/optimizer/selectors_actions/helpers.cc
@@ -314,16 +314,6 @@ Status MoveInputOutput(Graph& graph, const NodesToOptimize& selected_nodes, Node
     if (src != nullptr) {
       ORT_RETURN_IF_ERROR(MoveInputOutputImpl(graph, move.value_move_info, *src, dest, only_update_dest_definitions));
-    } else if (move.value_move_info.optional &&
-               move.value_move_info.fill_optional_with_empty) {
-      auto& dest_defs = (move.value_move_info.dest_slot.in_out == ArgType::kInput)
-                            ? dest.MutableInputDefs()
-                            : dest.MutableOutputDefs();
-      dest_defs.push_back(&graph.GetOrCreateNodeArg("", nullptr));
-
-      if (move.value_move_info.dest_slot.in_out == ArgType::kInput) {
-        dest.MutableInputArgsCount().push_back(1);
-      }
     }
   }
 }

diff --git a/onnxruntime/core/optimizer/selectors_actions/helpers.h b/onnxruntime/core/optimizer/selectors_actions/helpers.h
index a57ece5339112..c70fee9e9e668 100644
--- a/onnxruntime/core/optimizer/selectors_actions/helpers.h
+++ b/onnxruntime/core/optimizer/selectors_actions/helpers.h
@@ -160,24 +160,18 @@ struct ValueMoveInfo {
   }

   // append single value (may be variadic) from source to destination
-  ValueMoveInfo(InOutDefSlot src_slot_in,
-                ArgType dest_slot_type,
-                bool is_optional = false,
-                bool fill_optional_with_empty = false)
+  ValueMoveInfo(InOutDefSlot src_slot_in, ArgType dest_slot_type, bool is_optional = false)
       : src_slot(src_slot_in),
         dest_slot{dest_slot_type, -1},
         copy_all{false},
         append{true},
-        optional{is_optional},
-        fill_optional_with_empty{fill_optional_with_empty} {}
+        optional{is_optional} {}

   InOutDefSlot src_slot;
   InOutDefSlot dest_slot;
-  bool copy_all{false};           // ignore src_slot.idx and copy all values
-  bool append{false};             // ignore dest_slot.idx and append to existing values
-  bool optional{false};           // optional copy that can be skipped if source node is missing
-  bool fill_optional_with_empty;  // fill optional NodeArg by NodeArg with empty name.
-                                  // Only support in 'append single value' mode.
+  bool copy_all{false};  // ignore src_slot.idx and copy all values
+  bool append{false};    // ignore dest_slot.idx and append to existing values
+  bool optional{false};  // optional copy that can be skipped if source node is missing

 private:
  ValueMoveInfo() = default;
@@ -219,12 +213,10 @@ inline NodeAndMoveInfo MoveToSlot(const NodesToOptimize::NodeLocation& src_node,
 inline NodeAndMoveInfo MoveAndAppend(const NodesToOptimize::NodeLocation& src_node,
                                      ArgType src_direction, int src_slot,
                                      ArgType dest_direction,
-                                     bool optional = false,
-                                     bool fill_optional_with_empty = false) {
+                                     bool optional = false) {
   return NodeAndMoveInfo{src_node, ValueMoveInfo{
                                        InOutDefSlot{src_direction, src_slot},  // move from this slot
-                                       dest_direction, optional,
-                                       fill_optional_with_empty}};  // append here
+                                       dest_direction, optional}};  // append here
 }

 // move all inputs/outputs from the source node to the target/replacement node

diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 6349bd8cd53eb..ed99ee1560752 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -553,7 +553,7 @@ def find_quantized_value(self, input_name):
             return self.parent.find_quantized_value(input_name)
         return None

-    def quantize_bias_static(self, bias_name, input_name, weight_name, beta = 1.0):
+    def quantize_bias_static(self, bias_name, input_name, weight_name):
         '''
         Quantize the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
         '''
@@ -584,7 +584,7 @@ def quantize_bias_static(self, bias_name, input_name, weight_name, beta = 1.0):
         input_scale = self.tensor_proto_to_array(inputscale_initializer)

         # calculate scale for bias
-        bias_scale = input_scale * weight_scale * beta
+        bias_scale = input_scale * weight_scale

         # quantize bias
         quantized_data = (np.asarray(bias_data) / bias_scale).round().astype(np.int32)

diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py
deleted file mode 100644
index f297bfb428a19..0000000000000
--- a/onnxruntime/python/tools/quantization/operators/gemm.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import onnx
-import numpy as np
-import logging
-from .base_operator import QuantOperatorBase
-from .qdq_base_operator import QDQOperatorBase
-from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
-from onnx import onnx_pb as onnx_proto
-
-
-def is_B_transposed(gemm_node):
-    transB_attribute = [attr for attr in gemm_node.attribute if attr.name == 'transB']
-    if len(transB_attribute):
-        return 0 < onnx.helper.get_attribute_value(transB_attribute[0])
-
-    return False
-
-def get_beta(gemm_node):
-    beta_attribute = [attr for attr in gemm_node.attribute if attr.name == 'beta']
-    if len(beta_attribute):
-        return onnx.helper.get_attribute_value(beta_attribute[0])
-
-    return 1.0
-
-def set_default_beta(gemm_node):
-    beta_attribute = [attr for attr in gemm_node.attribute if attr.name == 'beta']
-    if len(beta_attribute):
-        beta_attribute[0].f = 1.0
-
-    return 1.0
-
-class QLinearGemm(QuantOperatorBase):
-    def __init__(self, onnx_quantizer, onnx_node):
-        super().__init__(onnx_quantizer, onnx_node)
-
-    def quantize(self):
-        node = self.node
-        assert (node.op_type == "Gemm")
-
-        data_found, output_scale_name, output_zp_name, _, _ = \
-            self.quantizer._get_quantization_params(node.output[0])
-
-        if self.quantizer.is_input_a_weight(node.input[1]) and self.quantizer.is_per_channel():
-            (quantized_input_names, zero_point_names, scale_names, nodes) = \
-                self.quantizer.quantize_inputs(node, [0], reduce_range=self.quantizer.reduce_range)
-            quant_weight_tuple = self.quantizer.quantize_weight_per_channel(node.input[1], onnx_proto.TensorProto.INT8,
-                                                                            0 if is_B_transposed(node) else 1)
-            quantized_input_names.append(quant_weight_tuple[0])
-            zero_point_names.append(quant_weight_tuple[1])
-            scale_names.append(quant_weight_tuple[2])
-        else:
-            (quantized_input_names, zero_point_names, scale_names, nodes) = \
-                self.quantizer.quantize_inputs(node, [0, 1], reduce_range=self.quantizer.reduce_range)
-
-        if not data_found or quantized_input_names is None:
-            return super().quantize()
-
-        quantized_bias_name = ""
-        if len(node.input) == 3:
-            if not self.quantizer.is_input_a_weight(node.input[2]):
-                return super().quantize()
-
-            quantized_bias_name = self.quantizer.quantize_bias_static(node.input[2], node.input[0], node.input[1], get_beta(self.node))
-
-        qgemm_output = node.output[0] + "_quantized"
-        qgemm_name = qgemm_name = node.name + "_quant" if node.name != "" else ""
-
-        kwargs = {}
-        for attribute in node.attribute:
-            if attribute.name != "beta":
-                kwargs.update(attribute_to_kwarg(attribute))
-        kwargs["domain"] = ms_domain
-
-        # generate input
-        qgemm_inputs = []
-        for i in range(2):
-            qgemm_inputs.extend([quantized_input_names[i], scale_names[i], zero_point_names[i]])
-
-        qgemm_inputs.extend([quantized_bias_name, output_scale_name, output_zp_name])
-
-        qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output],
-                                           qgemm_name, **kwargs)
-        nodes.append(qgemm_node)
-
-        # Create an entry for this quantized value
-        q_output = QuantizedValue(node.output[0], qgemm_output, output_scale_name, output_zp_name,
-                                  QuantizedValueType.Input)
-        self.quantizer.quantized_value_map[node.output[0]] = q_output
-
-        self.quantizer.new_nodes += nodes
-
-
-class QDQGemm(QDQOperatorBase):
-    def __init__(self, onnx_quantizer, onnx_node):
-        super().__init__(onnx_quantizer, onnx_node)
-
-    def quantize(self):
-        node = self.node
-        assert (node.op_type == "Gemm")
-
-        self.quantizer.quantize_tensor(node.input[0])
-        if not self.disable_qdq_for_node_output:
-            self.quantizer.quantize_tensor(node.output[0])
-
-        if self.quantizer.is_per_channel():
-            self.quantizer.quantize_tensor_per_channel(node.input[1], 0 if is_B_transposed(node) else 1)
-        else:
-            self.quantizer.quantize_tensor(node.input[1])
-
-        if len(node.input) == 3:
-            if self.quantizer.is_input_a_weight(node.input[2]):
-                self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1], get_beta(self.node))
-                set_default_beta(self.node)
-            else:
-                logging.warning(
-                    "Bias of Gemm node '{}' is not constant. Please exclude this node for better performance."
-                    .format(self.node.name))

diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index 09a45799f4737..f5797282dda06 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -83,11 +83,11 @@ def quantize_tensor_per_channel(self, tensor_name, axis):
                 tensor_name))
             self.quantize_tensor(tensor_name)

-    def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta = 1.0):
+    def quantize_bias_tensor(self, bias_name, input_name, weight_name):
         weight = find_by_name(bias_name, self.model.initializer())
         if weight is not None:
             if weight.data_type == onnx_proto.TensorProto.FLOAT:
-                self.bias_to_quantize.append((bias_name, input_name, weight_name, beta))
+                self.bias_to_quantize.append((bias_name, input_name, weight_name))
         else:
             logging.warning("Expected {} to be a weight".format(bias_name))

@@ -222,11 +222,11 @@ def quantize_tensors(self):
                 self.quantized_value_map[tensor_name] = quantized_value

     def quantize_bias_tensors(self):
-        for bias_name, input_name, weight_name, beta in self.bias_to_quantize:
+        for bias_name, input_name, weight_name in self.bias_to_quantize:
             if bias_name in self.quantized_value_map.keys():
                 continue
             # Quantize the input
-            self.quantize_bias_static(bias_name, input_name, weight_name, beta)
+            self.quantize_bias_static(bias_name, input_name, weight_name)
             self.model.remove_initializer(find_by_name(bias_name, self.model.initializer()))
             quant_value = self.quantized_value_map[bias_name]
             inputs = [quant_value.q_name, quant_value.scale_name, quant_value.zp_name]

diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 826d53884facd..a0a0b935226cd 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -42,16 +42,15 @@ def optimize_model(model_path: Path):
     return optimized_model

-def load_model(model_path: Path, optimize=True, handle_gemm_with_matmul=True):
-
-    model = optimize_model(Path(model_path)) if optimize else onnx.load(Path(model_path))
-
-    if handle_gemm_with_matmul:
-        onnx_model = ONNXModel(model)
+def load_model(model_path: Path, optimize=True):
+    if optimize:
+        # optimize the original model
+        onnx_model = ONNXModel(optimize_model(Path(model_path)))
+        # to support GEMM
         onnx_model.replace_gemm_with_matmul()
         return onnx_model.model

-    return model
+    return onnx.load(Path(model_path))

 def quantize(model,
@@ -212,7 +211,7 @@ def quantize_static(model_input,
     if not op_types_to_quantize or len(op_types_to_quantize) == 0:
         op_types_to_quantize = list(QLinearOpsRegistry.keys())

-    model = load_model(Path(model_input), optimize_model, False)
+    model = load_model(Path(model_input), optimize_model)

     calibrator = create_calibrator(model, op_types_to_quantize, calibrate_method=calibrate_method)
     calibrator.collect_data(calibration_data_reader)

diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py
index d046cbc6dfa86..e63e1761f7143 100644
--- a/onnxruntime/python/tools/quantization/registry.py
+++ b/onnxruntime/python/tools/quantization/registry.py
@@ -18,7 +18,6 @@
 from .operators.resize import QResize, QDQResize
 from .operators.pooling import QLinearPool
 from .operators.concat import QLinearConcat, QDQConcat
-from .operators.gemm import QLinearGemm, QDQGemm

 CommonOpsRegistry = {
     "Gather": GatherQuant,

@@ -37,7 +36,6 @@
 QLinearOpsRegistry = {
    "ArgMax":
QArgMax, "Conv": QLinearConv, - "Gemm": QLinearGemm, "MatMul": QLinearMatMul, "Add": QLinearBinaryOp, "Mul": QLinearBinaryOp, @@ -60,7 +58,6 @@ QDQRegistry = { "Conv": QDQConv, - "Gemm": QDQGemm, "Clip": QDQRemovableActivation, "Relu": QDQRemovableActivation, "Reshape": QDQDirect8BitOp, diff --git a/onnxruntime/test/contrib_ops/quant_gemm_test.cc b/onnxruntime/test/contrib_ops/quant_gemm_test.cc index c8dbea216ec2a..897320b30eef7 100644 --- a/onnxruntime/test/contrib_ops/quant_gemm_test.cc +++ b/onnxruntime/test/contrib_ops/quant_gemm_test.cc @@ -15,12 +15,11 @@ #include #include -#include namespace onnxruntime { namespace test { -template +template void RunQuantGemmU8X8Test(const int M, const int N, const int K, @@ -30,42 +29,36 @@ void RunQuantGemmU8X8Test(const int M, bool B_is_initializer, bool per_column = false) { static std::default_random_engine e(123); - static std::uniform_int_distribution random_A(std::numeric_limits::min(), - std::numeric_limits::max()); - - constexpr int overflow_adjust = std::is_signed_v ? 2 : 1; - constexpr int random_B_min = std::numeric_limits::min() / overflow_adjust; - constexpr int random_B_max = std::numeric_limits::min() / overflow_adjust; - static std::uniform_int_distribution random_B(random_B_min, - random_B_max); + static std::uniform_int_distribution n_unsigned(0, 127); + static std::uniform_int_distribution n_xint8(std::numeric_limits::min(), std::numeric_limits::max()); static std::uniform_real_distribution n_apha(1.0f, 2.0f); static std::uniform_real_distribution n_scale(0.003f, 0.004f); Eigen::MatrixXi matrix_a = Eigen::MatrixXi::Random(K, M) - .unaryExpr([](int) { return random_A(e); }); - std::vector matrix_a_data; + .unaryExpr([](int) { return n_unsigned(e); }); + std::vector matrix_a_data; if (is_A_trans) { Eigen::MatrixXi matrix_a_trans = matrix_a.transpose().eval(); - matrix_a_data = ToVector(matrix_a_trans.data(), M * K); + matrix_a_data = ToVector(matrix_a_trans.data(), M * K); } else { - matrix_a_data = ToVector(matrix_a.data(), M * K); + matrix_a_data = ToVector(matrix_a.data(), M * K); } - ActType a_zero_point = GetMiddle(matrix_a_data); + uint8_t a_zero_point = GetMiddle(matrix_a_data); Eigen::MatrixXi matrix_a_offset = matrix_a - a_zero_point * Eigen::MatrixXi::Ones(K, M); float a_scale = n_scale(e); Eigen::MatrixXi matrix_b = Eigen::MatrixXi::Random(N, K) - .unaryExpr([](int) { return random_B(e); }); - std::vector matrix_b_data; + .unaryExpr([](int) { return n_xint8(e); }); + std::vector matrix_b_data; if (is_B_trans) { Eigen::MatrixXi matrix_b_trans = matrix_b.transpose().eval(); - matrix_b_data = ToVector(matrix_b_trans.data(), N * K); + matrix_b_data = ToVector(matrix_b_trans.data(), N * K); } else { - matrix_b_data = ToVector(matrix_b.data(), N * K); + matrix_b_data = ToVector(matrix_b.data(), N * K); } - WeightType b_zero_point = GetMiddle(matrix_b_data); + ScalarB b_zero_point = GetMiddle(matrix_b_data); std::vector b_scale({n_scale(e)}); - std::vector b_zp_per_column({b_zero_point}); + std::vector b_zp_per_column({b_zero_point}); Eigen::MatrixXi b_zp_matrix = b_zero_point * Eigen::MatrixXi::Ones(N, K); Eigen::MatrixXf b_scale_matrix = b_scale[0] * Eigen::MatrixXf::Ones(N, M); if (per_column) { @@ -81,7 +74,7 @@ void RunQuantGemmU8X8Test(const int M, float alpha = n_apha(e); Eigen::MatrixXi matrix_c = Eigen::MatrixXi::Random(N, M) - .unaryExpr([](int) { return random_A(e); }); + .unaryExpr([](int) { return n_xint8(e); }); Eigen::MatrixXi matrix_int32 = (matrix_b - b_zp_matrix) * matrix_a_offset; if (has_C) { @@ -93,12 
+86,12 @@ void RunQuantGemmU8X8Test(const int M, test.AddAttribute("transA", is_A_trans ? 1 : 0); test.AddAttribute("transB", is_B_trans ? 1 : 0); test.AddAttribute("alpha", alpha); - test.AddInput("A", is_A_trans ? std::vector({K, M}) : std::vector({M, K}), std::move(matrix_a_data)); + test.AddInput("A", is_A_trans ? std::vector({K, M}) : std::vector({M, K}), std::move(matrix_a_data)); test.AddInput("a_scale", {}, {a_scale}); - test.AddInput("a_zero_point", {}, {a_zero_point}); - test.AddInput("B", is_B_trans ? std::vector({N, K}) : std::vector({K, N}), std::move(matrix_b_data), B_is_initializer); + test.AddInput("a_zero_point", {}, {a_zero_point}); + test.AddInput("B", is_B_trans ? std::vector({N, K}) : std::vector({K, N}), std::move(matrix_b_data), B_is_initializer); test.AddInput("b_scale", {SafeInt(b_scale.size())}, b_scale); - test.AddInput("b_zero_point", {SafeInt(b_zp_per_column.size())}, b_zp_per_column); + test.AddInput("b_zero_point", {SafeInt(b_zp_per_column.size())}, b_zp_per_column); if (has_C) { test.AddInput("C", {M, N}, ToVector(matrix_c.data(), M * N)); @@ -108,14 +101,14 @@ void RunQuantGemmU8X8Test(const int M, if constexpr (std::is_same_v) { test.AddOptionalInputEdge(); - test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); test.AddOutput("Y", {M, N}, std::vector(matrix_output.data(), matrix_output.data() + M * N)); } else { - std::vector quant_output(M * N); - quantization::Params quant_param = quantization::QuantizeLinear(matrix_output.data(), quant_output.data(), M * N); + std::vector quant_output(M * N); + quantization::Params quant_param = quantization::QuantizeLinear(matrix_output.data(), quant_output.data(), M * N); test.AddInput("y_scale", {}, {quant_param.scale}); - test.AddInput("y_zero_point", {}, {quant_param.zero_point}); - test.AddOutput("Y", {M, N}, quant_output); + test.AddInput("y_zero_point", {}, {quant_param.zero_point}); + test.AddOutput("Y", {M, N}, quant_output); } test.Run(); @@ -129,13 +122,10 @@ void RunQuantGemmTest(const int M, bool has_C, bool B_is_initializer, bool per_column = false) { - RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); - RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); - RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); - RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); - - RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); - RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); + RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); + RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); + RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); + RunQuantGemmU8X8Test(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column); } void RunQuantGemmTestBatch(const int M, const int N, const int K) { diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d74160217f46a..a81f95d994318 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -537,173 +537,6 @@ TEST(QDQTransformerTests, MatMul_S8S8U8) { QDQTransformerMatMulTests(true); } -template -void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool 
beta_not_one = false) { - auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape) { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); - auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); - auto* output_arg = builder.MakeOutput(); - - typedef std::numeric_limits Input1Limits; - typedef std::numeric_limits Input2Limits; - typedef std::numeric_limits OutputTypeLimits; - - std::vector input_args; - - // add QDQ A - auto* q1_output = builder.MakeIntermediate(); - auto* dq1_output = builder.MakeIntermediate(); - builder.AddQuantizeLinearNode(input1_arg, - .039f, - (Input1Limits::max() + Input1Limits::min()) / 2 + 1, - q1_output); - builder.AddDequantizeLinearNode(q1_output, - .039f, - (Input2Limits::max() + Input1Limits::min()) / 2 + 1, - dq1_output); - - input_args.push_back(dq1_output); - - // add QDQ B - auto* q2_output = builder.MakeIntermediate(); - auto* dq2_output = builder.MakeIntermediate(); - builder.AddQuantizeLinearNode(input2_arg, - .04f, - (Input2Limits::max() + Input2Limits::min()) / 2 + 1, - q2_output); - builder.AddDequantizeLinearNode(q2_output, - .04f, - (Input2Limits::max() + Input2Limits::min()) / 2 + 1, - dq2_output); - input_args.push_back(dq2_output); - - if (has_bias) { - auto* dq_bias_output = builder.MakeIntermediate(); - auto* bias = builder.MakeInitializer({input2_shape[1]}, static_cast(0), static_cast(127)); - builder.AddDequantizeLinearNode(bias, 0.00156f, - 0, - dq_bias_output); - input_args.push_back(dq_bias_output); - } - - Node* gemm_node = nullptr; - - if (has_output_q) { - auto* gemm_op_output = builder.MakeIntermediate(); - gemm_node = &builder.AddNode("Gemm", input_args, {gemm_op_output}); - - // add QDQ output - auto* q3_output = builder.MakeIntermediate(); - builder.AddQuantizeLinearNode(gemm_op_output, - .039f, - (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, - q3_output); - builder.AddDequantizeLinearNode(q3_output, - .039f, - (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, - output_arg); - } else { - gemm_node = &builder.AddNode("Gemm", input_args, {output_arg}); - } - - if (beta_not_one) { - gemm_node->AddAttribute("beta", 2.0f); - } - }; - - auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { - auto op_to_count = CountOpsInGraph(session.GetGraph()); - if ((!has_output_q || std::is_same_v)&& - (!has_bias || (std::is_same_v && !beta_not_one)) && - (std::is_same_v || std::is_same_v)) { - EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); - EXPECT_EQ(op_to_count["Gemm"], 0); - EXPECT_EQ(op_to_count["QuantizeLinear"], 2); - EXPECT_EQ(op_to_count["DequantizeLinear"], has_output_q ? 
1 : 0); - } else { - int q_count = 2; // Q for A and B - int dq_count = 2; // DQ for A and B - if (has_bias) { - dq_count++; - } - if (has_output_q) { - q_count++; - dq_count++; - } - EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 0); - EXPECT_EQ(op_to_count["Gemm"], 1); - EXPECT_EQ(op_to_count["QuantizeLinear"], q_count); - EXPECT_EQ(op_to_count["DequantizeLinear"], dq_count); - } - }; - - TransformerTester(build_test_case, - check_binary_op_graph, - TransformerLevel::Level1, - TransformerLevel::Level2, - 12 /*opset_version*/, - 0.01 /*per_sample_tolerance*/, - 0.01 /*relative_per_sample_tolerance*/, - std::make_unique()); - }; - - test_case({2, 2}, {2, 4}); - test_case({13, 15}, {15, 15}); -} - -template -void QDQTransformerGemmTests() { - QDQTransformerGemmTests(false, false); - QDQTransformerGemmTests(false, true); - QDQTransformerGemmTests(true, false); - QDQTransformerGemmTests(true, true); - QDQTransformerGemmTests(false, false, true); - QDQTransformerGemmTests(false, true, true); - QDQTransformerGemmTests(true, false, true); - QDQTransformerGemmTests(true, true, true); -} - -TEST(QDQTransformerTests, Gemm_U8U8U8) { - QDQTransformerGemmTests(); - QDQTransformerGemmTests(); -} - -TEST(QDQTransformerTests, Gemm_U8S8S8) { - QDQTransformerGemmTests(); - QDQTransformerGemmTests(); -} - -TEST(QDQTransformerTests, Gemm_U8U8S8) { - QDQTransformerGemmTests(); - QDQTransformerGemmTests(); -} - -TEST(QDQTransformerTests, Gemm_U8S8U8) { - QDQTransformerGemmTests(); - QDQTransformerGemmTests(); -} - -TEST(QDQTransformerTests, Gemm_S8S8S8) { - QDQTransformerGemmTests(); - QDQTransformerGemmTests(); -} - -TEST(QDQTransformerTests, Gemm_S8U8U8) { - QDQTransformerGemmTests(); - QDQTransformerGemmTests(); -} - -TEST(QDQTransformerTests, Gemm_S8U8S8) { - QDQTransformerGemmTests(); - QDQTransformerGemmTests(); -} - -TEST(QDQTransformerTests, Gemm_S8S8U8) { - QDQTransformerGemmTests(); - QDQTransformerGemmTests(); -} - TEST(QDQTransformerTests, Gather) { auto test_case = [&](const std::vector& input1_shape, const std::vector& weights_shape) { auto build_test_case = [&](ModelTestBuilder& builder) { diff --git a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc index 489d6ff915692..7e41fe020a91d 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc @@ -302,12 +302,12 @@ TEST(MatmulIntegerOpTest, MatMulInteger_PerColumn_ND) { } // [M x N] = [M x K] x [K x N] = [batch_seq x input_dim] x [input_dim x embed_dim] -template +template void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_zero_zp, bool B_is_initializer, bool per_column_zp = false) { OpTester test("MatMulInteger", 10); static std::default_random_engine e(123); static std::uniform_int_distribution n_unsigned(0, 127); - static std::uniform_int_distribution n_xint8(std::numeric_limits::min(), std::numeric_limits::max()); + static std::uniform_int_distribution n_xint8(std::numeric_limits::min(), std::numeric_limits::max()); Eigen::MatrixXi matrix_a = Eigen::MatrixXi::Random(K, M) .unaryExpr([](int) { return n_unsigned(e); }); @@ -317,9 +317,9 @@ void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_ze Eigen::MatrixXi matrix_b = Eigen::MatrixXi::Random(N, K) .unaryExpr([](int) { return n_xint8(e); }); - std::vector matrix_b_data = ToVector(matrix_b.data(), N * K); - WeightType b_zero_point = non_zero_zp ? 
GetMiddle(matrix_b_data) : 0; - std::vector b_zp_per_column(N, b_zero_point); + std::vector matrix_b_data = ToVector(matrix_b.data(), N * K); + ScalarB b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0; + std::vector b_zp_per_column(N, b_zero_point); Eigen::MatrixXi b_zp_matrix = b_zero_point * Eigen::MatrixXi::Ones(N, K); if (non_zero_zp && per_column_zp) { for (int i = 0; i < N; i++) { @@ -331,13 +331,13 @@ void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_ze Eigen::MatrixXi matrix_c = ((matrix_b - b_zp_matrix) * matrix_a_offset).eval(); test.AddInput("T1", {M, K}, std::move(matrix_a_data)); - test.AddInput("T2", {K, N}, std::move(matrix_b_data), B_is_initializer); + test.AddInput("T2", {K, N}, std::move(matrix_b_data), B_is_initializer); if (non_zero_zp) { test.AddInput("a_zero_point", {}, {a_zero_point}); if (per_column_zp) { - test.AddInput("b_zero_point", {N}, b_zp_per_column); + test.AddInput("b_zero_point", {N}, b_zp_per_column); } else { - test.AddInput("b_zero_point", {}, {b_zero_point}); + test.AddInput("b_zero_point", {}, {b_zero_point}); } } diff --git a/onnxruntime/test/python/quantization/test_op_gemm.py b/onnxruntime/test/python/quantization/test_op_gemm.py index 11a2ba488b4d5..cf61402fa5d84 100644 --- a/onnxruntime/test/python/quantization/test_op_gemm.py +++ b/onnxruntime/test/python/quantization/test_op_gemm.py @@ -130,7 +130,7 @@ def static_quant_test(self, model_fp32_path, data_reader, activation_type, weigh data_reader.rewind() quantize_static(model_fp32_path, model_int8_path, data_reader, activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - quant_nodes = {'QGemm': 2, 'QuantizeLinear': 1, 'DequantizeLinear': 1} + quant_nodes = {'QLinearMatMul': 2, 'QLinearAdd': 2, 'QuantizeLinear': 1, 'DequantizeLinear': 1} check_op_type_count(self, model_int8_path, **quant_nodes) qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]}) @@ -147,7 +147,7 @@ def static_quant_test_qdq(self, model_fp32_path, data_reader, activation_type, w data_reader.rewind() quantize_static(model_fp32_path, model_int8_path, data_reader, quant_format=QuantFormat.QDQ, activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - quant_nodes = {'Gemm': 2, 'QuantizeLinear': 3, 'DequantizeLinear': 7} + quant_nodes = {'MatMul': 2, 'Add': 2, 'QuantizeLinear': 5, 'DequantizeLinear': 9} check_op_type_count(self, model_int8_path, **quant_nodes) qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} check_qtype_by_node_type(self, model_int8_path, qnode_io_qtypes) diff --git a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json index 2a85dab3b5fd0..120e6de7c00f3 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json @@ -283,10 +283,6 @@ "QEmbedLayerNormalization com.microsoft CPUExecutionProvider", 9235385557940152248 ], - [ - "QGemm com.microsoft CPUExecutionProvider", - 13009794669709617232 - ], [ "QGemm com.microsoft CPUExecutionProvider", 13737193491843065240
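
Note on the quantization path after this revert: with the QDQ-to-QGemm fusion gone, Gemm is handled the way it was before the reverted change. When optimization is enabled, load_model() (quantize.py above) calls onnx_model.replace_gemm_with_matmul() to rewrite each float Gemm as a MatMul followed by an Add, and the two ops are then quantized individually. That is why the updated test_op_gemm.py expectations count two QLinearMatMul and two QLinearAdd nodes, and why the beta plumbing in quantize_bias_static/bias_to_quantize could be dropped: a standalone Add needs no beta scaling, so bias_scale is simply input_scale * weight_scale again. The sketch below illustrates the shape of that rewrite; it is a minimal illustration rather than the library's implementation, it assumes alpha == 1.0, beta == 1.0 and transA == 0, and the decompose_gemm helper name is hypothetical.

import onnx
from onnx import helper

def decompose_gemm(graph: onnx.GraphProto) -> None:
    # Rewrite Gemm (Y = alpha * A @ B + beta * C) as MatMul + Add so the
    # generic QLinear paths can quantize the two pieces independently.
    # Hypothetical sketch: only handles alpha == 1.0, beta == 1.0, transA == 0.
    rewritten = []
    for node in graph.node:
        attrs = {a.name: helper.get_attribute_value(a) for a in node.attribute}
        if (node.op_type != "Gemm" or attrs.get("alpha", 1.0) != 1.0
                or attrs.get("beta", 1.0) != 1.0 or attrs.get("transA", 0)):
            rewritten.append(node)  # leave anything we cannot safely rewrite
            continue
        b_input = node.input[1]
        if attrs.get("transB", 0):
            # materialize B^T with an explicit Transpose (default perm reverses axes)
            b_t = b_input + "_transposed"
            rewritten.append(helper.make_node("Transpose", [b_input], [b_t]))
            b_input = b_t
        has_bias = len(node.input) == 3
        matmul_out = node.output[0] + "_matmul" if has_bias else node.output[0]
        rewritten.append(helper.make_node("MatMul", [node.input[0], b_input], [matmul_out]))
        if has_bias:
            # fold C back in with a standalone Add, which quantizes as QLinearAdd
            rewritten.append(helper.make_node("Add", [matmul_out, node.input[2]], [node.output[0]]))
    del graph.node[:]
    graph.node.extend(rewritten)

After this rewrite the existing QLinearMatMul/QLinearAdd (QLinearOps) and QDQ MatMul handling cover the Gemm case, with no dedicated QGemm kernel needed for signed activations.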