Skip to content

Commit

Permalink
Add TensorRT Python interface and CLI (#251)
Browse files Browse the repository at this point in the history
* init commit

* Minor fix to the TensorRT engine exporting

* Init y_tensorrt inference script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add TensorRT inference CLI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Add missing type

* Minor fixes

* Define _build_engine in PredictorTRT

* Define _set_context in PredictorTRT

* Fix examples of PredictorTRT

* Add yolov5 reference

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zhiqwang and pre-commit-ci[bot] authored Dec 24, 2021
1 parent 75ef670 commit 5105b37
Show file tree
Hide file tree
Showing 9 changed files with 582 additions and 52 deletions.
259 changes: 259 additions & 0 deletions tools/detect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
# Copyright (c) 2021, yolort team. All rights reserved.
#
# This source code is licensed under the GPL-3.0 license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
from pathlib import Path
from typing import List, Optional, Tuple

import cv2
import torch
from torchvision.ops import box_convert
from yolort.runtime import PredictorTRT
from yolort.utils.image_utils import to_numpy
from yolort.v5.utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages
from yolort.v5.utils.general import (
colorstr,
set_logging,
increment_path,
check_img_size,
check_file,
scale_coords,
strip_optimizer,
)
from yolort.v5.utils.plots import Annotator, colors, save_one_box
from yolort.v5.utils.torch_utils import select_device, time_sync

logger = set_logging(__name__)


@torch.no_grad()
def run(
weights: str = "yolort.engine",
source: str = "bus.jpg",
img_size: Tuple[int, int] = (640, 640),
conf_thres: float = 0.25,
iou_thres: float = 0.45,
max_det: int = 1000,
device: str = "",
view_img: bool = False,
save_txt: bool = False,
save_conf: bool = False,
save_crop: bool = False,
nosave: bool = False,
classes: Optional[List] = None,
visualize: bool = False,
update: bool = False,
project: str = "./runs/detect",
name: str = "exp",
exist_ok: bool = False,
line_thickness=3,
hide_labels: bool = False,
hide_conf: bool = False,
half: bool = False,
):
"""
The core function for detecting source of image, path or directory.
Adapted form https://github.com/ultralytics/yolov5/blob/db6ec66/detect.py
Args:
weights: Path of the engine
source: file/dir/URL/glob, 0 for webcam
img_size: inference size (height, width)
conf_thres: confidence threshold
iou_thres: NMS IOU threshold
max_det: maximum detections per image
device: cuda device, i.e. 0 or 0,1,2,3 or cpu
view_img: show results
save_txt: save results to *.txt
save_conf: save confidences in --save-txt labels
save_crop: save cropped prediction boxes
nosave: do not save images/videos
classes: filter by class: --class 0, or --class 0 2 3
visualize: visualize features
update: update all models
project: save results to project/name
name: save results to project/name
exist_ok: existing project/name ok, do not increment
line_thickness: bounding box thickness (pixels)
hide_labels: hide labels
hide_conf: hide confidences
half: use FP16 half-precision inference
"""
source = str(source)
save_img = not nosave and not source.endswith(".txt") # save inference images
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(("rtsp://", "rtmp://", "http://", "https://"))
webcam = source.isnumeric() or source.endswith(".txt") or (is_url and not is_file)
if is_url and is_file:
source = check_file(source) # download

# Directories
# increment run
save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)
# make dir
(save_dir / "labels" if save_txt else save_dir).mkdir(parents=True, exist_ok=True)

# Load model
device = select_device(device)
model = PredictorTRT(
weights,
device=device,
score_thresh=conf_thres,
iou_thresh=iou_thres,
detections_per_img=max_det,
)
stride, names = model.stride, model.names
img_size = check_img_size(img_size, stride=stride) # check image size

# Dataloader
dataset = LoadImages(source, img_size=img_size, stride=stride, auto=False)

# Run inference
model.warmup(img_size=(1, 3, *img_size), half=half)
dt, seen = [0.0, 0.0, 0.0], 0
for path, im, im0s, _, s in dataset:
t1 = time_sync()
im = torch.from_numpy(im).to(device)
im = im.half() if half else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
t2 = time_sync()
dt[0] += t2 - t1

# Inference
visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
pred_logits = model(im)
t3 = time_sync()
dt[1] += t3 - t2

# NMS
detections = model.postprocessing(pred_logits)
dt[2] += time_sync() - t3

# Process predictions
for i, det in enumerate(detections): # per image
seen += 1
if webcam: # batch_size >= 1
p, im0, frame = path[i], im0s[i].copy(), dataset.count
s += f"{i}: "
else:
p, im0, frame = path, im0s.copy(), getattr(dataset, "frame", 0)

p = Path(p) # to Path
save_path = str(save_dir / p.name) # im.jpg
txt_path = str(save_dir / "labels" / p.stem) + ("" if dataset.mode == "image" else f"_{frame}")
s += "%gx%g " % im.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
imc = im0.copy() if save_crop else im0 # for save_crop
annotator = Annotator(im0, line_width=line_thickness, example=str(names))
if len(det):
# Rescale boxes from img_size to im0 size
boxes = scale_coords(im.shape[2:], det["boxes"], im0.shape).round()
scores = det["scores"]
labels = det["labels"]

# Print results
for c in labels.unique():
n = (labels == c).sum() # detections per class
s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string

# Write results
for box, score, class_idx in zip(boxes, scores, labels):
if save_txt: # Write to file
# normalized xywh
xywh = box_convert(torch.tensor(box).view(1, 4), in_fmt="xyxy", out_fmt="cxcywh")
xywh = (xywh / gn).view(-1).tolist()
# label format
line = (class_idx, *xywh, score) if save_conf else (class_idx, *xywh)
with open(txt_path + ".txt", "a") as f:
f.write(("%g " * len(line)).rstrip() % line + "\n")

xyxy = to_numpy(box)
if save_img or save_crop or view_img: # Add bbox to image
cls = int(class_idx) # integer class
label = (
None
if hide_labels
else (names[cls] if hide_conf else f"{names[cls]} {score:.2f}")
)
annotator.box_label(xyxy, label, color=colors(cls, True))
if save_crop:
save_path = save_dir / "crops" / names[cls] / f"{p.stem}.jpg"
save_one_box(xyxy, imc, file=save_path, BGR=True)

# Print time (inference-only)
logger.info(f"{s}Done. ({t3 - t2:.3f}s)")

# Stream results
im0 = annotator.result()
if view_img:
cv2.imshow(str(p), im0)
cv2.waitKey(1) # 1 millisecond

# Save results (image with detections)
if save_img:
if dataset.mode == "image":
cv2.imwrite(save_path, im0)
else:
raise NotImplementedError("Currently this method hasn't implemented yet.")

# Print results
speeds_info = tuple(x / seen * 1e3 for x in dt) # speeds per image
logger.info(
f"Speed: {speeds_info[0]:.1f}ms pre-process, {speeds_info[1]:.1f}ms inference, "
f"{speeds_info[2]:.1f}ms NMS per image at shape {(1, 3, *img_size)}",
)
if save_txt or save_img:
saved_info = (
f"\n{len(list(save_dir.glob('labels/*.txt')))} labels " f"saved to {save_dir / 'labels'}"
if save_txt
else "",
)
logger.info(f"Results saved to {colorstr('bold', save_dir)}{saved_info}")
if update:
# update model (to fix SourceChangeWarning)
strip_optimizer(weights)


def get_parser():
parser = argparse.ArgumentParser("CLI tool for detecting source.", add_help=True)
parser.add_argument("--weights", type=str, default="yolov5s.pt", help="model path(s)")
parser.add_argument("--source", type=str, default="data/images", help="file/dir/URL/glob, 0 for webcam")
parser.add_argument("--img_size", nargs="+", type=int, default=[640], help="inference size h,w")
parser.add_argument("--conf_thres", type=float, default=0.25, help="confidence threshold")
parser.add_argument("--iou_thres", type=float, default=0.45, help="NMS IoU threshold")
parser.add_argument("--max_det", type=int, default=1000, help="maximum detections per image")
parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
parser.add_argument("--view_img", action="store_true", help="show results")
parser.add_argument("--save_txt", action="store_true", help="save results to *.txt")
parser.add_argument("--save_conf", action="store_true", help="save confidences in --save-txt labels")
parser.add_argument("--save_crop", action="store_true", help="save cropped prediction boxes")
parser.add_argument("--nosave", action="store_true", help="do not save images/videos")
parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0 2 3")
parser.add_argument("--visualize", action="store_true", help="visualize features")
parser.add_argument("--update", action="store_true", help="update all models")
parser.add_argument("--project", default="./runs/detect", help="save results to project/name")
parser.add_argument("--name", default="exp", help="save results to project/name")
parser.add_argument("--exist_ok", action="store_true", help="existing project/name ok, do not increment")
parser.add_argument("--line_thickness", default=3, type=int, help="bounding box thickness (pixels)")
parser.add_argument("--hide_labels", default=False, action="store_true", help="hide labels")
parser.add_argument("--hide_conf", default=False, action="store_true", help="hide confidences")
parser.add_argument("--half", action="store_true", help="use FP16 half-precision inference")

return parser


def cli_main():
parser = get_parser()
args = parser.parse_args()
logger.info(f"Command Line Args: {args}")
run(**vars(args))


if __name__ == "__main__":
cli_main()
3 changes: 2 additions & 1 deletion yolort/models/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ def _concat_pred_logits(

return all_pred_logits

def _decode_pred_logits(self, pred_logits: Tensor):
@staticmethod
def _decode_pred_logits(pred_logits: Tensor):
"""
Decode the prediction logit from the PostPrecess.
"""
Expand Down
3 changes: 2 additions & 1 deletion yolort/runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2021, yolort team. All Rights Reserved.
from .y_onnxruntime import PredictorORT
from .y_tensorrt import PredictorTRT
from .yolo_tensorrt_model import YOLOTRTModule

__all__ = ["PredictorORT", "YOLOTRTModule"]
__all__ = ["PredictorORT", "YOLOTRTModule", "PredictorTRT"]
31 changes: 12 additions & 19 deletions yolort/runtime/trt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,22 @@ class EngineBuilder:
Parses an ONNX graph and builds a TensorRT engine from it.
"""

def __init__(self, verbose=False, workspace=8):
def __init__(self, verbose=False, workspace=4):
"""
Args:
verbose: If enabled, a higher verbosity level will be
set on the TensorRT logger.
workspace: Max memory workspace to allow, in Gb.
"""
self.trt_logger = trt.Logger(trt.Logger.INFO)
self.logger = trt.Logger(trt.Logger.INFO)
if verbose:
self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE
self.logger.min_severity = trt.Logger.Severity.VERBOSE

trt.init_libnvinfer_plugins(self.trt_logger, namespace="")
trt.init_libnvinfer_plugins(self.logger, namespace="")

self.builder = trt.Builder(self.trt_logger)
self.builder = trt.Builder(self.logger)
self.config = self.builder.create_builder_config()
self.config.max_workspace_size = workspace * (2 ** 30)
self.config.max_workspace_size = workspace * 1 << 30

self.batch_size = None
self.network = None
Expand All @@ -56,19 +56,12 @@ def create_network(self, onnx_path: str):
Args:
onnx_path: The path to the ONNX graph to load.
"""
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

self.network = self.builder.create_network(network_flags)
self.parser = trt.OnnxParser(self.network, self.trt_logger)

onnx_path = Path(onnx_path)
with open(onnx_path, "rb") as f:
if not self.parser.parse(f.read()):
err_message = f"Failed to load ONNX file: {onnx_path}"
log.error(err_message)
for error in range(self.parser.num_errors):
log.error(self.parser.get_error(error))
raise OSError(err_message)
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
self.network = self.builder.create_network(flag)
self.parser = trt.OnnxParser(self.network, self.logger)
if not self.parser.parse_from_file(onnx_path):
raise RuntimeError(f"Failed to load ONNX file: {onnx_path}")

inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]
Expand Down Expand Up @@ -120,5 +113,5 @@ def create_engine(
raise NotImplementedError(f"Currently hasn't been implemented: {precision}.")

with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
log.info(f"Serializing engine to file: {engine_path}")
f.write(engine.serialize())
log.info(f"Serialize engine success, saved as {engine_path}")
Loading

0 comments on commit 5105b37

Please sign in to comment.