Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TensorRT engine export #312

Merged
merged 6 commits into from
Feb 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 116 additions & 221 deletions notebooks/onnx-graphsurgeon-inference-tensorrt.ipynb

Large diffs are not rendered by default.

109 changes: 105 additions & 4 deletions test/test_relaying.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,112 @@
# Copyright (c) 2021, yolort team. All Rights Reserved.
# Copyright (c) 2021, yolort team. All rights reserved.

from pathlib import Path

import pytest
import torch
from torch import Tensor
from torch.jit._trace import TopLevelTracedModule
from yolort.models import yolov5s
from yolort.relaying import get_trace_module
from yolort.relaying import get_trace_module, YOLOInference
from yolort.relaying.trt_graphsurgeon import YOLOTRTGraphSurgeon
from yolort.v5 import attempt_download


def test_get_trace_module():
@pytest.mark.parametrize("h", [320, 416, 640])
@pytest.mark.parametrize("w", [320, 416, 640])
def test_get_trace_module(h, w):
model_func = yolov5s(pretrained=True)
script_module = get_trace_module(model_func, input_shape=(416, 320))
script_module = get_trace_module(model_func, input_shape=(h, w))
assert isinstance(script_module, TopLevelTracedModule)
assert script_module.code is not None


@pytest.mark.parametrize(
    "arch, version, upstream_version, hash_prefix",
    [
        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
        ("yolov5n", "r6.0", "v6.0", "649e089f"),
        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
    ],
)
def test_yolo_trt_module(arch, version, upstream_version, hash_prefix):
    """Run ``YOLOInference`` on a random 320x320 batch and check it returns two tensors."""
    # The base URL carries no trailing slash: the f-string below supplies the
    # separator, and keeping both produced ".../download//<tag>/<arch>.pt".
    base_url = "https://github.com/ultralytics/yolov5/releases/download"
    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

    model = YOLOInference(checkpoint_path, version=version)
    model.eval()
    samples = torch.rand(1, 3, 320, 320)
    outs = model(samples)

    assert isinstance(outs, tuple)
    assert len(outs) == 2
    assert isinstance(outs[0], Tensor)
    assert isinstance(outs[1], Tensor)


@pytest.mark.parametrize(
    "arch, version, upstream_version, hash_prefix",
    [
        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
        ("yolov5n", "r6.0", "v6.0", "649e089f"),
        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
    ],
)
def test_yolo_trt_module_to_onnx(arch, version, upstream_version, hash_prefix):
    """Export ``YOLOInference`` to ONNX and check that the output file is created."""
    # Local import: only the scratch directory below needs it.
    import tempfile

    # No trailing slash here; the f-string adds the separator (avoids "download//").
    base_url = "https://github.com/ultralytics/yolov5/releases/download"
    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

    model = YOLOInference(checkpoint_path, version=version)
    model.eval()

    # Export into a throwaway directory: writing into the working directory made
    # the "does not exist yet" assertion fail on every rerun and left artifacts behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = Path(tmp_dir) / f"yolo_trt_module_to_onnx_{arch}_{hash_prefix}.onnx"
        assert not onnx_file_path.exists()
        # Pass a plain string, as the original test did.
        model.to_onnx(str(onnx_file_path))
        assert onnx_file_path.exists()


@pytest.mark.parametrize(
    "arch, version, upstream_version, hash_prefix",
    [
        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
        ("yolov5n", "r6.0", "v6.0", "649e089f"),
        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
    ],
)
def test_yolo_graphsurgeon_wo_nms(arch, version, upstream_version, hash_prefix):
    """Save the ``YOLOTRTGraphSurgeon`` graph without registering NMS and check the file exists."""
    # Local import: only the scratch directory below needs it.
    import tempfile

    # No trailing slash here; the f-string adds the separator (avoids "download//").
    base_url = "https://github.com/ultralytics/yolov5/releases/download"
    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

    yolo_gs = YOLOTRTGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False)

    # Save into a throwaway directory so that reruns do not trip over a file
    # left behind by a previous run.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = Path(tmp_dir) / f"yolo_graphsurgeon_wo_nms_{arch}_{hash_prefix}.onnx"
        assert not onnx_file_path.exists()
        # Pass a plain string, as the original test did.
        yolo_gs.save(str(onnx_file_path))
        assert onnx_file_path.exists()


@pytest.mark.parametrize(
    "arch, version, upstream_version, hash_prefix",
    [
        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
        ("yolov5n", "r6.0", "v6.0", "649e089f"),
        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
    ],
)
def test_yolo_graphsurgeon_register_nms(arch, version, upstream_version, hash_prefix):
    """Call ``register_nms()`` on ``YOLOTRTGraphSurgeon``, save the graph and check the file exists."""
    # Local import: only the scratch directory below needs it.
    import tempfile

    # No trailing slash here; the f-string adds the separator (avoids "download//").
    base_url = "https://github.com/ultralytics/yolov5/releases/download"
    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

    yolo_gs = YOLOTRTGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False)
    yolo_gs.register_nms()

    # Save into a throwaway directory so that reruns do not trip over a file
    # left behind by a previous run.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The "_" before {arch} was missing; the sibling tests all use
        # "<test name>_<arch>_<hash>.onnx".
        onnx_file_path = Path(tmp_dir) / f"yolo_graphsurgeon_register_nms_{arch}_{hash_prefix}.onnx"
        assert not onnx_file_path.exists()
        # Pass a plain string, as the original test did.
        yolo_gs.save(str(onnx_file_path))
        assert onnx_file_path.exists()
100 changes: 0 additions & 100 deletions test/test_runtime.py
Original file line number Diff line number Diff line change
@@ -1,100 +0,0 @@
# Copyright (c) 2021, yolort team. All Rights Reserved.
from pathlib import Path

import pytest
import torch
from torch import Tensor
from yolort.runtime import YOLOGraphSurgeon
from yolort.runtime.trt_helper import YOLOTRTModule
from yolort.v5 import attempt_download


@pytest.mark.parametrize(
    "arch, version, upstream_version, hash_prefix",
    [
        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
        ("yolov5n", "r6.0", "v6.0", "649e089f"),
        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
    ],
)
def test_yolo_trt_module(arch, version, upstream_version, hash_prefix):
    """Run ``YOLOTRTModule`` on a random 320x320 batch and check it returns two tensors."""
    # The base URL carries no trailing slash: the f-string below supplies the
    # separator, and keeping both produced ".../download//<tag>/<arch>.pt".
    base_url = "https://github.com/ultralytics/yolov5/releases/download"
    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

    model = YOLOTRTModule(checkpoint_path, version=version)
    model.eval()
    samples = torch.rand(1, 3, 320, 320)
    outs = model(samples)

    assert isinstance(outs, tuple)
    assert len(outs) == 2
    assert isinstance(outs[0], Tensor)
    assert isinstance(outs[1], Tensor)


@pytest.mark.parametrize(
    "arch, version, upstream_version, hash_prefix",
    [
        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
        ("yolov5n", "r6.0", "v6.0", "649e089f"),
        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
    ],
)
def test_yolo_trt_module_to_onnx(arch, version, upstream_version, hash_prefix):
    """Export ``YOLOTRTModule`` to ONNX and check that the output file is created."""
    # Local import: only the scratch directory below needs it.
    import tempfile

    # No trailing slash here; the f-string adds the separator (avoids "download//").
    base_url = "https://github.com/ultralytics/yolov5/releases/download"
    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

    model = YOLOTRTModule(checkpoint_path, version=version)
    model.eval()

    # Export into a throwaway directory: writing into the working directory made
    # the "does not exist yet" assertion fail on every rerun and left artifacts behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = Path(tmp_dir) / f"yolo_trt_module_to_onnx_{arch}_{hash_prefix}.onnx"
        assert not onnx_file_path.exists()
        # Pass a plain string, as the original test did.
        model.to_onnx(str(onnx_file_path))
        assert onnx_file_path.exists()


@pytest.mark.parametrize(
    "arch, version, upstream_version, hash_prefix",
    [
        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
        ("yolov5n", "r6.0", "v6.0", "649e089f"),
        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
    ],
)
def test_yolo_graphsurgeon_wo_nms(arch, version, upstream_version, hash_prefix):
    """Save the ``YOLOGraphSurgeon`` graph without registering NMS and check the file exists."""
    # Local import: only the scratch directory below needs it.
    import tempfile

    # No trailing slash here; the f-string adds the separator (avoids "download//").
    base_url = "https://github.com/ultralytics/yolov5/releases/download"
    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

    yolo_gs = YOLOGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False)

    # Save into a throwaway directory so that reruns do not trip over a file
    # left behind by a previous run.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_file_path = Path(tmp_dir) / f"yolo_graphsurgeon_wo_nms_{arch}_{hash_prefix}.onnx"
        assert not onnx_file_path.exists()
        # Pass a plain string, as the original test did.
        yolo_gs.save(str(onnx_file_path))
        assert onnx_file_path.exists()


@pytest.mark.parametrize(
    "arch, version, upstream_version, hash_prefix",
    [
        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
        ("yolov5n", "r6.0", "v6.0", "649e089f"),
        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
    ],
)
def test_yolo_graphsurgeon_register_nms(arch, version, upstream_version, hash_prefix):
    """Call ``register_nms()`` on ``YOLOGraphSurgeon``, save the graph and check the file exists."""
    # Local import: only the scratch directory below needs it.
    import tempfile

    # No trailing slash here; the f-string adds the separator (avoids "download//").
    base_url = "https://github.com/ultralytics/yolov5/releases/download"
    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

    yolo_gs = YOLOGraphSurgeon(checkpoint_path, version=version, enable_dynamic=False)
    yolo_gs.register_nms()

    # Save into a throwaway directory so that reruns do not trip over a file
    # left behind by a previous run.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # The "_" before {arch} was missing; the sibling tests all use
        # "<test name>_<arch>_<hash>.onnx".
        onnx_file_path = Path(tmp_dir) / f"yolo_graphsurgeon_register_nms_{arch}_{hash_prefix}.onnx"
        assert not onnx_file_path.exists()
        # Pass a plain string, as the original test did.
        yolo_gs.save(str(onnx_file_path))
        assert onnx_file_path.exists()
6 changes: 4 additions & 2 deletions yolort/relaying/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) 2021, yolort team. All Rights Reserved.
# Copyright (c) 2021, yolort team. All rights reserved.

from .trace_wrapper import get_trace_module
from .yolo_inference import YOLOInference

__all__ = ["get_trace_module"]
__all__ = ["get_trace_module", "YOLOInference"]
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2021, yolort team. All rights reserved.

from typing import List, Tuple

import torch
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
# Copyright (c) 2021, yolort team. All rights reserved.
#
# This source code is licensed under the GPL-3.0 license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# This source code is licensed under the Apache-2.0 license found in the
# LICENSE file in the root directory of TensorRT source tree.
#

import logging
from pathlib import Path
Expand All @@ -24,16 +15,18 @@
except ImportError:
gs = None

from .trt_helper import YOLOTRTModule
from .yolo_inference import YOLOInference

logging.basicConfig(level=logging.INFO)
logging.getLogger("YOLOGraphSurgeon").setLevel(logging.INFO)
logger = logging.getLogger("YOLOGraphSurgeon")
logging.getLogger("YOLOTRTGraphSurgeon").setLevel(logging.INFO)
logger = logging.getLogger("YOLOTRTGraphSurgeon")

__all__ = ["YOLOTRTGraphSurgeon"]

class YOLOGraphSurgeon:

class YOLOTRTGraphSurgeon:
"""
Constructor of the YOLOv5 Graph Surgeon object.
YOLOv5 Graph Surgeon for TensorRT inference.

Because TensorRT treats ``torchvision::ops::nms`` as a plugin, we use a simple post-processing
module named ``LogitsDecoder`` to connect to ``EfficientNMS_TRT`` plugin in TensorRT.
Expand Down Expand Up @@ -66,8 +59,8 @@ def __init__(
checkpoint_path = Path(checkpoint_path)
assert checkpoint_path.exists()

# Use YOLOTRTModule to convert saved model to an initial ONNX graph.
model = YOLOTRTModule(checkpoint_path, version=version)
# Use YOLOInference to convert saved model to an initial ONNX graph.
model = YOLOInference(checkpoint_path, version=version)
model = model.eval()
model = model.to(device=device)
logger.info(f"Loaded saved model from {checkpoint_path}")
Expand Down
116 changes: 116 additions & 0 deletions yolort/relaying/yolo_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2021, yolort team. All rights reserved.

from pathlib import PosixPath
from typing import Optional, Tuple, Union

import torch
from torch import nn, Tensor
from yolort.models import YOLO
from yolort.models.anchor_utils import AnchorGenerator
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.utils import load_from_ultralytics

from .logits_decoder import LogitsDecoder

__all__ = ["YOLOInference"]


class YOLOInference(nn.Module):
    """
    YOLO wrapper intended for TensorRT deployment.

    The ``torchvision::nms`` operator is left out of this wrapper, because some
    third-party inference frameworks currently do not support it very well. The
    wrapped model uses a ``LogitsDecoder`` as its post-processing module.

    Args:
        checkpoint_path (string): Path of the trained YOLOv5 checkpoint.
        version (string): Upstream YOLOv5 version. Default: 'r6.0'
    """

    def __init__(self, checkpoint_path: str, version: str = "r6.0"):
        super().__init__()
        info = load_from_ultralytics(checkpoint_path, version=version)

        # Backbone names follow "darknet_<size>_<version with '.' replaced by '_'>".
        version_tag = version.replace(".", "_")
        backbone = darknet_pan_backbone(
            f"darknet_{info['size']}_{version_tag}",
            info["depth_multiple"],
            info["width_multiple"],
            version=version,
            use_p6=info["use_p6"],
        )

        num_classes = info["num_classes"]
        strides = info["strides"]
        yolo = YOLO(
            backbone,
            num_classes,
            anchor_generator=AnchorGenerator(strides, info["anchor_grids"]),
            post_process=LogitsDecoder(strides),
        )
        # Restore the weights that were read from the checkpoint.
        yolo.load_state_dict(info["state_dict"])

        self.model = yolo
        self.num_classes = num_classes

    @torch.no_grad()
    def forward(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Run the wrapped model on a batch of images with gradient tracking disabled.

        Args:
            inputs (Tensor): batched images, of shape [batch_size x 3 x H x W]

        Returns:
            Whatever the wrapped model returns for ``inputs``.
        """
        return self.model(inputs)

    @torch.no_grad()
    def to_onnx(
        self,
        file_path: Union[str, PosixPath],
        input_sample: Optional[Tensor] = None,
        opset_version: int = 11,
        enable_dynamic: bool = True,
        **kwargs,
    ):
        """
        Export the wrapped model to an ONNX file.

        Args:
            file_path (Union[string, PosixPath]): Destination path of the ONNX model.
            input_sample (Tensor, Optional): An input for tracing. When it is None, a random
                tensor of shape [1 x 3 x 640 x 640] on the module's device is used. Default: None.
            opset_version (int): ONNX opset version to export with. Default: 11.
            enable_dynamic (bool): Whether to specify axes of tensors as dynamic. Default: True.
            **kwargs: Will be passed to torch.onnx.export function.
        """
        sample = input_sample
        if sample is None:
            # Fall back to a random image batch placed on the module's device.
            sample = torch.rand(1, 3, 640, 640)
            sample = sample.to(next(self.parameters()).device)

        dynamic_axes = None
        if enable_dynamic:
            # Batch and spatial size of the input, and batch and object count
            # of both outputs, are declared as dynamic.
            dynamic_axes = {
                "images": {0: "batch", 2: "height", 3: "width"},
                "boxes": {0: "batch", 1: "num_objects"},
                "scores": {0: "batch", 1: "num_objects"},
            }

        torch.onnx.export(
            self.model,
            sample,
            file_path,
            do_constant_folding=True,
            opset_version=opset_version,
            input_names=["images"],
            output_names=["boxes", "scores"],
            dynamic_axes=dynamic_axes,
            **kwargs,
        )
Loading