diff --git a/notebooks/build_engine.py b/notebooks/build_engine.py index bdcf4bd8..a44a22c8 100644 --- a/notebooks/build_engine.py +++ b/notebooks/build_engine.py @@ -18,17 +18,17 @@ import sys from numpy.core.fromnumeric import trace + sys.path.append("./") -import logging import argparse +import logging +import traceback import numpy as np -import tensorrt as trt -import pycuda.driver as cuda import pycuda.autoinit -import traceback - +import pycuda.driver as cuda +import tensorrt as trt from yolort.v5.utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages logging.basicConfig(level=logging.INFO) @@ -54,8 +54,10 @@ def __init__(self, calib_shape=None, calib_dtype=None) -> None: self.shape = (self.batch_size, 3, *calib_shape) self.num_images = len(self.dataset) self.image_index = 0 - - def get_batch(self, ): + + def get_batch( + self, + ): return iter(self.dataset) @@ -73,7 +75,7 @@ def __init__(self, cache_file): self.image_batcher: ImageBatcher = None self.batch_allocation = None self.batch_generator = None - + def set_image_batcher(self, image_batcher: ImageBatcher): """ Define the image batcher to use, if any. If using only the cache file, an image batcher doesn't need @@ -111,8 +113,12 @@ def get_batch(self, names): image = image[np.newaxis, :, :, :] batch, _, _, _ = image.shape self.image_batcher.image_index += 1 - - log.info("Calibrating image {} / {}".format(self.image_batcher.image_index, self.image_batcher.num_images)) + + log.info( + "Calibrating image {} / {}".format( + self.image_batcher.image_index, self.image_batcher.num_images + ) + ) cuda.memcpy_htod(self.batch_allocation, np.ascontiguousarray(batch)) return [int(self.batch_allocation)] except StopIteration: @@ -120,7 +126,7 @@ def get_batch(self, names): return None except Exception: traceback.print_exc() - + def read_calibration_cache(self): """ Overrides from trt.IInt8EntropyCalibrator2. @@ -171,7 +177,7 @@ def create_network(self, onnx_path): Parse the ONNX graph and create the corresponding TensorRT network definition. :param onnx_path: The path to the ONNX graph to load. """ - network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) self.network = self.builder.create_network(network_flags) self.parser = trt.OnnxParser(self.network, self.trt_logger) @@ -196,8 +202,16 @@ def create_network(self, onnx_path): assert self.batch_size > 0 self.builder.max_batch_size = self.batch_size - def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=25000, - calib_batch_size=8, calib_preprocessor=None): + def create_engine( + self, + engine_path, + precision, + calib_input=None, + calib_cache=None, + calib_num_images=25000, + calib_batch_size=8, + calib_preprocessor=None, + ): """ Build the TensorRT engine and serialize it to disk. :param engine_path: The path where to serialize the engine to. @@ -229,9 +243,7 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No if not os.path.exists(calib_cache): calib_shape = [calib_batch_size] + list(inputs[0].shape[1:]) calib_dtype = trt.nptype(inputs[0].dtype) - self.config.int8_calibrator.set_image_batcher( - ImageBatcher(calib_shape, calib_dtype) - ) + self.config.int8_calibrator.set_image_batcher(ImageBatcher(calib_shape, calib_dtype)) with self.builder.build_engine(self.network, self.config) as engine: with open(engine_path, "wb") as f: @@ -242,26 +254,53 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No def main(args): builder = EngineBuilder(args.verbose) builder.create_network(args.onnx) - builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images, - args.calib_batch_size, args.calib_preprocessor) + builder.create_engine( + args.engine, + args.precision, + args.calib_input, + args.calib_cache, + args.calib_num_images, + args.calib_batch_size, + args.calib_preprocessor, + ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-o", "--onnx", help="The input ONNX model file to load") parser.add_argument("-e", "--engine", help="The output path for the TRT engine") - parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8"], - help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'") + parser.add_argument( + "-p", + "--precision", + default="fp16", + choices=["fp32", "fp16", "int8"], + help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'", + ) parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output") parser.add_argument("--calib_input", help="The directory holding images to use for calibration") - parser.add_argument("--calib_cache", default="./calibration.cache", - help="The file path for INT8 calibration cache to use, default: ./calibration.cache") - parser.add_argument("--calib_num_images", default=10, type=int, - help="The maximum number of images to use for calibration, default: 25000") - parser.add_argument("--calib_batch_size", default=1, type=int, - help="The batch size for the calibration process, default: 1") - parser.add_argument("--calib_preprocessor", default="V2", choices=["V1", "V1MS", "V2"], - help="Set the calibration image preprocessor to use, either 'V2', 'V1' or 'V1MS', default: V2") + parser.add_argument( + "--calib_cache", + default="./calibration.cache", + help="The file path for INT8 calibration cache to use, default: ./calibration.cache", + ) + parser.add_argument( + "--calib_num_images", + default=10, + type=int, + help="The maximum number of images to use for calibration, default: 25000", + ) + parser.add_argument( + "--calib_batch_size", + default=1, + type=int, + help="The batch size for the calibration process, default: 1", + ) + parser.add_argument( + "--calib_preprocessor", + default="V2", + choices=["V1", "V1MS", "V2"], + help="Set the calibration image preprocessor to use, either 'V2', 'V1' or 'V1MS', default: V2", + ) args = parser.parse_args() if not all([args.onnx, args.engine]): parser.print_help()