Forked from open-mmlab/mmdetection3d.
Showing 13 changed files with 394 additions and 13 deletions.
isort configuration:

@@ -1,2 +1,2 @@
 [settings]
-known_third_party = mmcv,mmdet,numpy,setuptools,torch
+known_third_party = mmcv,mmdet,numpy,setuptools,tensorrt,torch
TensorRT backend base config (referenced below as ../_base_/backends/tensorrt.py; the hunk counts indicate the backend line is unchanged context):

@@ -1 +1,7 @@
+import tensorrt as trt
+
 backend = 'tensorrt'
+tensorrt_param = dict(
+    log_level=trt.Logger.WARNING,
+    fp16_mode=False,
+    save_file='onnx2tensorrt.engine')
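These settings are consumed through mmcv's config system, which executes the Python config file (so tensorrt must be importable). A minimal read-back sketch; the path is an assumption, since the commit view does not preserve filenames:

import mmcv

# Assumed location of the backend base config shown above.
cfg = mmcv.Config.fromfile('configs/_base_/backends/tensorrt.py')
print(cfg.backend)                      # 'tensorrt'
print(cfg.tensorrt_param['save_file'])  # 'onnx2tensorrt.engine'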
ONNX export base config:

@@ -1,2 +1,5 @@
 pytorch2onnx = dict(
-    export_params=True, keep_initializers_as_inputs=False, opset_version=11)
+    export_params=True,
+    keep_initializers_as_inputs=False,
+    opset_version=11,
+    save_file='torch2onnx.onnx')
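For orientation, these fields map directly onto torch.onnx.export, with save_file supplying the output path. A minimal sketch; the model and input here are hypothetical stand-ins for whatever is being deployed:

import torch

# Hypothetical model and input; in the real pipeline these come from the
# mmdet/mmcls model being deployed.
model = torch.nn.Conv2d(3, 8, 3)
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model, dummy_input, 'torch2onnx.onnx',  # save_file
    export_params=True,                      # bake weights into the graph
    keep_initializers_as_inputs=False,       # weights as constants, not inputs
    opset_version=11)                        # ONNX opset used for export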
mmcls classification deploy config (extends the TensorRT backend base config):

@@ -1 +1,5 @@
 _base_ = ['./mmcls_base.py', '../_base_/backends/tensorrt.py']
+tensorrt_param = dict(
+    opt_shape_dict=dict(
+        input=[[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]),
+    max_workspace_size=1 << 30)
Detection deploy config (extends the TensorRT backend base config):

@@ -1 +1,5 @@
 _base_ = ['./base.py', '../_base_/backends/tensorrt.py']
+tensorrt_param = dict(
+    opt_shape_dict=dict(
+        input=[[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]),
+    max_workspace_size=1 << 30)
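In both configs, each opt_shape_dict entry is a [min, opt, max] shape triplet for one input binding; onnx2trt (shown later in this commit) feeds these to a TensorRT optimization profile. A sketch of that mapping, assuming a single input binding named 'input':

import tensorrt as trt

# How one [min, opt, max] triplet becomes an optimization profile.
shapes = [[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()
profile = builder.create_optimization_profile()
profile.set_shape('input', *[tuple(s) for s in shapes])  # min, opt, max
config.add_optimization_profile(profile)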
__init__.py of the TensorRT backend package (new file; the committed __all__ listed a misspelled 'TRTWraper' and two names that were never imported, which would break star-imports, so it is trimmed to the names actually defined):

# flake8: noqa
from .init_plugins import load_tensorrt_plugin
from .onnx2tensorrt import onnx2tensorrt
from .tensorrt_utils import (TRTWrapper, load_trt_engine, onnx2trt,
                             save_trt_engine)

# load tensorrt plugin lib
load_tensorrt_plugin()

__all__ = [
    'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWrapper',
    'onnx2tensorrt'
]
init_plugins.py (new file):

import ctypes
import glob
import logging
import os


def get_tensorrt_op_path():
    """Get the TensorRT plugin library path."""
    wildcard = os.path.abspath(
        os.path.join(
            os.path.dirname(__file__),
            '../../../build/lib/libmmlab_tensorrt_ops.so'))

    paths = glob.glob(wildcard)
    lib_path = paths[0] if len(paths) > 0 else ''
    return lib_path


def load_tensorrt_plugin():
    """Load the TensorRT plugin library, returning 0 on success."""
    lib_path = get_tensorrt_op_path()
    if os.path.exists(lib_path):
        ctypes.CDLL(lib_path)
        return 0
    else:
        logging.warning('Cannot load tensorrt custom ops.')
        return -1
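A quick status-check sketch using the two helpers above, run from the same package; useful before building an engine that needs the custom ops:

op_path = get_tensorrt_op_path()
if load_tensorrt_plugin() == 0:
    print(f'loaded custom ops from {op_path}')
else:
    print('custom ops unavailable; only stock ONNX ops will convert')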
onnx2tensorrt.py (new file; the committed version dereferenced the optional ret_value unconditionally and overwrote the caller's save_file with a hard-coded name, both fixed here):

import os.path as osp
from typing import Optional, Union

import tensorrt as trt

import mmcv
import onnx
import torch.multiprocessing as mp

from .tensorrt_utils import onnx2trt, save_trt_engine


def onnx2tensorrt(work_dir: str,
                  save_file: str,
                  deploy_cfg: Union[str, mmcv.Config],
                  onnx_model: Union[str, onnx.ModelProto],
                  device: str = 'cuda:0',
                  ret_value: Optional[mp.Value] = None):
    """Convert an ONNX model to a TensorRT engine, save it to
    osp.join(work_dir, save_file), and report success through the
    optional ret_value."""
    if ret_value is not None:
        ret_value.value = -1

    # load deploy_cfg if needed
    if isinstance(deploy_cfg, str):
        deploy_cfg = mmcv.Config.fromfile(deploy_cfg)
    elif not isinstance(deploy_cfg, mmcv.Config):
        raise TypeError('deploy_cfg must be a filename or Config object, '
                        f'but got {type(deploy_cfg)}')

    mmcv.mkdir_or_exist(osp.abspath(work_dir))

    assert 'tensorrt_param' in deploy_cfg

    tensorrt_param = deploy_cfg['tensorrt_param']

    # parse the device index out of strings like 'cuda:1'
    assert device.startswith('cuda')
    device_id = 0
    if len(device) >= 6:
        device_id = int(device[5:])
    engine = onnx2trt(
        onnx_model,
        opt_shape_dict=tensorrt_param['opt_shape_dict'],
        log_level=tensorrt_param.get('log_level', trt.Logger.WARNING),
        fp16_mode=tensorrt_param.get('fp16_mode', False),
        max_workspace_size=tensorrt_param.get('max_workspace_size', 0),
        device_id=device_id)

    save_trt_engine(engine, osp.join(work_dir, save_file))

    if ret_value is not None:
        ret_value.value = 0
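A usage sketch for this entry point; the paths and the deploy config name are placeholders, and the mp.Value mirrors how a caller can poll for success when the conversion runs in a child process:

import torch.multiprocessing as mp

ret = mp.Value('i', -1)  # shared int flag: -1 pending/failed, 0 success
onnx2tensorrt(
    work_dir='work_dir',
    save_file='end2end.engine',                  # placeholder engine name
    deploy_cfg='configs/mmcls/cls_tensorrt.py',  # placeholder config path
    onnx_model='torch2onnx.onnx',
    device='cuda:0',
    ret_value=ret)
assert ret.value == 0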
tensorrt_utils.py (new file):

import onnx
import tensorrt as trt
import torch


def onnx2trt(onnx_model,
             opt_shape_dict,
             log_level=trt.Logger.ERROR,
             fp16_mode=False,
             max_workspace_size=0,
             device_id=0):
    """Convert an ONNX model to a TensorRT engine.

    Arguments:
        onnx_model (str or onnx.ModelProto): the ONNX model to convert
        opt_shape_dict (dict): the min/opt/max shape of each input
        log_level (TensorRT log level): the log level of TensorRT
        fp16_mode (bool): enable fp16 mode
        max_workspace_size (int): max workspace size of the TensorRT
            engine; some tactics and layers need a large workspace.
        device_id (int): the device on which to create the engine

    Returns:
        tensorrt.ICudaEngine: the TensorRT engine created from onnx_model

    Example:
        >>> engine = onnx2trt(
        >>>     'onnx_model.onnx',
        >>>     {'input': [[1, 3, 160, 160],
        >>>                [1, 3, 320, 320],
        >>>                [1, 3, 640, 640]]},
        >>>     log_level=trt.Logger.WARNING,
        >>>     fp16_mode=True,
        >>>     max_workspace_size=1 << 30,
        >>>     device_id=0)
    """
    device = torch.device('cuda:{}'.format(device_id))
    # create builder and network
    logger = trt.Logger(log_level)
    builder = trt.Builder(logger)
    EXPLICIT_BATCH = 1 << int(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)

    # parse onnx
    parser = trt.OnnxParser(network, logger)

    if isinstance(onnx_model, str):
        onnx_model = onnx.load(onnx_model)

    if not parser.parse(onnx_model.SerializeToString()):
        error_msgs = ''
        for error in range(parser.num_errors):
            error_msgs += f'{parser.get_error(error)}\n'
        raise RuntimeError(f'parse onnx failed:\n{error_msgs}')

    # configure the builder: workspace size, optimization profile, precision
    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size
    profile = builder.create_optimization_profile()

    for input_name, param in opt_shape_dict.items():
        min_shape = tuple(param[0][:])
        opt_shape = tuple(param[1][:])
        max_shape = tuple(param[2][:])
        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    if fp16_mode:
        config.set_flag(trt.BuilderFlag.FP16)

    # create engine
    with torch.cuda.device(device):
        engine = builder.build_engine(network, config)

    return engine
def save_trt_engine(engine, path):
    """Serialize a TensorRT engine to disk.

    Arguments:
        engine (tensorrt.ICudaEngine): TensorRT engine to serialize
        path (str): disk path to write the engine to
    """
    with open(path, mode='wb') as f:
        f.write(bytearray(engine.serialize()))


def load_trt_engine(path):
    """Deserialize a TensorRT engine from disk.

    Arguments:
        path (str): disk path to read the engine from

    Returns:
        tensorrt.ICudaEngine: the TensorRT engine loaded from disk
    """
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        with open(path, mode='rb') as f:
            engine_bytes = f.read()
        engine = runtime.deserialize_cuda_engine(engine_bytes)
        return engine
def torch_dtype_from_trt(dtype):
    """Convert a TensorRT dtype to the equivalent PyTorch dtype."""
    if dtype == trt.bool:
        return torch.bool
    elif dtype == trt.int8:
        return torch.int8
    elif dtype == trt.int32:
        return torch.int32
    elif dtype == trt.float16:
        return torch.float16
    elif dtype == trt.float32:
        return torch.float32
    else:
        raise TypeError('%s is not supported by torch' % dtype)


def torch_device_from_trt(device):
    """Convert a TensorRT device location to the equivalent PyTorch device."""
    if device == trt.TensorLocation.DEVICE:
        return torch.device('cuda')
    elif device == trt.TensorLocation.HOST:
        return torch.device('cpu')
    else:
        raise TypeError('%s is not supported by torch' % device)
class TRTWrapper(torch.nn.Module):
    """TensorRT engine wrapper.

    Arguments:
        engine (tensorrt.ICudaEngine or str): the TensorRT engine to wrap,
            or a path to a serialized engine on disk

    Note:
        If the engine is converted from an ONNX model, the input and output
        names read from the engine bindings match those of the ONNX model.
    """

    def __init__(self, engine):
        super(TRTWrapper, self).__init__()
        self.engine = engine
        if isinstance(self.engine, str):
            self.engine = load_trt_engine(engine)

        if not isinstance(self.engine, trt.ICudaEngine):
            raise TypeError('engine should be str or trt.ICudaEngine')

        self._register_state_dict_hook(TRTWrapper._on_state_dict)
        self.context = self.engine.create_execution_context()

        self._load_io_names()

    def _load_io_names(self):
        # get input and output names from the engine bindings
        names = list(self.engine)
        input_names = list(filter(self.engine.binding_is_input, names))
        output_names = list(set(names) - set(input_names))
        self.input_names = input_names
        self.output_names = output_names

    def _on_state_dict(self, state_dict, prefix, local_metadata):
        # serialize the engine into the state dict so the wrapper can be
        # saved and restored like any other nn.Module
        state_dict[prefix + 'engine'] = bytearray(self.engine.serialize())
        state_dict[prefix + 'input_names'] = self.input_names
        state_dict[prefix + 'output_names'] = self.output_names

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        engine_bytes = state_dict[prefix + 'engine']

        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
            self.engine = runtime.deserialize_cuda_engine(engine_bytes)
            self.context = self.engine.create_execution_context()

        self.input_names = state_dict[prefix + 'input_names']
        self.output_names = state_dict[prefix + 'output_names']

    def forward(self, inputs):
        """Run inference on the wrapped engine.

        Arguments:
            inputs (dict): dict of input name-tensor pairs

        Returns:
            dict: dict of output name-tensor pairs
        """
        assert self.input_names is not None
        assert self.output_names is not None
        bindings = [None] * (len(self.input_names) + len(self.output_names))

        for input_name, input_tensor in inputs.items():
            idx = self.engine.get_binding_index(input_name)

            # TensorRT does not support int64; downcast long tensors
            if input_tensor.dtype == torch.long:
                input_tensor = input_tensor.int()
            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            bindings[idx] = input_tensor.contiguous().data_ptr()

        # create output tensors
        outputs = {}
        for i, output_name in enumerate(self.output_names):
            idx = self.engine.get_binding_index(output_name)
            dtype = torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
            shape = tuple(self.context.get_binding_shape(idx))

            device = torch_device_from_trt(self.engine.get_location(idx))
            output = torch.empty(size=shape, dtype=dtype, device=device)
            outputs[output_name] = output
            bindings[idx] = output.data_ptr()

        self.context.execute_async_v2(bindings,
                                      torch.cuda.current_stream().cuda_stream)

        return outputs
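A minimal end-to-end usage sketch for the wrapper; the engine path and the binding names 'input'/'output' are assumptions carried over from the configs above, not guaranteed by the module itself:

import torch

model = TRTWrapper('work_dir/onnx2tensorrt.engine')  # assumed engine path
dummy = torch.randn(1, 3, 224, 224, device='cuda')   # matches the min shape
with torch.no_grad():
    outputs = model({'input': dummy})                # binding name assumed
print({k: v.shape for k, v in outputs.items()})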