Skip to content

Commit

Permalink
Revert "add qdq support for QGemm (microsoft#10414)"
Browse files Browse the repository at this point in the history
This reverts commit 1aa0789.
  • Loading branch information
maxiwell committed Feb 9, 2022
1 parent 0d09dd5 commit d5b7845
Show file tree
Hide file tree
Showing 20 changed files with 68 additions and 579 deletions.
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ Do not modify directly.*
|Pad|*in* data:**T**<br> *in* pads:**tensor(int64)**<br> *in* value:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)<br/> **T4** = tensor(int32)|
|QEmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding_quant:**T2**<br> *in* position_embedding_quant:**T2**<br> *in* segment_embedding:**T2**<br> *in* gamma_quant:**T2**<br> *in* beta_quant:**T2**<br> *in* mask:**T1**<br> *in* word_embedding_scale:**T**<br> *in* position_embedding_scale:**T**<br> *in* segment_embedding_scale:**T**<br> *in* gamma_scale:**T**<br> *in* beta_scale:**T**<br> *in* word_embedding_zero_point:**T2**<br> *in* position_embedding_zero_point:**T2**<br> *in* segment_embedding_zero_point:**T2**<br> *in* gamma_zero_point:**T2**<br> *in* beta_zero_point:**T2**<br> *out* layernorm_out:**T**<br> *out* mask_index_out:**T1**|1+|**T** = tensor(float)|
|QGemm|*in* A:**TA**<br> *in* a_scale:**T**<br> *in* a_zero_point:**TA**<br> *in* B:**TB**<br> *in* b_scale:**T**<br> *in* b_zero_point:**TB**<br> *in* C:**TC**<br> *in* y_scale:**T**<br> *in* y_zero_point:**TYZ**<br> *out* Y:**TY**|1+|**T** = tensor(float)<br/> **TA** = tensor(int8), tensor(uint8)<br/> **TB** = tensor(int8), tensor(uint8)<br/> **TC** = tensor(int32)<br/> **TY** = tensor(float), tensor(int8), tensor(uint8)<br/> **TYZ** = tensor(int8), tensor(uint8)|
|QGemm|*in* A:**TA**<br> *in* a_scale:**T**<br> *in* a_zero_point:**TA**<br> *in* B:**TB**<br> *in* b_scale:**T**<br> *in* b_zero_point:**TB**<br> *in* C:**TC**<br> *in* y_scale:**T**<br> *in* y_zero_point:**TYZ**<br> *out* Y:**TY**|1+|**T** = tensor(float)<br/> **TA** = tensor(uint8)<br/> **TB** = tensor(int8), tensor(uint8)<br/> **TC** = tensor(int32)<br/> **TY** = tensor(float), tensor(uint8)<br/> **TYZ** = tensor(uint8)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearConv|*in* x:**T1**<br> *in* x_scale:**tensor(float)**<br> *in* x_zero_point:**T1**<br> *in* w:**T2**<br> *in* w_scale:**tensor(float)**<br> *in* w_zero_point:**T2**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T3**<br> *in* B:**T4**<br> *out* y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(int8), tensor(uint8)<br/> **T4** = tensor(int32)|
|QLinearLeakyRelu|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
Expand Down
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm);
// ******** End: Quantization ******************* //

Expand Down Expand Up @@ -170,7 +169,6 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, NhwcMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, NhwcMaxPool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QEmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, QGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm)>,
};

Expand Down
23 changes: 3 additions & 20 deletions onnxruntime/contrib_ops/cpu/quantization/quant_gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,8 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

bool a_is_signed = a->IsDataType<int8_t>();
const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());

BufferUniquePtr a_trans_buffer;
const uint8_t* a_data = a->template Data<uint8_t>();
if (trans_A_ == CblasTrans) {
a_data = quantization::TransPoseInputData(a_data, a_trans_buffer, allocator, K, M);
}
Expand Down Expand Up @@ -84,12 +82,12 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
GemmBroadcastBias(M, N, 1.f, c->template Data<int32_t>(), &(c->Shape()), gemm_output_data);
}

MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape{M, N, K, a_is_signed, b_is_signed, c != nullptr};
MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape{M, N, K, false /*AIsSigned*/, b_is_signed, c != nullptr};
MLAS_GEMM_QUANT_DATA_PARAMS gemm_param;

gemm_param.A = a_data;
gemm_param.lda = gemm_shape.K;
gemm_param.ZeroPointA = *(static_cast<const uint8_t*>(a_zp->DataRaw()));
gemm_param.ZeroPointA = *(a_zp->template Data<uint8_t>());

gemm_param.B = b_data;
gemm_param.ldb = gemm_shape.N;
Expand Down Expand Up @@ -222,20 +220,5 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
.TypeConstraint("TY", {DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<uint8_t>()}),
QGemm);

// Registers the int8 variant of the contrib QGemm kernel (kMSDomain, opset 1)
// with the CPU execution provider.
// Type constraints: A, B, their zero points and the output zero point are all
// int8; scales are float; the bias C is int32; the output Y may be float or
// int8.
ONNX_OPERATOR_TYPED_KERNEL_EX(
    QGemm,
    kMSDomain,
    1,
    int8_t,
    kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("TA", DataTypeImpl::GetTensorType<int8_t>())
        .TypeConstraint("TB", DataTypeImpl::GetTensorType<int8_t>())
        .TypeConstraint("TC", DataTypeImpl::GetTensorType<int32_t>())
        .TypeConstraint("TYZ", DataTypeImpl::GetTensorType<int8_t>())
        .TypeConstraint("TY", {DataTypeImpl::GetTensorType<float>(), DataTypeImpl::GetTensorType<int8_t>()}),
    QGemm);

} // namespace contrib
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -223,67 +223,5 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select
}
}

// Builds the input/output move plan used when replacing a DQ -> Gemm (-> Q)
// node group with a single QGemm node.
// Node-group layout: inputs 0 and 1 are the DQ nodes feeding Gemm's A and B,
// input 2 is the optional DQ feeding the bias C, and output 0 is the optional
// trailing Q node.
// If the Q node exists, its scale and zero point become extra QGemm inputs and
// its outputs are reused; otherwise the Gemm target's own outputs are kept.
static std::vector<NodeAndMoveInfo> GetGemmMoveInfo(bool does_q_node_exist) {
  NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0};
  NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1};
  NTO::NodeLocation dq_bias{NTO::NodeType::kInput, 2};
  NTO::NodeLocation target{NTO::NodeType::kTarget, 0};
  NTO::NodeLocation q{NTO::NodeType::kOutput, 0};

  std::vector<NodeAndMoveInfo> moves{
      MoveAll(dq_A, ArgType::kInput),                                            // append all inputs from DQ of A
      MoveAll(dq_B, ArgType::kInput),                                            // append all inputs from DQ of B
      MoveAndAppend(dq_bias, ArgType::kInput, 0, ArgType::kInput, true, true)};  // (optional) append bias

  if (does_q_node_exist) {
    moves.push_back(MoveAndAppend(q, ArgType::kInput, 1, ArgType::kInput));  // append scale (input 1) from Q
    moves.push_back(MoveAndAppend(q, ArgType::kInput, 2, ArgType::kInput));  // append zp (input 2) from Q
    moves.push_back(MoveAll(q, ArgType::kOutput));                           // and use the outputs from Q
  } else {
    // No trailing Q node: QGemm produces the Gemm's original (float) outputs.
    moves.push_back(MoveAll(target, ArgType::kOutput));
  }

  return moves;
}

// Prepares the two QGemm replacers up front: one for groups without a trailing
// Q node (QGemm emits a float output) and one for groups with a trailing Q
// node (QGemm emits a quantized 8-bit output). Both target the "QGemm" op in
// the Microsoft contrib domain.
GemmReplaceWithQuant::GemmReplaceWithQuant()
    : qgemm_with_float_as_output_replacer_(kMSDomain, GetGemmMoveInfo(false), "QGemm"),
      qgemm_with_8bits_as_output_replacer_(kMSDomain, GetGemmMoveInfo(true), "QGemm") {
}

// Replaces the selected DQ -> Gemm (-> Q) node group with a single QGemm node.
// The 'beta' attribute is dropped first because QGemm has no such attribute
// (the selector only matches groups where beta is 1.0, so removing it is
// lossless).
Status GemmReplaceWithQuant::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
  RemoveAttrBeta(selected_nodes);
  // num_outputs counts selected output (Q) nodes; zero means the Gemm's float
  // output is kept rather than being re-quantized.
  bool is_output_float = selected_nodes.num_outputs == 0;
  if (is_output_float) {
    return qgemm_with_float_as_output_replacer_.Run(graph, selected_nodes);
  }

  return qgemm_with_8bits_as_output_replacer_.Run(graph, selected_nodes);
}

#if !defined(ORT_MINIMAL_BUILD)
// Save-time counterpart of Run(): records the QGemm replacement as a runtime
// optimization (for later replay in a minimal build) instead of applying it
// immediately. Dispatches on the presence of a trailing Q node exactly as
// Run() does.
Status GemmReplaceWithQuant::RunForSave(Graph& graph,
                                        const NodesToOptimize& selected_nodes,
                                        const SatRuntimeOptimizationSaveContext& save_context,
                                        SavedState& saved_state,
                                        bool& graph_modified) const {
  RemoveAttrBeta(selected_nodes);
  // Zero selected output (Q) nodes => QGemm keeps the float output.
  bool is_output_float = selected_nodes.num_outputs == 0;
  if (is_output_float) {
    return qgemm_with_float_as_output_replacer_.RunForSave(graph,
                                                           selected_nodes,
                                                           save_context,
                                                           saved_state,
                                                           graph_modified);
  }

  return qgemm_with_8bits_as_output_replacer_.RunForSave(graph,
                                                         selected_nodes,
                                                         save_context,
                                                         saved_state,
                                                         graph_modified);
}
#endif  // !defined(ORT_MINIMAL_BUILD)

} // namespace QDQ
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -67,25 +67,5 @@ struct MatMulReplaceWithQLinear : public Action {
BinaryReplaceWithQLinear qlinear_matmul_replacer_;
};

// Action that replaces a DQ -> Gemm (-> Q) node group with a single QGemm
// node from the Microsoft contrib domain.
struct GemmReplaceWithQuant : public Action {
  GemmReplaceWithQuant();

  // Applies the replacement to the graph.
  Status Run(Graph&, const NodesToOptimize& selected_nodes) const override;

#if !defined(ORT_MINIMAL_BUILD)
  // Records the replacement as a saved runtime optimization instead of
  // applying it directly (full builds only).
  Status RunForSave(Graph& /*graph*/, const NodesToOptimize& /*selected_nodes*/,
                    const SatRuntimeOptimizationSaveContext& /*save_context*/,
                    SavedState& /*saved_state*/, bool& /*graph_modified*/) const override;
#endif  // !defined(ORT_MINIMAL_BUILD)

  // Drops Gemm's 'beta' attribute from the target node; QGemm has no beta and
  // the selector only matches groups where beta == 1.0.
  static inline void RemoveAttrBeta(const NodesToOptimize& selected_nodes) {
    selected_nodes.Target().ClearAttribute("beta");
  }

 private:
  // Replacer used when no trailing Q node exists (QGemm outputs float).
  QDQReplaceWithNew qgemm_with_float_as_output_replacer_;
  // Replacer used when a trailing Q node exists (QGemm outputs 8-bit data).
  QDQReplaceWithNew qgemm_with_8bits_as_output_replacer_;
};

} // namespace QDQ
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -167,26 +167,6 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i
#endif
}

void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
// 3 to 5 nodes. 0=DQ A, 1=DQ B, 2=DQ C(optional), 3=Gemm, 4=Q Y(optional)
// Replace with QGemm
// Delete all original nodes.
const std::string action_name{"Gemm"};

std::unique_ptr<Action> action = std::make_unique<QDQ::GemmReplaceWithQuant>();

#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::GemmSelector>();
qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
{{"Gemm", {}}},
std::move(selector),
std::move(action));

#else
qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
#endif
}

SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) {
SelectorActionRegistry qdq_selector_action_registry;

Expand All @@ -197,7 +177,6 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) {
VariadicOpQDQRules(qdq_selector_action_registry);
ConvQDQRules(qdq_selector_action_registry, is_int8_allowed);
MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed);
GemmQDQRules(qdq_selector_action_registry);

return qdq_selector_action_registry;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,16 @@ static std::vector<const Node*> FindQDQNodes(const GraphViewer& graph_viewer, co
bool NodeGroupSelector::CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes,
int num_dq_inputs,
bool is_empty_q_nodes_allowed) const {
int num_dq_inputs) const {
if (num_dq_inputs == -1) {
num_dq_inputs = NumActualValues(node, true);
}

// The input is a Graph Viewer, so cannot use graph_utils or optimizer_utils
if (num_dq_inputs != gsl::narrow_cast<int>(dq_nodes.size())) {
return false;
}

if (q_nodes.empty()) {
return is_empty_q_nodes_allowed;
}

int num_outputs = NumActualValues(node, false); // number of outputs that exist
return (num_outputs == gsl::narrow_cast<int>(q_nodes.size())) &&

// The input is a Graph Viewer, so cannot use graph_utils or optimizer_utils
return num_dq_inputs == gsl::narrow_cast<int>(dq_nodes.size()) &&
num_outputs == gsl::narrow_cast<int>(q_nodes.size()) &&
q_nodes.size() == node.GetOutputEdgesCount() &&
!graph_viewer.NodeProducesGraphOutput(node);
}
Expand Down Expand Up @@ -255,48 +248,6 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
}
}

// Validates a candidate DQ -> Gemm (-> Q) node group for QGemm fusion.
// Accepts groups without a trailing Q node (is_empty_q_nodes_allowed = true),
// in which case QGemm will produce a float output.
// Returns true only if the node counts, data types, and Gemm attributes are
// compatible with the QGemm kernel.
bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer,
                                  const Node& node,
                                  const std::vector<const Node*>& dq_nodes,
                                  const std::vector<const Node*>& q_nodes) const {
  if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes,
                     -1 /*num_dq_inputs*/, true /*is_empty_q_nodes_allowed*/)) {
    return false;
  }

  // input and output types need to be same
  int32_t dt_A = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
  int32_t dt_B = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();

  if (dt_A == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
    if (dt_A != dt_B) {  // if A is signed int, B must be signed int
      return false;
    }
  }

  if (!q_nodes.empty()) {
    int32_t dt_Y = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
    if (dt_A != dt_Y) {  // activation and output must be same type
      return false;
    }
  }

  if (dq_nodes.size() < 3) {  // no bias
    return true;
  }

  // QGemm has no beta attribute, so only beta == 1.0 can be fused away.
  // NOTE(review): GetAttributes().at("beta") throws if the attribute is
  // absent; ONNX Gemm's beta defaults to 1.0 when unset — confirm the
  // attribute is guaranteed to be present here. Exact float comparison
  // appears intentional (1.0 is exactly representable).
  if (node.GetAttributes().at("beta").f() != 1.0) {  // beta needs to be 1.0
    return false;
  }

  // Bias fed through DQ must be int32 to match QGemm's TC constraint.
  int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
  return dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
}

// Gemm's bias (C) input is optional; pad input_nodes to 3 entries so the
// optional bias DQ slot is always represented, using kEmptyNodeIndex when no
// bias DQ node was matched.
void GemmSelector::UpdateBuilder(NodesToOptimizeIndicesBuilder& builder) const {
  builder.input_nodes.resize(3, NodesToOptimizeIndices::kEmptyNodeIndex);
}

} // namespace QDQ
} // namespace onnxruntime

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ class NodeGroupSelector {
bool CheckQDQNodes(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes,
int num_dq_inputs = -1,
bool is_empty_q_nodes_allowed = false) const;
int num_dq_inputs = -1) const;

private:
// derived classes should implement this check
Expand Down Expand Up @@ -121,15 +120,6 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
bool matmulintegertofloat_allowed_;
};

// Node-group validity check for Gemm QDQ fusion.
// Input: DQ nodes for A, B and optional C
// Output: optional Q node for Y
class GemmNodeGroupSelector : public NodeGroupSelector {
 private:
  // Returns true if the DQ/Gemm/Q group can be fused into a QGemm node.
  bool Check(const GraphViewer& graph_viewer, const Node& node,
             const std::vector<const Node*>& dq_nodes,
             const std::vector<const Node*>& q_nodes) const override;
};

/*
* NodeSelector instances for use in the QDQ::SelectorActionTransformer.
*/
Expand Down Expand Up @@ -200,16 +190,6 @@ class MatMulSelector : public BaseSelector {
: BaseSelector(std::make_unique<MatMulNodeGroupSelector>(int8_allowed, /*matmulintegertofloat_allowed*/ true)) {}
};

// Selector wrapping GemmNodeGroupSelector for use in the QDQ
// SelectorActionTransformer.
// Input: DQ nodes for A, B and optional C
// Output: optional Q node for Y
class GemmSelector : public BaseSelector {
 public:
  GemmSelector()
      : BaseSelector(std::make_unique<GemmNodeGroupSelector>()) {}

  // Pads the builder's input_nodes so the optional bias DQ slot is always
  // present.
  void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};

} // namespace QDQ
} // namespace onnxruntime

Expand Down
10 changes: 0 additions & 10 deletions onnxruntime/core/optimizer/selectors_actions/helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,16 +314,6 @@ Status MoveInputOutput(Graph& graph, const NodesToOptimize& selected_nodes, Node
if (src != nullptr) {
ORT_RETURN_IF_ERROR(MoveInputOutputImpl(graph, move.value_move_info, *src, dest,
only_update_dest_definitions));
} else if (move.value_move_info.optional &&
move.value_move_info.fill_optional_with_empty) {
auto& dest_defs = (move.value_move_info.dest_slot.in_out == ArgType::kInput)
? dest.MutableInputDefs()
: dest.MutableOutputDefs();
dest_defs.push_back(&graph.GetOrCreateNodeArg("", nullptr));

if (move.value_move_info.dest_slot.in_out == ArgType::kInput) {
dest.MutableInputArgsCount().push_back(1);
}
}
}
}
Expand Down
22 changes: 7 additions & 15 deletions onnxruntime/core/optimizer/selectors_actions/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,18 @@ struct ValueMoveInfo {
}

// append single value (may be variadic) from source to destination
ValueMoveInfo(InOutDefSlot src_slot_in,
ArgType dest_slot_type,
bool is_optional = false,
bool fill_optional_with_empty = false)
ValueMoveInfo(InOutDefSlot src_slot_in, ArgType dest_slot_type, bool is_optional = false)
: src_slot(src_slot_in),
dest_slot{dest_slot_type, -1},
copy_all{false},
append{true},
optional{is_optional},
fill_optional_with_empty{fill_optional_with_empty} {}
optional{is_optional} {}

InOutDefSlot src_slot;
InOutDefSlot dest_slot;
bool copy_all{false}; // ignore src_slot.idx and copy all values
bool append{false}; // ignore dest_slot.idx and append to existing values
bool optional{false}; // optional copy that can be skipped if source node is missing
bool fill_optional_with_empty; // fill optional NodeArg by NodeArg with empty name.
// Only support in 'append single value' mode.
bool copy_all{false}; // ignore src_slot.idx and copy all values
bool append{false}; // ignore dest_slot.idx and append to existing values
bool optional{false}; // optional copy that can be skipped if source node is missing

private:
ValueMoveInfo() = default;
Expand Down Expand Up @@ -219,12 +213,10 @@ inline NodeAndMoveInfo MoveToSlot(const NodesToOptimize::NodeLocation& src_node,
inline NodeAndMoveInfo MoveAndAppend(const NodesToOptimize::NodeLocation& src_node,
ArgType src_direction, int src_slot,
ArgType dest_direction,
bool optional = false,
bool fill_optional_with_empty = false) {
bool optional = false) {
return NodeAndMoveInfo{src_node, ValueMoveInfo{
InOutDefSlot{src_direction, src_slot}, // move from this slot
dest_direction, optional,
fill_optional_with_empty}}; // append here
dest_direction, optional}}; // append here
}

// move all inputs/outputs from the source node to the target/replacement node
Expand Down
Loading

0 comments on commit d5b7845

Please sign in to comment.