Skip to content

Commit

Permalink
Refactor module structure for exporting TensorRT (#254)
Browse files Browse the repository at this point in the history
* Fix docstrings in YOLOGraphSurgeon

* Move YOLOTRTModule in yolort.runtime.trt_helper

* Make PostProcess irrelevant to LogitsDecoder

* Minor fixes
  • Loading branch information
zhiqwang authored Dec 26, 2021
1 parent 02c74a0 commit c5a10a5
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 173 deletions.
2 changes: 1 addition & 1 deletion test/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch
from torch import Tensor
from yolort.runtime.yolo_tensorrt_model import YOLOTRTModule
from yolort.runtime.trt_helper import YOLOTRTModule
from yolort.v5 import attempt_download


Expand Down
91 changes: 45 additions & 46 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,46 @@ def build_targets(
return target_cls, target_box, indices, anch


def _concat_pred_logits(
head_outputs: List[Tensor],
grids: List[Tensor],
shifts: List[Tensor],
strides: List[int],
) -> Tensor:
# Concat all pred logits
batch_size, _, _, _, K = head_outputs[0].shape

# Decode bounding box with the shifts and grids
all_pred_logits = []

for head_output, grid, shift, stride in zip(head_outputs, grids, shifts, strides):
head_feature = torch.sigmoid(head_output)
pred_xy, pred_wh = det_utils.decode_single(head_feature[..., :4], grid, shift, stride)
pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1)
all_pred_logits.append(pred_logits.view(batch_size, -1, K))

all_pred_logits = torch.cat(all_pred_logits, dim=1)

return all_pred_logits


def _decode_pred_logits(pred_logits: Tensor):
"""
Decode the prediction logit from the PostPrecess.
"""
# Compute conf
# box_conf x class_conf, w/ shape: num_anchors x num_classes
scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")

return boxes, scores


class LogitsDecoder(nn.Module):
"""
This is a simplified version of PostProcess to remove the ``torchvision::nms`` module.
This is a simplified version of post-processing module, we manually remove
the ``torchvision::ops::nms``, and it will be used later in the procedure of
exporting the ONNX graph for TensorRT.
"""

def __init__(self, strides: List[int]) -> None:
Expand All @@ -332,45 +369,6 @@ def __init__(self, strides: List[int]) -> None:
super().__init__()
self.strides = strides

def _concat_pred_logits(
self,
head_outputs: List[Tensor],
grids: List[Tensor],
shifts: List[Tensor],
) -> Tensor:
# Concat all pred logits
batch_size, _, _, _, K = head_outputs[0].shape

# Decode bounding box with the shifts and grids
all_pred_logits = []

for i, head_output in enumerate(head_outputs):
head_feature = torch.sigmoid(head_output)
pred_xy, pred_wh = det_utils.decode_single(
head_feature[..., :4],
grids[i],
shifts[i],
self.strides[i],
)
pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1)
all_pred_logits.append(pred_logits.view(batch_size, -1, K))

all_pred_logits = torch.cat(all_pred_logits, dim=1)

return all_pred_logits

@staticmethod
def _decode_pred_logits(pred_logits: Tensor):
"""
Decode the prediction logit from the PostPrecess.
"""
# Compute conf
# box_conf x class_conf, w/ shape: num_anchors x num_classes
scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")

return boxes, scores

def forward(
self,
head_outputs: List[Tensor],
Expand All @@ -390,14 +388,14 @@ def forward(
"""
batch_size = len(head_outputs[0])

all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides)

bbox_regression = []
pred_scores = []

for idx in range(batch_size): # image idx, image inference
pred_logits = all_pred_logits[idx]
boxes, scores = self._decode_pred_logits(pred_logits)
boxes, scores = _decode_pred_logits(pred_logits)
bbox_regression.append(boxes)
pred_scores.append(scores)

Expand All @@ -409,7 +407,7 @@ def forward(
return boxes, scores


class PostProcess(LogitsDecoder):
class PostProcess(nn.Module):
"""
Performs Non-Maximum Suppression (NMS) on inference results
"""
Expand All @@ -428,7 +426,8 @@ def __init__(
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
"""
super().__init__(strides)
super().__init__()
self.strides = strides
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.detections_per_img = detections_per_img
Expand All @@ -453,12 +452,12 @@ def forward(
"""
batch_size = len(head_outputs[0])

all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides)
detections: List[Dict[str, Tensor]] = []

for idx in range(batch_size): # image idx, image inference
pred_logits = all_pred_logits[idx]
boxes, scores = self._decode_pred_logits(pred_logits)
boxes, scores = _decode_pred_logits(pred_logits)
# remove low scoring boxes
inds, labels = torch.where(scores > self.score_thresh)
boxes, scores = boxes[inds], scores[inds, labels]
Expand Down
131 changes: 121 additions & 10 deletions yolort/runtime/trt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,127 @@

import logging
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple, Union

try:
import tensorrt as trt
except ImportError:
trt = None

import torch
from torch import nn, Tensor
from yolort.models import YOLO
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import LogitsDecoder
from yolort.utils import load_from_ultralytics

logging.basicConfig(level=logging.INFO)
logging.getLogger("EngineBuilder").setLevel(logging.INFO)
log = logging.getLogger("EngineBuilder")
logging.getLogger("TRTHelper").setLevel(logging.INFO)
logger = logging.getLogger("TRTHelper")


__all__ = ["YOLOTRTModule", "EngineBuilder"]


class YOLOTRTModule(nn.Module):
"""
TensorRT deployment friendly wrapper for YOLO.
Remove the ``torchvision::nms`` in this warpper, due to the fact that some third-party
inference frameworks currently do not support this operator very well.
"""

def __init__(
self,
checkpoint_path: str,
version: str = "r6.0",
):
super().__init__()
model_info = load_from_ultralytics(checkpoint_path, version=version)

backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}"
depth_multiple = model_info["depth_multiple"]
width_multiple = model_info["width_multiple"]
use_p6 = model_info["use_p6"]
backbone = darknet_pan_backbone(
backbone_name,
depth_multiple,
width_multiple,
version=version,
use_p6=use_p6,
)
num_classes = model_info["num_classes"]
anchor_generator = AnchorGenerator(model_info["strides"], model_info["anchor_grids"])
post_process = LogitsDecoder(model_info["strides"])
model = YOLO(
backbone,
num_classes,
anchor_generator=anchor_generator,
post_process=post_process,
)

model.load_state_dict(model_info["state_dict"])
self.model = model
self.num_clases = num_classes

@torch.no_grad()
def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
inputs (Tensor): batched images, of shape [batch_size x 3 x H x W]
"""
# Compute the detections
outputs = self.model(inputs)

return outputs

@torch.no_grad()
def to_onnx(
self,
file_path: Union[str, Path],
input_sample: Optional[Tensor] = None,
opset_version: int = 11,
enable_dynamic: bool = True,
**kwargs,
):
"""
Saves the model in ONNX format.
Args:
file_path: The path of the file the onnx model should be saved to.
input_sample: An input for tracing. Default: None.
opset_version: Opset version we export the model to the onnx submodule. Default: 11.
enable_dynamic: Whether to specify axes of tensors as dynamic. Default: True.
**kwargs: Will be passed to torch.onnx.export function.
"""
if input_sample is None:
input_sample = torch.rand(1, 3, 320, 320).to(next(self.parameters()).device)

dynamic_axes = (
{
"images": {0: "batch", 2: "height", 3: "width"},
"boxes": {0: "batch", 1: "num_objects"},
"scores": {0: "batch", 1: "num_objects"},
}
if enable_dynamic
else None
)

input_names = ["images"]
output_names = ["boxes", "scores"]

torch.onnx.export(
self.model,
input_sample,
file_path,
do_constant_folding=True,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
**kwargs,
)


class EngineBuilder:
Expand Down Expand Up @@ -66,12 +177,12 @@ def create_network(self, onnx_path: str):
inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

log.info("Network Description")
logger.info("Network Description")
for input in inputs:
self.batch_size = input.shape[0]
log.info(f"Input '{input.name}' with shape {input.shape} and dtype {input.dtype}")
logger.info(f"Input '{input.name}' with shape {input.shape} and dtype {input.dtype}")
for output in outputs:
log.info(f"Output '{output.name}' with shape {output.shape} and dtype {output.dtype}")
logger.info(f"Output '{output.name}' with shape {output.shape} and dtype {output.dtype}")
assert self.batch_size > 0
self.builder.max_batch_size = self.batch_size

Expand Down Expand Up @@ -100,18 +211,18 @@ def create_engine(

engine_path.parent.mkdir(parents=True, exist_ok=True)

log.info(f"Building {precision} Engine in {engine_path}")
logger.info(f"Building {precision} Engine in {engine_path}")

if precision == "fp16":
if not self.builder.platform_has_fast_fp16:
log.warning("FP16 is not supported natively on this platform/device")
logger.warning("FP16 is not supported natively on this platform/device")
else:
self.config.set_flag(trt.BuilderFlag.FP16)
elif precision == "fp32":
log.info("Using fp32 mode.")
logger.info("Using fp32 mode.")
else:
raise NotImplementedError(f"Currently hasn't been implemented: {precision}.")

with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
f.write(engine.serialize())
log.info(f"Serialize engine success, saved as {engine_path}")
logger.info(f"Serialize engine success, saved as {engine_path}")
11 changes: 8 additions & 3 deletions yolort/runtime/yolo_graphsurgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
except ImportError:
gs = None

from .yolo_tensorrt_model import YOLOTRTModule
from .trt_helper import YOLOTRTModule

logging.basicConfig(level=logging.INFO)
logging.getLogger("YOLOGraphSurgeon").setLevel(logging.INFO)
Expand All @@ -30,8 +30,13 @@

class YOLOGraphSurgeon:
"""
Constructor of the YOLOv5 Graph Surgeon object, TensorRT treat ``nms`` as
plugin, especially ``EfficientNMS_TRT`` in our yolort PostProcess module.
Constructor of the YOLOv5 Graph Surgeon object.
Because TensorRT treat the ``torchvision::ops::nms`` as plugin, we use the a simple post-processing
module named ``LogitsDecoder`` to connect to ``BatchedNMS_TRT`` plugin in TensorRT.
And the ``BatchedNMS_TRT`` plays the same role of following computation.
https://github.com/zhiqwang/yolov5-rt-stack/blob/02c74a0/yolort/models/box_head.py#L462-L470
Args:
checkpoint_path: The path pointing to the PyTorch saved model to load.
Expand Down
Loading

0 comments on commit c5a10a5

Please sign in to comment.