Skip to content

Commit

Permalink
[Refactor][tools] Add prebuild tools. (open-mmlab#347)
Browse files Browse the repository at this point in the history
* move to lib

* optional import pytorch rewriter

* reduce torch dependancy of tensorrt export

* remove more mmcv support

* fix pytest

* remove mmcv logge

* Add `mmdeploy.utils.logging`

* Improve the common of the `get_logger`

* Fix lint

* onnxruntim add try catch to  import wrapper if pytorch is available

* Using `mmcv.utils.logging` in all files under `mmdeploy/codebase`

* add __init__

* add prebuild tools

* support windows

* for comment

* exit if failed

* add exist

* decouple

* add tags

* remove .mmdeploy_python

* read python version from system

* update windows config

* update linux config

* remote many

* better build name

* rename python tag

* fix pyhon-tag

* update window config

* add env search

* update tag

* fix build without CUDA_TOOLKIT_ROOT_DIR

Co-authored-by: HinGwenWoong <[email protected]>
  • Loading branch information
2 people authored and lvhan028 committed Jun 3, 2022
1 parent d3304cf commit 218f287
Show file tree
Hide file tree
Showing 37 changed files with 723 additions and 153 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ __pycache__/

# C extensions
*.so
onnx2ncnn

# Distribution / packaging
.Python
Expand Down
6 changes: 6 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
include requirements/*.txt
include mmdeploy/backend/ncnn/*.so
include mmdeploy/backend/ncnn/*.dll
include mmdeploy/backend/ncnn/*.pyd
include mmdeploy/lib/*.so
include mmdeploy/lib/*.dll
include mmdeploy/lib/*.pyd
2 changes: 2 additions & 0 deletions csrc/backend_ops/ncnn/onnx2ncnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ if (PROTOBUF_FOUND)
${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(onnx2ncnn PRIVATE ${PROTOBUF_LIBRARIES})

set(_NCNN_CONVERTER_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/backend/ncnn)
install(TARGETS onnx2ncnn DESTINATION ${_NCNN_CONVERTER_DIR})
else ()
message(
FATAL_ERROR "Protobuf not found, onnx model convert tool won't be built")
Expand Down
3 changes: 3 additions & 0 deletions csrc/backend_ops/ncnn/ops/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ target_include_directories(${PROJECT_NAME}
PUBLIC ${_COMMON_INCLUDE_DIRS})

add_library(mmdeploy::ncnn_ops ALIAS ${PROJECT_NAME})

set(_NCNN_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
install(TARGETS ${PROJECT_NAME} DESTINATION ${_NCNN_OPS_DIR})
3 changes: 3 additions & 0 deletions csrc/backend_ops/onnxruntime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ target_link_libraries(${PROJECT_NAME}_obj PUBLIC onnxruntime)
mmdeploy_add_library(${PROJECT_NAME} SHARED EXCLUDE "")
target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_obj)
add_library(mmdeploy::onnxruntime::ops ALIAS ${PROJECT_NAME})

set(_ORT_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
install(TARGETS ${PROJECT_NAME} DESTINATION ${_ORT_OPS_DIR})
3 changes: 3 additions & 0 deletions csrc/backend_ops/tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ mmdeploy_export(${PROJECT_NAME}_obj)
mmdeploy_add_module(${PROJECT_NAME} MODULE EXCLUDE "")
target_link_libraries(${PROJECT_NAME} PRIVATE ${PROJECT_NAME}_obj)
add_library(mmdeploy::tensorrt_ops ALIAS ${PROJECT_NAME})

set(_TRT_OPS_DIR ${CMAKE_SOURCE_DIR}/mmdeploy/lib)
install(TARGETS ${PROJECT_NAME} DESTINATION ${_TRT_OPS_DIR})
6 changes: 5 additions & 1 deletion mmdeploy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
from mmdeploy.utils import get_root_logger
from .version import __version__ # noqa F401

importlib.import_module('mmdeploy.pytorch')
if importlib.util.find_spec('torch'):
importlib.import_module('mmdeploy.pytorch')
else:
logger = get_root_logger()
logger.debug('torch is not installed.')

if importlib.util.find_spec('mmcv'):
importlib.import_module('mmdeploy.mmcv')
Expand Down
26 changes: 0 additions & 26 deletions mmdeploy/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.backend.ncnn import is_available as ncnn_available
from mmdeploy.backend.onnxruntime import is_available as ort_available
from mmdeploy.backend.openvino import is_available as openvino_available
from mmdeploy.backend.pplnn import is_available as pplnn_available
from mmdeploy.backend.sdk import is_available as sdk_available
from mmdeploy.backend.tensorrt import is_available as trt_available

__all__ = []
if ncnn_available():
from .ncnn import NCNNWrapper # noqa: F401,F403
__all__.append('NCNNWrapper')
if ort_available():
from .onnxruntime import ORTWrapper # noqa: F401,F403
__all__.append('ORTWrapper')
if trt_available():
from .tensorrt import TRTWrapper # noqa: F401,F403
__all__.append('TRTWrapper')
if pplnn_available():
from .pplnn import PPLNNWrapper # noqa: F401,F403
__all__.append('PPLNNWrapper')
if openvino_available():
from .openvino import OpenVINOWrapper # noqa: F401,F403
__all__.append('OpenVINOWrapper')
if sdk_available():
from .sdk import SDKWrapper # noqa: F401,F403
__all__.append('SDKWrapper')
7 changes: 5 additions & 2 deletions mmdeploy/backend/ncnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def is_plugin_available():


if is_available():
from .wrapper import NCNNWrapper
try:
from .wrapper import NCNNWrapper

__all__ = ['NCNNWrapper']
__all__ = ['NCNNWrapper']
except Exception:
pass
7 changes: 2 additions & 5 deletions mmdeploy/backend/ncnn/init_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def get_ops_path() -> str:
str: The library path of ncnn custom ops.
"""
candidates = [
'../../../build/lib/libmmdeploy_ncnn_ops.so',
'../../../build/bin/*/mmdeploy_ncnn_ops.dll'
'../../lib/libmmdeploy_ncnn_ops.so', '../../lib/mmdeploy_ncnn_ops.dll'
]
return get_file_path(os.path.dirname(__file__), candidates)

Expand All @@ -23,7 +22,5 @@ def get_onnx2ncnn_path() -> str:
Returns:
str: A path of onnx2ncnn tool.
"""
candidates = [
'../../../build/bin/onnx2ncnn', '../../../build/bin/*/onnx2ncnn.exe'
]
candidates = ['./onnx2ncnn', './onnx2ncnn.exe']
return get_file_path(os.path.dirname(__file__), candidates)
12 changes: 9 additions & 3 deletions mmdeploy/backend/ncnn/onnx2ncnn.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from subprocess import call
from typing import List

import mmcv

from .init_plugins import get_onnx2ncnn_path


def mkdir_or_exist(dir_name, mode=0o777):
if dir_name == '':
return
dir_name = osp.expanduser(dir_name)
os.makedirs(dir_name, mode=mode, exist_ok=True)


def get_output_model_file(onnx_path: str, work_dir: str) -> List[str]:
"""Returns the path to the .param, .bin file with export result.
Expand All @@ -19,7 +25,7 @@ def get_output_model_file(onnx_path: str, work_dir: str) -> List[str]:
List[str]: The path to the files where the export result will be
located.
"""
mmcv.mkdir_or_exist(osp.abspath(work_dir))
mkdir_or_exist(osp.abspath(work_dir))
file_name = osp.splitext(osp.split(onnx_path)[1])[0]
save_param = osp.join(work_dir, file_name + '.param')
save_bin = osp.join(work_dir, file_name + '.bin')
Expand Down
8 changes: 6 additions & 2 deletions mmdeploy/backend/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,9 @@ def is_plugin_available():


if is_available():
from .wrapper import ORTWrapper
__all__ = ['ORTWrapper']
try:
# import wrapper if pytorch is available
from .wrapper import ORTWrapper
__all__ = ['ORTWrapper']
except Exception:
pass
4 changes: 2 additions & 2 deletions mmdeploy/backend/onnxruntime/init_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def get_ops_path() -> str:
str: The library path to onnxruntime custom ops.
"""
candidates = [
'../../../build/lib/libmmdeploy_onnxruntime_ops.so',
'../../../build/bin/*/mmdeploy_onnxruntime_ops.dll',
'../../lib/libmmdeploy_onnxruntime_ops.so',
'../../lib/mmdeploy_onnxruntime_ops.dll',
]
return get_file_path(os.path.dirname(__file__), candidates)
13 changes: 10 additions & 3 deletions mmdeploy/backend/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@

if lib_path:
lib_dir = os.path.dirname(lib_path)
sys.path.insert(0, lib_dir)
sys.path.append(lib_dir)

if importlib.util.find_spec(module_name) is not None:
from .wrapper import SDKWrapper
__all__ = ['SDKWrapper']
_is_available = True


def is_available() -> bool:
return _is_available


if is_available():

try:
from .wrapper import SDKWrapper
__all__ = ['SDKWrapper']
except Exception:
pass
15 changes: 9 additions & 6 deletions mmdeploy/backend/tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import importlib
import os.path as osp

import torch

from .init_plugins import get_ops_path, load_tensorrt_plugin


Expand All @@ -15,8 +13,7 @@ def is_available():
bool: True if TensorRT package is installed and cuda is available.
"""

return importlib.util.find_spec('tensorrt') is not None and \
torch.cuda.is_available()
return importlib.util.find_spec('tensorrt') is not None


def is_plugin_available():
Expand All @@ -31,9 +28,15 @@ def is_plugin_available():

if is_available():
from .utils import create_trt_engine, load_trt_engine, save_trt_engine
from .wrapper import TRTWrapper

__all__ = [
'create_trt_engine', 'save_trt_engine', 'load_trt_engine',
'TRTWrapper', 'load_tensorrt_plugin'
'load_tensorrt_plugin'
]

try:
# import wrapper if pytorch is available
from .wrapper import TRTWrapper
__all__ += ['TRTWrapper']
except Exception:
pass
25 changes: 12 additions & 13 deletions mmdeploy/backend/tensorrt/calib_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import h5py
import numpy as np
import pycuda.autoinit # noqa:F401
import pycuda.driver as cuda
import tensorrt as trt
import torch

DEFAULT_CALIBRATION_ALGORITHM = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2

Expand Down Expand Up @@ -67,30 +68,28 @@ def get_batch(self, names: Sequence[str], **kwargs) -> list:
ret = []
for name in names:
input_group = self.calib_data[name]
data_np = input_group[str(self.count)][...]
data_torch = torch.from_numpy(data_np)
data_np = input_group[str(self.count)][...].astype(np.float32)

# tile the tensor so we can keep the same distribute
opt_shape = self.input_shapes[name]['opt_shape']
data_shape = data_torch.shape
data_shape = data_np.shape

reps = [
int(np.ceil(opt_s / data_s))
for opt_s, data_s in zip(opt_shape, data_shape)
]

data_torch = data_torch.tile(reps)
data_np = np.tile(data_np, reps)

for dim, opt_s in enumerate(opt_shape):
if data_torch.shape[dim] != opt_s:
data_torch = data_torch.narrow(dim, 0, opt_s)
slice_list = tuple(slice(0, end) for end in opt_shape)
data_np = data_np[slice_list]

if name not in self.buffers:
self.buffers[name] = data_torch.cuda(self.device_id)
else:
self.buffers[name].copy_(data_torch.cuda(self.device_id))
data_np_cuda_ptr = cuda.mem_alloc(data_np.nbytes)
cuda.memcpy_htod(data_np_cuda_ptr,
np.ascontiguousarray(data_np))
self.buffers[name] = data_np_cuda_ptr

ret.append(int(self.buffers[name].data_ptr()))
ret.append(self.buffers[name])
self.count += 1
return ret
else:
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/backend/tensorrt/init_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def get_ops_path() -> str:
str: A path of the TensorRT plugin library.
"""
candidates = [
'../../../build/lib/libmmdeploy_tensorrt_ops.so',
'../../../build/bin/*/mmdeploy_tensorrt_ops.dll'
'../../lib/libmmdeploy_tensorrt_ops.so',
'../../lib/mmdeploy_tensorrt_ops.dll'
]
return get_file_path(os.path.dirname(__file__), candidates)

Expand Down
57 changes: 12 additions & 45 deletions mmdeploy/backend/tensorrt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@

import onnx
import tensorrt as trt
import torch
from packaging import version

from mmdeploy.utils import get_root_logger
from .calib_utils import HDF5Calibrator
from .init_plugins import load_tensorrt_plugin


Expand Down Expand Up @@ -54,8 +52,17 @@ def create_trt_engine(onnx_model: Union[str, onnx.ModelProto],
>>> device_id=0)
>>> })
"""

import os
old_cuda_device = os.environ.get('CUDA_DEVICE', None)
os.environ['CUDA_DEVICE'] = str(device_id)
import pycuda.autoinit # noqa:F401
if old_cuda_device is not None:
os.environ['CUDA_DEVICE'] = old_cuda_device
else:
os.environ.pop('CUDA_DEVICE')

load_tensorrt_plugin()
device = torch.device('cuda:{}'.format(device_id))
# create builder and network
logger = trt.Logger(log_level)
builder = trt.Builder(logger)
Expand Down Expand Up @@ -96,6 +103,7 @@ def create_trt_engine(onnx_model: Union[str, onnx.ModelProto],
config.set_flag(trt.BuilderFlag.FP16)

if int8_mode:
from .calib_utils import HDF5Calibrator
config.set_flag(trt.BuilderFlag.INT8)
assert int8_param is not None
config.int8_calibrator = HDF5Calibrator(
Expand All @@ -110,8 +118,7 @@ def create_trt_engine(onnx_model: Union[str, onnx.ModelProto],
builder.int8_calibrator = config.int8_calibrator

# create engine
with torch.cuda.device(device):
engine = builder.build_engine(network, config)
engine = builder.build_engine(network, config)

assert engine is not None, 'Failed to create TensorRT engine'
return engine
Expand Down Expand Up @@ -145,46 +152,6 @@ def load_trt_engine(path: str) -> trt.ICudaEngine:
return engine


def torch_dtype_from_trt(dtype: trt.DataType) -> torch.dtype:
"""Convert pytorch dtype to TensorRT dtype.
Args:
dtype (str.DataType): The data type in tensorrt.
Returns:
torch.dtype: The corresponding data type in torch.
"""

if dtype == trt.bool:
return torch.bool
elif dtype == trt.int8:
return torch.int8
elif dtype == trt.int32:
return torch.int32
elif dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
else:
raise TypeError(f'{dtype} is not supported by torch')


def torch_device_from_trt(device: trt.TensorLocation):
"""Convert pytorch device to TensorRT device.
Args:
device (trt.TensorLocation): The device in tensorrt.
Returns:
torch.device: The corresponding device in torch.
"""
if device == trt.TensorLocation.DEVICE:
return torch.device('cuda')
elif device == trt.TensorLocation.HOST:
return torch.device('cpu')
else:
return TypeError(f'{device} is not supported by torch')


def get_trt_log_level() -> trt.Logger.Severity:
"""Get tensorrt log level from root logger.
Expand Down
Loading

0 comments on commit 218f287

Please sign in to comment.