Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor module structure for exporting TensorRT #254

Merged
merged 4 commits into from
Dec 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch
from torch import Tensor
from yolort.runtime.yolo_tensorrt_model import YOLOTRTModule
from yolort.runtime.trt_helper import YOLOTRTModule
from yolort.v5 import attempt_download


Expand Down
91 changes: 45 additions & 46 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,46 @@ def build_targets(
return target_cls, target_box, indices, anch


def _concat_pred_logits(
head_outputs: List[Tensor],
grids: List[Tensor],
shifts: List[Tensor],
strides: List[int],
) -> Tensor:
# Concat all pred logits
batch_size, _, _, _, K = head_outputs[0].shape

# Decode bounding box with the shifts and grids
all_pred_logits = []

for head_output, grid, shift, stride in zip(head_outputs, grids, shifts, strides):
head_feature = torch.sigmoid(head_output)
pred_xy, pred_wh = det_utils.decode_single(head_feature[..., :4], grid, shift, stride)
pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1)
all_pred_logits.append(pred_logits.view(batch_size, -1, K))

all_pred_logits = torch.cat(all_pred_logits, dim=1)

return all_pred_logits


def _decode_pred_logits(pred_logits: Tensor):
"""
Decode the prediction logit from the PostPrecess.
"""
# Compute conf
# box_conf x class_conf, w/ shape: num_anchors x num_classes
scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")

return boxes, scores


class LogitsDecoder(nn.Module):
"""
This is a simplified version of PostProcess to remove the ``torchvision::nms`` module.
This is a simplified version of post-processing module, we manually remove
the ``torchvision::ops::nms``, and it will be used later in the procedure of
exporting the ONNX graph for TensorRT.
"""

def __init__(self, strides: List[int]) -> None:
Expand All @@ -332,45 +369,6 @@ def __init__(self, strides: List[int]) -> None:
super().__init__()
self.strides = strides

def _concat_pred_logits(
self,
head_outputs: List[Tensor],
grids: List[Tensor],
shifts: List[Tensor],
) -> Tensor:
# Concat all pred logits
batch_size, _, _, _, K = head_outputs[0].shape

# Decode bounding box with the shifts and grids
all_pred_logits = []

for i, head_output in enumerate(head_outputs):
head_feature = torch.sigmoid(head_output)
pred_xy, pred_wh = det_utils.decode_single(
head_feature[..., :4],
grids[i],
shifts[i],
self.strides[i],
)
pred_logits = torch.cat((pred_xy, pred_wh, head_feature[..., 4:]), dim=-1)
all_pred_logits.append(pred_logits.view(batch_size, -1, K))

all_pred_logits = torch.cat(all_pred_logits, dim=1)

return all_pred_logits

@staticmethod
def _decode_pred_logits(pred_logits: Tensor):
"""
Decode the prediction logit from the PostPrecess.
"""
# Compute conf
# box_conf x class_conf, w/ shape: num_anchors x num_classes
scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")

return boxes, scores

def forward(
self,
head_outputs: List[Tensor],
Expand All @@ -390,14 +388,14 @@ def forward(
"""
batch_size = len(head_outputs[0])

all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides)

bbox_regression = []
pred_scores = []

for idx in range(batch_size): # image idx, image inference
pred_logits = all_pred_logits[idx]
boxes, scores = self._decode_pred_logits(pred_logits)
boxes, scores = _decode_pred_logits(pred_logits)
bbox_regression.append(boxes)
pred_scores.append(scores)

Expand All @@ -409,7 +407,7 @@ def forward(
return boxes, scores


class PostProcess(LogitsDecoder):
class PostProcess(nn.Module):
"""
Performs Non-Maximum Suppression (NMS) on inference results
"""
Expand All @@ -428,7 +426,8 @@ def __init__(
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
"""
super().__init__(strides)
super().__init__()
self.strides = strides
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.detections_per_img = detections_per_img
Expand All @@ -453,12 +452,12 @@ def forward(
"""
batch_size = len(head_outputs[0])

all_pred_logits = self._concat_pred_logits(head_outputs, grids, shifts)
all_pred_logits = _concat_pred_logits(head_outputs, grids, shifts, self.strides)
detections: List[Dict[str, Tensor]] = []

for idx in range(batch_size): # image idx, image inference
pred_logits = all_pred_logits[idx]
boxes, scores = self._decode_pred_logits(pred_logits)
boxes, scores = _decode_pred_logits(pred_logits)
# remove low scoring boxes
inds, labels = torch.where(scores > self.score_thresh)
boxes, scores = boxes[inds], scores[inds, labels]
Expand Down
131 changes: 121 additions & 10 deletions yolort/runtime/trt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,127 @@

import logging
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple, Union

try:
import tensorrt as trt
except ImportError:
trt = None

import torch
from torch import nn, Tensor
from yolort.models import YOLO
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import LogitsDecoder
from yolort.utils import load_from_ultralytics

logging.basicConfig(level=logging.INFO)
logging.getLogger("EngineBuilder").setLevel(logging.INFO)
log = logging.getLogger("EngineBuilder")
logging.getLogger("TRTHelper").setLevel(logging.INFO)
logger = logging.getLogger("TRTHelper")


__all__ = ["YOLOTRTModule", "EngineBuilder"]


class YOLOTRTModule(nn.Module):
"""
TensorRT deployment friendly wrapper for YOLO.

Remove the ``torchvision::nms`` in this warpper, due to the fact that some third-party
inference frameworks currently do not support this operator very well.
"""

def __init__(
self,
checkpoint_path: str,
version: str = "r6.0",
):
super().__init__()
model_info = load_from_ultralytics(checkpoint_path, version=version)

backbone_name = f"darknet_{model_info['size']}_{version.replace('.', '_')}"
depth_multiple = model_info["depth_multiple"]
width_multiple = model_info["width_multiple"]
use_p6 = model_info["use_p6"]
backbone = darknet_pan_backbone(
backbone_name,
depth_multiple,
width_multiple,
version=version,
use_p6=use_p6,
)
num_classes = model_info["num_classes"]
anchor_generator = AnchorGenerator(model_info["strides"], model_info["anchor_grids"])
post_process = LogitsDecoder(model_info["strides"])
model = YOLO(
backbone,
num_classes,
anchor_generator=anchor_generator,
post_process=post_process,
)

model.load_state_dict(model_info["state_dict"])
self.model = model
self.num_clases = num_classes

@torch.no_grad()
def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
inputs (Tensor): batched images, of shape [batch_size x 3 x H x W]
"""
# Compute the detections
outputs = self.model(inputs)

return outputs

@torch.no_grad()
def to_onnx(
self,
file_path: Union[str, Path],
input_sample: Optional[Tensor] = None,
opset_version: int = 11,
enable_dynamic: bool = True,
**kwargs,
):
"""
Saves the model in ONNX format.

Args:
file_path: The path of the file the onnx model should be saved to.
input_sample: An input for tracing. Default: None.
opset_version: Opset version we export the model to the onnx submodule. Default: 11.
enable_dynamic: Whether to specify axes of tensors as dynamic. Default: True.
**kwargs: Will be passed to torch.onnx.export function.
"""
if input_sample is None:
input_sample = torch.rand(1, 3, 320, 320).to(next(self.parameters()).device)

dynamic_axes = (
{
"images": {0: "batch", 2: "height", 3: "width"},
"boxes": {0: "batch", 1: "num_objects"},
"scores": {0: "batch", 1: "num_objects"},
}
if enable_dynamic
else None
)

input_names = ["images"]
output_names = ["boxes", "scores"]

torch.onnx.export(
self.model,
input_sample,
file_path,
do_constant_folding=True,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
**kwargs,
)


class EngineBuilder:
Expand Down Expand Up @@ -66,12 +177,12 @@ def create_network(self, onnx_path: str):
inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

log.info("Network Description")
logger.info("Network Description")
for input in inputs:
self.batch_size = input.shape[0]
log.info(f"Input '{input.name}' with shape {input.shape} and dtype {input.dtype}")
logger.info(f"Input '{input.name}' with shape {input.shape} and dtype {input.dtype}")
for output in outputs:
log.info(f"Output '{output.name}' with shape {output.shape} and dtype {output.dtype}")
logger.info(f"Output '{output.name}' with shape {output.shape} and dtype {output.dtype}")
assert self.batch_size > 0
self.builder.max_batch_size = self.batch_size

Expand Down Expand Up @@ -100,18 +211,18 @@ def create_engine(

engine_path.parent.mkdir(parents=True, exist_ok=True)

log.info(f"Building {precision} Engine in {engine_path}")
logger.info(f"Building {precision} Engine in {engine_path}")

if precision == "fp16":
if not self.builder.platform_has_fast_fp16:
log.warning("FP16 is not supported natively on this platform/device")
logger.warning("FP16 is not supported natively on this platform/device")
else:
self.config.set_flag(trt.BuilderFlag.FP16)
elif precision == "fp32":
log.info("Using fp32 mode.")
logger.info("Using fp32 mode.")
else:
raise NotImplementedError(f"Currently hasn't been implemented: {precision}.")

with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
f.write(engine.serialize())
log.info(f"Serialize engine success, saved as {engine_path}")
logger.info(f"Serialize engine success, saved as {engine_path}")
11 changes: 8 additions & 3 deletions yolort/runtime/yolo_graphsurgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
except ImportError:
gs = None

from .yolo_tensorrt_model import YOLOTRTModule
from .trt_helper import YOLOTRTModule

logging.basicConfig(level=logging.INFO)
logging.getLogger("YOLOGraphSurgeon").setLevel(logging.INFO)
Expand All @@ -30,8 +30,13 @@

class YOLOGraphSurgeon:
"""
Constructor of the YOLOv5 Graph Surgeon object, TensorRT treat ``nms`` as
plugin, especially ``EfficientNMS_TRT`` in our yolort PostProcess module.
Constructor of the YOLOv5 Graph Surgeon object.

Because TensorRT treat the ``torchvision::ops::nms`` as plugin, we use the a simple post-processing
module named ``LogitsDecoder`` to connect to ``BatchedNMS_TRT`` plugin in TensorRT.

And the ``BatchedNMS_TRT`` plays the same role of following computation.
https://github.com/zhiqwang/yolov5-rt-stack/blob/02c74a0/yolort/models/box_head.py#L462-L470

Args:
checkpoint_path: The path pointing to the PyTorch saved model to load.
Expand Down
Loading