
[Performance] Unable to use calibration table for INT8 TensorRT execution. #17235

Closed

laclouis5 opened this issue Aug 21, 2023 · 11 comments

Labels: ep:TensorRT (issues related to TensorRT execution provider), quantization (issues related to quantization)

laclouis5 commented Aug 21, 2023

Describe the issue

I'm trying to run EfficientNet-B0 (classification) from timm on the TensorrtExecutionProvider with INT8 execution, but I get a warning indicating that the calibration data is not being used during inference:

2023-08-21 16:33:00.127300669 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 14:33:00 WARNING] Calibrator is not being used. Users must provide dynamic range for all tensors that are not Int32 or Bool.

The execution is fine and the inference time is on par with native TensorRT INT8 quantization, but the accuracy is very low, probably because the calibration data is not being used, as the warning message suggests. I tried running inference with logging enabled, but it did not add anything useful to the log messages. Unfortunately, I have no further information on this issue.

For context, it's an EfficientNet-B0 from timm trained on a private dataset, so I cannot share the details of the calibration data or the model weights. It classifies 6 classes and achieves around 89% top-1 accuracy. The following configurations work as expected, both in terms of accuracy and inference speed:

  • Pytorch execution in mixed precision
  • ONNX with CUDAExecutionProvider
  • ONNX with TensorrtExecutionProvider and FP16 enabled
  • Native TensorRT using the Torch-TensorRT framework for conversion
  • Native TensorRT with INT8 quantization, where the accuracy is lower (82%) but still acceptable

Since native TensorRT INT8 works, I expected ONNX TRT INT8 to work as well. Unfortunately, ONNX TRT quantized yields an accuracy of 17% (no better than random guessing over 6 classes), which is probably due to the calibration data not being used, as suggested by the warning message.

To reproduce

I closely followed the documentation and the ResNet-50 example (here).

The model weights and validation dataset are private but here is a script that should reproduce the warning message:

import torch
import timm
import tensorrt  # For onnxruntime
import onnx
import onnxruntime as ort
from onnxruntime.quantization import (
    CalibrationDataReader,
    create_calibrator,
    write_calibration_table,
)

N = 1_000
B, C, H, W = 8, 3, 200, 200


class ONNXCalibrationDataset(CalibrationDataReader):
    """Stand-in data reader: random tensors replace the private validation data."""

    def __init__(self):
        super().__init__()
        self.count = 0
        self.total = N

    def get_next(self) -> dict:
        if self.count < self.total:
            self.count += 1
            return {"input": torch.randn(B, C, H, W).numpy()}
        else:
            return None


def export():
    model = (
        timm.create_model("efficientnet_b0", pretrained=True, exportable=True)
        .eval()
        .to(memory_format=torch.channels_last)
    )

    sample = torch.randn(B, C, H, W)

    torch.onnx.export(
        model,
        args=sample,
        f="model.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}},
    )

    model = onnx.load("model.onnx")
    model = onnx.shape_inference.infer_shapes(model)
    onnx.save(model, "model_dynamic.onnx")

    # Calibrate and write calibration.json / calibration.cache / calibration.flatbuffers
    calibration_dataset = ONNXCalibrationDataset()
    calibrator = create_calibrator(
        "model_dynamic.onnx",
        [],
        "model_augmented.onnx",
    )
    calibrator.set_execution_providers(["CUDAExecutionProvider"])
    calibrator.collect_data(data_reader=calibration_dataset)
    write_calibration_table(calibrator.compute_range())


def run():
    model = ort.InferenceSession(
        "model_dynamic.onnx",
        providers=[
            (
                "TensorrtExecutionProvider",
                {
                    "trt_fp16_enable": True,
                    "trt_int8_enable": True,
                    "trt_int8_calibration_table_name": "calibration.flatbuffers",
                },
            ),
            "CUDAExecutionProvider",
        ],
    )

    data = torch.randn(B, C, H, W)
    _ = model.run(["output"], {"input": data.numpy()})


def main():
    export()
    run()


if __name__ == "__main__":
    main()

Export logs:

2023-08-21 17:24:54.982333189 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:24:54 WARNING] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
2023-08-21 17:24:59.346273604 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:24:59 WARNING] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
2023-08-21 17:24:59.358624197 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:24:59 WARNING] Calibrator is not being used. Users must provide dynamic range for all tensors that are not Int32 or Bool.
2023-08-21 17:24:59.358977861 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:24:59 WARNING] Missing scale and zero-point for tensor (Unnamed Layer* 244) [Matrix Multiply]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
2023-08-21 17:24:59.358983657 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:24:59 WARNING] Missing scale and zero-point for tensor (Unnamed Layer* 246) [Shuffle]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor
2023-08-21 17:26:02.860807018 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:26:02 WARNING] TensorRT encountered issues when converting weights between types and that could affect accuracy.
2023-08-21 17:26:02.860825995 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:26:02 WARNING] If this is not the desired behavior, please modify the weights or retrain with regularization to adjust the magnitude of the weights.
2023-08-21 17:26:02.860831129 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:26:02 WARNING] Check verbose logs for the list of affected weights.
2023-08-21 17:26:02.860835038 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:26:02 WARNING] - 69 weights are affected by this issue: Detected subnormal FP16 values.
2023-08-21 17:26:02.860852767 [W:onnxruntime:Default, tensorrt_execution_provider.h:75 log] [2023-08-21 15:26:02 WARNING] - 9 weights are affected by this issue: Detected values less than smallest positive FP16 subnormal value and converted them to the FP16 minimum subnormalized value.

Most of these warnings are fine: they also occur with the un-quantized model and do not affect its performance. Only the third one (the calibrator warning) looks concerning.

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15

ONNX Runtime API

Python

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

No response

Model File

No response

Is this a quantized model?

Yes

The github-actions bot added the ep:TensorRT and quantization labels on Aug 21, 2023.
chilo-ms (Contributor) commented Aug 22, 2023

@laclouis5

From your script: when you create the ONNXCalibrationDataset, you should provide the real dataset to the data reader, so that the calibrator can later compute the dynamic range for each tensor.

Please see the following examples of how to provide the dataset:

        data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
                                         start_index=0,
                                         end_index=calibration_dataset_size,
                                         stride=calibration_dataset_size,
                                         batch_size=batch_size,
                                         model_path=augmented_model_path,
                                         input_name=input_name)

https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/trt/resnet50/e2e_tensorrt_resnet_example.py#L338C9-L338C20
https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/object_detection/trt/yolov3/e2e_user_yolov3_example.py#L89

One other thing to double check is the content of calibration.cache or calibration.json.
In addition to calibration.flatbuffers, there should be two other plain-text files written out for the user to inspect.
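
For a quick sanity check, something like this prints a few of the recorded ranges (assuming the default calibration.json written by write_calibration_table):

import json

# Each entry maps a tensor name to its [min, max] dynamic range.
with open("calibration.json") as f:
    ranges = json.load(f)

for name, (lo, hi) in list(ranges.items())[:5]:
    print(f"{name}: [{lo}, {hi}]")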

laclouis5 (Author) commented Aug 22, 2023

What do you mean by "you should provide the real dataset for the data reader"? I used a dataset of random tensors in the code I provided since I cannot share the details, but in reality the ONNXCalibrationDataset (which I admittedly should have called ONNXCalibrationDataReader) draws its data from the validation dataset.

I double-checked the JSON and cache files and they look fine. Here is an excerpt of the JSON (truncated):

{"/model/blocks/blocks.5/blocks.5.2/se/Mul_output_0": [-0.27840250730514526, 47.94845199584961], "/model/blocks/blocks.1/blocks.1.0/bn1/act/Sigmoid_output_0": [0.0, 1.0], "/model/blocks/blocks.4/blocks.4.0/bn2/act/Sigmoid_output_0": [0.0, 1.0], "/model/blocks/blocks.6/blocks.6.0/se/act1/Sigmoid_output_0": [0.0, 1.0], "/model/blocks/blocks.3/blocks.3.1/se/ReduceMean_output_0": [-0.276947021484375, 35.54975128173828], "/model/blocks/blocks.3/blocks.3.1/bn1/act/Sigmoid_output_0": [0.0, 1.0], "/model/blocks/blocks.4/blocks.4.1/se/act1/Mul_output_0": [-0.2784646153450012, 11.509453773498535], "/model/blocks/blocks.5/blocks.5.0/se/ReduceMean_output_0": [-0.2540878653526306, 10.721695899963379], "/model/blocks/blocks.6/blocks.6.0/conv_pw/Conv_output_0": [-25.566896438598633, 24.689128875732422], "/model/blocks/blocks.1/blocks.1.1/se/Mul_output_0": [-0.27753210067749023, 201.57601928710938], "/model/blocks/blocks.3/blocks.3.2/se/act1/Sigmoid_output_0": [0.0, 0.9999947547912598], "/model/blocks/blocks.2/blocks.2.0/se/act1/Sigmoid_output_0": [0.0, 0.999976396560669]}

If we imagine that ONNXCalibrationDataset.get_next() returns data from the real validation set, then what is the issue in my code?
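
For illustration, here is roughly what the real reader does; the paths and preprocessing below are placeholders standing in for the private dataset:

import glob

import numpy as np
from PIL import Image
from onnxruntime.quantization import CalibrationDataReader


class ONNXCalibrationDataReader(CalibrationDataReader):
    """Feeds preprocessed validation images to the calibrator, batch by batch."""

    def __init__(self, image_dir: str, batch_size: int = 8, size: int = 200):
        super().__init__()
        # Placeholder glob and preprocessing; the real dataset is private.
        self.paths = sorted(glob.glob(f"{image_dir}/*.jpg"))
        self.batch_size = batch_size
        self.size = size
        self.index = 0

    def _load(self, path: str) -> np.ndarray:
        img = Image.open(path).convert("RGB").resize((self.size, self.size))
        x = np.asarray(img, dtype=np.float32) / 255.0  # HWC, values in [0, 1]
        return x.transpose(2, 0, 1)  # CHW

    def get_next(self) -> dict:
        if self.index >= len(self.paths):
            return None
        batch = self.paths[self.index : self.index + self.batch_size]
        self.index += self.batch_size
        return {"input": np.stack([self._load(p) for p in batch])}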

chilo-ms (Contributor) commented Aug 22, 2023

Ah, I didn't notice you mentioned the dataset is private. Okay, the calibration table seems fine.

But I need to investigate more into why we still get this warning message from TRT.

One other configuration to check is the calibration method: the ORT quantization script uses MinMax by default. I'm curious whether accuracy increases if we choose the Percentile or Entropy method.
Also, do you know which calibration method TRT used? And is it possible to check the calibration table / dynamic ranges that TRT uses?
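
(For reference, the calibration method is selected when creating the calibrator; a minimal sketch, assuming the create_calibrator signature from onnxruntime.quantization:)

from onnxruntime.quantization import CalibrationMethod, create_calibrator

# MinMax is the default; Entropy and Percentile build per-tensor histograms instead.
calibrator = create_calibrator(
    "model_dynamic.onnx",
    [],
    "model_augmented.onnx",
    calibrate_method=CalibrationMethod.Entropy,
    # Or, for Percentile calibration:
    # calibrate_method=CalibrationMethod.Percentile,
    # extra_options={"percentile": 99.99},
)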

laclouis5 (Author) commented:

TensorRT uses Entropy calibration by default, while ORT uses MinMax.

I'll try ORT Entropy again, but I think I already tried it in the past and it gave me OOM issues.

TensorRT with MinMax calibration yields poor results with my network, which is why I kept the default Entropy calibration.

I'll try to investigate the TensorRT calibration table, but the error/bug seems to be in the runtime inference rather than in the calibration process.

chilo-ms (Contributor) commented Aug 22, 2023

Re: "but the error/bug seems to be in the runtime inference rather than in the calibration process"

Yes, the message is from TRT. I'm investigating now and will reach out to Nvidia about this if needed.
Will let you know once I have the answer.

chilo-ms (Contributor) commented Aug 23, 2023

@laclouis5
I checked with Nvidia.
The message occurs anytime a user requests INT8 from TensorRT without a corresponding INT8 calibrator set in the builder config, and this is what TRT EP does in your case.

Please also note that if the user properly sets the dynamic ranges of the tensors, the warning can be safely ignored. Even if we provided dynamic ranges for every tensor, the warning would still print, since the INT8 calibrator is not set. The default values for dynamic ranges are ±NaN, but the NaNs are ignored and those layers run in higher precision, as per the messages:

“Missing scale and zero-point for tensor (Unnamed Layer* 244) [Matrix Multiply]_output, expect fall back to non-int8 implementation for any layer consuming or producing given tensor”.
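
(For context, "setting the dynamic ranges" means assigning per-tensor INT8 ranges through the TensorRT API, which TRT EP populates from the calibration table. An illustrative sketch, not TRT EP's actual code:)

import tensorrt as trt

# Illustrative only: roughly what TRT EP does internally with the table values.
def set_dynamic_ranges(network: trt.INetworkDefinition, table: dict) -> None:
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        for j in range(layer.num_outputs):
            tensor = layer.get_output(j)
            if tensor.name in table:
                lo, hi = table[tensor.name]
                tensor.set_dynamic_range(lo, hi)  # per-tensor INT8 range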

chilo-ms (Contributor) commented Aug 23, 2023

I think there are two things you can try to improve accuracy (see the sketches below):

  • Dump the calibration table from TRT and provide it to TRT EP with the additional provider option trt_int8_use_native_calibration_table set to true.
  • Change the onnxruntime quantization tool to use the Entropy method for calibration. You mentioned you encountered OOM; please reference this example to see whether it helps.
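
For the first option, the session creation would look roughly like this (the table file name below is a placeholder for the TRT-dumped table):

import onnxruntime as ort

# Sketch: point TRT EP at a TensorRT-native calibration table.
session = ort.InferenceSession(
    "model_dynamic.onnx",
    providers=[
        (
            "TensorrtExecutionProvider",
            {
                "trt_int8_enable": True,
                "trt_int8_calibration_table_name": "calibration.cache",  # placeholder name
                "trt_int8_use_native_calibration_table": True,
            },
        ),
        "CUDAExecutionProvider",
    ],
)

For the OOM issue, the linked example calibrates in chunks. A minimal sketch of that pattern, reusing N and calibrator from your script and assuming collect_data() accumulates statistics across calls; ChunkedDataReader is a hypothetical reader serving one slice of the dataset per chunk:

stride = 100
for start in range(0, N, stride):
    # Hypothetical reader that serves samples [start, start + stride)
    reader = ChunkedDataReader(start_index=start, end_index=start + stride)
    calibrator.collect_data(data_reader=reader)

write_calibration_table(calibrator.compute_range())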

laclouis5 (Author) commented:

So to sum up: the warning can be ignored, and calibration did take place for my network. It's just that this calibration method (MinMax) yields poor performance for my specific network.

I obtained good performance with TensorRT Entropy calibration, so I'll try this with ONNX and the OOM workaround.

Thanks!

laclouis5 (Author) commented:

FYI, I was finally able to correctly calibrate my model using the ONNX calibration tools. The only setting that reached acceptable performance was Percentile calibration with a percentile value of 99.9 (on par with TensorRT Entropy calibration).

However, I encountered some bugs. When using Percentile and Entropy calibration, the calibration data (as returned by calibrator.compute_range()) contains some np.float32 values instead of native Python floats, so write_calibration_table fails because np.float32 is not JSON serializable. I was able to solve this issue by converting the np.float32 values to float.
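
The workaround, assuming compute_range() returns a dict mapping tensor names to (min, max) pairs as it does in my ORT version:

ranges = calibrator.compute_range()
# np.float32 is not JSON serializable; cast to native Python floats first.
ranges = {name: (float(lo), float(hi)) for name, (lo, hi) in ranges.items()}
write_calibration_table(ranges)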

Calibration using Percentile and Entropy is also very slow, probably an order of magnitude slower than TensorRT Entropy calibration. MinMax is fast, though.

chilo-ms (Contributor) commented Aug 25, 2023

Thanks for the follow-up and for providing the feedback.

The Percentile and Entropy calibration methods in the onnxruntime calibration tools use numpy's histogram functions to build a histogram of each tensor and collect the data, which can be far slower than the MinMax method. As for why TensorRT Entropy calibration is fast: it's a black box to us, so I can only guess that they optimized the calibration or even run it on the GPU, whereas for ORT the numpy code runs on the CPU.

pravn commented Nov 8, 2024

Excellent post. I experienced the exact same issues. The default (MinMax?) does not work, but the other two methods (Entropy and Percentile) do, though only after replacing np.float32 with float as @laclouis5 kindly pointed out. In my case, I have a dict with different data types (lidar, camera), so the data setup is a little different, although it broadly follows the same trail as the ImageNet example.
