Skip to content

Commit

Permalink
addressed CR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
guoyu-wang committed Jan 29, 2022
1 parent 5f52b46 commit 1b9e628
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 16 deletions.
19 changes: 10 additions & 9 deletions onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,15 @@ ConvType GetConvType(const NodeUnit& node_unit, const InitializedTensorSet& init
return ConvType::Grouped;
}

// Returns true iff the quantized op type is a quantized Conv,
// i.e. either QLinearConv or the QDQ form of Conv (QDQConv).
bool IsQuantizedConv(QuantizedOpType quant_op_type) {
  switch (quant_op_type) {
    case QuantizedOpType::QLinearConv:
    case QuantizedOpType::QDQConv:
      return true;
    default:
      return false;
  }
}

// Returns true iff the quantized op type is a binary op (2 inputs, 1 output),
// i.e. QLinearMatMul, QLinearAdd, or a quantized Conv (QLinearConv/QDQConv).
// NOTE(review): the diff view had merged the pre- and post-commit bodies here,
// leaving two return statements; this is the intended post-commit form, which
// delegates the Conv cases to IsQuantizedConv.
bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type) {
  return quant_op_type == QuantizedOpType::QLinearMatMul ||
         quant_op_type == QuantizedOpType::QLinearAdd ||
         IsQuantizedConv(quant_op_type);
}

bool HasValidUnaryOpQuantizedInputs(const NodeUnit& node_unit) {
Expand Down Expand Up @@ -134,8 +138,7 @@ bool HasValidBinaryOpQuantizedInputs(const NodeUnit& node_unit) {

// QlinearConv supports u8u8 or u8s8
// QLinearMatMul/Add only support u8u8
bool is_quant_conv = (quant_op_type == QuantizedOpType::QLinearConv) ||
(quant_op_type == QuantizedOpType::QDQConv);
bool is_quant_conv = IsQuantizedConv(quant_op_type);
bool has_valid_qlinear_conv_weight =
(b_input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 ||
b_input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8);
Expand All @@ -157,8 +160,7 @@ bool HasValidQuantizationScales(const InitializedTensorSet& initializers, const
const std::vector<size_t>& indices, const OpSupportCheckParams& params, bool is_input) {
const auto& op_type = node_unit.OpType();
auto quant_op_type = GetQuantizedOpType(node_unit);
bool is_quant_conv = (quant_op_type == QuantizedOpType::QLinearConv) ||
(quant_op_type == QuantizedOpType::QDQConv);
bool is_quant_conv = IsQuantizedConv(quant_op_type);
bool is_quant_matmul = (quant_op_type == QuantizedOpType::QLinearMatMul);
const auto& io_defs = is_input ? node_unit.Inputs() : node_unit.Outputs();
for (const auto idx : indices) {
Expand Down Expand Up @@ -232,8 +234,7 @@ bool HasValidQuantizationZeroPoints(const InitializedTensorSet& initializers, co
const std::vector<size_t>& indices, bool is_input) {
const auto& op_type = node_unit.OpType();
auto quant_op_type = GetQuantizedOpType(node_unit);
bool is_quant_conv = (quant_op_type == QuantizedOpType::QLinearConv) ||
(quant_op_type == QuantizedOpType::QDQConv);
bool is_quant_conv = IsQuantizedConv(quant_op_type);
bool is_quant_matmul = (quant_op_type == QuantizedOpType::QLinearMatMul);

const auto& io_defs = is_input ? node_unit.Inputs() : node_unit.Outputs();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit);
// This function assumes the input is a 2d conv node
ConvType GetConvType(const NodeUnit& node_unit, const InitializedTensorSet& initializers);

// If this is a quantized Conv (QLinearConv or QDQConv)
bool IsQuantizedConv(QuantizedOpType quant_op_type);

// Whether this quantized op is an operator or QDQ node unit that takes 2 inputs and produces 1 output,
// such as QLinearConv, QLinearMatMul, QLinearAdd, QDQConv, ...
bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,23 @@ Status ModelBuilder::AddOperations() {
const auto* node(graph_viewer_.GetNode(node_idx));
const NodeUnit& node_unit = GetNodeUnit(node);

// Since we may have a NodeUnit made up of multiple nodes, inserting the NodeUnit at the first
// occurrence of any of its node(s) in topological order may produce an incorrect topological
// order of the NodeUnits. For example:
//     Q1
//     |
//    DQ1 DQ2
//      \ |
//      CONV
//       |
//       Q2
// In the above graph, we will have 2 NodeUnits, NU1 [Q1] and NU2 [DQ1, DQ2, CONV, Q2].
// Q1 and DQ2 have the same topological order; if we insert NU2 when we visit DQ2 first in
// topological order, the input from Q1 required by NU2 has not yet been inserted, which will
// cause a failure when finding the inputs for NU2.
//
// So we only insert the NodeUnit once, when we hit its target node, to ensure the correct
// topological order of the NodeUnits.
if (node != &node_unit.GetNode())
continue;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1260,9 +1260,7 @@ class ConvOpBuilder : public BaseOpBuilder {
};

// Returns true iff this node unit is a quantized Conv (QLinearConv or QDQConv).
// NOTE(review): the diff view had merged the pre- and post-commit bodies here
// (two return statements); this is the intended post-commit form delegating to
// the shared IsQuantizedConv helper.
/* static */ bool ConvOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) {
  return IsQuantizedConv(GetQuantizedOpType(node_unit));
}

/* static */ void
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,9 +743,7 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
}

// Returns true iff this node unit is a quantized Conv (QLinearConv or QDQConv).
// NOTE(review): the diff view had merged the pre- and post-commit bodies here
// (two return statements); this is the intended post-commit form delegating to
// the shared IsQuantizedConv helper, keeping it consistent with ConvOpBuilder.
/* static */ bool ConvOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) {
  return IsQuantizedConv(GetQuantizedOpType(node_unit));
}

bool ConvOpSupportChecker::HasSupportedInputsImpl(const NodeUnit& node_unit) const {
Expand Down

0 comments on commit 1b9e628

Please sign in to comment.