Skip to content

Commit

Permalink
Add qdq mul support
Browse files Browse the repository at this point in the history
  • Loading branch information
guoyu-wang committed Feb 7, 2022
1 parent 9cadda7 commit 1238049
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit) {
return QuantizedOpType::QDQResize;
else if (op_type == "AveragePool")
return QuantizedOpType::QDQAveragePool;
else if (op_type == "Add")
return QuantizedOpType::QDQAdd;
else if (op_type == "Mul")
return QuantizedOpType::QDQMul;
} else {
// throw?
// Do we want to throw here? This seems to have been overlooked previously.
Expand Down Expand Up @@ -114,6 +118,8 @@ bool IsQuantizedPool(QuantizedOpType quant_op_type) {
// Returns true if the quantized op type is a binary (two-input) operation:
// QLinearMatMul, the QLinear/QDQ forms of Add, QDQ Mul, or any quantized conv variant.
bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type) {
  switch (quant_op_type) {
    case QuantizedOpType::QLinearMatMul:
    case QuantizedOpType::QLinearAdd:
    case QuantizedOpType::QDQAdd:
    case QuantizedOpType::QDQMul:
      return true;
    default:
      // Conv variants are also binary ops (input + weight)
      return IsQuantizedConv(quant_op_type);
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ enum class QuantizedOpType : uint8_t {
QDQConv,
QDQResize,
QDQAveragePool,
QDQAdd,
QDQMul,
// TODO, add other QDQ NodeUnit types
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,10 @@ class BinaryOpBuilder : public BaseOpBuilder {
};

// Whether this node unit is one of the quantized binary ops this builder
// supports: QLinearAdd, or the QDQ forms of Add and Mul.
/* static */ bool BinaryOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) {
  switch (GetQuantizedOpType(node_unit)) {
    case QuantizedOpType::QLinearAdd:
    case QuantizedOpType::QDQAdd:
    case QuantizedOpType::QDQMul:
      return true;
    default:
      return false;
  }
}

void BinaryOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
Expand Down Expand Up @@ -690,12 +692,12 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const

int32_t op_code;
bool add_activation = true;
bool op_is_qlinear = op_type == "QLinearAdd";
if (op_type == "Add" || op_is_qlinear) {
bool is_quant_op = IsQuantizedOp(node_unit);
if (op_type == "Add" || is_quant_op) { // Add/QLinearAdd/QDQAdd
op_code = ANEURALNETWORKS_ADD;
} else if (op_type == "Sub") {
op_code = ANEURALNETWORKS_SUB;
} else if (op_type == "Mul") {
} else if (op_type == "Mul" || is_quant_op) { // Mul/QDQMul
op_code = ANEURALNETWORKS_MUL;
} else if (op_type == "Div") {
op_code = ANEURALNETWORKS_DIV;
Expand All @@ -721,15 +723,15 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
b_zero_point = 0,
y_zero_point = 0;

if (op_is_qlinear) {
if (is_quant_op) {
ORT_RETURN_IF_ERROR(GetBinaryOpQuantizationScaleAndZeroPoint(
model_builder.GetInitializerTensors(), node_unit,
a_scale, b_scale, y_scale,
a_zero_point, b_zero_point, y_zero_point));
}

// Verify that the scale and zero point of the ONNX inputs match those of the NNAPI inputs
if (op_is_qlinear) {
if (is_quant_op) {
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input1, a_scale, a_zero_point));
ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input2, b_scale, b_zero_point));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
const OpSupportCheckParams& params) const override;
int GetMinSupportedOpSet(const NodeUnit& node_unit) const override;

bool IsNodeUnitTypeSupported(const NodeUnit& node_unit) const override;
static bool IsQuantizedOp(const NodeUnit& node_unit);
};

Expand All @@ -470,8 +471,21 @@ class BinaryOpSupportChecker : public BaseOpSupportChecker {
});
}

// A QDQ group node unit is only supported here when it is QDQAdd or QDQMul;
// any other node unit type (e.g. a single node) is accepted.
bool BinaryOpSupportChecker::IsNodeUnitTypeSupported(const NodeUnit& node_unit) const {
  if (node_unit.UnitType() != NodeUnit::Type::QDQGroup)
    return true;

  const auto quant_type = GetQuantizedOpType(node_unit);
  return quant_type == QuantizedOpType::QDQAdd ||
         quant_type == QuantizedOpType::QDQMul;
}

// Whether this node unit is a quantized binary op handled by this checker
// (kept consistent with BinaryOpBuilder::IsQuantizedOp).
/* static */ bool BinaryOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) {
  switch (GetQuantizedOpType(node_unit)) {
    case QuantizedOpType::QLinearAdd:
    case QuantizedOpType::QDQAdd:
    case QuantizedOpType::QDQMul:
      return true;
    default:
      return false;
  }
}

int32_t BinaryOpSupportChecker::GetMinSupportedNNAPIFeatureLevel(
Expand Down Expand Up @@ -760,7 +774,7 @@ class PoolOpSupportChecker : public BaseOpSupportChecker {
bool HasSupportedInputOutputsImpl(
const InitializedTensorSet& initializers, const NodeUnit& node_unit,
const OpSupportCheckParams& params) const override;
bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; }
bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override;
static bool IsQuantizedOp(const NodeUnit& node_unit);
};

Expand All @@ -777,6 +791,15 @@ class PoolOpSupportChecker : public BaseOpSupportChecker {
});
}

// A QDQ group node unit is only supported here when it is QDQAveragePool;
// any other node unit type (e.g. a single node) is accepted.
bool PoolOpSupportChecker::IsNodeUnitTypeSupported(const NodeUnit& node_unit) const {
  if (node_unit.UnitType() != NodeUnit::Type::QDQGroup)
    return true;

  return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQAveragePool;
}

// Whether this node unit is a quantized pooling op (QLinear or QDQ variant);
// the classification itself is delegated to IsQuantizedPool.
/* static */ bool PoolOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) {
  const auto quant_type = GetQuantizedOpType(node_unit);
  return IsQuantizedPool(quant_type);
}
Expand Down
51 changes: 40 additions & 11 deletions onnxruntime/test/optimizer/qdq_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,27 @@ GetQDQTestCaseFn BuildQDQConvTestCase(const std::vector<int64_t>& input_shape, c
template <typename InputType, typename OutputType>
GetQDQTestCaseFn BuildQDQAveragePoolTestCase(const std::vector<int64_t>& input_shape) {
return [input_shape](ModelTestBuilder& builder) {

#ifdef USE_NNAPI // NNAPI require consistent scales for DQ -> Pool -> Q
float dq_scale = 0.0038f;
float pool_output_scale = 0.0038f;
float q_scale = 0.0038f;
InputType dq_zp = std::numeric_limits<OutputType>::max() / 2;
InputType pool_output_zp = std::numeric_limits<OutputType>::max() / 2;
InputType q_zp = std::numeric_limits<OutputType>::max() / 2;
#else
float dq_scale = 0.0035f;
float pool_output_scale = 0.0038f;
float q_scale = 0.0039f;
InputType dq_zp = 7;
InputType pool_output_zp = std::numeric_limits<OutputType>::max() / 2;
InputType q_zp = std::numeric_limits<OutputType>::max() / 2;
#endif

auto* input_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
auto* output_arg = builder.MakeOutput();
// add QDQ + AveragePool
auto* dq_output = AddQDQNodePair<InputType>(builder, input_arg, .0035f, 7);
auto* dq_output = AddQDQNodePair<InputType>(builder, input_arg, dq_scale, dq_zp);
auto* averagepool_output = builder.MakeIntermediate();
Node& pool_node = builder.AddNode("AveragePool", {dq_output}, {averagepool_output});
std::vector<int64_t> pads((input_shape.size() - 2) * 2, 1);
Expand All @@ -95,12 +112,12 @@ GetQDQTestCaseFn BuildQDQAveragePoolTestCase(const std::vector<int64_t>& input_s
// add QDQ output
auto* q_output = builder.MakeIntermediate();
builder.AddQuantizeLinearNode<OutputType>(averagepool_output,
.0038f,
std::numeric_limits<OutputType>::max() / 2,
pool_output_scale,
pool_output_zp,
q_output);
builder.AddDequantizeLinearNode<OutputType>(q_output,
.0039f,
std::numeric_limits<OutputType>::max() / 2,
q_scale,
q_zp,
output_arg);
};
}
Expand All @@ -118,27 +135,39 @@ GetQDQTestCaseFn BuildBinaryOpTestCase(const std::vector<int64_t>& input_shape,
auto* input2_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
auto* output_arg = builder.MakeOutput();

#ifdef USE_NNAPI // NNAPI require consistent scales for DQ -> bin op -> Q
float dq_scale = 0.0039f;
float op_input_scale = 0.0039f;
float op_output_scale = 0.0039f;
float q_scale = 0.0039f;
#else
float dq_scale = 0.004f;
float op_input_scale = 0.0039f;
float op_output_scale = 0.0038f;
float q_scale = 0.0039f;
#endif

// add QDQ 1
auto* q1_output = builder.MakeIntermediate();
auto* dq1_output = builder.MakeIntermediate();
builder.AddQuantizeLinearNode<Input1Type>(input1_arg,
.004f,
dq_scale,
std::numeric_limits<Input1Type>::max() / 2,
q1_output);
builder.AddDequantizeLinearNode<Input1Type>(q1_output,
.0039f,
op_input_scale,
std::numeric_limits<Input1Type>::max() / 2,
dq1_output);

// add QDQ 2
auto* q2_output = builder.MakeIntermediate();
auto* dq2_output = builder.MakeIntermediate();
builder.AddQuantizeLinearNode<Input2Type>(input2_arg,
.004f,
dq_scale,
std::numeric_limits<Input2Type>::max() / 2,
q2_output);
builder.AddDequantizeLinearNode<Input2Type>(q2_output,
.0039f,
op_input_scale,
std::numeric_limits<Input2Type>::max() / 2,
dq2_output);

Expand All @@ -149,11 +178,11 @@ GetQDQTestCaseFn BuildBinaryOpTestCase(const std::vector<int64_t>& input_shape,
// add QDQ output
auto* q3_output = builder.MakeIntermediate();
builder.AddQuantizeLinearNode<OutputType>(binary_op_output,
.0038f,
op_output_scale,
std::numeric_limits<OutputType>::max() / 2,
q3_output);
builder.AddDequantizeLinearNode<OutputType>(q3_output,
.0039f,
q_scale,
std::numeric_limits<OutputType>::max() / 2,
output_arg);
};
Expand Down
37 changes: 32 additions & 5 deletions onnxruntime/test/providers/nnapi/nnapi_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,9 @@ TEST(NnapiExecutionProviderTest, TestNoShapeInputModel) {
<< "No node should be taken by the NNAPI EP";
}

static void RunQDQModelTest(const GetQDQTestCaseFn& build_test_case, const char* test_description) {
static void RunQDQModelTest(const GetQDQTestCaseFn& build_test_case,
const char* test_description,
const EPVerificationParams& params = EPVerificationParams()) {
onnxruntime::Model model(test_description, false, DefaultLoggingManager().DefaultLogger());
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
Expand All @@ -286,7 +288,7 @@ static void RunQDQModelTest(const GetQDQTestCaseFn& build_test_case, const char*
#if defined(__ANDROID__)
RunAndVerifyOutputsWithEP(model_data, "NnapiExecutionProviderTest.TestQDQModel",
std::make_unique<NnapiExecutionProvider>(0),
helper.feeds_);
helper.feeds_, params);
#else
// test load only
SessionOptions so;
Expand All @@ -306,7 +308,8 @@ TEST(NnapiExecutionProviderTest, TestQDQConv) {
uint8_t /* OutputType */>(
{1, 1, 5, 5} /*input_shape*/,
{1, 1, 3, 3} /*weights_shape*/),
"nnapi_qdq_test_graph_conv");
"nnapi_qdq_test_graph_conv",
{true /* verify_entire_graph_use_ep */});
}

TEST(NnapiExecutionProviderTest, TestQDQResize) {
Expand All @@ -316,14 +319,38 @@ TEST(NnapiExecutionProviderTest, TestQDQResize) {
{1, 3, 32, 32} /* sizes_data */,
"linear" /* mode */,
"asymmetric" /* coordinate_transformation_mode */),
"nnapi_qdq_test_graph_resize");
"nnapi_qdq_test_graph_resize",
{true /* verify_entire_graph_use_ep */});
}

TEST(NnapiExecutionProviderTest, TestQDQAveragePool) {
  // NNAPI Pool uses different rounding, which may cause ~1% difference in the result,
  // so allow a larger absolute error while still requiring the whole graph on the EP.
  const EPVerificationParams verification_params{
      true /* verify_entire_graph_use_ep */,
      1e-2f /* fp32_abs_err */,
  };
  RunQDQModelTest(
      BuildQDQAveragePoolTestCase<uint8_t /* InputType */,
                                  uint8_t /* OutputType */>(
          {1, 3, 32, 32} /* input_shape */),
      "nnapi_qdq_test_graph_averagepool",
      verification_params);
}

// Verify both QDQAdd and QDQMul node units are fully taken by the NNAPI EP.
TEST(NnapiExecutionProviderTest, TestQDQBinaryOp) {
  RunQDQModelTest(BuildBinaryOpTestCase<uint8_t /* Input1Type */,
                                        uint8_t /* Input2Type */,
                                        uint8_t /* OutputType */>(
                      {1, 23, 13, 13} /* input_shape */,
                      "Add" /* op_type */),
                  "nnapi_qdq_test_graph_add",
                  {true /* verify_entire_graph_use_ep */});

  // Fix: this case previously built a second "Add" graph, so QDQMul was never
  // exercised even though the graph name said mul. Build "Mul" as intended.
  RunQDQModelTest(BuildBinaryOpTestCase<uint8_t /* Input1Type */,
                                        uint8_t /* Input2Type */,
                                        uint8_t /* OutputType */>(
                      {1, 23, 13, 13} /* input_shape */,
                      "Mul" /* op_type */),
                  "nnapi_qdq_test_graph_mul",
                  {true /* verify_entire_graph_use_ep */});
}

#endif // !(ORT_MINIMAL_BUILD)
Expand Down
19 changes: 16 additions & 3 deletions onnxruntime/test/util/include/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ class Graph;

namespace test {

// struct to hold some verification params for RunAndVerifyOutputsWithEP
struct EPVerificationParams {
  // Verify the entire graph is taken by the EP.
  // If set to false, only verify that at least one node is assigned to 'execution_provider'.
  bool verify_entire_graph_use_ep{false};

  // Some EPs may use different rounding than the ORT CPU EP, which may cause a bigger abs error
  // than the default of 1e-5f, especially for scenarios such as [Q -> Quantized op -> DQ].
  // Set this only if necessary.
  // Note: brace-initialized for consistency with the member above.
  float fp32_abs_err{1e-5f};
};

// return number of nodes in the Graph and any subgraphs that are assigned to the specified execution provider
int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type);

Expand All @@ -23,13 +35,14 @@ int CountAssignedNodes(const Graph& current_graph, const std::string& ep_type);
void RunAndVerifyOutputsWithEP(const ORTCHAR_T* model_path,
const char* log_id,
std::unique_ptr<IExecutionProvider> execution_provider,
const NameMLValMap& feeds);
const NameMLValMap& feeds,
const EPVerificationParams& params = EPVerificationParams());

// helper function that takes in model_data
// used in nnapi qdq model tests
void RunAndVerifyOutputsWithEP(const std::string& model_data,
const char* log_id,
std::unique_ptr<IExecutionProvider> execution_provider,
const NameMLValMap& feeds);
const NameMLValMap& feeds,
const EPVerificationParams& params = EPVerificationParams());
} // namespace test
} // namespace onnxruntime
Loading

0 comments on commit 1238049

Please sign in to comment.