diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index ad4bab1d09745..35ac6e0e5b154 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -405,7 +405,7 @@ Do not modify directly.*
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* value:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float)<br> **T4** = tensor(int32)|
|QEmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding_quant:**T2**<br> *in* position_embedding_quant:**T2**<br> *in* segment_embedding:**T2**<br> *in* gamma_quant:**T2**<br> *in* beta_quant:**T2**<br> *in* mask:**T1**<br> *in* word_embedding_scale:**T**<br> *in* position_embedding_scale:**T**<br> *in* segment_embedding_scale:**T**<br> *in* gamma_scale:**T**<br> *in* beta_scale:**T**<br> *in* word_embedding_zero_point:**T2**<br> *in* position_embedding_zero_point:**T2**<br> *in* segment_embedding_zero_point:**T2**<br> *in* gamma_zero_point:**T2**<br> *in* beta_zero_point:**T2**<br> *out* layernorm_out:**T**<br> *out* mask_index_out:**T1**|1+|**T** = tensor(float)|
-|QGemm|*in* A:**TA**<br> *in* a_scale:**T**<br> *in* a_zero_point:**TA**<br> *in* B:**TB**<br> *in* b_scale:**T**<br> *in* b_zero_point:**TB**<br> *in* C:**TC**<br> *in* y_scale:**T**<br> *in* y_zero_point:**TYZ**<br> *out* Y:**TY**|1+|**T** = tensor(float)<br> **TA** = tensor(int8), tensor(uint8)<br> **TB** = tensor(int8), tensor(uint8)<br> **TC** = tensor(int32)<br> **TY** = tensor(float), tensor(int8), tensor(uint8)<br> **TYZ** = tensor(int8), tensor(uint8)|
+|QGemm|*in* A:**TA**<br> *in* a_scale:**T**<br> *in* a_zero_point:**TA**<br> *in* B:**TB**<br> *in* b_scale:**T**<br> *in* b_zero_point:**TB**<br> *in* C:**TC**<br> *in* y_scale:**T**<br> *in* y_zero_point:**TYZ**<br> *out* Y:**TY**|1+|**T** = tensor(float)<br> **TA** = tensor(uint8)<br> **TB** = tensor(int8), tensor(uint8)<br> **TC** = tensor(int32)<br> **TY** = tensor(float), tensor(uint8)<br> **TYZ** = tensor(uint8)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(int8), tensor(uint8)<br> **T4** = tensor(int32)|
|QLinearLeakyRelu|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
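The documentation row above is the user-visible summary of this change: QGemm's activation constraint **TA** drops tensor(int8), and the quantized-output constraints **TY** and **TYZ** drop tensor(int8) with it, leaving only the uint8-activation kernel. For orientation, here is a minimal NumPy sketch of the dequantize-multiply-requantize math those constraints describe; the function name and rounding details are illustrative, not the MLAS kernel's exact path:

```python
import numpy as np

def qgemm_reference(A, a_scale, a_zp, B, b_scale, b_zp,
                    C=None, y_scale=None, y_zp=None):
    # Integer accumulation of (A - a_zp) @ (B - b_zp) in int32.
    acc = (A.astype(np.int32) - a_zp) @ (B.astype(np.int32) - b_zp)
    if C is not None:
        acc += C  # int32 bias, pre-quantized with scale a_scale * b_scale
    y = acc * (a_scale * b_scale)
    if y_scale is None:
        return y.astype(np.float32)          # TY = tensor(float)
    q = np.rint(y / y_scale) + y_zp          # TY = tensor(uint8)
    return np.clip(q, 0, 255).astype(np.uint8)
```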
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 08125ce51bca2..e9d22000b07fb 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -78,7 +78,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
// ******** End: Quantization ******************* //
@@ -170,7 +169,6 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm)>,
};
diff --git a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc
index 8d83ab7f681db..9bc7b8e76907e 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc
@@ -44,10 +44,8 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
- bool a_is_signed = a->IsDataType<int8_t>();
- const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
-
BufferUniquePtr a_trans_buffer;
+ const uint8_t* a_data = a->template Data<uint8_t>();
if (trans_A_ == CblasTrans) {
a_data = quantization::TransPoseInputData(a_data, a_trans_buffer, allocator, K, M);
}
@@ -84,12 +82,12 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
GemmBroadcastBias(M, N, 1.f, c->template Data<int32_t>(), &(c->Shape()), gemm_output_data);
}
- MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape{M, N, K, a_is_signed, b_is_signed, c != nullptr};
+ MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape{M, N, K, false /*AIsSigned*/, b_is_signed, c != nullptr};
MLAS_GEMM_QUANT_DATA_PARAMS gemm_param;
gemm_param.A = a_data;
gemm_param.lda = gemm_shape.K;
- gemm_param.ZeroPointA = *(static_cast<const uint8_t*>(a_zp->DataRaw()));
+ gemm_param.ZeroPointA = *(a_zp->template Data<uint8_t>());
gemm_param.B = b_data;
gemm_param.ldb = gemm_shape.N;
@@ -222,20 +220,5 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
.TypeConstraint("TY", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}),
QGemm);
-ONNX_OPERATOR_TYPED_KERNEL_EX(
-    QGemm,
-    kMSDomain,
-    1,
-    int8_t,
-    kCpuExecutionProvider,
-    KernelDefBuilder()
-        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
-        .TypeConstraint("TA", DataTypeImpl::GetTensorType<int8_t>())
-        .TypeConstraint("TB", DataTypeImpl::GetTensorType<int8_t>())
-        .TypeConstraint("TC", DataTypeImpl::GetTensorType<int32_t>())
-        .TypeConstraint("TYZ", DataTypeImpl::GetTensorType<int8_t>())
-        .TypeConstraint("TY", {DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<int8_t>()}),
-    QGemm);
-
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
index f89f201fd7aab..63741cfb9773f 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -223,67 +223,5 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select
}
}
-static std::vector<NodeAndMoveInfo> GetGemmMoveInfo(bool does_q_node_exist) {
- NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0};
- NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1};
- NTO::NodeLocation dq_bias{NTO::NodeType::kInput, 2};
- NTO::NodeLocation target{NTO::NodeType::kTarget, 0};
- NTO::NodeLocation q{NTO::NodeType::kOutput, 0};
-
- std::vector<NodeAndMoveInfo> moves{
- MoveAll(dq_A, ArgType::kInput), // append all inputs from DQ of A
- MoveAll(dq_B, ArgType::kInput), // append all inputs from DQ of B
- MoveAndAppend(dq_bias, ArgType::kInput, 0, ArgType::kInput, true, true)}; // (optional) append bias
-
- if (does_q_node_exist) {
- moves.push_back(MoveAndAppend(q, ArgType::kInput, 1, ArgType::kInput)); // append scale (input 1) from Q
- moves.push_back(MoveAndAppend(q, ArgType::kInput, 2, ArgType::kInput)); // append zp (input 2) from Q
- moves.push_back(MoveAll(q, ArgType::kOutput)); // and use the outputs from Q
- } else {
- moves.push_back(MoveAll(target, ArgType::kOutput));
- }
-
- return moves;
-}
-
-GemmReplaceWithQuant::GemmReplaceWithQuant()
- : qgemm_with_float_as_output_replacer_(kMSDomain, GetGemmMoveInfo(false), "QGemm"),
- qgemm_with_8bits_as_output_replacer_(kMSDomain, GetGemmMoveInfo(true), "QGemm") {
-}
-
-Status GemmReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
- RemoveAttrBeta(selected_nodes);
- bool is_output_float = selected_nodes.num_outputs == 0;
- if (is_output_float) {
- return qgemm_with_float_as_output_replacer_.Run(graph, selected_nodes);
- }
-
- return qgemm_with_8bits_as_output_replacer_.Run(graph, selected_nodes);
-}
-
-#if !defined(ORT_MINIMAL_BUILD)
-Status GemmReplaceWithQuant::RunForSave(Graph& graph,
- const NodesToOptimize& selected_nodes,
- const SatRuntimeOptimizationSaveContext& save_context,
- SavedState& saved_state,
- bool& graph_modified) const {
- RemoveAttrBeta(selected_nodes);
- bool is_output_float = selected_nodes.num_outputs == 0;
- if (is_output_float) {
- return qgemm_with_float_as_output_replacer_.RunForSave(graph,
- selected_nodes,
- save_context,
- saved_state,
- graph_modified);
- }
-
- return qgemm_with_8bits_as_output_replacer_.RunForSave(graph,
- selected_nodes,
- save_context,
- saved_state,
- graph_modified);
-}
-#endif // !defined(ORT_MINIMAL_BUILD)
-
} // namespace QDQ
} // namespace onnxruntime
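For orientation, the deleted GetGemmMoveInfo rewired a DQ/DQ/(optional DQ)/Gemm/(optional Q) group into a single QGemm node. A hedged onnx.helper sketch of the node group being matched (tensor names are illustrative):

```python
from onnx import helper

# DQ(A), DQ(B), optional DQ(bias) feed a float Gemm, and an optional
# QuantizeLinear consumes Y; the removed action spliced all of their
# quantization parameters into one com.microsoft QGemm node.
dq_a = helper.make_node("DequantizeLinear", ["A_q", "a_scale", "a_zp"], ["A_f"])
dq_b = helper.make_node("DequantizeLinear", ["B_q", "b_scale", "b_zp"], ["B_f"])
gemm = helper.make_node("Gemm", ["A_f", "B_f"], ["Y_f"])
q_y = helper.make_node("QuantizeLinear", ["Y_f", "y_scale", "y_zp"], ["Y_q"])
```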
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
index a81a4ec3ea0c0..2fa72a19c03b1 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -67,25 +67,5 @@ struct MatMulReplaceWithQLinear : public Action {
BinaryReplaceWithQLinear qlinear_matmul_replacer_;
};
-struct GemmReplaceWithQuant : public Action {
- GemmReplaceWithQuant();
-
- Status Run(Graph&, const NodesToOptimize& selected_nodes) const override;
-
-#if !defined(ORT_MINIMAL_BUILD)
- Status RunForSave(Graph& /*graph*/, const NodesToOptimize& /*selected_nodes*/,
- const SatRuntimeOptimizationSaveContext& /*save_context*/,
- SavedState& /*saved_state*/, bool& /*graph_modified*/) const override;
-#endif // !defined(ORT_MINIMAL_BUILD)
-
- static inline void RemoveAttrBeta(const NodesToOptimize& selected_nodes) {
- selected_nodes.Target().ClearAttribute("beta");
- }
-
- private:
- QDQReplaceWithNew qgemm_with_float_as_output_replacer_;
- QDQReplaceWithNew qgemm_with_8bits_as_output_replacer_;
-};
-
} // namespace QDQ
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
index ab1fcdcec3e73..50bc405378fb0 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -167,26 +167,6 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i
#endif
}
-void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
- // 3 to 5 nodes. 0=DQ A, 1=DQ B, 2=DQ C(optional), 3=Gemm, 4=Q Y(optional)
- // Replace with QGemm
- // Delete all original nodes.
- const std::string action_name{"Gemm"};
-
- std::unique_ptr<Action> action = std::make_unique<GemmReplaceWithQuant>();
-
-#if !defined(ORT_MINIMAL_BUILD)
- std::unique_ptr<NodeSelector> selector = std::make_unique<GemmSelector>();
- qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
- {{"Gemm", {}}},
- std::move(selector),
- std::move(action));
-
-#else
- qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
-#endif
-}
-
SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) {
SelectorActionRegistry qdq_selector_action_registry;
@@ -197,7 +177,6 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) {
VariadicOpQDQRules(qdq_selector_action_registry);
ConvQDQRules(qdq_selector_action_registry, is_int8_allowed);
MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed);
- GemmQDQRules(qdq_selector_action_registry);
return qdq_selector_action_registry;
}
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 3722504695176..db0c1dfe58646 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -40,23 +40,16 @@ static std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, co
bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes,
- int num_dq_inputs,
- bool is_empty_q_nodes_allowed) const {
+ int num_dq_inputs) const {
if (num_dq_inputs == -1) {
num_dq_inputs = NumActualValues(node, true);
}
- // The input is a Graph Viewer, so cannot use graph_utils or optimizer_utils
- if (num_dq_inputs != gsl::narrow_cast<int>(dq_nodes.size())) {
- return false;
- }
-
- if (q_nodes.empty()) {
- return is_empty_q_nodes_allowed;
- }
-
int num_outputs = NumActualValues(node, false); // number of outputs that exist
- return (num_outputs == gsl::narrow_cast<int>(q_nodes.size())) &&
+
+ // The input is a Graph Viewer, so cannot use graph_utils or optimizer_utils
+ return num_dq_inputs == gsl::narrow_cast<int>(dq_nodes.size()) &&
+ num_outputs == gsl::narrow_cast<int>(q_nodes.size()) &&
q_nodes.size() == node.GetOutputEdgesCount() &&
!graph_viewer.NodeProducesGraphOutput(node);
}
@@ -255,48 +248,6 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
}
}
-bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer,
- const Node& node,
- const std::vector<const Node*>& dq_nodes,
- const std::vector<const Node*>& q_nodes) const {
- if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes,
- -1 /*num_dq_inputs*/, true /*is_empty_q_nodes_allowed*/)) {
- return false;
- }
-
- // input and output types need to be same
- int32_t dt_A = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
- int32_t dt_B = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
-
- if (dt_A == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
- if (dt_A != dt_B) { // if A is signed int, B must be signed int
- return false;
- }
- }
-
- if (!q_nodes.empty()) {
- int32_t dt_Y = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
- if (dt_A != dt_Y) { // activation and output must be same type
- return false;
- }
- }
-
- if (dq_nodes.size() < 3) { // no bias
- return true;
- }
-
- if (node.GetAttributes().at("beta").f() != 1.0) { // beta needs to be 1.0
- return false;
- }
-
- int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
- return dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
-}
-
-void GemmSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
- builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex);
-}
-
} // namespace QDQ
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index ddafc20d0bb85..7bef53e8c0465 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -34,8 +34,7 @@ class NodeGroupSelector {
bool CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes,
- int num_dq_inputs = -1,
- bool is_empty_q_nodes_allowed = false) const;
+ int num_dq_inputs = -1) const;
private:
// derived classes should implement this check
@@ -121,15 +120,6 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
bool matmulintegertofloat_allowed_;
};
-// Input: DQ nodes for A, B and optional C
-// Output: optional Q node for Y
-class GemmNodeGroupSelector : public NodeGroupSelector {
- private:
- bool Check(const GraphViewer& graph_viewer, const Node& node,
- const std::vector<const Node*>& dq_nodes,
- const std::vector<const Node*>& q_nodes) const override;
-};
-
/*
* NodeSelector instances for use in the QDQ::SelectorActionTransformer.
*/
@@ -200,16 +190,6 @@ class MatMulSelector : public BaseSelector {
: BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed, /*matmulintegertofloat_allowed*/ true)) {}
};
-// Input: DQ nodes for A, B and optional C
-// Output: optional Q node for Y
-class GemmSelector : public BaseSelector {
- public:
- GemmSelector()
- : BaseSelector(std::make_unique<GemmNodeGroupSelector>()) {}
-
- void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
-};
-
} // namespace QDQ
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/selectors_actions/helpers.cc b/onnxruntime/core/optimizer/selectors_actions/helpers.cc
index 41444ba8000b4..c83f30f3fc474 100644
--- a/onnxruntime/core/optimizer/selectors_actions/helpers.cc
+++ b/onnxruntime/core/optimizer/selectors_actions/helpers.cc
@@ -314,16 +314,6 @@ Status MoveInputOutput(Graph& graph, const NodesToOptimize& selected_nodes, Node
if (src != nullptr) {
ORT_RETURN_IF_ERROR(MoveInputOutputImpl(graph, move.value_move_info, *src, dest,
only_update_dest_definitions));
- } else if (move.value_move_info.optional &&
- move.value_move_info.fill_optional_with_empty) {
- auto& dest_defs = (move.value_move_info.dest_slot.in_out == ArgType::kInput)
- ? dest.MutableInputDefs()
- : dest.MutableOutputDefs();
- dest_defs.push_back(&graph.GetOrCreateNodeArg("", nullptr));
-
- if (move.value_move_info.dest_slot.in_out == ArgType::kInput) {
- dest.MutableInputArgsCount().push_back(1);
- }
}
}
}
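The deleted branch above was the consumer of fill_optional_with_empty: when an optional source input was absent, it appended a NodeArg with an empty name. In the ONNX text form that placeholder is simply an empty string in the node's input list, as in this hedged sketch:

```python
from onnx import helper

# An absent optional input is encoded as an empty name in the input
# list - the textual equivalent of the empty-named NodeArg the removed
# code appended when fill_optional_with_empty was set.
node = helper.make_node(
    "QGemm",
    inputs=["A_q", "a_scale", "a_zp", "B_q", "b_scale", "b_zp",
            "",  # optional C (bias) not provided
            "y_scale", "y_zp"],
    outputs=["Y_q"],
    domain="com.microsoft")
```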
diff --git a/onnxruntime/core/optimizer/selectors_actions/helpers.h b/onnxruntime/core/optimizer/selectors_actions/helpers.h
index a57ece5339112..c70fee9e9e668 100644
--- a/onnxruntime/core/optimizer/selectors_actions/helpers.h
+++ b/onnxruntime/core/optimizer/selectors_actions/helpers.h
@@ -160,24 +160,18 @@ struct ValueMoveInfo {
}
// append single value (may be variadic) from source to destination
- ValueMoveInfo(InOutDefSlot src_slot_in,
- ArgType dest_slot_type,
- bool is_optional = false,
- bool fill_optional_with_empty = false)
+ ValueMoveInfo(InOutDefSlot src_slot_in, ArgType dest_slot_type, bool is_optional = false)
: src_slot(src_slot_in),
dest_slot{dest_slot_type, -1},
copy_all{false},
append{true},
- optional{is_optional},
- fill_optional_with_empty{fill_optional_with_empty} {}
+ optional{is_optional} {}
InOutDefSlot src_slot;
InOutDefSlot dest_slot;
- bool copy_all{false}; // ignore src_slot.idx and copy all values
- bool append{false}; // ignore dest_slot.idx and append to existing values
- bool optional{false}; // optional copy that can be skipped if source node is missing
- bool fill_optional_with_empty; // fill optional NodeArg by NodeArg with empty name.
- // Only support in 'append single value' mode.
+ bool copy_all{false}; // ignore src_slot.idx and copy all values
+ bool append{false}; // ignore dest_slot.idx and append to existing values
+ bool optional{false}; // optional copy that can be skipped if source node is missing
private:
ValueMoveInfo() = default;
@@ -219,12 +213,10 @@ inline NodeAndMoveInfo MoveToSlot(const NodesToOptimize::NodeLocation& src_node,
inline NodeAndMoveInfo MoveAndAppend(const NodesToOptimize::NodeLocation& src_node,
ArgType src_direction, int src_slot,
ArgType dest_direction,
- bool optional = false,
- bool fill_optional_with_empty = false) {
+ bool optional = false) {
return NodeAndMoveInfo{src_node, ValueMoveInfo{
InOutDefSlot{src_direction, src_slot}, // move from this slot
- dest_direction, optional,
- fill_optional_with_empty}}; // append here
+ dest_direction, optional}}; // append here
}
// move all inputs/outputs from the source node to the target/replacement node
diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 6349bd8cd53eb..ed99ee1560752 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -553,7 +553,7 @@ def find_quantized_value(self, input_name):
return self.parent.find_quantized_value(input_name)
return None
- def quantize_bias_static(self, bias_name, input_name, weight_name, beta = 1.0):
+ def quantize_bias_static(self, bias_name, input_name, weight_name):
'''
Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale
'''
@@ -584,7 +584,7 @@ def quantize_bias_static(self, bias_name, input_name, weight_name, beta = 1.0):
input_scale = self.tensor_proto_to_array(inputscale_initializer)
# calcuate scale for bias
- bias_scale = input_scale * weight_scale * beta
+ bias_scale = input_scale * weight_scale
# quantize bias
quantized_data = (np.asarray(bias_data) / bias_scale).round().astype(np.int32)
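With beta gone, the static bias rule is exactly the two lines above: zero point 0 and scale input_scale * weight_scale, so the resulting int32 bias adds directly into the int32 accumulator of the quantized matmul. A tiny NumPy illustration with made-up calibration values:

```python
import numpy as np

input_scale, weight_scale = 0.039, 0.004   # illustrative calibration results
bias_scale = input_scale * weight_scale    # no beta factor any more
bias_data = np.array([0.25, -0.5, 1.0], dtype=np.float32)
quantized_bias = (bias_data / bias_scale).round().astype(np.int32)
```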
diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py
deleted file mode 100644
index f297bfb428a19..0000000000000
--- a/onnxruntime/python/tools/quantization/operators/gemm.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import onnx
-import numpy as np
-import logging
-from .base_operator import QuantOperatorBase
-from .qdq_base_operator import QDQOperatorBase
-from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain
-from onnx import onnx_pb as onnx_proto
-
-
-def is_B_transposed(gemm_node):
- transB_attribute = [attr for attr in gemm_node.attribute if attr.name == 'transB']
- if len(transB_attribute):
- return 0 < onnx.helper.get_attribute_value(transB_attribute[0])
-
- return False
-
-def get_beta(gemm_node):
- beta_attribute = [attr for attr in gemm_node.attribute if attr.name == 'beta']
- if len(beta_attribute):
- return onnx.helper.get_attribute_value(beta_attribute[0])
-
- return 1.0
-
-def set_default_beta(gemm_node):
- beta_attribute = [attr for attr in gemm_node.attribute if attr.name == 'beta']
- if len(beta_attribute):
- beta_attribute[0].f = 1.0
-
- return 1.0
-
-class QLinearGemm(QuantOperatorBase):
- def __init__(self, onnx_quantizer, onnx_node):
- super().__init__(onnx_quantizer, onnx_node)
-
- def quantize(self):
- node = self.node
- assert (node.op_type == "Gemm")
-
- data_found, output_scale_name, output_zp_name, _, _ = \
- self.quantizer._get_quantization_params(node.output[0])
-
- if self.quantizer.is_input_a_weight(node.input[1]) and self.quantizer.is_per_channel():
- (quantized_input_names, zero_point_names, scale_names, nodes) = \
- self.quantizer.quantize_inputs(node, [0], reduce_range=self.quantizer.reduce_range)
- quant_weight_tuple = self.quantizer.quantize_weight_per_channel(node.input[1], onnx_proto.TensorProto.INT8,
- 0 if is_B_transposed(node) else 1)
- quantized_input_names.append(quant_weight_tuple[0])
- zero_point_names.append(quant_weight_tuple[1])
- scale_names.append(quant_weight_tuple[2])
- else:
- (quantized_input_names, zero_point_names, scale_names, nodes) = \
- self.quantizer.quantize_inputs(node, [0, 1], reduce_range=self.quantizer.reduce_range)
-
- if not data_found or quantized_input_names is None:
- return super().quantize()
-
- quantized_bias_name = ""
- if len(node.input) == 3:
- if not self.quantizer.is_input_a_weight(node.input[2]):
- return super().quantize()
-
- quantized_bias_name = self.quantizer.quantize_bias_static(node.input[2], node.input[0], node.input[1], get_beta(self.node))
-
- qgemm_output = node.output[0] + "_quantized"
- qgemm_name = qgemm_name = node.name + "_quant" if node.name != "" else ""
-
- kwargs = {}
- for attribute in node.attribute:
- if attribute.name != "beta":
- kwargs.update(attribute_to_kwarg(attribute))
- kwargs["domain"] = ms_domain
-
- # generate input
- qgemm_inputs = []
- for i in range(2):
- qgemm_inputs.extend([quantized_input_names[i], scale_names[i], zero_point_names[i]])
-
- qgemm_inputs.extend([quantized_bias_name, output_scale_name, output_zp_name])
-
- qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output],
- qgemm_name, **kwargs)
- nodes.append(qgemm_node)
-
- # Create an entry for this quantized value
- q_output = QuantizedValue(node.output[0], qgemm_output, output_scale_name, output_zp_name,
- QuantizedValueType.Input)
- self.quantizer.quantized_value_map[node.output[0]] = q_output
-
- self.quantizer.new_nodes += nodes
-
-
-class QDQGemm(QDQOperatorBase):
- def __init__(self, onnx_quantizer, onnx_node):
- super().__init__(onnx_quantizer, onnx_node)
-
- def quantize(self):
- node = self.node
- assert (node.op_type == "Gemm")
-
- self.quantizer.quantize_tensor(node.input[0])
- if not self.disable_qdq_for_node_output:
- self.quantizer.quantize_tensor(node.output[0])
-
- if self.quantizer.is_per_channel():
- self.quantizer.quantize_tensor_per_channel(node.input[1], 0 if is_B_transposed(node) else 1)
- else:
- self.quantizer.quantize_tensor(node.input[1])
-
- if len(node.input) == 3:
- if self.quantizer.is_input_a_weight(node.input[2]):
- self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1], get_beta(self.node))
- set_default_beta(self.node)
- else:
- logging.warning(
- "Bias of Gemm node '{}' is not constant. Please exclude this node for better performance."
- .format(self.node.name))
-
diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py
index 09a45799f4737..f5797282dda06 100644
--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py
+++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -83,11 +83,11 @@ def quantize_tensor_per_channel(self, tensor_name, axis):
tensor_name))
self.quantize_tensor(tensor_name)
- def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta = 1.0):
+ def quantize_bias_tensor(self, bias_name, input_name, weight_name):
weight = find_by_name(bias_name, self.model.initializer())
if weight is not None:
if weight.data_type == onnx_proto.TensorProto.FLOAT:
- self.bias_to_quantize.append((bias_name, input_name, weight_name, beta))
+ self.bias_to_quantize.append((bias_name, input_name, weight_name))
else:
logging.warning("Expected {} to be a weight".format(bias_name))
@@ -222,11 +222,11 @@ def quantize_tensors(self):
self.quantized_value_map[tensor_name] = quantized_value
def quantize_bias_tensors(self):
- for bias_name, input_name, weight_name, beta in self.bias_to_quantize:
+ for bias_name, input_name, weight_name in self.bias_to_quantize:
if bias_name in self.quantized_value_map.keys():
continue
# Quantize the input
- self.quantize_bias_static(bias_name, input_name, weight_name, beta)
+ self.quantize_bias_static(bias_name, input_name, weight_name)
self.model.remove_initializer(find_by_name(bias_name, self.model.initializer()))
quant_value = self.quantized_value_map[bias_name]
inputs = [quant_value.q_name, quant_value.scale_name, quant_value.zp_name]
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 826d53884facd..a0a0b935226cd 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -42,16 +42,15 @@ def optimize_model(model_path: Path):
return optimized_model
-def load_model(model_path: Path, optimize=True, handle_gemm_with_matmul=True):
-
- model = optimize_model(Path(model_path)) if optimize else onnx.load(Path(model_path))
-
- if handle_gemm_with_matmul:
- onnx_model = ONNXModel(model)
+def load_model(model_path: Path, optimize=True):
+ if optimize:
+ #optimize the original model
+ onnx_model = ONNXModel(optimize_model(Path(model_path)))
+ # to support GEMM
onnx_model.replace_gemm_with_matmul()
return onnx_model.model
- return model
+ return onnx.load(Path(model_path))
def quantize(model,
@@ -212,7 +211,7 @@ def quantize_static(model_input,
if not op_types_to_quantize or len(op_types_to_quantize) == 0:
op_types_to_quantize = list(QLinearOpsRegistry.keys())
- model = load_model(Path(model_input), optimize_model, False)
+ model = load_model(Path(model_input), optimize_model)
calibrator = create_calibrator(model, op_types_to_quantize, calibrate_method=calibrate_method)
calibrator.collect_data(calibration_data_reader)
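Because quantize_static no longer passes handle_gemm_with_matmul=False, load_model always runs replace_gemm_with_matmul on the optimized model, so a float Gemm reaches the quantizer as MatMul plus Add; that is why the test expectations below count QLinearMatMul/QLinearAdd instead of QGemm. A hedged sketch of the decomposition for the plain alpha == 1, beta == 1, transA == 0 case (the real pass handles transB by transposing the weight initializer):

```python
from onnx import helper

# Y = A @ B + C becomes MatMul followed by Add.
matmul = helper.make_node("MatMul", ["A", "B"], ["gemm_matmul_out"])
add = helper.make_node("Add", ["gemm_matmul_out", "C"], ["Y"])
```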
diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py
index d046cbc6dfa86..e63e1761f7143 100644
--- a/onnxruntime/python/tools/quantization/registry.py
+++ b/onnxruntime/python/tools/quantization/registry.py
@@ -18,7 +18,6 @@
from .operators.resize import QResize, QDQResize
from .operators.pooling import QLinearPool
from .operators.concat import QLinearConcat, QDQConcat
-from .operators.gemm import QLinearGemm, QDQGemm
CommonOpsRegistry = {
"Gather": GatherQuant,
@@ -37,7 +36,6 @@
QLinearOpsRegistry = {
"ArgMax": QArgMax,
"Conv": QLinearConv,
- "Gemm": QLinearGemm,
"MatMul": QLinearMatMul,
"Add": QLinearBinaryOp,
"Mul": QLinearBinaryOp,
@@ -60,7 +58,6 @@
QDQRegistry = {
"Conv": QDQConv,
- "Gemm": QDQGemm,
"Clip": QDQRemovableActivation,
"Relu": QDQRemovableActivation,
"Reshape": QDQDirect8BitOp,
diff --git a/onnxruntime/test/contrib_ops/quant_gemm_test.cc b/onnxruntime/test/contrib_ops/quant_gemm_test.cc
index c8dbea216ec2a..897320b30eef7 100644
--- a/onnxruntime/test/contrib_ops/quant_gemm_test.cc
+++ b/onnxruntime/test/contrib_ops/quant_gemm_test.cc
@@ -15,12 +15,11 @@
#include <algorithm>
#include <random>
-#include <type_traits>
namespace onnxruntime {
namespace test {
-template <typename ActType, typename WeightType, typename OutputType>
+template <typename ScalarB, typename OutputType>
void RunQuantGemmU8X8Test(const int M,
const int N,
const int K,
@@ -30,42 +29,36 @@ void RunQuantGemmU8X8Test(const int M,
bool B_is_initializer,
bool per_column = false) {
static std::default_random_engine e(123);
- static std::uniform_int_distribution<int> random_A(std::numeric_limits<ActType>::min(),
- std::numeric_limits<ActType>::max());
-
- constexpr int overflow_adjust = std::is_signed_v<WeightType> ? 2 : 1;
- constexpr int random_B_min = std::numeric_limits<WeightType>::min() / overflow_adjust;
- constexpr int random_B_max = std::numeric_limits<WeightType>::min() / overflow_adjust;
- static std::uniform_int_distribution<int> random_B(random_B_min,
- random_B_max);
+ static std::uniform_int_distribution<int> n_unsigned(0, 127);
+ static std::uniform_int_distribution<int> n_xint8(std::numeric_limits<ScalarB>::min(), std::numeric_limits<ScalarB>::max());
static std::uniform_real_distribution<float> n_apha(1.0f, 2.0f);
static std::uniform_real_distribution<float> n_scale(0.003f, 0.004f);
Eigen::MatrixXi matrix_a = Eigen::MatrixXi::Random(K, M)
- .unaryExpr([](int) { return random_A(e); });
- std::vector<ActType> matrix_a_data;
+ .unaryExpr([](int) { return n_unsigned(e); });
+ std::vector<uint8_t> matrix_a_data;
if (is_A_trans) {
Eigen::MatrixXi matrix_a_trans = matrix_a.transpose().eval();
- matrix_a_data = ToVector<ActType>(matrix_a_trans.data(), M * K);
+ matrix_a_data = ToVector<uint8_t>(matrix_a_trans.data(), M * K);
} else {
- matrix_a_data = ToVector<ActType>(matrix_a.data(), M * K);
+ matrix_a_data = ToVector<uint8_t>(matrix_a.data(), M * K);
}
- ActType a_zero_point = GetMiddle(matrix_a_data);
+ uint8_t a_zero_point = GetMiddle(matrix_a_data);
Eigen::MatrixXi matrix_a_offset = matrix_a - a_zero_point * Eigen::MatrixXi::Ones(K, M);
float a_scale = n_scale(e);
Eigen::MatrixXi matrix_b = Eigen::MatrixXi::Random(N, K)
- .unaryExpr([](int) { return random_B(e); });
- std::vector<WeightType> matrix_b_data;
+ .unaryExpr([](int) { return n_xint8(e); });
+ std::vector<ScalarB> matrix_b_data;
if (is_B_trans) {
Eigen::MatrixXi matrix_b_trans = matrix_b.transpose().eval();
- matrix_b_data = ToVector<WeightType>(matrix_b_trans.data(), N * K);
+ matrix_b_data = ToVector<ScalarB>(matrix_b_trans.data(), N * K);
} else {
- matrix_b_data = ToVector<WeightType>(matrix_b.data(), N * K);
+ matrix_b_data = ToVector<ScalarB>(matrix_b.data(), N * K);
}
- WeightType b_zero_point = GetMiddle(matrix_b_data);
+ ScalarB b_zero_point = GetMiddle(matrix_b_data);
std::vector<float> b_scale({n_scale(e)});
- std::vector<WeightType> b_zp_per_column({b_zero_point});
+ std::vector<ScalarB> b_zp_per_column({b_zero_point});
Eigen::MatrixXi b_zp_matrix = b_zero_point * Eigen::MatrixXi::Ones(N, K);
Eigen::MatrixXf b_scale_matrix = b_scale[0] * Eigen::MatrixXf::Ones(N, M);
if (per_column) {
@@ -81,7 +74,7 @@ void RunQuantGemmU8X8Test(const int M,
float alpha = n_apha(e);
Eigen::MatrixXi matrix_c = Eigen::MatrixXi::Random(N, M)
- .unaryExpr([](int) { return random_A(e); });
+ .unaryExpr([](int) { return n_xint8(e); });
Eigen::MatrixXi matrix_int32 = (matrix_b - b_zp_matrix) * matrix_a_offset;
if (has_C) {
@@ -93,12 +86,12 @@ void RunQuantGemmU8X8Test(const int M,
test.AddAttribute("transA", is_A_trans ? 1 : 0);
test.AddAttribute("transB", is_B_trans ? 1 : 0);
test.AddAttribute("alpha", alpha);
- test.AddInput("A", is_A_trans ? std::vector({K, M}) : std::vector({M, K}), std::move(matrix_a_data));
+ test.AddInput("A", is_A_trans ? std::vector({K, M}) : std::vector({M, K}), std::move(matrix_a_data));
test.AddInput("a_scale", {}, {a_scale});
- test.AddInput("a_zero_point", {}, {a_zero_point});
- test.AddInput("B", is_B_trans ? std::vector({N, K}) : std::vector({K, N}), std::move(matrix_b_data), B_is_initializer);
+ test.AddInput("a_zero_point", {}, {a_zero_point});
+ test.AddInput("B", is_B_trans ? std::vector({N, K}) : std::vector({K, N}), std::move(matrix_b_data), B_is_initializer);
test.AddInput("b_scale", {SafeInt(b_scale.size())}, b_scale);
- test.AddInput("b_zero_point", {SafeInt(b_zp_per_column.size())}, b_zp_per_column);
+ test.AddInput("b_zero_point", {SafeInt(b_zp_per_column.size())}, b_zp_per_column);
if (has_C) {
test.AddInput("C", {M, N}, ToVector(matrix_c.data(), M * N));
@@ -108,14 +101,14 @@ void RunQuantGemmU8X8Test(const int M,
if constexpr (std::is_same_v<OutputType, float>) {
test.AddOptionalInputEdge<float>();
- test.AddOptionalInputEdge<ActType>();
+ test.AddOptionalInputEdge<uint8_t>();
test.AddOutput<float>("Y", {M, N}, std::vector<float>(matrix_output.data(), matrix_output.data() + M * N));
} else {
- std::vector<ActType> quant_output(M * N);
- quantization::Params<ActType> quant_param = quantization::QuantizeLinear(matrix_output.data(), quant_output.data(), M * N);
+ std::vector<uint8_t> quant_output(M * N);
+ quantization::Params<uint8_t> quant_param = quantization::QuantizeLinear(matrix_output.data(), quant_output.data(), M * N);
test.AddInput<float>("y_scale", {}, {quant_param.scale});
- test.AddInput<ActType>("y_zero_point", {}, {quant_param.zero_point});
- test.AddOutput<ActType>("Y", {M, N}, quant_output);
+ test.AddInput<uint8_t>("y_zero_point", {}, {quant_param.zero_point});
+ test.AddOutput<uint8_t>("Y", {M, N}, quant_output);
}
test.Run();
@@ -129,13 +122,10 @@ void RunQuantGemmTest(const int M,
bool has_C,
bool B_is_initializer,
bool per_column = false) {
- RunQuantGemmU8X8Test<uint8_t, uint8_t, float>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
- RunQuantGemmU8X8Test<uint8_t, uint8_t, uint8_t>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
- RunQuantGemmU8X8Test<uint8_t, int8_t, float>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
- RunQuantGemmU8X8Test<uint8_t, int8_t, uint8_t>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
-
- RunQuantGemmU8X8Test<int8_t, int8_t, float>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
- RunQuantGemmU8X8Test<int8_t, int8_t, int8_t>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
+ RunQuantGemmU8X8Test<uint8_t, float>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
+ RunQuantGemmU8X8Test<uint8_t, uint8_t>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
+ RunQuantGemmU8X8Test<int8_t, float>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
+ RunQuantGemmU8X8Test<int8_t, uint8_t>(M, N, K, is_A_trans, is_B_trans, has_C, B_is_initializer, per_column);
}
void RunQuantGemmTestBatch(const int M, const int N, const int K) {
diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc
index d74160217f46a..a81f95d994318 100644
--- a/onnxruntime/test/optimizer/qdq_transformer_test.cc
+++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc
@@ -537,173 +537,6 @@ TEST(QDQTransformerTests, MatMul_S8S8U8) {
QDQTransformerMatMulTests(true);
}
-template <typename Input1Type, typename Input2Type, typename OutputType, typename BiasType>
-void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false) {
- auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape) {
- auto build_test_case = [&](ModelTestBuilder& builder) {
- auto* input1_arg = builder.MakeInput<float>(input1_shape, -1.f, 1.f);
- auto* input2_arg = builder.MakeInput<float>(input2_shape, -1.f, 1.f);
- auto* output_arg = builder.MakeOutput();
-
- typedef std::numeric_limits<Input1Type> Input1Limits;
- typedef std::numeric_limits<Input2Type> Input2Limits;
- typedef std::numeric_limits<OutputType> OutputTypeLimits;
-
- std::vector<NodeArg*> input_args;
-
- // add QDQ A
- auto* q1_output = builder.MakeIntermediate();
- auto* dq1_output = builder.MakeIntermediate();
- builder.AddQuantizeLinearNode<Input1Type>(input1_arg,
- .039f,
- (Input1Limits::max() + Input1Limits::min()) / 2 + 1,
- q1_output);
- builder.AddDequantizeLinearNode<Input1Type>(q1_output,
- .039f,
- (Input2Limits::max() + Input1Limits::min()) / 2 + 1,
- dq1_output);
-
- input_args.push_back(dq1_output);
-
- // add QDQ B
- auto* q2_output = builder.MakeIntermediate();
- auto* dq2_output = builder.MakeIntermediate();
- builder.AddQuantizeLinearNode<Input2Type>(input2_arg,
- .04f,
- (Input2Limits::max() + Input2Limits::min()) / 2 + 1,
- q2_output);
- builder.AddDequantizeLinearNode<Input2Type>(q2_output,
- .04f,
- (Input2Limits::max() + Input2Limits::min()) / 2 + 1,
- dq2_output);
- input_args.push_back(dq2_output);
-
- if (has_bias) {
- auto* dq_bias_output = builder.MakeIntermediate();
- auto* bias = builder.MakeInitializer<BiasType>({input2_shape[1]}, static_cast<BiasType>(0), static_cast<BiasType>(127));
- builder.AddDequantizeLinearNode<BiasType>(bias, 0.00156f,
- 0,
- dq_bias_output);
- input_args.push_back(dq_bias_output);
- }
-
- Node* gemm_node = nullptr;
-
- if (has_output_q) {
- auto* gemm_op_output = builder.MakeIntermediate();
- gemm_node = &builder.AddNode("Gemm", input_args, {gemm_op_output});
-
- // add QDQ output
- auto* q3_output = builder.MakeIntermediate();
- builder.AddQuantizeLinearNode<OutputType>(gemm_op_output,
- .039f,
- (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1,
- q3_output);
- builder.AddDequantizeLinearNode<OutputType>(q3_output,
- .039f,
- (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1,
- output_arg);
- } else {
- gemm_node = &builder.AddNode("Gemm", input_args, {output_arg});
- }
-
- if (beta_not_one) {
- gemm_node->AddAttribute("beta", 2.0f);
- }
- };
-
- auto check_binary_op_graph = [&](InferenceSessionWrapper& session) {
- auto op_to_count = CountOpsInGraph(session.GetGraph());
- if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) &&
- (!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
- (std::is_same_v<Input1Type, uint8_t> || std::is_same_v<Input2Type, int8_t>)) {
- EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
- EXPECT_EQ(op_to_count["Gemm"], 0);
- EXPECT_EQ(op_to_count["QuantizeLinear"], 2);
- EXPECT_EQ(op_to_count["DequantizeLinear"], has_output_q ? 1 : 0);
- } else {
- int q_count = 2; // Q for A and B
- int dq_count = 2; // DQ for A and B
- if (has_bias) {
- dq_count++;
- }
- if (has_output_q) {
- q_count++;
- dq_count++;
- }
- EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 0);
- EXPECT_EQ(op_to_count["Gemm"], 1);
- EXPECT_EQ(op_to_count["QuantizeLinear"], q_count);
- EXPECT_EQ(op_to_count["DequantizeLinear"], dq_count);
- }
- };
-
- TransformerTester(build_test_case,
- check_binary_op_graph,
- TransformerLevel::Level1,
- TransformerLevel::Level2,
- 12 /*opset_version*/,
- 0.01 /*per_sample_tolerance*/,
- 0.01 /*relative_per_sample_tolerance*/,
- std::make_unique<QDQSelectorActionTransformer>());
- };
-
- test_case({2, 2}, {2, 4});
- test_case({13, 15}, {15, 15});
-}
-
-template <typename Input1Type, typename Input2Type, typename OutputType, typename BiasType = int32_t>
-void QDQTransformerGemmTests() {
- QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, false);
- QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true);
- QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false);
- QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true);
- QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, false, true);
- QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, true);
- QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, true);
- QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, true);
-}
-
-TEST(QDQTransformerTests, Gemm_U8U8U8) {
- QDQTransformerGemmTests<uint8_t, uint8_t, uint8_t>();
- QDQTransformerGemmTests<uint8_t, uint8_t, uint8_t, uint8_t>();
-}
-
-TEST(QDQTransformerTests, Gemm_U8S8S8) {
- QDQTransformerGemmTests<uint8_t, int8_t, int8_t>();
- QDQTransformerGemmTests<uint8_t, int8_t, int8_t, uint8_t>();
-}
-
-TEST(QDQTransformerTests, Gemm_U8U8S8) {
- QDQTransformerGemmTests<uint8_t, uint8_t, int8_t>();
- QDQTransformerGemmTests<uint8_t, uint8_t, int8_t, uint8_t>();
-}
-
-TEST(QDQTransformerTests, Gemm_U8S8U8) {
- QDQTransformerGemmTests<uint8_t, int8_t, uint8_t>();
- QDQTransformerGemmTests<uint8_t, int8_t, uint8_t, uint8_t>();
-}
-
-TEST(QDQTransformerTests, Gemm_S8S8S8) {
- QDQTransformerGemmTests<int8_t, int8_t, int8_t>();
- QDQTransformerGemmTests<int8_t, int8_t, int8_t, uint8_t>();
-}
-
-TEST(QDQTransformerTests, Gemm_S8U8U8) {
- QDQTransformerGemmTests<int8_t, uint8_t, uint8_t>();
- QDQTransformerGemmTests<int8_t, uint8_t, uint8_t, uint8_t>();
-}
-
-TEST(QDQTransformerTests, Gemm_S8U8S8) {
- QDQTransformerGemmTests<int8_t, uint8_t, int8_t>();
- QDQTransformerGemmTests<int8_t, uint8_t, int8_t, uint8_t>();
-}
-
-TEST(QDQTransformerTests, Gemm_S8S8U8) {
- QDQTransformerGemmTests<int8_t, int8_t, uint8_t>();
- QDQTransformerGemmTests<int8_t, int8_t, uint8_t, uint8_t>();
-}
-
TEST(QDQTransformerTests, Gather) {
auto test_case = [&](const std::vector& input1_shape, const std::vector& weights_shape) {
auto build_test_case = [&](ModelTestBuilder& builder) {
diff --git a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc
index 489d6ff915692..7e41fe020a91d 100644
--- a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc
+++ b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc
@@ -302,12 +302,12 @@ TEST(MatmulIntegerOpTest, MatMulInteger_PerColumn_ND) {
}
// [M x N] = [M x K] x [K x N] = [batch_seq x input_dim] x [input_dim x embed_dim]
-template <typename WeightType>
+template <typename ScalarB>
void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_zero_zp, bool B_is_initializer, bool per_column_zp = false) {
OpTester test("MatMulInteger", 10);
static std::default_random_engine e(123);
static std::uniform_int_distribution<int> n_unsigned(0, 127);
- static std::uniform_int_distribution<int> n_xint8(std::numeric_limits<WeightType>::min(), std::numeric_limits<WeightType>::max());
+ static std::uniform_int_distribution<int> n_xint8(std::numeric_limits<ScalarB>::min(), std::numeric_limits<ScalarB>::max());
Eigen::MatrixXi matrix_a = Eigen::MatrixXi::Random(K, M)
.unaryExpr([](int) { return n_unsigned(e); });
@@ -317,9 +317,9 @@ void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_ze
Eigen::MatrixXi matrix_b = Eigen::MatrixXi::Random(N, K)
.unaryExpr([](int) { return n_xint8(e); });
- std::vector<WeightType> matrix_b_data = ToVector<WeightType>(matrix_b.data(), N * K);
- WeightType b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0;
- std::vector<WeightType> b_zp_per_column(N, b_zero_point);
+ std::vector<ScalarB> matrix_b_data = ToVector<ScalarB>(matrix_b.data(), N * K);
+ ScalarB b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0;
+ std::vector<ScalarB> b_zp_per_column(N, b_zero_point);
Eigen::MatrixXi b_zp_matrix = b_zero_point * Eigen::MatrixXi::Ones(N, K);
if (non_zero_zp && per_column_zp) {
for (int i = 0; i < N; i++) {
@@ -331,13 +331,13 @@ void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_ze
Eigen::MatrixXi matrix_c = ((matrix_b - b_zp_matrix) * matrix_a_offset).eval();
test.AddInput("T1", {M, K}, std::move(matrix_a_data));
- test.AddInput("T2", {K, N}, std::move(matrix_b_data), B_is_initializer);
+ test.AddInput("T2", {K, N}, std::move(matrix_b_data), B_is_initializer);
if (non_zero_zp) {
test.AddInput("a_zero_point", {}, {a_zero_point});
if (per_column_zp) {
- test.AddInput("b_zero_point", {N}, b_zp_per_column);
+ test.AddInput("b_zero_point", {N}, b_zp_per_column);
} else {
- test.AddInput("b_zero_point", {}, {b_zero_point});
+ test.AddInput("b_zero_point", {}, {b_zero_point});
}
}
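As context for the per-column branch above: each of B's N output columns may carry its own zero point, and the test's reference matrix subtracts them column-wise. A hedged NumPy equivalent of that reference computation:

```python
import numpy as np

K, N = 4, 3
B = np.random.randint(-128, 128, size=(K, N)).astype(np.int8)
b_zp_per_column = np.array([1, -2, 3], dtype=np.int8)  # shape {N}
# Broadcasting subtracts a distinct zero point from each column of B.
B_offset = B.astype(np.int32) - b_zp_per_column.astype(np.int32)
```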
diff --git a/onnxruntime/test/python/quantization/test_op_gemm.py b/onnxruntime/test/python/quantization/test_op_gemm.py
index 11a2ba488b4d5..cf61402fa5d84 100644
--- a/onnxruntime/test/python/quantization/test_op_gemm.py
+++ b/onnxruntime/test/python/quantization/test_op_gemm.py
@@ -130,7 +130,7 @@ def static_quant_test(self, model_fp32_path, data_reader, activation_type, weigh
data_reader.rewind()
quantize_static(model_fp32_path, model_int8_path, data_reader,
activation_type=activation_type, weight_type=weight_type, extra_options=extra_options)
- quant_nodes = {'QGemm': 2, 'QuantizeLinear': 1, 'DequantizeLinear': 1}
+ quant_nodes = {'QLinearMatMul': 2, 'QLinearAdd': 2, 'QuantizeLinear': 1, 'DequantizeLinear': 1}
check_op_type_count(self, model_int8_path, **quant_nodes)
qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]})
@@ -147,7 +147,7 @@ def static_quant_test_qdq(self, model_fp32_path, data_reader, activation_type, w
data_reader.rewind()
quantize_static(model_fp32_path, model_int8_path, data_reader, quant_format=QuantFormat.QDQ,
activation_type=activation_type, weight_type=weight_type, extra_options=extra_options)
- quant_nodes = {'Gemm': 2, 'QuantizeLinear': 3, 'DequantizeLinear': 7}
+ quant_nodes = {'MatMul': 2, 'Add': 2, 'QuantizeLinear': 5, 'DequantizeLinear': 9}
check_op_type_count(self, model_int8_path, **quant_nodes)
qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}
check_qtype_by_node_type(self, model_int8_path, qnode_io_qtypes)
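The updated counts follow from the Gemm decomposition: each of the model's two Gemm layers now yields one QLinearMatMul plus one QLinearAdd (and, in QDQ mode, an extra Q/DQ pair per added op). For readers verifying by hand, a rough stand-in for the check_op_type_count test utility:

```python
from collections import Counter
import onnx

def count_ops(model_path):
    model = onnx.load(model_path)
    return Counter(node.op_type for node in model.graph.node)

# Expected for the statically quantized model:
# {'QLinearMatMul': 2, 'QLinearAdd': 2, 'QuantizeLinear': 1, 'DequantizeLinear': 1}
```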
diff --git a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json
index 2a85dab3b5fd0..120e6de7c00f3 100644
--- a/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json
+++ b/onnxruntime/test/testdata/kernel_def_hashes/contrib.cpu.json
@@ -283,10 +283,6 @@
"QEmbedLayerNormalization com.microsoft CPUExecutionProvider",
9235385557940152248
],
- [
- "QGemm com.microsoft CPUExecutionProvider",
- 13009794669709617232
- ],
[
"QGemm com.microsoft CPUExecutionProvider",
13737193491843065240