Skip to content

Commit

Permalink
[Improvement] Better unit test. (#1619)
Browse files Browse the repository at this point in the history
* update test for mmcls and mmdet

* update det3d mmedit mmocr mmpose mmrotate

* update mmseg

* bug fixing

* refactor ops

* rename variable

* remove comment
  • Loading branch information
grimoire authored Feb 8, 2023
1 parent 5de0ecf commit d8e4a78
Show file tree
Hide file tree
Showing 58 changed files with 3,702 additions and 3,868 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,13 @@ nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin(
if (field_name.compare("mode") == 0) {
int data_size = fc->fields[i].length;
const char *data_start = static_cast<const char *>(fc->fields[i].data);
std::string poolModeStr(data_start, data_size);
if (poolModeStr == "avg") {
std::string pool_mode_str(data_start);
if (pool_mode_str == "avg") {
poolMode = 1;
} else if (poolModeStr == "max") {
} else if (pool_mode_str == "max") {
poolMode = 0;
} else {
std::cout << "Unknown pool mode \"" << poolModeStr << "\"." << std::endl;
std::cout << "Unknown pool mode \"" << pool_mode_str << "\"." << std::endl;
}
ASSERT(poolMode >= 0);
}
Expand Down
1 change: 0 additions & 1 deletion mmdeploy/codebase/mmcls/models/utils/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def shift_window_msa__forward__default(ctx, self, query, hw_shape):
'mmcls.models.utils.ShiftWindowMSA.get_attn_mask',
extra_checkers=LibVersionChecker('mmcls', min_version='0.21.0'))
def shift_window_msa__get_attn_mask__default(ctx,
self,
hw_shape,
window_size,
shift_size,
Expand Down
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet3d/core/bbox/fcos3d_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.core.bbox.coders.fcos3d_bbox_coder.FCOS3DBBoxCoder.decode_yaw')
def decode_yaw(ctx, self, bbox, centers2d, dir_cls, dir_offset, cam2img):
def decode_yaw(ctx, bbox, centers2d, dir_cls, dir_offset, cam2img):
"""Decode yaw angle and change it from local to global.i. Rewrite this func
to use slice instead of the original operation.
Args:
Expand Down
12 changes: 11 additions & 1 deletion mmdeploy/core/rewriters/function_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import (Any, Callable, Dict, List, MutableSequence, Optional,
Tuple, Union)

Expand Down Expand Up @@ -72,7 +73,16 @@ def _set_func(origin_func_path: str,
rewrite_func,
ignore_refs=ignore_refs,
ignore_keys=ignore_keys)
exec(f'{origin_func_path} = rewrite_func')

is_static_method = False
if method_class:
origin_type = inspect.getattr_static(module_or_class, split_path[-1])
is_static_method = isinstance(origin_type, staticmethod)

if is_static_method:
exec(f'{origin_func_path} = staticmethod(rewrite_func)')
else:
exec(f'{origin_func_path} = rewrite_func')


def _del_func(path: str):
Expand Down
23 changes: 23 additions & 0 deletions mmdeploy/core/rewriters/rewriter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,29 @@ def decorator(object):

return decorator

def remove_record(self, object: Any, filter_cb: Optional[Callable] = None):
"""Remove record.
Args:
object (Any): The object to remove.
filter_cb (Callable): Check if the object need to be remove.
Defaults to None.
"""
key_to_pop = []
for key, records in self._rewrite_records.items():
for rec in records:
if rec['_object'] == object:
if filter_cb is not None:
if filter_cb(rec):
continue
key_to_pop.append((key, rec))

for key, rec in key_to_pop:
records = self._rewrite_records[key]
records.remove(rec)
if len(records) == 0:
self._rewrite_records.pop(key)


class ContextCaller:
"""A callable object used in RewriteContext.
Expand Down
37 changes: 26 additions & 11 deletions mmdeploy/utils/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes,
get_ir_config, get_onnx_config)

try:
from torch.testing import assert_close as torch_assert_close
except Exception:
from torch.testing import assert_allclose as torch_assert_close


def backend_checker(backend: Backend, require_plugin: bool = False):
"""A decorator which checks if a backend is available.
Expand Down Expand Up @@ -189,12 +194,6 @@ def __init__(self, recover_class):
self._recover_class = recover_class

def __enter__(self):
return self

def __exit__(self, type, value, trace):
self.recover()

def set(self, **kwargs):
"""Replace attributes in backend wrappers with dummy items."""
obj = self._recover_class
self.init = obj.__init__
Expand All @@ -203,10 +202,9 @@ def set(self, **kwargs):
obj.__init__ = SwitchBackendWrapper.BackendWrapper.__init__
obj.forward = SwitchBackendWrapper.BackendWrapper.forward
obj.__call__ = SwitchBackendWrapper.BackendWrapper.__call__
for k, v in kwargs.items():
setattr(obj, k, v)
return self

def recover(self):
def __exit__(self, type, value, trace):
"""Recover to original class."""
assert self.init is not None and \
self.forward is not None,\
Expand All @@ -216,6 +214,11 @@ def recover(self):
obj.forward = self.forward
obj.__call__ = self.call

def set(self, **kwargs):
obj = self._recover_class
for k, v in kwargs.items():
setattr(obj, k, v)


def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
actual: List[Union[torch.Tensor, np.ndarray]],
Expand All @@ -239,8 +242,7 @@ def assert_allclose(expected: List[Union[torch.Tensor, np.ndarray]],
if isinstance(actual[i], (list, np.ndarray)):
actual[i] = torch.tensor(actual[i])
try:
torch.testing.assert_allclose(
actual[i], expected[i], rtol=1e-03, atol=1e-05)
torch_assert_close(actual[i], expected[i], rtol=1e-03, atol=1e-05)
except AssertionError as error:
if tolerate_small_mismatch:
assert '(0.00%)' in str(error), str(error)
Expand Down Expand Up @@ -417,6 +419,19 @@ def get_backend_outputs(ir_file_path: str,
if backend == Backend.TENSORRT:
device = 'cuda'
model_inputs = dict((k, v.cuda()) for k, v in model_inputs.items())
input_shapes = dict(
(k, dict(min_shape=v.shape, max_shape=v.shape, opt_shape=v.shape))
for k, v in model_inputs.items())
model_inputs_cfg = deploy_cfg['backend_config'].get(
'model_inputs', [dict(input_shapes=input_shapes)])
if len(model_inputs_cfg) < 1:
model_inputs_cfg = [dict(input_shapes=input_shapes)]

if 'input_shapes' not in model_inputs_cfg[0]:
model_inputs_cfg[0]['input_shapes'] = input_shapes

deploy_cfg['backend_config']['model_inputs'] = model_inputs_cfg

elif backend == Backend.OPENVINO:
input_info = {
name: value.shape
Expand Down
33 changes: 19 additions & 14 deletions tests/test_apis/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
from multiprocessing import Process

import mmcv
import pytest

from mmdeploy.apis import create_calib_input_data

calib_file = tempfile.NamedTemporaryFile(suffix='.h5').name
ann_file = 'tests/data/annotation.json'


def get_end2end_deploy_cfg():
@pytest.fixture
def deploy_cfg():
deploy_cfg = mmcv.Config(
dict(
onnx_config=dict(
Expand Down Expand Up @@ -53,14 +55,15 @@ def get_end2end_deploy_cfg():
return deploy_cfg


def get_partition_deploy_cfg():
deploy_cfg = get_end2end_deploy_cfg()
@pytest.fixture
def partition_deploy_cfg(deploy_cfg):
deploy_cfg._cfg_dict['partition_config'] = dict(
type='two_stage', apply_marks=True)
return deploy_cfg


def get_model_cfg():
@pytest.fixture
def model_cfg():
dataset_type = 'CustomDataset'
data_root = 'tests/data/'
img_norm_cfg = dict(
Expand Down Expand Up @@ -169,10 +172,8 @@ def get_model_cfg():
return model_cfg


def run_test_create_calib_end2end():
def run_test_create_calib_end2end(deploy_cfg, model_cfg):
import h5py
model_cfg = get_model_cfg()
deploy_cfg = get_end2end_deploy_cfg()
create_calib_input_data(
calib_file,
deploy_cfg,
Expand All @@ -194,18 +195,19 @@ def run_test_create_calib_end2end():
# new process.


def test_create_calib_end2end():
p = Process(target=run_test_create_calib_end2end)
def test_create_calib_end2end(deploy_cfg, model_cfg):
p = Process(
target=run_test_create_calib_end2end,
kwargs=dict(deploy_cfg=deploy_cfg, model_cfg=model_cfg))
try:
p.start()
finally:
p.join()


def run_test_create_calib_parittion():
def run_test_create_calib_parittion(partition_deploy_cfg, model_cfg):
import h5py
model_cfg = get_model_cfg()
deploy_cfg = get_partition_deploy_cfg()
deploy_cfg = partition_deploy_cfg
create_calib_input_data(
calib_file,
deploy_cfg,
Expand All @@ -227,8 +229,11 @@ def run_test_create_calib_parittion():
assert calib_data[partition_name][input_names[i]]['0'] is not None


def test_create_calib_parittion():
p = Process(target=run_test_create_calib_parittion)
def test_create_calib_parittion(partition_deploy_cfg, model_cfg):
p = Process(
target=run_test_create_calib_parittion,
kwargs=dict(
partition_deploy_cfg=partition_deploy_cfg, model_cfg=model_cfg))
try:
p.start()
finally:
Expand Down
1 change: 1 addition & 0 deletions tests/test_backend/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def run_wrapper(backend, wrapper, input):
ALL_BACKEND = list(Backend)
ALL_BACKEND.remove(Backend.DEFAULT)
ALL_BACKEND.remove(Backend.PYTORCH)
ALL_BACKEND.remove(Backend.SNPE)
ALL_BACKEND.remove(Backend.SDK)


Expand Down
19 changes: 19 additions & 0 deletions tests/test_codebase/test_mmcls/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Codebase


def pytest_ignore_collect(*args, **kwargs):
import importlib
return importlib.util.find_spec('mmcls') is None


@pytest.fixture(autouse=True, scope='package')
def import_all_modules():
codebase = Codebase.MMCLS
try:
import_codebase(codebase)
except ImportError:
pytest.skip(f'{codebase} is not installed.', allow_module_level=True)
Loading

0 comments on commit d8e4a78

Please sign in to comment.