[Performance] Unable to use calibration table for INT8 TensorRT execution. #17235
Comments
From your script, when you create the data reader, you should provide the real dataset for the data reader. Please see the example of how to provide the dataset:
https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/trt/resnet50/e2e_tensorrt_resnet_example.py#L338C9-L338C20

One thing to double check is the content of the generated calibration table.
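For reference, the data reader in that example boils down to something like the sketch below; the input name, batching, and preprocessing here are placeholders, not the exact code from the example:

```python
import numpy as np
from onnxruntime.quantization import CalibrationDataReader


class MyCalibrationDataReader(CalibrationDataReader):
    """Feeds preprocessed real calibration samples to the calibrator, one dict per call."""

    def __init__(self, samples, input_name="input"):
        # `samples`: an iterable of preprocessed float32 arrays from the real dataset.
        # `input_name` must match the ONNX model's actual input name ("input" is a guess).
        self.data = iter({input_name: np.expand_dims(s, axis=0)} for s in samples)

    def get_next(self):
        # Returning None tells the calibrator the dataset is exhausted.
        return next(self.data, None)
```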
What do you mean by "you should provide the real dataset for the data reader"? I used a dataset of random tensors in the code I provided since I cannot share the details, but in reality the data reader is fed with my real (private) dataset.

I double checked the json and cache files and they look fine. Here is an excerpt of the json (truncated):

{"/model/blocks/blocks.5/blocks.5.2/se/Mul_output_0": [-0.27840250730514526, 47.94845199584961], "/model/blocks/blocks.1/blocks.1.0/bn1/act/Sigmoid_output_0": [0.0, 1.0], "/model/blocks/blocks.4/blocks.4.0/bn2/act/Sigmoid_output_0": [0.0, 1.0], "/model/blocks/blocks.6/blocks.6.0/se/act1/Sigmoid_output_0": [0.0, 1.0], "/model/blocks/blocks.3/blocks.3.1/se/ReduceMean_output_0": [-0.276947021484375, 35.54975128173828], "/model/blocks/blocks.3/blocks.3.1/bn1/act/Sigmoid_output_0": [0.0, 1.0], "/model/blocks/blocks.4/blocks.4.1/se/act1/Mul_output_0": [-0.2784646153450012, 11.509453773498535], "/model/blocks/blocks.5/blocks.5.0/se/ReduceMean_output_0": [-0.2540878653526306, 10.721695899963379], "/model/blocks/blocks.6/blocks.6.0/conv_pw/Conv_output_0": [-25.566896438598633, 24.689128875732422], "/model/blocks/blocks.1/blocks.1.1/se/Mul_output_0": [-0.27753210067749023, 201.57601928710938], "/model/blocks/blocks.3/blocks.3.2/se/act1/Sigmoid_output_0": [0.0, 0.9999947547912598], "/model/blocks/blocks.2/blocks.2.0/se/act1/Sigmoid_output_0": [0.0, 0.999976396560669]}

If we imagine that …
Ah, I didn't notice you mentioned the dataset is private. Okay, the calibration table seems fine, but why we still get this error message from TRT I need to investigate more. One other configuration to check is the calibration method: the ORT quantization script uses MinMax as the default. I'm curious whether accuracy will increase if we choose the Percentile or Entropy method.
TensorRT uses Entropy calibration by default while ORT uses MinMax. I'll try ORT Entropy again, but I think I already tried it in the past and it gave me OOM issues. TensorRT with MinMax calibration yields poor results with my network, which is why I kept the default Entropy calibration. I'll try to investigate the TensorRT calibration table, but the error/bug seems to be in the runtime inference rather than in the calibration process.
Re: "but the error/bug seems to be in the runtime inference rather than in the calibration process": Yes, the message is from TRT. I'm investigating now and will reach out to Nvidia about this if needed.
@laclouis5 Please also note that if the user is properly setting the dynamic ranges of the tensors, then the warning can be safely ignored.

Even if we provided dynamic ranges for each tensor, the warning will still print since the INT8 calibrator is not set. The default values for dynamic ranges are ±NaN, but the NaNs are ignored and the layers run in higher precision, as per the messages: "Missing scale and zero-point for tensor (Unnamed Layer* 244) [Matrix Multiply]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor".
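To make the "setting the dynamic ranges" part concrete, at the native TensorRT level it amounts to roughly the sketch below (the ORT TensorRT EP does the equivalent internally from the calibration table). Here `network` is an already-built `trt.INetworkDefinition` and `ranges` is a name-to-(min, max) mapping; both are assumed, and the symmetric clamping is just one common choice:

```python
import tensorrt as trt  # only needed for the native TensorRT path, not for ORT


def apply_dynamic_ranges(network, ranges):
    """Sketch: assign a symmetric dynamic range to every tensor we have data for."""
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        for j in range(layer.num_outputs):
            tensor = layer.get_output(j)
            if tensor.name in ranges:
                lo, hi = ranges[tensor.name]
                bound = max(abs(lo), abs(hi))  # symmetric range for INT8 scales
                tensor.dynamic_range = (-bound, bound)
            # Tensors left without a range are exactly the ones the
            # "Missing scale and zero-point ..." warning reports; they fall
            # back to a higher-precision implementation.
```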
I think there are two things you can try to improve accuracy:
So to sum up, the warning can be ignored and the calibration did take place for my network. It's just that this calibration method (MinMax) yields poor accuracy for my specific network. I obtained good results with TensorRT Entropy calibration, so I'll try this with ONNX and the OOM workaround. Thanks!
FYI, I was finally able to correctly calibrate my model using the ONNX Runtime calibration tools. The only setting that reached an acceptable accuracy was Percentile calibration with a percentile value of 99.9 (on par with TensorRT Entropy calibration).

However, I encountered some bugs. When using Percentile and Entropy calibration, the calibration data returned by the calibration tools contains np.float32 values instead of plain Python floats, and these have to be replaced with float for the calibration table to be written correctly.

Calibration using Percentile and Entropy is also very slow, probably an order of magnitude slower than TensorRT Entropy calibration. MinMax is fast, though.
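Roughly what that setup looks like as a sketch (the `extra_options` key for the percentile value and the return format of `compute_range()` should be double-checked against the installed onnxruntime version):

```python
from onnxruntime.quantization import CalibrationMethod, create_calibrator, write_calibration_table


def calibrate_percentile(model_path, data_reader, percentile=99.9):
    """Sketch: Percentile calibration plus the np.float32 -> float workaround."""
    calibrator = create_calibrator(
        model_path,
        augmented_model_path="augmented_model.onnx",
        calibrate_method=CalibrationMethod.Percentile,
        extra_options={"percentile": percentile},
    )
    calibrator.set_execution_providers(["CUDAExecutionProvider"])
    calibrator.collect_data(data_reader)

    # Workaround mentioned above: the computed ranges can come back as
    # np.float32; cast to plain Python floats before writing the table.
    ranges = {
        name: (float(lo), float(hi))
        for name, (lo, hi) in calibrator.compute_range().items()
    }
    write_calibration_table(ranges)
```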
Thanks for the follow-up and for providing the feedback. The Percentile and Entropy calibration methods of the onnxruntime calibration tools use numpy's histogram function to create the histogram of each tensor and collect the data, which can be much slower than the MinMax calibration method. As for why TensorRT Entropy calibration is fast: it's a black box for us, so I can only guess that they did some optimization for the calibration or even run it on the GPU, whereas for ORT the numpy code runs on the CPU.
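Illustratively (this is not the actual ORT implementation), the per-tensor collection is a plain numpy loop of roughly this shape, executed on the CPU for every intermediate tensor of every calibration batch, which is where the time goes:

```python
import numpy as np


def update_histogram(running, tensor, num_bins=2048):
    """Illustrative sketch: fold one intermediate tensor from one batch into a running histogram."""
    values = np.abs(np.asarray(tensor, dtype=np.float32).ravel())
    if running is None:
        # First batch for this tensor: build the histogram and keep its bin edges.
        counts, edges = np.histogram(values, bins=num_bins)
        return counts, edges
    counts, edges = running
    # Later batches reuse the same edges so the counts can simply be accumulated.
    new_counts, _ = np.histogram(values, bins=edges)
    return counts + new_counts, edges
```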
Excellent post. I had experienced the exact same issues. The default (MinMax?) does not work, but the other two methods (Entropy and Percentile) do, though only after we replace np.float32 with float as @laclouis5 kindly informed us. In my case, I have a dict with different data types (lidar, camera), so the data setup is a little different, although it broadly follows the same pattern as the ImageNet example.
Describe the issue
I'm trying to run EfficientNetB0 (classification) from timm on the TensorrtExecutionProvider using INT8 execution, but I get a warning related to the calibration data not being used during inference:

The execution is fine and the inference time is on par with TensorRT native INT8 quantization, but the accuracy is very low, probably because the calibration data is not being used, as mentioned by the warning message. I tried doing inference with logging activated but it did not improve the log messages. I don't have further information on this issue unfortunately.
For context, it's an EfficientNetB0 from timm trained on a private dataset, so I cannot share the details of the calibration data or the model weights. It is used to classify 6 classes and it achieves around 89% top-1 accuracy. These configurations work as expected, both in terms of accuracy and expected inference speed (see the session sketch below):

- CUDAExecutionProvider
- TensorrtExecutionProvider with FP16 enabled

Since TensorRT INT8 works, I expected ONNX TRT INT8 to also work. Unfortunately, ONNX TRT quantized yields an accuracy of 17% (no better than a random guess), which is probably due to the calibration data not being used, as suggested by the warning message.
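For reference, the two working configurations above correspond to session setups roughly like this (a sketch; "model.onnx" is a placeholder and the provider option names should be checked against the installed onnxruntime-gpu build):

```python
import onnxruntime as ort

# Working configuration 1: plain CUDA EP.
cuda_session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])

# Working configuration 2: TensorRT EP with FP16 enabled, falling back to CUDA.
trt_fp16_session = ort.InferenceSession(
    "model.onnx",
    providers=[
        ("TensorrtExecutionProvider", {"trt_fp16_enable": True}),
        "CUDAExecutionProvider",
    ],
)
```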
To reproduce
I closely followed the documentation and the Resnet50 example (here).
The model weights and validation dataset are private but here is a script that should reproduce the warning message:
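The actual script is not reproduced here, but a minimal sketch of that flow, following the linked ResNet-50 example, looks roughly like the code below. The file names, input name and shape, sample count, and provider option names are assumptions for illustration, not the real private setup:

```python
import numpy as np
import onnxruntime as ort
from onnxruntime.quantization import (
    CalibrationDataReader,
    CalibrationMethod,
    create_calibrator,
    write_calibration_table,
)

MODEL = "efficientnet_b0.onnx"  # assumed file name
INPUT_NAME = "input"            # assumed input name
SHAPE = (1, 3, 224, 224)        # assumed EfficientNetB0 input shape


class RandomDataReader(CalibrationDataReader):
    """Stand-in reader: random tensors instead of the private validation images."""

    def __init__(self, num_samples=64):
        self.data = iter(
            {INPUT_NAME: np.random.rand(*SHAPE).astype(np.float32)}
            for _ in range(num_samples)
        )

    def get_next(self):
        return next(self.data, None)


# 1. Collect tensor ranges and write the calibration table files.
calibrator = create_calibrator(
    MODEL,
    augmented_model_path="augmented_model.onnx",
    calibrate_method=CalibrationMethod.MinMax,
)
calibrator.set_execution_providers(["CUDAExecutionProvider"])
calibrator.collect_data(RandomDataReader())
write_calibration_table(calibrator.compute_range())

# 2. Run the model through the TensorRT EP in INT8, pointing it at the table.
providers = [
    (
        "TensorrtExecutionProvider",
        {
            "trt_int8_enable": True,
            "trt_int8_calibration_table_name": "calibration.flatbuffers",
            "trt_int8_use_native_calibration_table": False,
        },
    ),
    "CUDAExecutionProvider",
]
session = ort.InferenceSession(MODEL, providers=providers)
out = session.run(None, {INPUT_NAME: np.random.rand(*SHAPE).astype(np.float32)})
```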
Export logs:
Most warnings are fine; they also occur in the un-quantized model and do not affect the model's performance. Only the third one looks concerning.
Urgency
No response
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.15
ONNX Runtime API
Python
Architecture
X64
Execution Provider
TensorRT
Execution Provider Library Version
No response
Model File
No response
Is this a quantized model?
Yes