[Relay][Quantization] KL-divergence-based per-layer calibration #3538
Conversation
This one is ready. Please review and share your thoughts on the calibration API design.
src/relay/pass/quantize.cc (outdated)
// =============
// calibration
class StatsCollector : private ExprMutator {
If we just collect stats, ExprVisitor should be enough.
ExprMutator is actually needed. This mutator transforms the annotated expr back to the original expr by removing each simulate_quantize.
For example, given this Relay program:
%1 = ..
%2 = simulate_quantize(%1)
%3 = op(%2)
%4 = simulate_quantize(%3)
We need to profile %1 and %3. But %3 takes %2 as input, so we need to replace the input of %3 with %1 (because the simulate_quantize in %2 produced by the Annotate pass is not in passthrough mode, we need to either remove it or rewrite it in passthrough mode).
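For illustration, here is a rough Python sketch of that transformation. The pass in this PR is implemented in C++, and the op name and class below are assumptions rather than the PR's exact code:

```python
# Hypothetical Python sketch of the StatsCollector idea described above.
# Assumes the annotation op is registered as "relay.op.annotation.simulated_quantize".
from tvm import relay
from tvm.relay.expr_functor import ExprMutator


class StatsCollectorSketch(ExprMutator):
    """Remove simulated_quantize calls and record the tensors to profile."""

    def __init__(self):
        super().__init__()
        self.profile_points = []  # original (pre-annotation) inputs of each simulated_quantize

    def visit_call(self, call):
        new_args = [self.visit(arg) for arg in call.args]
        if getattr(call.op, "name", None) == "relay.op.annotation.simulated_quantize":
            data = new_args[0]        # the tensor being (simulated-)quantized
            self.profile_points.append(data)
            return data               # bypass simulated_quantize: %3 now consumes %1 directly
        return relay.Call(call.op, new_args, call.attrs, call.type_args)
```

Packing profile_points into a tuple and wrapping it in a function would yield a single graph whose outputs are exactly the tensors to run over the calibration dataset.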
@ZihengJiang I was thinking that the other PR #3543 actually breaks this pass (because the result of this pass contains annotations and casts).
@vinx13 Why not collect stats before annotate?
@ZihengJiang Annotations tell us which nodes should be profiled. If we want to collect stats before annotation, we would need to repeat code similar to the annotate pass to decide which nodes should be quantized.
Okay, let's keep the current way. #3543 will not break this pass since annotation.cast_hint and annotation.stop_fusion will not change the running result. They are just annotations and you can view them as identity. One thing: instead of detecting and skipping simulated_quantize inside the IRMutator, let's add an option like simulated_quantize(kind=kIdentity) to eliminate the impact of simulated_quantize.
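For reference, a small NumPy sketch of what a simulated quantize computes numerically and how an identity kind would turn it into a passthrough. This only illustrates the suggestion; the kind values and the rounding/clipping details are assumptions, not this PR's exact formula:

```python
import numpy as np

# Hypothetical kind values, mirroring the suggested kIdentity option.
K_IDENTITY = 0
K_ACTIVATION = 1


def simulated_quantize(data, scale, clip_min, clip_max, kind=K_ACTIVATION):
    """Simulate fixed-point quantization while staying in floating point.

    With kind=K_IDENTITY the op degenerates to an identity, so a stats pass
    can keep it in the graph without changing the network's outputs.
    """
    if kind == K_IDENTITY:
        return data
    return np.clip(np.round(data / scale), clip_min, clip_max) * scale
```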
@ZihengJiang updated
@vinx13 Could you please address the other comments?
@ZihengJiang @vinx13 please follow up on this and let us merge soon.
I left some comments, please ping me if I have misunderstood anything.
LGTM basically.
…he#3538)
* [Relay][Quantization] Support floating-point scale
* [Relay][Quantization] KL-divergence calibration on dataset
* Fix unhandled LeftShift case in QuantizeRealize
* Fix lint
* drop QBias
* fix lint
* address comments
* address comments
* Update comments
* address comments
* lint
* kQIdentity = 0
* CollectStats pass that collects the input of each simulated_quantize in the annotated graph into a tuple output
* max_scale as an alternative for power2_scale in weight quantization (a rough sketch of the two scale choices follows below)

Evaluation code: https://gist.github.com/vinx13/6f1eb1f9e2c0a8786149ee881bfcd6aa
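The following NumPy sketch illustrates the difference between the two weight-scale choices; it is an approximation for illustration, not this PR's code:

```python
import numpy as np


def max_scale(weight):
    """Use the absolute maximum of the weight tensor as the calibration value."""
    return float(np.max(np.abs(weight)))


def power2_scale(weight):
    """Round the absolute maximum up to the nearest power of two."""
    val = float(np.max(np.abs(weight)))
    return float(2.0 ** np.ceil(np.log2(val))) if val > 0 else 1.0
```

The actual quantization scale is then derived from this per-tensor value together with the weight bit width; restricting it to a power of two lets the realized multiplication be lowered to a shift.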
What's left:
* I added QAnnotateKind.BIAS. I'm not sure whether it is necessary. Currently there are a few tricks in handling bias (nbit_bias, valid_range, ...). It would be good to find a better solution and avoid these tricks.
* calibrate function. We need to decide how users can specify different ways of quantization (max, power2, KLD, ...).

Evaluation result on ImageNet:
max_scale for weights, KL divergence for activations (top-1 / top-5 accuracy):
resnet18_v1, 0.70642 / 0.89702
resnet50_v1, 0.73682 / 0.91664
resnet101_v1, 0.74484 / 0.9208
resnet18_v2, 0.70794 / 0.89832
resnet50_v2, 0.7691 / 0.93268
resnet101_v2, 0.78204 / 0.94124
power2 for weights, KL divergence restricted to power-of-two values for activations (top-1 / top-5 accuracy; use the --eval-power2 option in my evaluation script):
resnet18_v1, 0.70332 / 0.89526
resnet50_v1, 0.73426 / 0.9146
resnet101_v1, 0.72434 / 0.91058
resnet18_v2, 0.70314 / 0.89618
resnet50_v2, 0.76486 / 0.93108
resnet101_v2, 0.78066 / 0.94002
These experiments are done under opt_level=2. When opt_level=3, FoldScaleAxis might generate some outliers in the bias vector and cause significant accuracy drops. We should use different scales than taking the maximum for bias in this case.

cc @tqchen @ZihengJiang @eqy @ajtulloch @antinucleon @FrozenGene
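For readers unfamiliar with the calibration method, below is a condensed, illustrative sketch of KL-divergence-based threshold selection for one layer's activations. The bin counts and search details are assumptions in the spirit of the technique, not this PR's exact implementation:

```python
import numpy as np
from scipy import stats


def kl_divergence_threshold(samples, num_bins=2048, num_quantized_bins=255):
    """Pick a saturation threshold that minimizes KL(reference || quantized).

    `samples` are the recorded activations of one layer over the calibration
    dataset (e.g. the tensors collected by the stats pass above).
    """
    hist, edges = np.histogram(np.abs(samples), bins=num_bins)
    best_threshold, best_kl = edges[-1], float("inf")

    # Try truncating the distribution at increasingly large thresholds.
    for i in range(num_quantized_bins, num_bins + 1):
        ref = hist[:i].astype(np.float64)
        ref[-1] += hist[i:].sum()                  # clip the tail into the last kept bin
        # Re-quantize the kept bins into num_quantized_bins coarse bins.
        idx = (np.arange(i) * num_quantized_bins // i).astype(np.int64)
        coarse = np.zeros(num_quantized_bins, dtype=np.float64)
        np.add.at(coarse, idx, ref)
        expanded = coarse[idx]                     # expand back to i bins for comparison
        kl = stats.entropy(ref + 1e-12, expanded + 1e-12)
        if kl < best_kl:
            best_kl, best_threshold = kl, edges[i]
    return best_threshold
```

The chosen threshold then determines the activation scale, and it can additionally be rounded to a power of two for the power2-restricted variant reported above.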