[NNAPI QDQ] Add QDQReshape op support #10533
@@ -861,18 +861,28 @@ class ReshapeOpBuilder : public BaseOpBuilder {
  public:
   void AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
   static Status AddReshapeOperator(ModelBuilder& model_builder, const NodeUnit& node_unit,
-                                   const std::string& input, const std::vector<int32_t>& shape);
+                                   const std::string& input, const std::vector<int32_t>& shape,
+                                   float scale = 0.0f, int32_t zero_point = 0);

  private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const override;
   static bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit,
                              size_t input_rank, size_t output_rank);
+  static bool IsQuantizedOp(const NodeUnit& node_unit) ORT_MUST_USE_RESULT;  // TODO, see if we want to move this to BaseOpBuilder
 };

 void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const NodeUnit& node_unit) const {
+  if (IsQuantizedOp(node_unit)) {
+    AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Inputs()[0].quant_param);   // x_scale, x_zp
+    AddQuantizationScaleAndZeroPointToSkip(model_builder, *node_unit.Outputs()[0].quant_param);  // y_scale, y_zp
+  }
   model_builder.AddInitializerToSkip(node_unit.Inputs()[1].node_arg.Name());
 }

+/* static */ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) {
+  return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQReshape;
+}
+
 // We can skip the Reshape if all the output edges satisfies both the following conditions
 // 1. The output the reshape/flatten is not an output of the graph
 // 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators,
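The (truncated) comment above spells out the two conditions under which the Reshape can be dropped. As a rough illustration only, here is a minimal sketch of such a check using simplified stand-in types; the real CanSkipReshape works on the NNAPI EP's ModelBuilder/NodeUnit, not these hypothetical structs.

#include <string>
#include <vector>

// Hypothetical, simplified view of a reshape output and its consumers.
struct Consumer {
  std::string op_type;      // e.g. "Gemm" or "MatMul"
  size_t dst_input_index;   // which input slot of the consumer this output feeds
};

struct ReshapeOutput {
  bool is_graph_output;             // condition 1: must be false
  std::vector<Consumer> consumers;  // condition 2: every consumer is GEMM/MatMul input 0
};

bool CanSkipReshapeSketch(const ReshapeOutput& out) {
  if (out.is_graph_output)
    return false;  // condition 1 violated
  for (const auto& c : out.consumers) {
    const bool is_gemm_like = (c.op_type == "Gemm" || c.op_type == "MatMul");
    if (!is_gemm_like || c.dst_input_index != 0)
      return false;  // condition 2 violated
  }
  return !out.consumers.empty();
}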
@@ -947,7 +957,8 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
 /* static */ Status ReshapeOpBuilder::AddReshapeOperator(ModelBuilder& model_builder,
                                                          const NodeUnit& node_unit,
                                                          const std::string& input,
-                                                         const std::vector<int32_t>& shape) {
+                                                         const std::vector<int32_t>& shape,
+                                                         float scale, int32_t zero_point) {
Review comment: It will be easier to get the scale and zero_point from the input inside this function, instead of passing them in; use something like this:

+ // For reshape, the output type should be the same as the input type except the shape is different
+ auto output_operand_type = operand_types.at(input);
+ output_operand_type.SetDimensions(shaper[output]);
+
// Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now)
// We will try to see if we the skip the Reshape to prevent context switching between
// NNAPI CPU impl and NNAPI hardware accelerator impl
if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) {
// Since reshape can be skipped, only register the dimension and type, with same index and new name
- const OperandType output_operand_type(operand_types.at(input).type, shaper[output], scale, zero_point);
model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);
} else {
// We still need to perform a reshape here
// Add new shape
Shape shape_dimen = {static_cast<uint32_t>(shape.size())};
std::string shape_name = model_builder.GetUniqueName(node_unit.Name() + input + "newshape");
- OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen, scale, zero_point);
+ OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen);
ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type));
input_indices.push_back(operand_indices.at(shape_name));
-
- const OperandType output_operand_type(operand_types.at(input).type, shaper[output], scale, zero_point);
ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}, {false}));
}
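To make the suggestion above concrete: the idea is to copy the input operand's type (which already carries scale/zero_point for quantized tensors) and only swap in the output dimensions. A minimal sketch with simplified stand-in types follows; OperandTypeSketch and MakeReshapeOutputType are illustrative names, not the EP's real OperandType class.

#include <cstdint>
#include <vector>

// Simplified stand-in for an NNAPI operand type descriptor.
struct OperandTypeSketch {
  int type = 0;                    // e.g. TENSOR_QUANT8_ASYMM or TENSOR_FLOAT32
  std::vector<uint32_t> dims;
  float scale = 0.0f;
  int32_t zero_point = 0;
  void SetDimensions(const std::vector<uint32_t>& d) { dims = d; }
};

// Output type of a reshape: same element type, scale and zero_point as the input,
// only the shape differs.
OperandTypeSketch MakeReshapeOutputType(const OperandTypeSketch& input_type,
                                        const std::vector<uint32_t>& output_shape) {
  OperandTypeSketch output_type = input_type;  // keeps type, scale, zero_point
  output_type.SetDimensions(output_shape);     // only the dimensions change
  return output_type;
}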

 auto& shaper(model_builder.GetShaper());
 const auto& operand_indices(model_builder.GetOperandIndices());
 const auto& operand_types(model_builder.GetOperandTypes());
@@ -961,7 +972,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
   // NNAPI CPU impl and NNAPI hardware accelerator impl
   if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) {
     // Since reshape can be skipped, only register the dimension and type, with same index and new name
-    const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
+    const OperandType output_operand_type(operand_types.at(input).type, shaper[output], scale, zero_point);
     model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, false);
   } else {
     // We still need to perform a reshape here
@@ -971,11 +982,11 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
     // Add new shape
     Shape shape_dimen = {static_cast<uint32_t>(shape.size())};
     std::string shape_name = model_builder.GetUniqueName(node_unit.Name() + input + "newshape");
-    OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen);
+    OperandType shape_operand_type(Type::TENSOR_INT32, shape_dimen, scale, zero_point);

Review comment: This is the new shape of the reshape op, it does not need scale and zero_point.

     ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(shape_name, shape.data(), shape_operand_type));
     input_indices.push_back(operand_indices.at(shape_name));

-    const OperandType output_operand_type(operand_types.at(input).type, shaper[output]);
+    const OperandType output_operand_type(operand_types.at(input).type, shaper[output], scale, zero_point);
     ORT_RETURN_IF_ERROR(model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}, {false}));
   }
@@ -1006,7 +1017,16 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
     shape[i] = dim == 0 ? input_shape[i] : dim;
   }

-  return AddReshapeOperator(model_builder, node_unit, input, shape);
+  // Check if the quantization scale and ZP are correct
+  float x_scale = 0.0f;
+  int32_t x_zero_point = 0;
+  if (IsQuantizedOp(node_unit)) {
+    ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint(
+        initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point));
+    ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point));
+  }
+
+  return AddReshapeOperator(model_builder, node_unit, input, shape, x_scale, x_zero_point);
 }

 #pragma endregion op_reshape
Review comment: Doesn't have to be in this PR, but when will the decision be made?

Reply: Yes, @gwang-msft do you think we should move it to BaseOpBuilder now, or keep it here at the individual op builder level since we don't have that many QDQ ops supported?

Reply: We can move it as a virtual function for BaseOpBuilder that returns false by default; each individual builder will override it if necessary. Same for BaseOpSupportChecker.
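A rough sketch of the direction discussed in the last comment, not code from this PR: NodeUnit, GetQuantizedOpType and QuantizedOpType are the existing helpers already referenced in the diff above; the class bodies here are simplified and hypothetical.

// Sketch only: IsQuantizedOp as a virtual on the base builder, defaulting to false.
class BaseOpBuilder {
 public:
  virtual ~BaseOpBuilder() = default;
  // Default: an op builder does not handle a QDQ (quantized) node unit.
  virtual bool IsQuantizedOp(const NodeUnit& /*node_unit*/) const { return false; }
};

class ReshapeOpBuilder : public BaseOpBuilder {
 public:
  bool IsQuantizedOp(const NodeUnit& node_unit) const override {
    return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQReshape;
  }
};

// BaseOpSupportChecker could expose the same default-false virtual, with each
// per-op support checker overriding it in the same way.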