Skip to content

Commit

Permalink
Fix register_nms in YOLOGraphSurgeon (#252)
Browse files Browse the repository at this point in the history
* Fix exporting TensorRT engine and inferencing

* Fix register_nms in YOLOv5GraphSurgeon

* Fixing docstrings

* Fix preprocessing and postprocessing in PredictorTRT

* Fix pylint

* Fix detect and plugins register

* Fix importer in test_yolo_trt_module
  • Loading branch information
zhiqwang authored Dec 24, 2021
1 parent 5105b37 commit 1d86967
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 94 deletions.
2 changes: 1 addition & 1 deletion test/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch
from torch import Tensor
from yolort.runtime import YOLOTRTModule
from yolort.runtime.yolo_tensorrt_model import YOLOTRTModule
from yolort.v5 import attempt_download


Expand Down
34 changes: 13 additions & 21 deletions tools/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,44 +97,36 @@ def run(
# make dir
(save_dir / "labels" if save_txt else save_dir).mkdir(parents=True, exist_ok=True)

# Load model
# Load the TensorRT engine
device = select_device(device)
model = PredictorTRT(
engine = PredictorTRT(
weights,
device=device,
score_thresh=conf_thres,
iou_thresh=iou_thres,
detections_per_img=max_det,
)
stride, names = model.stride, model.names
stride, names = engine.stride, engine.names
img_size = check_img_size(img_size, stride=stride) # check image size

# Dataloader
dataset = LoadImages(source, img_size=img_size, stride=stride, auto=False)

# Run inference
model.warmup(img_size=(1, 3, *img_size), half=half)
dt, seen = [0.0, 0.0, 0.0], 0
for path, im, im0s, _, s in dataset:
engine.warmup(img_size=(1, 3, *img_size), half=half)
dt, seen = [0.0, 0.0], 0
for path, image, im0s, _, s in dataset:
t1 = time_sync()
im = torch.from_numpy(im).to(device)
im = im.half() if half else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
image = engine.preprocessing(image)
t2 = time_sync()
dt[0] += t2 - t1

# Inference
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred_logits = model(im)
detections = engine.run_on_image(image)
t3 = time_sync()
dt[1] += t3 - t2

# NMS
detections = model.postprocessing(pred_logits)
dt[2] += time_sync() - t3

# Process predictions
for i, det in enumerate(detections): # per image
seen += 1
Expand All @@ -145,15 +137,15 @@ def run(
p, im0, frame = path, im0s.copy(), getattr(dataset, "frame", 0)

p = Path(p) # to Path
save_path = str(save_dir / p.name) # im.jpg
save_path = str(save_dir / p.name) # image.jpg
txt_path = str(save_dir / "labels" / p.stem) + ("" if dataset.mode == "image" else f"_{frame}")
s += "%gx%g " % im.shape[2:] # print string
s += "%gx%g " % image.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
imc = im0.copy() if save_crop else im0 # for save_crop
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
# Rescale boxes from img_size to im0 size
boxes = scale_coords(im.shape[2:], det["boxes"], im0.shape).round()
boxes = scale_coords(image.shape[2:], det["boxes"], im0.shape).round()
scores = det["scores"]
labels = det["labels"]

Expand Down Expand Up @@ -205,8 +197,8 @@ def run(
# Print results
speeds_info = tuple(x / seen * 1e3 for x in dt) # speeds per image
logger.info(
f"Speed: {speeds_info[0]:.1f}ms pre-process, {speeds_info[1]:.1f}ms inference, "
f"{speeds_info[2]:.1f}ms NMS per image at shape {(1, 3, *img_size)}",
f"Speed: {speeds_info[0]:.1f}ms pre-process, {speeds_info[1]:.1f}ms inference & "
f"NMS per image at shape {(1, 3, *img_size)}",
)
if save_txt or save_img:
saved_info = (
Expand Down
7 changes: 6 additions & 1 deletion yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,12 @@ def forward(
bbox_regression.append(boxes)
pred_scores.append(scores)

return torch.stack(bbox_regression), torch.stack(pred_scores)
# The default boxes tensor has shape [batch_size, number_boxes, 4].
# This will insert a "1" dimension in the second axis, to become
# [batch_size, number_boxes, 1, 4], the shape that plugin/BatchedNMS expects.
boxes = torch.stack(bbox_regression).unsqueeze_(2)
scores = torch.stack(pred_scores)
return boxes, scores


class PostProcess(LogitsDecoder):
Expand Down
4 changes: 2 additions & 2 deletions yolort/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2021, yolort team. All Rights Reserved.
from .y_onnxruntime import PredictorORT
from .y_tensorrt import PredictorTRT
from .yolo_tensorrt_model import YOLOTRTModule
from .yolo_graphsurgeon import YOLOGraphSurgeon

__all__ = ["PredictorORT", "YOLOTRTModule", "PredictorTRT"]
__all__ = ["PredictorORT", "PredictorTRT", "YOLOGraphSurgeon"]
2 changes: 1 addition & 1 deletion yolort/runtime/trt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def create_network(self, onnx_path: str):
def create_engine(
self,
engine_path: str,
precision: str,
precision: str = "fp32",
calib_input: Optional[str] = None,
calib_cache: Optional[str] = None,
calib_num_images: int = 5000,
Expand Down
60 changes: 25 additions & 35 deletions yolort/runtime/y_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
import torch
from torch import Tensor
from torchvision.ops import box_convert, boxes as box_ops

try:
import tensorrt as trt
Expand Down Expand Up @@ -40,7 +39,7 @@ class PredictorTRT:
>>> detector = PredictorTRT(engine_path, device)
>>>
>>> img_path = 'bus.jpg'
>>> scores, class_ids, boxes = detector.run_on_image(img_path)
>>> detections = detector.run_on_image(img_path)
"""

def __init__(
Expand All @@ -62,10 +61,12 @@ def __init__(

self.engine = self._build_engine()
self._set_context()
self.half = False

def _build_engine(self):
logger.info(f"Loading {self.engine_path} for TensorRT inference...")
trt_logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(trt_logger, namespace="")
with open(self.engine_path, "rb") as f, trt.Runtime(trt_logger) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())

Expand All @@ -83,6 +84,14 @@ def _set_context(self):
self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
self.context = self.engine.create_execution_context()

def preprocessing(self, image):
image = torch.from_numpy(image).to(self.device)
image = image.half() if self.half else image.float() # uint8 to fp16/32
image /= 255 # 0 - 255 to 0.0 - 1.0
if len(image.shape) == 3:
image = image[None] # expand for batch dim
return image

def __call__(self, image: Tensor):
"""
Args:
Expand All @@ -95,50 +104,31 @@ def __call__(self, image: Tensor):
assert image.shape == self.bindings["images"].shape, (image.shape, self.bindings["images"].shape)
self.binding_addrs["images"] = int(image.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
pred_logits = self.bindings["output"].data
return pred_logits
num_dets = self.bindings["num_detections"].data
boxes = self.bindings["detection_boxes"].data
scores = self.bindings["detection_scores"].data
labels = self.bindings["detection_classes"].data
return boxes, scores, labels, num_dets

def run_on_image(self, image):
def run_on_image(self, image: Tensor):
"""
Run the TensorRT engine for one image only.
Args:
image_path (str): The image path to be predicted.
image (Tensor): an image of shape (C, N, H, W).
"""
pred_logits = self(image)
detections = self.postprocessing(pred_logits)
boxes, scores, labels, num_dets = self(image)

detections = self.postprocessing(boxes, scores, labels, num_dets)
return detections

@staticmethod
def _decode_pred_logits(pred_logits: Tensor):
"""
Decode the prediction logit from the PostPrecess.
"""
# Compute conf
# box_conf x class_conf, w/ shape: num_anchors x num_classes
scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")

return boxes, scores

def postprocessing(self, pred_logits: Tensor):
batch_size = pred_logits.shape[0]
def postprocessing(all_boxes, all_scores, all_labels, all_num_dets):
detections: List[Dict[str, Tensor]] = []

for idx in range(batch_size): # image idx, image inference
# Decode the predict logits
boxes, scores = self._decode_pred_logits(pred_logits[idx])

# remove low scoring boxes
inds, labels = torch.where(scores > self.score_thresh)
boxes, scores = boxes[inds], scores[inds, labels]

# non-maximum suppression, independently done per level
keep = box_ops.batched_nms(boxes, scores, labels, self.iou_thresh)
# Keep only topk scoring head_outputs
keep = keep[: self.detections_per_img]
boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

for boxes, scores, labels, num_dets in zip(all_boxes, all_scores, all_labels, all_num_dets):
keep = num_dets.item()
boxes, scores, labels = boxes[:keep], scores[:keep], labels[:keep]
detections.append({"scores": scores, "labels": labels, "boxes": boxes})

return detections
Expand Down
64 changes: 33 additions & 31 deletions yolort/runtime/yolo_graphsurgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from .yolo_tensorrt_model import YOLOTRTModule

logging.basicConfig(level=logging.INFO)
logging.getLogger("YOLOv5GraphSurgeon").setLevel(logging.INFO)
log = logging.getLogger("YOLOv5GraphSurgeon")
logging.getLogger("YOLOGraphSurgeon").setLevel(logging.INFO)
logger = logging.getLogger("YOLOGraphSurgeon")


class YOLOv5GraphSurgeon:
class YOLOGraphSurgeon:
"""
Constructor of the YOLOv5 Graph Surgeon object, TensorRT treat ``nms`` as
plugin, especially ``EfficientNMS_TRT`` in our yolort PostProcess module.
Expand All @@ -44,27 +44,26 @@ class YOLOv5GraphSurgeon:
def __init__(
self,
checkpoint_path: str,
score_thresh: float = 0.25,
version: str = "r6.0",
enable_dynamic: bool = True,
):
checkpoint_path = Path(checkpoint_path)
assert checkpoint_path.exists()

# Use YOLOTRTModule to convert saved model to an initial ONNX graph.
model = YOLOTRTModule(checkpoint_path, score_thresh=score_thresh, version=version)
model = YOLOTRTModule(checkpoint_path, version=version)
model = model.eval()

log.info(f"Loaded saved model from {checkpoint_path}")
logger.info(f"Loaded saved model from {checkpoint_path}")
onnx_model_path = checkpoint_path.with_suffix(".onnx")
model.to_onnx(onnx_model_path, enable_dynamic=enable_dynamic)
self.graph = gs.import_onnx(onnx.load(onnx_model_path))
assert self.graph
log.info("PyTorch2ONNX graph created successfully")
logger.info("PyTorch2ONNX graph created successfully")

# Fold constants via ONNX-GS that PyTorch2ONNX may have missed
self.graph.fold_constants()

self.num_classes = model.num_clases
self.batch_size = 1

def infer(self):
Expand All @@ -85,11 +84,11 @@ def infer(self):
model = shape_inference.infer_shapes(model)
self.graph = gs.import_onnx(model)
except Exception as e:
log.info(f"Shape inference could not be performed at this time:\n{e}")
logger.info(f"Shape inference could not be performed at this time:\n{e}")
try:
self.graph.fold_constants(fold_shapes=True)
except TypeError as e:
log.error(
logger.error(
"This version of ONNX GraphSurgeon does not support folding shapes, "
f"please upgrade your onnx_graphsurgeon module. Error:\n{e}"
)
Expand All @@ -111,43 +110,46 @@ def save(self, output_path):
self.graph.cleanup().toposort()
model = gs.export_onnx(self.graph)
onnx.save(model, output_path)
log.info(f"Saved ONNX model to {output_path}")
logger.info(f"Saved ONNX model to {output_path}")

def register_nms(
self,
score_thresh: float = 0.25,
nms_thresh: float = 0.45,
detections_per_img: int = 100,
normalized: bool = True,
):
"""
Register the ``EfficientNMS_TRT`` plugin node.
Register the ``BatchedNMS_TRT`` plugin node.
NMS expects these shapes for its input tensors:
- box_net: [batch_size, number_boxes, 4]
- class_net: [batch_size, number_boxes, number_labels]
As the original tensors from YOLOv5 will be used, the NMS code type is set to 0 (Corners),
because this is the internal box coding format used by the network.
- box_net: [batch_size, number_boxes, 1, 4]
- class_net: [batch_size, number_boxes, number_labels]
Args:
threshold: Override the score threshold attribute. If set to None,
use the value in the graph.
detections: Override the max detections attribute. If set to None,
use the value in the graph.
score_thresh (float): The scalar threshold for score (low scoring boxes are removed).
nms_thresh (float): The scalar threshold for IOU (new boxes that have high IOU
overlap with previously selected boxes are removed).
detections_per_img (int): Number of best detections to keep after NMS.
normalized (bool): Set to false if the box coordinates are not normalized,
meaning they are not in the range [0,1]. Defaults: True.
"""

self.infer()
# Find the concat node at the end of the network
nms_inputs = self.graph.outputs
op = "EfficientNMS_TRT"
op = "BatchedNMS_TRT"
attrs = {
"plugin_version": "1",
"background_class": -1, # no background class
"max_output_boxes": detections_per_img,
"score_threshold": max(0.01, score_thresh),
"iou_threshold": nms_thresh,
"score_activation": True,
"box_coding": 0,
"shareLocation": True,
"backgroundLabelId": -1, # no background class
"numClasses": self.num_classes,
"topK": 1024,
"keepTopK": detections_per_img,
"scoreThreshold": score_thresh,
"iouThreshold": nms_thresh,
"isNormalized": normalized,
"clipBoxes": False,
}

# NMS Outputs
Expand All @@ -167,8 +169,8 @@ def register_nms(
shape=[self.batch_size, detections_per_img],
)
output_labels = gs.Variable(
name="detection_labels",
dtype=np.int32,
name="detection_classes",
dtype=np.float32,
shape=[self.batch_size, detections_per_img],
)

Expand All @@ -183,7 +185,7 @@ def register_nms(
outputs=nms_outputs,
attrs=attrs,
)
log.info(f"Created NMS plugin '{op}' with attributes: {attrs}")
logger.info(f"Created NMS plugin '{op}' with attributes: {attrs}")

self.graph.outputs = nms_outputs

Expand Down
Loading

0 comments on commit 1d86967

Please sign in to comment.