From c5a10a5d92fc3390f41e098e81b9cf966e7e3d90 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Sun, 26 Dec 2021 19:15:33 +0800 Subject: [PATCH] Refactor module structure for exporting TensorRT (#254) * Fix docstrings in YOLOGraphSurgeon * Move YOLOTRTModule in yolort.runtime.trt_helper * Make PostProcess irrelevant to LogitsDecoder * Minor fixes --- test/test_runtime.py | 2 +- yolort/models/box_head.py | 91 +++++++++--------- yolort/runtime/trt_helper.py | 131 ++++++++++++++++++++++++-- yolort/runtime/yolo_graphsurgeon.py | 11 ++- yolort/runtime/yolo_tensorrt_model.py | 113 ---------------------- 5 files changed, 175 insertions(+), 173 deletions(-) delete mode 100644 yolort/runtime/yolo_tensorrt_model.py diff --git a/test/test_runtime.py b/test/test_runtime.py index f721d0ab..29407b4e 100644 --- a/test/test_runtime.py +++ b/test/test_runtime.py @@ -4,7 +4,7 @@ import pytest import torch from torch import Tensor -from yolort.runtime.yolo_tensorrt_model import YOLOTRTModule +from yolort.runtime.trt_helper import YOLOTRTModule from yolort.v5 import attempt_download diff --git a/yolort/models/box_head.py b/yolort/models/box_head.py index 7d7595a3..3db871d4 100644 --- a/yolort/models/box_head.py +++ b/yolort/models/box_head.py @@ -318,9 +318,46 @@ def build_targets( return target_cls, target_box, indices, anch +def _concat_pred_logits( + head_outputs: List[Tensor], + grids: List[Tensor], + shifts: List[Tensor], + strides: List[int], +) -> Tensor: + # Concat all pred logits + batch_size, _, _, _, K = head_outputs[0].shape + + # Decode bounding box with the shifts and grids + all_pred_logits = [] + + for head_output, grid, shift, stride in zip(head_outputs, grids, shifts, strides): + head_feature = torch.sigmoid(head_output) + pred_xy, pred_wh = det_utils.decode_single(head_feature[..., :4], grid, shift, stride) + pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1) + all_pred_logits.append(pred_logits.view(batch_size, -1, K)) + + all_pred_logits = torch.cat(all_pred_logits, dim=1) + + return all_pred_logits + + +def _decode_pred_logits(pred_logits: Tensor): + """ + Decode the prediction logit from the PostPrecess. + """ + # Compute conf + # box_conf x class_conf, w/ shape: num_anchors x num_classes + scores = pred_logits[:, 5:] * pred_logits[:, 4:5] + boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy") + + return boxes, scores + + class LogitsDecoder(nn.Module): """ - This is a simplified version of PostProcess to remove the ``torchvision::nms`` module. + This is a simplified version of post-processing module, we manually remove + the ``torchvision::ops::nms``, and it will be used later in the procedure of + exporting the ONNX graph for TensorRT. """ def __init__(self, strides: List[int]) -> None: @@ -332,45 +369,6 @@ def __init__(self, strides: List[int]) -> None: super().__init__() self.strides = strides - def _concat_pred_logits( - self, - head_outputs: List[Tensor], - grids: List[Tensor], - shifts: List[Tensor], - ) -> Tensor: - # Concat all pred logits - batch_size, _, _, _, K = head_outputs[0].shape - - # Decode bounding box with the shifts and grids - all_pred_logits = [] - - for i, head_output in enumerate(head_outputs): - head_feature = torch.sigmoid(head_output) - pred_xy, pred_wh = det_utils.decode_single( - head_feature[..., :4], - grids[i], - shifts[i], - self.strides[i], - ) - pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1) - all_pred_logits.append(pred_logits.view(batch_size, -1, K)) - - all_pred_logits = torch.cat(all_pred_logits, dim=1) - - return all_pred_logits - - @staticmethod - def _decode_pred_logits(pred_logits: Tensor): - """ - Decode the prediction logit from the PostPrecess. - """ - # Compute conf - # box_conf x class_conf, w/ shape: num_anchors x num_classes - scores = pred_logits[:, 5:] * pred_logits[:, 4:5] - boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy") - - return boxes, scores - def forward( self, head_outputs: List[Tensor], @@ -390,14 +388,14 @@ def forward( """ batch_size = len(head_outputs[0]) - all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts) + all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides) bbox_regression = [] pred_scores = [] for idx in range(batch_size): # image idx, image inference pred_logits = all_pred_logits[idx] - boxes, scores = self._decode_pred_logits(pred_logits) + boxes, scores = _decode_pred_logits(pred_logits) bbox_regression.append(boxes) pred_scores.append(scores) @@ -409,7 +407,7 @@ def forward( return boxes, scores -class PostProcess(LogitsDecoder): +class PostProcess(nn.Module): """ Performs Non-Maximum Suppression (NMS) on inference results """ @@ -428,7 +426,8 @@ def __init__( nms_thresh (float): NMS threshold used for postprocessing the detections. detections_per_img (int): Number of best detections to keep after NMS. """ - super().__init__(strides) + super().__init__() + self.strides = strides self.score_thresh = score_thresh self.nms_thresh = nms_thresh self.detections_per_img = detections_per_img @@ -453,12 +452,12 @@ def forward( """ batch_size = len(head_outputs[0]) - all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts) + all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides) detections: List[Dict[str, Tensor]] = [] for idx in range(batch_size): # image idx, image inference pred_logits = all_pred_logits[idx] - boxes, scores = self._decode_pred_logits(pred_logits) + boxes, scores = _decode_pred_logits(pred_logits) # remove low scoring boxes inds, labels = torch.where(scores > self.score_thresh) boxes, scores = boxes[inds], scores[inds, labels] diff --git a/yolort/runtime/trt_helper.py b/yolort/runtime/trt_helper.py index bc5652e5..16bf9c06 100644 --- a/yolort/runtime/trt_helper.py +++ b/yolort/runtime/trt_helper.py @@ -11,16 +11,127 @@ import logging from pathlib import Path -from typing import Optional +from typing import Optional, Tuple, Union try: import tensorrt as trt except ImportError: trt = None +import torch +from torch import nn, Tensor +from yolort.models import YOLO +from yolort.models.anchor_utils import AnchorGenerator +from yolort.models.backbone_utils import darknet_pan_backbone +from yolort.models.box_head import LogitsDecoder +from yolort.utils import load_from_ultralytics + logging.basicConfig(level=logging.INFO) -logging.getLogger("EngineBuilder").setLevel(logging.INFO) -log = logging.getLogger("EngineBuilder") +logging.getLogger("TRTHelper").setLevel(logging.INFO) +logger = logging.getLogger("TRTHelper") + + +__all__ = ["YOLOTRTModule", "EngineBuilder"] + + +class YOLOTRTModule(nn.Module): + """ + TensorRT deployment friendly wrapper for YOLO. + + Remove the ``torchvision::nms`` in this warpper, due to the fact that some third-party + inference frameworks currently do not support this operator very well. + """ + + def __init__( + self, + checkpoint_path: str, + version: str = "r6.0", + ): + super().__init__() + model_info = load_from_ultralytics(checkpoint_path, version=version) + + backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}" + depth_multiple = model_info["depth_multiple"] + width_multiple = model_info["width_multiple"] + use_p6 = model_info["use_p6"] + backbone = darknet_pan_backbone( + backbone_name, + depth_multiple, + width_multiple, + version=version, + use_p6=use_p6, + ) + num_classes = model_info["num_classes"] + anchor_generator = AnchorGenerator(model_info["strides"], model_info["anchor_grids"]) + post_process = LogitsDecoder(model_info["strides"]) + model = YOLO( + backbone, + num_classes, + anchor_generator=anchor_generator, + post_process=post_process, + ) + + model.load_state_dict(model_info["state_dict"]) + self.model = model + self.num_clases = num_classes + + @torch.no_grad() + def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + inputs (Tensor): batched images, of shape [batch_size x 3 x H x W] + """ + # Compute the detections + outputs = self.model(inputs) + + return outputs + + @torch.no_grad() + def to_onnx( + self, + file_path: Union[str, Path], + input_sample: Optional[Tensor] = None, + opset_version: int = 11, + enable_dynamic: bool = True, + **kwargs, + ): + """ + Saves the model in ONNX format. + + Args: + file_path: The path of the file the onnx model should be saved to. + input_sample: An input for tracing. Default: None. + opset_version: Opset version we export the model to the onnx submodule. Default: 11. + enable_dynamic: Whether to specify axes of tensors as dynamic. Default: True. + **kwargs: Will be passed to torch.onnx.export function. + """ + if input_sample is None: + input_sample = torch.rand(1, 3, 320, 320).to(next(self.parameters()).device) + + dynamic_axes = ( + { + "images": {0: "batch", 2: "height", 3: "width"}, + "boxes": {0: "batch", 1: "num_objects"}, + "scores": {0: "batch", 1: "num_objects"}, + } + if enable_dynamic + else None + ) + + input_names = ["images"] + output_names = ["boxes", "scores"] + + torch.onnx.export( + self.model, + input_sample, + file_path, + do_constant_folding=True, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + **kwargs, + ) class EngineBuilder: @@ -66,12 +177,12 @@ def create_network(self, onnx_path: str): inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] - log.info("Network Description") + logger.info("Network Description") for input in inputs: self.batch_size = input.shape[0] - log.info(f"Input '{input.name}' with shape {input.shape} and dtype {input.dtype}") + logger.info(f"Input '{input.name}' with shape {input.shape} and dtype {input.dtype}") for output in outputs: - log.info(f"Output '{output.name}' with shape {output.shape} and dtype {output.dtype}") + logger.info(f"Output '{output.name}' with shape {output.shape} and dtype {output.dtype}") assert self.batch_size > 0 self.builder.max_batch_size = self.batch_size @@ -100,18 +211,18 @@ def create_engine( engine_path.parent.mkdir(parents=True, exist_ok=True) - log.info(f"Building {precision} Engine in {engine_path}") + logger.info(f"Building {precision} Engine in {engine_path}") if precision == "fp16": if not self.builder.platform_has_fast_fp16: - log.warning("FP16 is not supported natively on this platform/device") + logger.warning("FP16 is not supported natively on this platform/device") else: self.config.set_flag(trt.BuilderFlag.FP16) elif precision == "fp32": - log.info("Using fp32 mode.") + logger.info("Using fp32 mode.") else: raise NotImplementedError(f"Currently hasn't been implemented: {precision}.") with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f: f.write(engine.serialize()) - log.info(f"Serialize engine success, saved as {engine_path}") + logger.info(f"Serialize engine success, saved as {engine_path}") diff --git a/yolort/runtime/yolo_graphsurgeon.py b/yolort/runtime/yolo_graphsurgeon.py index 0807d905..b0feb2f3 100644 --- a/yolort/runtime/yolo_graphsurgeon.py +++ b/yolort/runtime/yolo_graphsurgeon.py @@ -21,7 +21,7 @@ except ImportError: gs = None -from .yolo_tensorrt_model import YOLOTRTModule +from .trt_helper import YOLOTRTModule logging.basicConfig(level=logging.INFO) logging.getLogger("YOLOGraphSurgeon").setLevel(logging.INFO) @@ -30,8 +30,13 @@ class YOLOGraphSurgeon: """ - Constructor of the YOLOv5 Graph Surgeon object, TensorRT treat ``nms`` as - plugin, especially ``EfficientNMS_TRT`` in our yolort PostProcess module. + Constructor of the YOLOv5 Graph Surgeon object. + + Because TensorRT treat the ``torchvision::ops::nms`` as plugin, we use the a simple post-processing + module named ``LogitsDecoder`` to connect to ``BatchedNMS_TRT`` plugin in TensorRT. + + And the ``BatchedNMS_TRT`` plays the same role of following computation. + https://github.com/zhiqwang/yolov5-rt-stack/blob/02c74a0/yolort/models/box_head.py#L462-L470 Args: checkpoint_path: The path pointing to the PyTorch saved model to load. diff --git a/yolort/runtime/yolo_tensorrt_model.py b/yolort/runtime/yolo_tensorrt_model.py deleted file mode 100644 index e0493b7c..00000000 --- a/yolort/runtime/yolo_tensorrt_model.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) 2021, yolort team. All Rights Reserved. -from pathlib import Path -from typing import Optional, Tuple, Union - -import torch -from torch import nn, Tensor -from yolort.models import YOLO -from yolort.models.anchor_utils import AnchorGenerator -from yolort.models.backbone_utils import darknet_pan_backbone -from yolort.models.box_head import LogitsDecoder -from yolort.utils import load_from_ultralytics - -__all__ = ["YOLOTRTModule"] - - -class YOLOTRTModule(nn.Module): - """ - TensorRT deployment friendly wrapper for YOLO. - - Remove the ``torchvision::nms`` in this warpper, due to the fact that some third-party - inference frameworks currently do not support this operator very well. - """ - - def __init__( - self, - checkpoint_path: str, - version: str = "r6.0", - ): - super().__init__() - model_info = load_from_ultralytics(checkpoint_path, version=version) - - backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}" - depth_multiple = model_info["depth_multiple"] - width_multiple = model_info["width_multiple"] - use_p6 = model_info["use_p6"] - backbone = darknet_pan_backbone( - backbone_name, - depth_multiple, - width_multiple, - version=version, - use_p6=use_p6, - ) - num_classes = model_info["num_classes"] - anchor_generator = AnchorGenerator(model_info["strides"], model_info["anchor_grids"]) - post_process = LogitsDecoder(model_info["strides"]) - model = YOLO( - backbone, - num_classes, - anchor_generator=anchor_generator, - post_process=post_process, - ) - - model.load_state_dict(model_info["state_dict"]) - self.model = model - self.num_clases = num_classes - - @torch.no_grad() - def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]: - """ - Args: - inputs (Tensor): batched images, of shape [batch_size x 3 x H x W] - """ - # Compute the detections - outputs = self.model(inputs) - - return outputs - - @torch.no_grad() - def to_onnx( - self, - file_path: Union[str, Path], - input_sample: Optional[Tensor] = None, - opset_version: int = 11, - enable_dynamic: bool = True, - **kwargs, - ): - """ - Saves the model in ONNX format. - - Args: - file_path: The path of the file the onnx model should be saved to. - input_sample: An input for tracing. Default: None. - opset_version: Opset version we export the model to the onnx submodule. Default: 11. - enable_dynamic: Whether to specify axes of tensors as dynamic. Default: True. - **kwargs: Will be passed to torch.onnx.export function. - """ - if input_sample is None: - input_sample = torch.rand(1, 3, 320, 320).to(next(self.parameters()).device) - - dynamic_axes = ( - { - "images": {0: "batch", 2: "height", 3: "width"}, - "boxes": {0: "batch", 1: "num_objects"}, - "scores": {0: "batch", 1: "num_objects"}, - } - if enable_dynamic - else None - ) - - input_names = ["images"] - output_names = ["boxes", "scores"] - - torch.onnx.export( - self.model, - input_sample, - file_path, - do_constant_folding=True, - opset_version=opset_version, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - **kwargs, - )