add tensorrt support (open-mmlab#2)
grimoire authored Jun 23, 2021
1 parent 6eb2e89 commit 6c47ee3
Showing 13 changed files with 394 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,2 +1,2 @@
[settings]
-known_third_party = mmcv,mmdet,numpy,setuptools,torch
+known_third_party = mmcv,mmdet,numpy,setuptools,tensorrt,torch
6 changes: 6 additions & 0 deletions configs/_base_/backends/tensorrt.py
@@ -1 +1,7 @@
import tensorrt as trt

backend = 'tensorrt'
tensorrt_param = dict(
    log_level=trt.Logger.WARNING,
    fp16_mode=False,
    save_file='onnx2tensorrt.engine')
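Downstream model configs inherit this backend file through `_base_` and extend `tensorrt_param` with model-specific entries such as `opt_shape_dict` and `max_workspace_size`, as the mmcls and mmdet configs below show.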
5 changes: 4 additions & 1 deletion configs/_base_/torch2onnx.py
@@ -1,2 +1,5 @@
pytorch2onnx = dict(
-    export_params=True, keep_initializers_as_inputs=False, opset_version=11)
+    export_params=True,
+    keep_initializers_as_inputs=False,
+    opset_version=11,
+    save_file='torch2onnx.onnx')
4 changes: 4 additions & 0 deletions configs/mmcls/mmcls_tensorrt.py
@@ -1 +1,5 @@
_base_ = ['./mmcls_base.py', '../_base_/backends/tensorrt.py']
tensorrt_param = dict(
    opt_shape_dict=dict(
        input=[[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]),
    max_workspace_size=1 << 30)
4 changes: 4 additions & 0 deletions configs/mmdet/tensorrt.py
@@ -1 +1,5 @@
_base_ = ['./base.py', '../_base_/backends/tensorrt.py']
tensorrt_param = dict(
    opt_shape_dict=dict(
        input=[[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]),
    max_workspace_size=1 << 30)
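In both configs, the three shapes listed for each input are the min/opt/max shapes that `onnx2trt` (in `tensorrt_utils.py` below) passes to the TensorRT optimization profile, so the built engine accepts any input whose dimensions lie between the first and third entries.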
9 changes: 5 additions & 4 deletions mmdeploy/apis/pytorch2onnx.py
@@ -11,7 +11,8 @@


def torch2onnx(img: Any,
-               work_dir: Optional[str],
+               work_dir: str,
+               save_file: str,
               deploy_cfg: Union[str, mmcv.Config],
               model_cfg: Union[str, mmcv.Config],
               model_checkpoint: Optional[str] = None,
@@ -22,18 +23,18 @@ def torch2onnx(img: Any,
    # load deploy_cfg if needed
    if isinstance(deploy_cfg, str):
        deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
-    elif not isinstance(deploy_cfg, mmcv.Config):
+    if not isinstance(deploy_cfg, mmcv.Config):
        raise TypeError('deploy_cfg must be a filename or Config object, '
                        f'but got {type(deploy_cfg)}')
    # load model_cfg if needed
    if isinstance(model_cfg, str):
        model_cfg = mmcv.Config.fromfile(model_cfg)
-    elif not isinstance(model_cfg, mmcv.Config):
+    if not isinstance(model_cfg, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(model_cfg)}')

    mmcv.mkdir_or_exist(osp.abspath(work_dir))
-    output_file = osp.join(work_dir, 'torch2onnx.onnx')
+    output_file = osp.join(work_dir, save_file)

    pytorch2onnx_cfg = deploy_cfg['pytorch2onnx']
    codebase = deploy_cfg['codebase']
14 changes: 14 additions & 0 deletions mmdeploy/apis/tensorrt/__init__.py
@@ -0,0 +1,14 @@
# flake8: noqa
from .init_plugins import load_tensorrt_plugin
from .onnx2tensorrt import onnx2tensorrt
from .tensorrt_utils import (TRTWrapper, load_trt_engine, onnx2trt,
                             save_trt_engine)

# load tensorrt plugin lib
load_tensorrt_plugin()

__all__ = [
    'load_tensorrt_plugin', 'onnx2tensorrt', 'onnx2trt', 'save_trt_engine',
    'load_trt_engine', 'TRTWrapper'
]
27 changes: 27 additions & 0 deletions mmdeploy/apis/tensorrt/init_plugins.py
@@ -0,0 +1,27 @@
import ctypes
import glob
import logging
import os


def get_tensorrt_op_path():
    """Get the TensorRT plugin library path."""
    wildcard = os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            '../../../build/lib/libmmlab_tensorrt_ops.so'))

    paths = glob.glob(wildcard)
    lib_path = paths[0] if len(paths) > 0 else ''
    return lib_path


def load_tensorrt_plugin():
    """Load the TensorRT plugin library."""
    lib_path = get_tensorrt_op_path()
    if os.path.exists(lib_path):
        ctypes.CDLL(lib_path)
        return 0
    else:
        logging.warning('Cannot load TensorRT custom ops.')
        return -1
49 changes: 49 additions & 0 deletions mmdeploy/apis/tensorrt/onnx2tensorrt.py
@@ -0,0 +1,49 @@
import os.path as osp
from typing import Optional, Union

import mmcv
import onnx
import tensorrt as trt
import torch.multiprocessing as mp

from .tensorrt_utils import onnx2trt, save_trt_engine


def onnx2tensorrt(work_dir: str,
                  save_file: str,
                  deploy_cfg: Union[str, mmcv.Config],
                  onnx_model: Union[str, onnx.ModelProto],
                  device: str = 'cuda:0',
                  ret_value: Optional[mp.Value] = None):
    if ret_value is not None:
        ret_value.value = -1

    # load deploy_cfg if needed
    if isinstance(deploy_cfg, str):
        deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
    if not isinstance(deploy_cfg, mmcv.Config):
        raise TypeError('deploy_cfg must be a filename or Config object, '
                        f'but got {type(deploy_cfg)}')

    mmcv.mkdir_or_exist(osp.abspath(work_dir))

    assert 'tensorrt_param' in deploy_cfg

    tensorrt_param = deploy_cfg['tensorrt_param']

    # parse the device id from strings such as 'cuda' or 'cuda:1'
    assert device.startswith('cuda')
    device_id = 0
    if len(device) >= 6:
        device_id = int(device[5:])
    engine = onnx2trt(
        onnx_model,
        opt_shape_dict=tensorrt_param['opt_shape_dict'],
        log_level=tensorrt_param.get('log_level', trt.Logger.WARNING),
        fp16_mode=tensorrt_param.get('fp16_mode', False),
        max_workspace_size=tensorrt_param.get('max_workspace_size', 0),
        device_id=device_id)

    save_trt_engine(engine, osp.join(work_dir, save_file))

    if ret_value is not None:
        ret_value.value = 0
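Since `ret_value` is a shared `torch.multiprocessing` value that the function sets to -1 on entry and 0 on success, it appears designed to run in a worker process. A minimal sketch of such a call might look like the following; the work directory, engine filename, and ONNX path are placeholders, and the config is one of the files added in this commit:

import torch.multiprocessing as mp

from mmdeploy.apis.tensorrt import onnx2tensorrt

if __name__ == '__main__':
    # shared value the parent process can poll: -1 = failed, 0 = succeeded
    ret_value = mp.Value('d', 0.0)
    proc = mp.Process(
        target=onnx2tensorrt,
        args=('work_dir', 'onnx2tensorrt.engine',
              'configs/mmcls/mmcls_tensorrt.py', 'work_dir/torch2onnx.onnx'),
        kwargs=dict(device='cuda:0', ret_value=ret_value))
    proc.start()
    proc.join()
    assert ret_value.value == 0, 'onnx2tensorrt failed'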
222 changes: 222 additions & 0 deletions mmdeploy/apis/tensorrt/tensorrt_utils.py
@@ -0,0 +1,222 @@
import onnx
import tensorrt as trt
import torch


def onnx2trt(onnx_model,
             opt_shape_dict,
             log_level=trt.Logger.ERROR,
             fp16_mode=False,
             max_workspace_size=0,
             device_id=0):
    """Convert an ONNX model to a TensorRT engine.

    Arguments:
        onnx_model (str or onnx.ModelProto): the ONNX model to convert from
        opt_shape_dict (dict): the min/opt/max shape of each input
        log_level (TensorRT log level): the log level of TensorRT
        fp16_mode (bool): enable fp16 mode
        max_workspace_size (int): set the max workspace size of the TensorRT
            engine. Some tactics and layers need a large workspace.
        device_id (int): choose the device to create the engine on.

    Returns:
        tensorrt.ICudaEngine: the TensorRT engine created from onnx_model

    Example:
        >>> engine = onnx2trt(
        >>>     "onnx_model.onnx",
        >>>     {'input': [[1, 3, 160, 160],
        >>>                [1, 3, 320, 320],
        >>>                [1, 3, 640, 640]]},
        >>>     log_level=trt.Logger.WARNING,
        >>>     fp16_mode=True,
        >>>     max_workspace_size=1 << 30,
        >>>     device_id=0)
    """
    device = torch.device('cuda:{}'.format(device_id))
    # create builder and network
    logger = trt.Logger(log_level)
    builder = trt.Builder(logger)
    EXPLICIT_BATCH = 1 << int(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)

    # parse onnx
    parser = trt.OnnxParser(network, logger)

    if isinstance(onnx_model, str):
        onnx_model = onnx.load(onnx_model)

    if not parser.parse(onnx_model.SerializeToString()):
        error_msgs = ''
        for error in range(parser.num_errors):
            error_msgs += f'{parser.get_error(error)}\n'
        raise RuntimeError(f'parse onnx failed:\n{error_msgs}')

    # config builder
    builder.max_workspace_size = max_workspace_size

    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size
    profile = builder.create_optimization_profile()

    for input_name, param in opt_shape_dict.items():
        min_shape = tuple(param[0][:])
        opt_shape = tuple(param[1][:])
        max_shape = tuple(param[2][:])
        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    if fp16_mode:
        builder.fp16_mode = fp16_mode
        config.set_flag(trt.BuilderFlag.FP16)

    # create engine
    with torch.cuda.device(device):
        engine = builder.build_engine(network, config)

    return engine


def save_trt_engine(engine, path):
    """Serialize a TensorRT engine to disk.

    Arguments:
        engine (tensorrt.ICudaEngine): TensorRT engine to serialize
        path (str): disk path to write the engine to
    """
    with open(path, mode='wb') as f:
        f.write(bytearray(engine.serialize()))


def load_trt_engine(path):
    """Deserialize a TensorRT engine from disk.

    Arguments:
        path (str): disk path to read the engine from

    Returns:
        tensorrt.ICudaEngine: the TensorRT engine loaded from disk
    """
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        with open(path, mode='rb') as f:
            engine_bytes = f.read()
        engine = runtime.deserialize_cuda_engine(engine_bytes)
        return engine


def torch_dtype_from_trt(dtype):
    """Convert a TensorRT dtype to the corresponding PyTorch dtype."""
    if dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int8:
        return torch.int8
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError('%s is not supported by torch' % dtype)


def torch_device_from_trt(device):
    """Convert a TensorRT device to the corresponding PyTorch device."""
    if device == trt.TensorLocation.DEVICE:
        return torch.device('cuda')
    elif device == trt.TensorLocation.HOST:
        return torch.device('cpu')
    else:
        raise TypeError('%s is not supported by torch' % device)


class TRTWrapper(torch.nn.Module):
    """TensorRT engine wrapper.

    Arguments:
        engine (tensorrt.ICudaEngine or str): the TensorRT engine to wrap,
            or a path to a serialized engine on disk.

    Note:
        If the engine is converted from an ONNX model, the input and output
        names are the same as in the ONNX model and are read directly from
        the engine bindings.
    """

    def __init__(self, engine):
        super(TRTWrapper, self).__init__()
        self.engine = engine
        if isinstance(self.engine, str):
            self.engine = load_trt_engine(engine)

        if not isinstance(self.engine, trt.ICudaEngine):
            raise TypeError('engine should be str or trt.ICudaEngine')

        self._register_state_dict_hook(TRTWrapper._on_state_dict)
        self.context = self.engine.create_execution_context()

        self._load_io_names()

    def _load_io_names(self):
        # get input and output names from the engine bindings
        names = [name for name in self.engine]
        input_names = list(filter(self.engine.binding_is_input, names))
        output_names = list(set(names) - set(input_names))
        self.input_names = input_names
        self.output_names = output_names

    def _on_state_dict(self, state_dict, prefix, local_metadata):
        state_dict[prefix + 'engine'] = bytearray(self.engine.serialize())
        state_dict[prefix + 'input_names'] = self.input_names
        state_dict[prefix + 'output_names'] = self.output_names

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        engine_bytes = state_dict[prefix + 'engine']

        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(engine_bytes)
            self.context = self.engine.create_execution_context()

        self.input_names = state_dict[prefix + 'input_names']
        self.output_names = state_dict[prefix + 'output_names']

    def forward(self, inputs):
        """Run inference on the wrapped engine.

        Arguments:
            inputs (dict): dict mapping input names to tensors

        Returns:
            dict: dict mapping output names to tensors
        """
        assert self.input_names is not None
        assert self.output_names is not None
        bindings = [None] * (len(self.input_names) + len(self.output_names))

        for input_name, input_tensor in inputs.items():
            idx = self.engine.get_binding_index(input_name)

            # TensorRT does not accept int64 tensors; downcast to int32
            if input_tensor.dtype == torch.long:
                input_tensor = input_tensor.int()
            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            bindings[idx] = input_tensor.contiguous().data_ptr()

        # create output tensors
        outputs = {}
        for i, output_name in enumerate(self.output_names):
            idx = self.engine.get_binding_index(output_name)
            dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
            shape = tuple(self.context.get_binding_shape(idx))

            device = torch_device_from_trt(self.engine.get_location(idx))
            output = torch.empty(size=shape, dtype=dtype, device=device)
            outputs[output_name] = output
            bindings[idx] = output.data_ptr()

        self.context.execute_async_v2(bindings,
                                      torch.cuda.current_stream().cuda_stream)

        return outputs
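Taken together, a round trip through these utilities might look like the following sketch; the model path, input name, and shapes are illustrative, reusing the docstring example above:

import tensorrt as trt
import torch

from mmdeploy.apis.tensorrt import TRTWrapper, onnx2trt, save_trt_engine

# build an engine with a dynamic-shape profile, then persist it
engine = onnx2trt(
    'onnx_model.onnx',
    {'input': [[1, 3, 160, 160], [1, 3, 320, 320], [1, 3, 640, 640]]},
    log_level=trt.Logger.WARNING,
    fp16_mode=False,
    max_workspace_size=1 << 30)
save_trt_engine(engine, 'model.engine')

# TRTWrapper accepts either an ICudaEngine or a path to a serialized engine
model = TRTWrapper('model.engine')
with torch.no_grad():
    outputs = model({'input': torch.rand(1, 3, 320, 320).cuda()})
print({name: tuple(t.shape) for name, t in outputs.items()})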
10 changes: 10 additions & 0 deletions mmdeploy/mmcv/funcs.py
@@ -45,3 +45,13 @@ def rewrite_topk_tensorrt(rewriter,
        k = int(k)
    return rewriter.origin_func(
        input, k, dim=dim, largest=largest, sorted=sorted)


@FUNCTION_REWRITERS.register_rewriter(
    func_name='torch.Tensor.repeat', backend='tensorrt')
def rewrite_repeat_tensorrt(rewriter, input, *size):
    origin_func = rewriter.origin_func
    if input.dim() == 1 and len(size) == 1:
        return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
    else:
        return origin_func(input, *size)
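The 1-D special case lifts the tensor to 2-D before repeating and squeezes it back afterwards, presumably because the TensorRT backend handles the 2-D form more readily; the two forms are equivalent, as this small check illustrates:

import torch

x = torch.arange(3)
direct = x.repeat(2)                                # tensor([0, 1, 2, 0, 1, 2])
rewritten = x.unsqueeze(0).repeat(1, 2).squeeze(0)  # same result via the 2-D path
assert torch.equal(direct, rewritten)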