update quantized io check functions
guoyu-wang committed Feb 4, 2022
1 parent 87f4d1d commit a6f0a0d
Showing 2 changed files with 93 additions and 65 deletions.
@@ -138,16 +138,17 @@ bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit) {
if (!GetType(inputs[1].node_arg, b_input_type))
return false;

-  // QlinearConv supports u8u8 or u8s8
-  // QLinearMatMul/Add only support u8u8
-  bool is_quant_conv = IsQuantizedConv(quant_op_type);
+  // QlinearConv/MatMul supports u8u8 or u8s8
+  // QLinearAdd only supports u8u8
+  bool is_quant_conv_or_matmul = IsQuantizedConv(quant_op_type) || (quant_op_type == QuantizedOpType::QLinearMatMul);

bool has_valid_qlinear_conv_weight =
(b_input_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8 ||
b_input_type == ONNX_NAMESPACE::TensorProto_DataType_INT8);

if (a_input_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8 ||
-      (!is_quant_conv && a_input_type != b_input_type) ||
-      (is_quant_conv && !has_valid_qlinear_conv_weight)) {
+      (!is_quant_conv_or_matmul && a_input_type != b_input_type) ||
+      (is_quant_conv_or_matmul && !has_valid_qlinear_conv_weight)) {
LOGS_DEFAULT(VERBOSE) << "[" << node_unit.OpType()
<< "] A Input type: [" << a_input_type
<< "] B Input type: [" << b_input_type
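To make the new predicate concrete, here is a minimal, self-contained sketch of the accepted input-type combinations; the enum values are the ONNX TensorProto_DataType codes, and everything else is a stand-in rather than the EP's real API.

```cpp
#include <cstdint>
#include <iostream>

// Stand-ins for the ONNX TensorProto_DataType codes used in the check above.
enum DataType : std::int32_t { FLOAT = 1, UINT8 = 2, INT8 = 3 };

// Mirrors the updated predicate: quantized Conv/MatMul accept u8 input A with
// u8 or s8 input B (the weight), while QLinearAdd requires u8 for both inputs.
bool InputTypesOk(bool is_quant_conv_or_matmul, std::int32_t a_type, std::int32_t b_type) {
  const bool has_valid_weight = (b_type == UINT8 || b_type == INT8);
  if (a_type != UINT8 ||
      (!is_quant_conv_or_matmul && a_type != b_type) ||
      (is_quant_conv_or_matmul && !has_valid_weight))
    return false;
  return true;
}

int main() {
  std::cout << InputTypesOk(true, UINT8, INT8) << '\n';   // 1: u8s8 QLinearMatMul now passes
  std::cout << InputTypesOk(false, UINT8, INT8) << '\n';  // 0: QLinearAdd remains u8u8 only
}
```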
@@ -100,7 +100,9 @@ class BaseOpSupportChecker : public IOpSupportChecker {
return ANEURALNETWORKS_FEATURE_LEVEL_1;
}

-  virtual bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const;
+  virtual bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+      const OpSupportCheckParams& params) const;

virtual int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const { return 1; }
virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 15; }
@@ -112,7 +114,8 @@ class BaseOpSupportChecker : public IOpSupportChecker {

private:
bool HasSupportedOpSet(const NodeUnit& node_unit) const;
-  bool HasSupportedInputOutputs(const NodeUnit& node_unit) const;
+  bool HasSupportedInputOutputs(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+                                const OpSupportCheckParams& params) const;
};

/* static */ void BaseOpSupportChecker::CreateSharedOpSupportChecker(
@@ -138,7 +141,7 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer
if (!IsNodeUnitTypeSupported(node_unit))
return false;

-  if (!HasSupportedInputOutputs(node_unit))
+  if (!HasSupportedInputOutputs(initializers, node_unit, params))
return false;

// We do not support external initializers for now
@@ -151,7 +154,8 @@ bool BaseOpSupportChecker::IsOpSupported(const InitializedTensorSet& initializer
return IsOpSupportedImpl(initializers, node_unit, params);
}

-bool BaseOpSupportChecker::HasSupportedInputOutputs(const NodeUnit& node_unit) const {
+bool BaseOpSupportChecker::HasSupportedInputOutputs(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+                                                    const OpSupportCheckParams& params) const {
// We do not support unknown(null) input shape
auto has_supported_shape = [](const NodeArg& node_arg, const std::string& name, const std::string op_type) {
const auto* shape_proto = node_arg.Shape();
@@ -185,10 +189,12 @@ bool BaseOpSupportChecker::HasSupportedInputOutputs(const NodeUnit& node_unit) c
return false;
}
}
-  return HasSupportedInputOutputsImpl(node_unit);
+  return HasSupportedInputOutputsImpl(initializers, node_unit, params);
}

-bool BaseOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool BaseOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+    const OpSupportCheckParams& params) const {
// We only check the type of input 0 by default
// specific op builder can override this
const auto& input = node_unit.Inputs()[0].node_arg;
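The hunks that follow repeat this signature change across the individual op checkers. As a stripped-down illustration of the call chain being rethreaded — all types below are stand-ins and the shape/type checks are elided, so this is a sketch of the pattern, not the EP's actual code:

```cpp
#include <iostream>

// Stand-ins for the real NNAPI EP types; illustration only.
struct InitializedTensorSet {};
struct NodeUnit {};
struct OpSupportCheckParams {};

class BaseOpSupportChecker {
 public:
  // Entry point: the per-op input/output check now receives the initializers
  // and params before the op-specific support logic runs.
  bool IsOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
                     const OpSupportCheckParams& params) const {
    if (!HasSupportedInputOutputs(initializers, node_unit, params))
      return false;
    return IsOpSupportedImpl(initializers, node_unit, params);
  }

 protected:
  // Overrides can now validate quantization scales/zero-points (which live in
  // the initializers), not just tensor element types.
  virtual bool HasSupportedInputOutputsImpl(
      const InitializedTensorSet& /* initializers */, const NodeUnit& /* node_unit */,
      const OpSupportCheckParams& /* params */) const {
    return true;  // base default: type-check input 0 only (elided here)
  }
  virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
                                 const NodeUnit& /* node_unit */,
                                 const OpSupportCheckParams& /* params */) const {
    return true;  // op-specific checks (elided here)
  }

 private:
  bool HasSupportedInputOutputs(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
                                const OpSupportCheckParams& params) const {
    // Shape validation would run here before delegating to the virtual impl.
    return HasSupportedInputOutputsImpl(initializers, node_unit, params);
  }
};

int main() {
  BaseOpSupportChecker checker;
  std::cout << checker.IsOpSupported({}, {}, {}) << '\n';  // 1
}
```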
@@ -245,8 +251,12 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
const OpSupportCheckParams& params) const override;
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
const OpSupportCheckParams& params) const override;
-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+      const OpSupportCheckParams& params) const override;
int GetMinSupportedOpSet(const NodeUnit& node_unit) const override;

+  static bool IsQuantizedOp(const NodeUnit& node_unit);
};

/* static */ void BinaryOpSupportChecker::CreateSharedOpSupportChecker(
@@ -263,6 +273,10 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
});
}

+/* static */ bool BinaryOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) {
+  return GetQuantizedOpType(node_unit) == QuantizedOpType::QLinearAdd;
+}

int32_t BinaryOpSupportChecker::GetMinSupportedNNAPIFeatureLevel(
const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const {
const auto& op(node_unit.OpType());
@@ -287,16 +301,24 @@ int BinaryOpSupportChecker::GetMinSupportedOpSet(const NodeUnit& node_unit) cons
return 1;
}

-bool BinaryOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
-  bool is_qlinear_add = node_unit.OpType() == "QLinearAdd";
+bool BinaryOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+    const OpSupportCheckParams& params) const {
+  bool is_quantized_op = IsQuantizedOp(node_unit);
bool is_pow = node_unit.OpType() == "Pow";
-  if (!is_qlinear_add && !is_pow)
-    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit);
+  if (!is_quantized_op && !is_pow)
+    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(initializers, node_unit, params);

-  if (is_qlinear_add) {
-    // QLinearAdd
+  if (is_quantized_op) {
+    // QLinearAdd/QDQAdd/QDQMul
if (!HasValidBinaryOpQuantizedInputTypes(node_unit))
return false;

+    if (!IsQuantizedIOSupported(initializers, node_unit, {0, 1}, params, true /* is_input */))
+      return false;
+
+    if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, false /* is_input */))
+      return false;
}

// Pow we only support both input as fp32 now
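The body of IsQuantizedIOSupported is not part of this diff. Judging from the per-op scale and zero-point checks it replaces (see the block removed from IsOpSupportedImpl below), a plausible sketch of the consolidation — with the helper bodies stubbed out — is:

```cpp
#include <cstddef>
#include <vector>

// Stand-ins and stubs for illustration; the real helpers live in the EP and
// validate the scale / zero-point initializers for the given IO indices.
struct InitializedTensorSet {};
struct NodeUnit {};
struct OpSupportCheckParams {};

bool HasValidQuantizationScales(const InitializedTensorSet&, const NodeUnit&,
                                const std::vector<std::size_t>& /* indices */,
                                const OpSupportCheckParams&, bool /* is_input */) {
  return true;  // stub
}

bool HasValidQuantizationZeroPoints(const InitializedTensorSet&, const NodeUnit&,
                                    const std::vector<std::size_t>& /* indices */,
                                    bool /* is_input */) {
  return true;  // stub
}

// Hypothetical consolidation: one call per IO set instead of separate scale
// and zero-point checks spelled out at every call site.
bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
                            const std::vector<std::size_t>& indices, const OpSupportCheckParams& params,
                            bool is_input) {
  return HasValidQuantizationScales(initializers, node_unit, indices, params, is_input) &&
         HasValidQuantizationZeroPoints(initializers, node_unit, indices, is_input);
}

int main() {
  // Mirrors the new call sites above: inputs {0, 1}, then output {0}.
  return IsQuantizedIOSupported({}, {}, {0, 1}, {}, true) &&
                 IsQuantizedIOSupported({}, {}, {0}, {}, false)
             ? 0
             : 1;
}
```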
@@ -324,7 +346,6 @@ bool BinaryOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi
const OpSupportCheckParams& params) const {
const auto& op_type(node_unit.OpType());
const auto& inputs = node_unit.Inputs();
-  bool op_is_qlinear = op_type == "QLinearAdd";
Shape input1_shape, input2_shape;
if (!GetShape(inputs[0].node_arg, input1_shape) ||
!GetShape(inputs[1].node_arg, input2_shape))
@@ -339,32 +360,6 @@ bool BinaryOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initi
return false;
}

-  if (op_is_qlinear) {
-    // For QLinearAdd, we only support uint8 output now
-    int32_t output_type;
-    if (!GetType(node_unit.Outputs()[0].node_arg, output_type))
-      return false;
-
-    if (output_type != ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
-      LOGS_DEFAULT(VERBOSE) << "[" << op_type
-                            << "] output type: [" << output_type
-                            << "] is not supported for now";
-      return false;
-    }
-
-    // Check input scales and ZPs
-    if (!HasValidQuantizationScales(initializers, node_unit, {0, 1}, params, true /* is_input */))
-      return false;
-    if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0, 1}, true /* is_input */))
-      return false;
-
-    // Check output scale and ZP
-    if (!HasValidQuantizationScales(initializers, node_unit, {0}, params, false /* is_input */))
-      return false;
-    if (!HasValidQuantizationZeroPoints(initializers, node_unit, {0}, false /* is_input */))
-      return false;
-  }
-
return true;
}

@@ -382,7 +377,9 @@ class TransposeOpSupportChecker : public BaseOpSupportChecker {
return ANEURALNETWORKS_FEATURE_LEVEL_2;
}

-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+      const OpSupportCheckParams& params) const override;
};

bool TransposeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
@@ -401,7 +398,9 @@ bool TransposeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /*
return true;
}

-bool TransposeOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool TransposeOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+    const OpSupportCheckParams& /* params */) const {
int32_t input_type;
if (!GetType(node_unit.Inputs()[0].node_arg, input_type))
return false;
@@ -561,7 +560,9 @@ class PoolOpSupportChecker : public BaseOpSupportChecker {
return params.use_nchw ? ANEURALNETWORKS_FEATURE_LEVEL_3 : ANEURALNETWORKS_FEATURE_LEVEL_2;
}

-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+      const OpSupportCheckParams& /* params */) const override;
};

/* static */ void PoolOpSupportChecker::CreateSharedOpSupportChecker(
@@ -691,11 +692,13 @@ bool PoolOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initial
return true;
}

-bool PoolOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool PoolOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+    const OpSupportCheckParams& params) const {
bool is_max_pool = node_unit.OpType() == "MaxPool";
bool is_qlinear_average_pool = node_unit.OpType() == "QLinearAveragePool";
if (!is_max_pool && !is_qlinear_average_pool)
-    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit);
+    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(initializers, node_unit, params);

if (is_qlinear_average_pool) {
return HasValidUnaryOpQuantizedInputs(node_unit);
@@ -736,7 +739,9 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
return params.use_nchw ? ANEURALNETWORKS_FEATURE_LEVEL_3 : ANEURALNETWORKS_FEATURE_LEVEL_2;
}

-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+      const OpSupportCheckParams& /* params */) const override;
bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; }
static bool IsQuantizedOp(const NodeUnit& node_unit);
};
@@ -755,9 +760,11 @@ class ConvOpSupportChecker : public BaseOpSupportChecker {
return IsQuantizedConv(GetQuantizedOpType(node_unit));
}

-bool ConvOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool ConvOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+    const OpSupportCheckParams& params) const {
if (!IsQuantizedOp(node_unit))
-    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit);
+    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(initializers, node_unit, params);

// QLinearConv only supports input of uint8 for now
if (!HasValidBinaryOpQuantizedInputTypes(node_unit))
@@ -925,13 +932,17 @@ class GemmOpSupportChecker : public BaseOpSupportChecker {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
const OpSupportCheckParams& params) const override;
-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+      const OpSupportCheckParams& /* params */) const override;
int GetMinSupportedOpSet(const NodeUnit& node_unit) const override;
};

-bool GemmOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool GemmOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+    const OpSupportCheckParams& params) const {
if (node_unit.OpType() != "QLinearMatMul")
-    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit);
+    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(initializers, node_unit, params);

// QLinearMatMul
if (!HasValidBinaryOpQuantizedInputTypes(node_unit))
@@ -1121,7 +1132,9 @@ class UnaryOpSupportChecker : public BaseOpSupportChecker {
int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */,
const OpSupportCheckParams& params) const override;

-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+      const OpSupportCheckParams& /* params */) const override;

int GetMinSupportedOpSet(const NodeUnit& node_unit) const override;

@@ -1170,10 +1183,12 @@ int32_t UnaryOpSupportChecker::GetMinSupportedNNAPIFeatureLevel(const NodeUnit&
return ANEURALNETWORKS_FEATURE_LEVEL_1;
}

-bool UnaryOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool UnaryOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+    const OpSupportCheckParams& params) const {
// We only need to override input check for QLinearSigmoid
if (node_unit.OpType() != "QLinearSigmoid")
-    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(node_unit);
+    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(initializers, node_unit, params);

return HasValidUnaryOpQuantizedInputs(node_unit);
}
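HasValidUnaryOpQuantizedInputs is likewise defined outside this diff; a hedged sketch of the check it plausibly performs — input 0 must be uint8 — with stand-in types and a stubbed GetType:

```cpp
#include <cstdint>

// Stand-in types; ONNX_NAMESPACE::TensorProto_DataType_UINT8 has value 2.
struct NodeArg {};
struct NodeUnit {};
constexpr std::int32_t kUint8 = 2;

// Hypothetical stub of the EP's GetType helper: reports the element type of a
// NodeArg, returning false when the type is unknown.
bool GetType(const NodeArg& /* node_arg */, std::int32_t& type) {
  type = kUint8;
  return true;
}

// Plausible shape of the unary quantized-input check used by QLinearSigmoid
// (and QLinearAveragePool earlier): input 0 must be uint8.
bool HasValidUnaryOpQuantizedInputs(const NodeUnit& /* node_unit */) {
  NodeArg arg;  // stands in for node_unit.Inputs()[0].node_arg
  std::int32_t input_type = 0;
  if (!GetType(arg, input_type))
    return false;
  return input_type == kUint8;
}

int main() { return HasValidUnaryOpQuantizedInputs({}) ? 0 : 1; }
```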
@@ -1243,7 +1258,9 @@ class ConcatOpSupportChecker : public BaseOpSupportChecker {
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
const OpSupportCheckParams& params) const override;

-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+      const OpSupportCheckParams& /* params */) const override;
};

bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
@@ -1262,7 +1279,9 @@ bool ConcatOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& /* in
return true;
}

-bool ConcatOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool ConcatOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+    const OpSupportCheckParams& /* params */) const {
int32_t input_type;
if (!GetType(node_unit.Inputs()[0].node_arg, input_type))
return false;
@@ -1370,7 +1389,9 @@ class DequantizeLinearOpSupportChecker : public BaseOpSupportChecker {
const OpSupportCheckParams& /* params */) const override {
return ANEURALNETWORKS_FEATURE_LEVEL_1;
}
-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+      const OpSupportCheckParams& /* params */) const override;
};

bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
Expand All @@ -1385,7 +1406,9 @@ bool DequantizeLinearOpSupportChecker::IsOpSupportedImpl(const InitializedTensor
return true;
}

-bool DequantizeLinearOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool DequantizeLinearOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+    const OpSupportCheckParams& /* params */) const {
int32_t input_type;
if (!GetType(node_unit.Inputs()[0].node_arg, input_type))
return false;
@@ -1474,7 +1497,9 @@ class ResizeOpSupportChecker : public BaseOpSupportChecker {
// We only support Resize opset 11+ here
int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 11; }

-  bool HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const override;
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+      const OpSupportCheckParams& /* params */) const override;
bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; }
static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT; // TODO, see if we want to move this to BaseOpBuilder
};
@@ -1647,7 +1672,9 @@ int32_t ResizeOpSupportChecker::GetMinSupportedNNAPIFeatureLevel(const NodeUnit&
return ANEURALNETWORKS_FEATURE_LEVEL_2;
}

-bool ResizeOpSupportChecker::HasSupportedInputOutputsImpl(const NodeUnit& node_unit) const {
+bool ResizeOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+    const OpSupportCheckParams& /* params */) const {
int32_t input_type;
if (!GetType(node_unit.Inputs()[0].node_arg, input_type))
return false;