fixed point based requantization on arm64 #11540
May want to add the rounding concerns in the description? In reply to: 1137660710
// S is the scale with type float
// Z is the zero point with type same as TOutput.
// min is the minimum value of type TOutput.
// max is the maximum value of type TOutput.
Great comments, very informative!
- This is a performance improvement, yet the option is by default "off", due to rounding errors? Would you consider specifying the reason for this?
- The same trick might work on other CPUs too. Since it's off by default, maybe remove ARM from the name? #Resolved
Yes, it is turned off by default because of rounding. It rounds half up, whereas the ONNX spec requires rounding half to even. Will add the info.
There is no plan to support it on x86. Keeping ARM in the name avoids the confusion of users thinking x86 supports a similar option.
@@ -500,8 +508,7 @@ MlasConvSym(
 }

 MLAS_CONV_SYM_POST_PROCESS_PARAMS PostProcessParams = {};

-MlasConvSymSetOutputZeroPoint(PostProcessParams, Params.OutputZeroPoint, Params.InputIsSigned);
+MlasConvSymSetOutputZeroPoint(PostProcessParams, OutputZeroPoint, Params.InputIsSigned);
Is this done repeatedly on the same set of parameters? Should we consider moving this out in the future? That would need a change to the MLAS interface. Since MLAS is not public yet, maybe it's OK?
 void
 MLASCALL
-MlasConvSymDepthwiseKernelSize25ArmU8S8(
+MlasConvSymDepthwiseKernelSize25ArmS8S8Impl(
This looks like a huge change. The U8S8 code used to be defined before S8S8; now the order is reversed. It seems that this is the cause of most of the changes. Why do you need to flip the position of these two? #Resolved
It is unintentional. Let me swap them back.
Multiplier(Multiplier), PreShift(PreShift), PostShift(PostShift),
Size(Size), ZeroPoint(ZeroPoint){}

MLAS_ROUND_KIND RequantRoundKind;
This field might be redundant; it can be deduced from the value of Scale or Multiplier (zero vs. non-zero). #Resolved
It is more descriptive. I would like to keep it.
It can be replaced by a const method.
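For illustration, the const-method suggestion might look roughly like the sketch below. Everything here is hypothetical: `RoundKindSketch` stands in for `MLAS_ROUND_KIND` (whose enumerators are not shown in this thread), `RequantParamsSketch` is not the real MLAS struct, and treating `Scale` and `Multiplier` as nullable pointers is an assumption based on the "zero vs. non-zero" remark.

```cpp
#include <cstdint>

// Hypothetical stand-in for MLAS_ROUND_KIND; the real enumerators are not shown here.
enum class RoundKindSketch { HalfEven, HalfUp };

// Hypothetical parameter struct; the pointer types are an assumption.
struct RequantParamsSketch {
    const float* Scale = nullptr;         // assumed to be set on the float path
    const int32_t* Multiplier = nullptr;  // assumed to be set on the fixed-point path

    // The round kind is derived on demand instead of being stored,
    // so it cannot drift out of sync with the other fields.
    RoundKindSketch RequantRoundKind() const {
        return Multiplier != nullptr ? RoundKindSketch::HalfUp
                                     : RoundKindSketch::HalfEven;
    }
};
```

The trade-off raised in the thread still applies: a stored field is more explicit at the call site, while a derived accessor removes one piece of state that has to be kept consistent.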
Since the change touches the kernels, and in particular adds branches, consider adding perf tests with one of the models that has big conv or gemm ops?
This reverts commit 1f2c926 because it makes our packaging pipeline crash.
Error message:
[ RUN ] QLinearConvTest.Conv3D_S8S8_Depthwise
Test #1: onnxruntime_test_all ...................Subprocess killed***Exception: 838.24 sec
We have not yet reproduced the bug on real ARM64 hardware; so far it has only shown up under qemu. Investigation is ongoing.
This PR adds fixed point based requantization for ARM64 devices.
Requantization is computed with formula:
v = round(clamp(S * (I - Z), min, max))
where v is the target value with type TOutput, which is either int8_t or uint8_t
I is the input value with type int32_t
S is the scale with type float
Z is the zero point with type same as TOutput.
min is the minimum value of type TOutput.
max is the maximum value of type TOutput.
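As a reading aid, here is a literal floating-point transcription of the formula exactly as written above. `RequantizeFloatRef` is a hypothetical name, not an MLAS function, and the use of `std::nearbyint` (round half to even under the default rounding mode) is an assumption chosen to match the ONNX rounding discussed in this PR.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Literal transcription of v = round(clamp(S * (I - Z), min, max)).
// Hypothetical helper for illustration only; not the MLAS kernel.
template <typename TOutput>
TOutput RequantizeFloatRef(int32_t I, float S, TOutput Z) {
    const float lo = static_cast<float>(std::numeric_limits<TOutput>::min());
    const float hi = static_cast<float>(std::numeric_limits<TOutput>::max());
    // Subtract in 64 bits so I - Z cannot overflow int32_t.
    const float scaled = S * static_cast<float>(static_cast<int64_t>(I) - Z);
    const float clamped = std::min(std::max(scaled, lo), hi);
    // std::nearbyint rounds half to even under the default rounding mode.
    return static_cast<TOutput>(std::nearbyint(clamped));
}
```

With `S = 0.5f` and `Z = 0`, an input of 5 gives 2.5, which rounds to 2 on this path; that tie case is where the float path and a round-half-up path can disagree.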
Because of power consumption concerns, and because some ARM devices don't even have FPUs, it is important to be able to run
quantization with integer instructions only. Fixed-point requantization is introduced to support this. Its general
idea is to convert the scale S to fixed point. The implementation follows the methods used in Ruy and XNNPACK.
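A minimal sketch of the general idea, assuming a scale in (0, 1) and a single right shift. This is a simplification: the PR itself carries `Multiplier`, `PreShift`, and `PostShift`, and the names `FixedPointScale`, `ToFixedPoint`, and `ApplyFixedPoint` below are hypothetical. Clamping and zero-point handling are left out.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// S is approximated as multiplier * 2^-shift (hypothetical representation).
struct FixedPointScale {
    int32_t multiplier;
    int shift;
};

// Convert a float scale in (0, 1) to a Q31 multiplier plus a right shift.
FixedPointScale ToFixedPoint(float S) {
    assert(S > 0.0f && S < 1.0f);
    int exponent = 0;
    // S = mantissa * 2^exponent, with mantissa in [0.5, 1).
    const double mantissa = std::frexp(static_cast<double>(S), &exponent);
    // A float mantissa has 24 bits, so this product is an exact integer below 2^31.
    const int64_t q = std::llround(mantissa * 2147483648.0);  // mantissa * 2^31
    return {static_cast<int32_t>(q), 31 - exponent};
}

// Integer-only scaling: round(I * S), with ties rounded up.
int32_t ApplyFixedPoint(int32_t I, FixedPointScale fp) {
    assert(fp.shift >= 1 && fp.shift <= 62);
    const int64_t product = static_cast<int64_t>(I) * fp.multiplier;
    const int64_t rounding = int64_t{1} << (fp.shift - 1);
    // Right shift of a negative value is arithmetic on mainstream compilers
    // (and guaranteed since C++20), so this computes floor(x + 0.5).
    return static_cast<int32_t>((product + rounding) >> fp.shift);
}
```

For example, `ApplyFixedPoint(5, ToFixedPoint(0.5f))` yields 3 and `ApplyFixedPoint(-5, ToFixedPoint(0.5f))` yields -2, since adding half and flooring sends ties toward positive infinity.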
// NOTE: fixed-point requantization rounds half up, whereas the ONNX spec rounds half to even, so for an identical
// model and input the inference results may not be exactly the same with the option kOrtSessionOptionsConfigFixedPointRequantOnARM64 on and off. The impact should be
// small in practice (the NNAPI EP uses the same rounding).
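The two rounding modes only differ on exact ties. The sketch below shows this with two hypothetical helpers, `RoundHalfUp` and `RoundHalfEven`; they are not MLAS functions, and `RoundHalfEven` assumes the default floating-point rounding mode is in effect.

```cpp
#include <cmath>
#include <cstdint>

// Round half up: ties go toward positive infinity.
int32_t RoundHalfUp(float x) {
    return static_cast<int32_t>(std::floor(x + 0.5f));
}

// Round half to even: ties go to the nearest even integer.
// std::nearbyint follows the current rounding mode, which defaults to nearest-even.
int32_t RoundHalfEven(float x) {
    return static_cast<int32_t>(std::nearbyint(x));
}

// RoundHalfUp(2.5f)  == 3,  RoundHalfEven(2.5f)  == 2   (differ)
// RoundHalfUp(3.5f)  == 4,  RoundHalfEven(3.5f)  == 4   (agree)
// RoundHalfUp(-1.5f) == -1, RoundHalfEven(-1.5f) == -2  (differ)
// RoundHalfUp(2.4f)  == 2,  RoundHalfEven(2.4f)  == 2   (agree, not a tie)
```

A disagreement of this kind shifts an output by at most one quantization step, which is consistent with the note that the practical impact should be small.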