forked from open-mmlab/mmdetection3d
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Apis unit test (open-mmlab#7)
* add apis test * split torch2onnx impl, prepare for codebase test * add is_available to backend * lint
- Loading branch information
Showing
9 changed files
with
463 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
[settings] | ||
known_third_party = mmcv,mmdet,numpy,onnx,setuptools,tensorrt,torch | ||
known_third_party = mmcv,mmdet,numpy,onnx,pytest,setuptools,tensorrt,torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .pytorch2onnx import torch2onnx | ||
from .pytorch2onnx import torch2onnx, torch2onnx_impl | ||
|
||
__all__ = ['torch2onnx'] | ||
__all__ = ['torch2onnx_impl', 'torch2onnx'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from .init_plugins import get_ops_path | ||
|
||
__all__ = ['get_ops_path'] | ||
|
||
|
||
def is_available(): | ||
import os.path as osp | ||
tensorrt_op_path = get_ops_path() | ||
if not osp.exists(tensorrt_op_path): | ||
return False | ||
|
||
import importlib | ||
return importlib.util.find_spec('onnxruntime') is not None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import ctypes | ||
import glob | ||
import logging | ||
import os | ||
|
||
|
||
def get_ops_path(): | ||
"""Get TensorRT plugins library path.""" | ||
wildcard = os.path.abspath( | ||
os.path.join( | ||
os.path.dirname(__file__), | ||
'../../../build/lib/libmmlab_onnxruntime_ops.so')) | ||
|
||
paths = glob.glob(wildcard) | ||
lib_path = paths[0] if len(paths) > 0 else '' | ||
return lib_path | ||
|
||
|
||
def load_tensorrt_plugin(): | ||
"""load TensorRT plugins library.""" | ||
lib_path = get_ops_path() | ||
if os.path.exists(lib_path): | ||
ctypes.CDLL(lib_path) | ||
return 0 | ||
else: | ||
logging.warning('Can not load tensorrt custom ops.') | ||
return -1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,27 @@ | ||
# flake8: noqa | ||
from .init_plugins import load_tensorrt_plugin | ||
from .onnx2tensorrt import onnx2tensorrt | ||
from .tensorrt_utils import (TRTWrapper, load_trt_engine, onnx2trt, | ||
save_trt_engine) | ||
|
||
# load tensorrt plugin lib | ||
load_tensorrt_plugin() | ||
|
||
__all__ = [ | ||
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper', | ||
'TRTWrapper', 'is_tensorrt_plugin_loaded', 'preprocess_onnx', | ||
'onnx2tensorrt' | ||
] | ||
from .init_plugins import get_ops_path, load_tensorrt_plugin | ||
|
||
|
||
def is_available(): | ||
import os.path as osp | ||
tensorrt_op_path = get_ops_path() | ||
if not osp.exists(tensorrt_op_path): | ||
return False | ||
|
||
import importlib | ||
return importlib.util.find_spec('tensorrt') is not None | ||
|
||
|
||
if is_available(): | ||
from .onnx2tensorrt import onnx2tensorrt | ||
from .tensorrt_utils import (TRTWrapper, load_trt_engine, onnx2trt, | ||
save_trt_engine) | ||
|
||
# load tensorrt plugin lib | ||
load_tensorrt_plugin() | ||
|
||
__all__ = [ | ||
'onnx2trt', 'save_trt_engine', 'load_trt_engine', 'TRTWraper', | ||
'TRTWrapper', 'is_tensorrt_plugin_loaded', 'preprocess_onnx', | ||
'onnx2tensorrt' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import os | ||
import os.path as osp | ||
import shutil | ||
|
||
import mmcv | ||
import pytest | ||
import torch | ||
import torch.multiprocessing as mp | ||
from torch import nn | ||
|
||
import mmdeploy.apis.tensorrt as trt_apis | ||
|
||
# skip if tensorrt apis can not loaded | ||
if not trt_apis.is_available(): | ||
pytest.skip('TensorRT apis is not prepared.') | ||
trt = pytest.importorskip('tensorrt', reason='Import tensorrt failed.') | ||
if not torch.cuda.is_available(): | ||
pytest.skip('CUDA is not available.') | ||
|
||
# load apis from trt_apis | ||
TRTWrapper = trt_apis.TRTWrapper | ||
onnx2tensorrt = trt_apis.onnx2tensorrt | ||
|
||
ret_value = mp.Value('d', 0, lock=False) | ||
work_dir = './tmp/' | ||
onnx_file = 'tmp.onnx' | ||
save_file = 'tmp.engine' | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def clear_workdir_after_test(): | ||
# clear work_dir before test | ||
if osp.exists(work_dir): | ||
shutil.rmtree(work_dir) | ||
os.mkdir(work_dir) | ||
|
||
yield | ||
|
||
# clear work_dir after test | ||
if osp.exists(work_dir): | ||
shutil.rmtree(work_dir) | ||
|
||
|
||
def test_onnx2tensorrt(): | ||
|
||
# dummy model | ||
class TestModel(nn.Module): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x): | ||
return x + 1 | ||
|
||
model = TestModel().eval().cuda() | ||
x = torch.rand(1, 3, 64, 64).cuda() | ||
|
||
onnx_path = osp.join(work_dir, onnx_file) | ||
# export to onnx | ||
torch.onnx.export( | ||
model, | ||
x, | ||
onnx_path, | ||
input_names=['input'], | ||
output_names=['output'], | ||
dynamic_axes={'input': { | ||
0: 'batch', | ||
2: 'height', | ||
3: 'width' | ||
}}) | ||
|
||
assert osp.exists(onnx_path) | ||
|
||
# deploy config | ||
deploy_cfg = mmcv.Config( | ||
dict( | ||
backend='tensorrt', | ||
tensorrt_param=dict( | ||
shared_param=dict( | ||
log_level=trt.Logger.WARNING, fp16_mode=False), | ||
model_params=[ | ||
dict( | ||
opt_shape_dict=dict( | ||
input=[[1, 3, 32, 32], [1, 3, 64, 64], | ||
[1, 3, 128, 128]]), | ||
max_workspace_size=1 << 30) | ||
]))) | ||
|
||
# convert to engine | ||
onnx2tensorrt( | ||
work_dir, | ||
save_file, | ||
0, | ||
deploy_cfg=deploy_cfg, | ||
onnx_model=onnx_path, | ||
ret_value=ret_value) | ||
|
||
assert ret_value.value == 0 | ||
assert osp.exists(work_dir) | ||
assert osp.exists(osp.join(work_dir, save_file)) | ||
|
||
# test | ||
trt_model = TRTWrapper(osp.join(work_dir, save_file)) | ||
x = x.cuda() | ||
|
||
with torch.no_grad(): | ||
trt_output = trt_model({'input': x})['output'] | ||
|
||
torch.testing.assert_allclose(trt_output, x + 1) |
Oops, something went wrong.