[NNAPI QDQ] Add QDQReshape op support #10533

Merged: 13 commits, Feb 15, 2022
Changes from 6 commits
@@ -80,6 +80,8 @@ QuantizedOpType GetQuantizedOpType(const NodeUnit& node_unit) {
       return QuantizedOpType::QDQMul;
     else if (op_type == "Transpose")
       return QuantizedOpType::QDQTranspose;
+    else if (op_type == "Reshape")
+      return QuantizedOpType::QDQReshape;
   } else {
     // throw?
   }
@@ -91,6 +91,7 @@ enum class QuantizedOpType : uint8_t {
   QDQAdd,
   QDQMul,
   QDQTranspose,
+  QDQReshape,
   // TODO, add other QDQ NodeUnit types
 };

@@ -861,18 +861,28 @@ class ReshapeOpBuilder : public BaseOpBuilder {
  public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
   static Status AddReshapeOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
-                                   const std::string& input, const std::vector<int32_t>& shape);
+                                   const std::string& input, const std::vector<int32_t>& shape,
+                                   float scale = 0.0f, int32_t zero_point = 0);

  private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
   static bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit,
                              size_t input_rank, size_t output_rank);
+  static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT;  // TODO, see if we want to move this to BaseOpBuilder
Contributor:

> see if we want to move this to BaseOpBuilder

doesn't have to be in this PR, but when will the decision be made?

Contributor (author):

yes, @gwang-msft do you think we should move it to BaseOpBuilder now, or keep it here at the individual op builder level since we don't have that many QDQ ops supported?

Contributor:

We can move it as a virtual function on BaseOpBuilder that returns false by default; each individual builder will override it if necessary.
Same for BaseOpSupportChecker.
 };

 void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
+  if (IsQuantizedOp(node_unit)) {
+    AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Inputs()[0].quant_param);   // x_scale, x_zp
+    AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param);  // y_scale, y_zp
+  }
   model_builder.AddInitializerToSkip(node_unit.Inputs()[1].node_arg.Name());
 }

+/* static */ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) {
+  return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQReshape;
+}
+
 // We can skip the Reshape if all the output edges satisfy both of the following conditions:
 // 1. The output of the reshape/flatten is not an output of the graph
 // 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
@@ -947,7 +957,8 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
 /* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder,
                                                          const NodeUnit& node_unit,
                                                          const std::string& input,
-                                                         const std::vector<int32_t>& shape) {
+                                                         const std::vector<int32_t>& shape,
+                                                         float scale, int32_t zero_point) {
Contributor:

It would be easier to get the scale and zero_point from the input inside this function instead of passing them in; use something like this:

+  // For reshape, the output type should be the same as the input type except the shape is different
+  auto output_operand_type = operand_types.at(input);
+  output_operand_type.SetDimensions(shaper[output]);
+
   // Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now)
   // We will try to see if we can skip the Reshape to prevent context switching between
   // NNAPI CPU impl and NNAPI hardware accelerator impl
   if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) {
     // Since reshape can be skipped, only register the dimension and type, with same index and new name
-    const OperandType output_operand_type(operand_types.at(input).type, shaper[output], scale, zero_point);
     model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);
   } else {
     // We still need to perform a reshape here
     // Add new shape
     Shape shape_dimen = {static_cast<uint32_t>(shape.size())};
     std::string shape_name = model_builder.GetUniqueName(node_unit.Name() + input + "newshape");
-    OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen, scale, zero_point);
+    OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen);
     ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type));
     input_indices.push_back(operand_indices.at(shape_name));
-
-    const OperandType output_operand_type(operand_types.at(input).type, shaper[output], scale, zero_point);
     ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}, {false}));
   }

   auto& shaper(model_builder.GetShaper());
   const auto& operand_indices(model_builder.GetOperandIndices());
   const auto& operand_types(model_builder.GetOperandTypes());
@@ -961,7 +972,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
   // NNAPI CPU impl and NNAPI hardware accelerator impl
   if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) {
     // Since reshape can be skipped, only register the dimension and type, with same index and new name
-    const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
+    const OperandType output_operand_type(operand_types.at(input).type, shaper[output], scale, zero_point);
     model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);
   } else {
     // We still need to perform a reshape here
@@ -971,11 +982,11 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
     // Add new shape
     Shape shape_dimen = {static_cast<uint32_t>(shape.size())};
     std::string shape_name = model_builder.GetUniqueName(node_unit.Name() + input + "newshape");
-    OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen);
+    OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen, scale, zero_point);
Contributor:

This is the new shape of the Reshape op, it does not need scale and zero_point.

     ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type));
     input_indices.push_back(operand_indices.at(shape_name));

-    const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
+    const OperandType output_operand_type(operand_types.at(input).type, shaper[output], scale, zero_point);
     ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}, {false}));
   }

@@ -1006,7 +1017,16 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
     shape[i] = dim == 0 ? input_shape[i] : dim;
   }

-  return AddReshapeOperator(model_builder, node_unit, input, shape);
+  // Check if the quantization scale and ZP are correct
+  float x_scale = 0.0f;
+  int32_t x_zero_point = 0;
+  if (IsQuantizedOp(node_unit)) {
+    ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint(
+        initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point));
+    ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point));
+  }
+
+  return AddReshapeOperator(model_builder, node_unit, input, shape, x_scale, x_zero_point);
 }

 #pragma endregion op_reshape
@@ -653,8 +653,17 @@ class ReshapeOpSupportChecker : public BaseOpSupportChecker {

   // Reshape opset 4- uses attributes for new shape which we do not support for now
   int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 5; }
+  bool HasSupportedInputOutputsImpl(
+      const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit,
+      const OpSupportCheckParams& /* params */) const override;
+  bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; }
+  static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT;  // TODO, see if we want to move this to BaseOpBuilder
 };

+/* static */ bool ReshapeOpSupportChecker::IsQuantizedOp(const NodeUnit& node_unit) {
+  return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQReshape;
+}
+
 bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit,
                                                 const OpSupportCheckParams& /* params */) const {
   const auto& inputs = node_unit.Inputs();
@@ -685,7 +694,7 @@ bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& init
   const auto perm_size = SafeInt<uint32_t>(perm_tensor.dims()[0]);

   NodeAttrHelper helper(node_unit);
-  const bool allow_zero = helper.Get("allowzero ", 0) == 1;
+  const bool allow_zero = helper.Get("allowzero", 0) == 1;
   for (uint32_t i = 0; i < perm_size; i++) {
     // NNAPI reshape does not support 0 as dimension
     if (raw_perm[i] == 0) {
@@ -704,6 +713,24 @@ bool ReshapeOpSupportChecker::IsOpSupportedImpl(const InitializedTensorSet& init
   return true;
 }

+bool ReshapeOpSupportChecker::HasSupportedInputOutputsImpl(
+    const InitializedTensorSet& initializers, const NodeUnit& node_unit,
+    const OpSupportCheckParams& params) const {
+  if (!IsQuantizedOp(node_unit)) {
+    return BaseOpSupportChecker::HasSupportedInputOutputsImpl(initializers, node_unit, params);
+  }
+
+  if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, IOKind::Input)) {
+    return false;
+  }
+
+  if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, IOKind::Output)) {
+    return false;
+  }
+
+  return true;
+}

 #pragma endregion

 #pragma region op_batchnormalization
22 changes: 22 additions & 0 deletions onnxruntime/test/optimizer/qdq_test_utils.cc
@@ -36,5 +36,27 @@ GetQDQTestCaseFn BuildQDQResizeTestCase(
   };
 }

+GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector<int64_t>& input_shape,
+                                         const std::vector<int64_t>& reshape_shape) {
+  return [input_shape, reshape_shape](ModelTestBuilder& builder) {
+    auto* input_arg = builder.MakeInput<uint8_t>(input_shape,
+                                                 std::numeric_limits<uint8_t>::min(),
+                                                 std::numeric_limits<uint8_t>::max());
+    auto* output_arg = builder.MakeOutput();
+
+    // add DQ
+    auto* dq_output = builder.MakeIntermediate();
+    builder.AddDequantizeLinearNode<uint8_t>(input_arg, .003f, 1, dq_output);
+
+    // add Reshape
+    auto* reshape_output = builder.MakeIntermediate();
+    auto* shape = builder.Make1DInitializer<int64_t>(reshape_shape);
+    builder.AddNode("Reshape", {dq_output, shape}, {reshape_output});
+
+    // add Q
+    builder.AddQuantizeLinearNode<uint8_t>(reshape_output, .003f, 1, output_arg);
+  };
+}

 }  // namespace test
 }  // namespace onnxruntime
3 changes: 3 additions & 0 deletions onnxruntime/test/optimizer/qdq_test_utils.h
@@ -212,5 +212,8 @@ GetQDQTestCaseFn BuildQDQTransposeTestCase(
     builder.AddQuantizeLinearNode<OutputType>(transpose_output, .003f, q_zp, output_arg);
   };
 }

+GetQDQTestCaseFn BuildQDQReshapeTestCase(const std::vector<int64_t>& input_shape,
+                                         const std::vector<int64_t>& reshape_shape);
 }  // namespace test
 }  // namespace onnxruntime
9 changes: 9 additions & 0 deletions onnxruntime/test/providers/nnapi/nnapi_basic_test.cc
@@ -370,6 +370,15 @@ TEST(NnapiExecutionProviderTest, TestQDQTranspose) {
       });
 }

+TEST(NnapiExecutionProviderTest, TestQDQReshape) {
+  RunQDQModelTest(BuildQDQReshapeTestCase({1, 3, 64, 64} /* input_shape */,
+                                          {1, 64, 64, 3} /* reshape_shape */),
+                  "nnapi_qdq_test_graph_reshape",
+                  {
+                      true /* verify_entire_graph_use_ep */
+                  });
+}

 #endif  // !(ORT_MINIMAL_BUILD)

 TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {