From be5423d860448a55df43e58557d839e684d451b4 Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Fri, 11 Feb 2022 14:37:47 +0800 Subject: [PATCH 01/10] Finish function tests --- mmdeploy/core/rewriters/function_rewriter.py | 35 ++- mmdeploy/core/rewriters/module_rewriter.py | 35 +-- mmdeploy/core/rewriters/rewriter_manager.py | 25 +-- mmdeploy/core/rewriters/rewriter_utils.py | 221 +++++++++++++++---- mmdeploy/core/rewriters/symbolic_rewriter.py | 20 +- mmdeploy/utils/__init__.py | 6 +- mmdeploy/utils/constants.py | 7 + mmdeploy/utils/env.py | 49 ++++ tests/test_core/test_function_rewriter.py | 22 +- tests/test_core/test_rewriter_registry.py | 59 ----- tests/test_core/test_rewriter_utils.py | 110 +++++++++ tests/test_utils/test_util.py | 23 ++ tools/check_env.py | 81 ++----- 13 files changed, 448 insertions(+), 245 deletions(-) create mode 100644 mmdeploy/utils/env.py delete mode 100644 tests/test_core/test_rewriter_registry.py create mode 100644 tests/test_core/test_rewriter_utils.py diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index 674361f634..5b02bd5cf6 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict +from typing import Callable, Dict, List, Optional, Union -from mmdeploy.utils import Backend, get_root_logger -from .rewriter_utils import ContextCaller, RewriterRegistry, import_function +from mmdeploy.utils import IR, Backend, get_root_logger +from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, + import_function) def _set_func(origin_func_path: str, rewrite_func: Callable): @@ -66,32 +67,30 @@ class FunctionRewriter: def __init__(self): self._registry = RewriterRegistry() - def add_backend(self, backend: str): - """Add a backend by calling the _registry.add_backend.""" - self._registry.add_backend(backend) - - def register_rewriter(self, - func_name: str, - backend: str = Backend.DEFAULT.value, - **kwargs): + def register_rewriter( + self, + func_name: str, + backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, + extra_checkers: Optional[Union[Checker, List[Checker]]] = None, + **kwargs): """The interface of function rewriter decorator. Args: func_name (str): The function name/path to rewrite. - backend (str): The inference engine name. + backend (str): The rewriter will be activated on which backend. + Returns: Callable: The process of registering function. """ - return self._registry.register_object(func_name, backend, **kwargs) + return self._registry.register_object(func_name, backend, ir, + extra_checkers, **kwargs) - def enter(self, - cfg: Dict = dict(), - backend: str = Backend.DEFAULT.value, - **kwargs): + def enter(self, cfg: Dict = dict(), env: Dict = dict(), **kwargs): """The implementation of function rewrite.""" # Get current records - functions_records = self._registry.get_records(backend) + functions_records = self._registry.get_records(env) self._origin_functions = list() self._additional_functions = list() diff --git a/mmdeploy/core/rewriters/module_rewriter.py b/mmdeploy/core/rewriters/module_rewriter.py index 43720443c6..5d2166192d 100644 --- a/mmdeploy/core/rewriters/module_rewriter.py +++ b/mmdeploy/core/rewriters/module_rewriter.py @@ -1,11 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect +from typing import Dict, List, Optional, Union import mmcv from torch import nn -from mmdeploy.utils.constants import Backend -from .rewriter_utils import RewriterRegistry, eval_with_import +from mmdeploy.utils.constants import IR, Backend +from .rewriter_utils import (Checker, RewriterRegistry, collect_env, + eval_with_import) class ModuleRewriter: @@ -26,14 +28,13 @@ class ModuleRewriter: def __init__(self): self._registry = RewriterRegistry() - def add_backend(self, backend: str): - """Add a backend by calling the _registry.add_backend.""" - self._registry.add_backend(backend) - - def register_rewrite_module(self, - module_type: str, - backend: str = Backend.DEFAULT.value, - **kwargs): + def register_rewrite_module( + self, + module_type: str, + backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, + extra_checkers: Optional[Union[Checker, List[Checker]]] = None, + **kwargs): """The interface of module rewriter decorator. Args: @@ -41,15 +42,17 @@ def register_rewrite_module(self, backend (str): The inference engine name. Returns: - nn.Module: THe rewritten model. + nn.Module: The rewritten model. """ - return self._registry.register_object(module_type, backend, **kwargs) + return self._registry.register_object(module_type, backend, ir, + extra_checkers, **kwargs) def patch_model(self, model: nn.Module, cfg: mmcv.Config, backend: str = Backend.DEFAULT.value, recursive: bool = True, + ir: IR = IR.DEFAULT, **kwargs) -> nn.Module: """Replace the models that was registered. @@ -67,7 +70,9 @@ def patch_model(self, >>> patched_model = patch_model(model, cfg=deploy_cfg, >>> backend=backend) """ - self._collect_record(backend) + # TODO: Make the type of parameter backend to Backend + env = collect_env(Backend.get(backend), ir) + self._collect_record(env) return self._replace_module(model, cfg, recursive, **kwargs) def _replace_one_module(self, module, cfg, **kwargs): @@ -103,9 +108,9 @@ def _replace_module_impl(model, cfg, **kwargs): return _replace_module_impl(model, cfg, **kwargs) - def _collect_record(self, backend: str): + def _collect_record(self, env: Dict): """Collect models in registry.""" self._records = {} - records = self._registry.get_records(backend) + records = self._registry.get_records(env) for name, kwargs in records: self._records[eval_with_import(name)] = kwargs diff --git a/mmdeploy/core/rewriters/rewriter_manager.py b/mmdeploy/core/rewriters/rewriter_manager.py index df7e82703d..f76072e55d 100644 --- a/mmdeploy/core/rewriters/rewriter_manager.py +++ b/mmdeploy/core/rewriters/rewriter_manager.py @@ -4,9 +4,10 @@ import mmcv import torch.nn as nn -from mmdeploy.utils.constants import Backend +from mmdeploy.utils.constants import IR, Backend from .function_rewriter import FunctionRewriter from .module_rewriter import ModuleRewriter +from .rewriter_utils import collect_env from .symbolic_rewriter import SymbolicRewriter @@ -18,20 +19,8 @@ def __init__(self): self.function_rewriter = FunctionRewriter() self.symbolic_rewriter = SymbolicRewriter() - def add_backend(self, backend: str): - """Add backend to all rewriters. - - Args: - backend (str): The backend to support. - """ - self.module_rewriter.add_backend(backend) - self.function_rewriter.add_backend(backend) - self.symbolic_rewriter.add_backend(backend) - REWRITER_MANAGER = RewriterManager() -for backend in Backend: - REWRITER_MANAGER.add_backend(backend.value) MODULE_REWRITER = REWRITER_MANAGER.module_rewriter FUNCTION_REWRITER = REWRITER_MANAGER.function_rewriter @@ -84,20 +73,20 @@ class RewriterContext: def __init__(self, cfg: Dict = dict(), backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, rewriter_manager: RewriterManager = REWRITER_MANAGER, **kwargs): self._cfg = cfg - self._backend = backend self._kwargs = kwargs self._rewriter_manager = rewriter_manager + # TODO: Make the type of parameter backend to Backend + self._env = collect_env(Backend.get(backend), ir) def enter(self): """Call the enter() of rewriters.""" - self._rewriter_manager.function_rewriter.enter(self._cfg, - self._backend, + self._rewriter_manager.function_rewriter.enter(self._cfg, self._env, **self._kwargs) - self._rewriter_manager.symbolic_rewriter.enter(self._cfg, - self._backend, + self._rewriter_manager.symbolic_rewriter.enter(self._cfg, self._env, **self._kwargs) def exit(self): diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 701078144a..12e29d605e 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -1,8 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple +import warnings +from abc import ABCMeta, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from mmdeploy.utils.constants import Backend +import mmdeploy +from mmdeploy.utils.constants import IR, Backend def eval_with_import(path: str) -> Any: @@ -56,6 +59,66 @@ def import_function(path: str) -> Tuple[Callable, Optional[type]]: return obj, None +def collect_env(backend: Backend, ir: IR, **kwargs): + from mmdeploy.utils import get_codebase_version, get_backend_version + env = dict(backend=backend, ir=ir) + env['mmdeploy'] = mmdeploy.__version__ + env.update(get_backend_version()) + env.update(get_codebase_version()) + env.update(kwargs) + return env + + +class Checker(metaclass=ABCMeta): + + def __init__(self): + pass + + @abstractmethod + def check(self, env: Dict) -> bool: + pass + + +class BackendChecker(Checker): + + def __init__(self, required_backend: Backend): + super().__init__() + self.required_backend = required_backend + + def check(self, env: Dict) -> bool: + return env['backend'] == self.required_backend + + +class IRChecker(Checker): + + def __init__(self, required_ir: Backend): + super().__init__() + self.required_ir = required_ir + + def check(self, env: Dict) -> bool: + return env['ir'] == self.required_ir + + +class LibVersionChecker(Checker): + + def __init__(self, lib: str, min_version=None, max_version=None): + super().__init__() + self.lib = lib + self.min_version = min_version + self.max_version = max_version + + def check(self, env: Dict) -> bool: + from packaging import version + valid = True + if self.min_version is not None: + valid = version.parse(env[self.lib]) >= version.parse( + self.min_version) + if self.max_version is not None: + valid = version.parse(env[self.lib]) <= version.parse( + self.max_version) + return valid + + class RewriterRegistry: """A registry that recoreds rewrite objects. @@ -78,55 +141,121 @@ class RewriterRegistry: # TODO: replace backend string with "Backend" constant def __init__(self): self._rewrite_records = dict() - self.add_backend(Backend.DEFAULT.value) - - def _check_backend(self, backend: str): - """Check if a backend has been supported.""" - if backend not in self._rewrite_records: - raise Exception('Backend is not supported by registry.') - - def add_backend(self, backend: str): - """Add a backend dictionary.""" - if backend not in self._rewrite_records: - self._rewrite_records[backend] = dict() - - def get_records(self, backend: str) -> List: - """Get all registered records in record table.""" - self._check_backend(backend) - - if backend != Backend.DEFAULT.value: - # Update dict A with dict B. - # Then convert the result dict to a list, while keeping the order - # of A and B: the elements only belong to B should alwarys come - # after the elements only belong to A. - # The complexity is O(n + m). - dict_a = self._rewrite_records[Backend.DEFAULT.value] - dict_b = self._rewrite_records[backend] - records = [] - for k, v in dict_a.items(): - if k in dict_b: - records.append((k, dict_b[k])) + + def get_records(self, env: Dict) -> List: + """Get all registered records that are valid in the given environment + from record table. + + If the backend and ir of rewriter are set to 'default', then the + rewriter is regarded as default rewriter. The default rewriter will be + activated only when all other rewriters are not valid. If there are + multiple rewriters are valid (except default rewriter), we will + activate the first one (The order is determined by the time when + rewriters are loaded). + + Args: + env (dict): Environment dictionary that includes backend, ir, + codebase version, etc. + + Returns: + List: A list that includes valid records. + """ + default_records = list() + records = list() + + for origin_function, rewriter_records in self._rewrite_records.items(): + default_rewriter = None + final_rewriter = None + for record in rewriter_records: + # Get the checkers of current rewriter + checkers: List[Checker] = record['_checkers'] + + # Check if the rewriter is default rewriter + if len(checkers) == 0: + # Process the default rewriter exceptionally + default_rewriter = record else: - records.append((k, v)) - for k, v in dict_b.items(): - if k not in dict_a: - records.append((k, v)) - else: - records = list( - self._rewrite_records[Backend.DEFAULT.value].items()) - return records - - def _register(self, name: str, backend: str, **kwargs): + # Check if the checker is valid. + # The checker is valid only if all the checks are passed + valid = True + for checker in checkers: + if not checker.check(env): + valid = False + break + + if valid: + # Check if there are multiple valid rewriters + if final_rewriter is not None: + warnings.warn( + 'Detect multiple valid rewriters for' + f'{origin_function}, use the first rewriter') + else: + final_rewriter = record + + # Append final rewriter. + # If there is no valid rewriter, try not apply default rewriter + if final_rewriter is not None: + records.append((origin_function, final_rewriter)) + elif default_rewriter is not None: + default_records.append((origin_function, default_rewriter)) + + # Make the default records como to the front of list because we may + # want the non-default records to override them. + return default_records + records + + def _register(self, name: str, backend: Backend, ir: IR, + extra_checkers: List[Checker], **kwargs): """The implementation of register.""" - self._check_backend(backend) - self._rewrite_records[backend][name] = kwargs - def register_object(self, name: str, backend: str, **kwargs) -> Callable: - """The decorator to register an object.""" - self._check_backend(backend) + # Merge checkers to kwargs + record_dict = kwargs + + # Try to create a checker according to 'backend' field + if backend != Backend.DEFAULT: + extra_checkers.append(BackendChecker(backend)) + + # Try to create a checker according to 'ir' field + if ir != IR.DEFAULT: + extra_checkers.append(IRChecker(ir)) + + record_dict['_checkers'] = extra_checkers + + # There may be multiple rewriters of a function/module. We use a list + # to store the rewriters of a function/module. + if name not in self._rewrite_records: + self._rewrite_records[name] = list() + self._rewrite_records[name].append(record_dict) + + def register_object(self, + name: str, + backend: str, + ir: IR, + extra_checkers: Optional[Union[Checker, + List[Checker]]] = None, + **kwargs) -> Callable: + """The decorator to register an object. + + Args: + name (str): The import path to access the function/module. + backend (str): The rewriter will be activated on which backend. + ir (IR): The rewriter will be activated on which ir. + extra_chekcers (None | Checker | List[Checker]): Other requirements + for the rewriters. Default to `None`. + + Returns: + Callable: The decorator. + """ + + if extra_checkers is None: + extra_checkers = [] + elif isinstance(extra_checkers, Checker): + extra_checkers = [extra_checkers] + + backend = Backend.get(backend) def decorator(object): - self._register(name, backend, _object=object, **kwargs) + self._register( + name, backend, ir, extra_checkers, _object=object, **kwargs) return object return decorator diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index c9c16d071d..f53dbfa261 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -1,13 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Callable, Dict, Optional, Sequence +from typing import Callable, Dict, List, Optional, Sequence, Union from torch.autograd import Function from torch.onnx.symbolic_helper import parse_args from torch.onnx.symbolic_registry import _registry as pytorch_registry from torch.onnx.symbolic_registry import register_op -from mmdeploy.utils import Backend, get_root_logger -from .rewriter_utils import ContextCaller, RewriterRegistry, eval_with_import +from mmdeploy.utils import IR, Backend, get_root_logger +from .rewriter_utils import (Checker, ContextCaller, RewriterRegistry, + eval_with_import) class SymbolicRewriter: @@ -35,15 +36,14 @@ class SymbolicRewriter: def __init__(self) -> None: self._registry = RewriterRegistry() - def add_backend(self, backend: str): - """Add a backend by calling the _registry.add_backend.""" - self._registry.add_backend(backend) - def register_symbolic(self, func_name: str, backend: str = Backend.DEFAULT.value, is_pytorch: bool = False, arg_descriptors: Optional[Sequence[str]] = None, + ir: IR = IR.DEFAULT, + extra_checkers: Optional[Union[ + Checker, List[Checker]]] = None, **kwargs) -> Callable: """The decorator of the custom symbolic. @@ -61,18 +61,20 @@ def register_symbolic(self, return self._registry.register_object( func_name, backend, + ir, + extra_checkers, is_pytorch=is_pytorch, arg_descriptors=arg_descriptors, **kwargs) def enter(self, cfg: Dict = dict(), - backend: str = Backend.DEFAULT.value, + env: Dict = dict(), opset: int = 11, **kwargs): """The implementation of symbolic register.""" # Get current records - symbolic_records = self._registry.get_records(backend) + symbolic_records = self._registry.get_records(env) self._pytorch_symbolic = list() self._extra_symbolic = list() diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py index 03543f9d5f..850113e253 100644 --- a/mmdeploy/utils/__init__.py +++ b/mmdeploy/utils/__init__.py @@ -6,8 +6,9 @@ get_model_inputs, get_onnx_config, get_partition_config, get_task_type, is_dynamic_batch, is_dynamic_shape, load_config) -from .constants import SDK_TASK_MAP, Backend, Codebase, Task +from .constants import IR, SDK_TASK_MAP, Backend, Codebase, Task from .device import parse_cuda_device_id, parse_device_id +from .env import get_backend_version, get_codebase_version, get_library_version from .utils import get_root_logger, target_wrapper __all__ = [ @@ -18,5 +19,6 @@ 'get_model_inputs', 'cfg_apply_marks', 'get_input_shape', 'parse_device_id', 'parse_cuda_device_id', 'get_codebase_config', 'get_backend_config', 'get_root_logger', 'get_dynamic_axes', - 'target_wrapper', 'SDK_TASK_MAP' + 'target_wrapper', 'SDK_TASK_MAP', 'get_library_version', + 'get_codebase_version', 'get_backend_version', 'IR' ] diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py index ab726fd528..fade370904 100644 --- a/mmdeploy/utils/constants.py +++ b/mmdeploy/utils/constants.py @@ -35,6 +35,13 @@ class Codebase(AdvancedEnum): MMEDIT = 'mmedit' +class IR(AdvancedEnum): + """Define intermediate representation enumerations.""" + ONNX = 'onnx' + TORCHSCRIPT = 'torchscript' + DEFAULT = 'default' + + class Backend(AdvancedEnum): """Define backend enumerations.""" PYTORCH = 'pytorch' diff --git a/mmdeploy/utils/env.py b/mmdeploy/utils/env.py new file mode 100644 index 0000000000..901cc87fd7 --- /dev/null +++ b/mmdeploy/utils/env.py @@ -0,0 +1,49 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib + +from mmdeploy.utils import Codebase + + +def get_library_version(lib): + """Try to get the version of a library if it has been installed. + + Args: + lib (str): The name of library. + + Returns: + None | str: If the library has been installed, return version. + """ + try: + lib = importlib.import_module(lib) + except ImportError: + version = None + else: + version = lib.__version__ + + return version + + +def get_codebase_version(): + """Get the version dictionary of all supported codebases. + + Returns: + Dict: The name and the version of supported codebases. + """ + version_dict = dict() + for enum in Codebase: + codebase = enum.value + version_dict[codebase] = get_library_version(codebase) + return version_dict + + +def get_backend_version(): + """Get the version dictionary of some supported backend. + + Returns: + Dict: The name and the version of some supported backend. + """ + backend_library_list = ['tensorrt', 'onnxruntime', 'ncnn'] + version_dict = dict() + for backend in backend_library_list: + version_dict[backend] = get_library_version(backend) + return version_dict diff --git a/tests/test_core/test_function_rewriter.py b/tests/test_core/test_function_rewriter.py index b9b43fb688..97a814e929 100644 --- a/tests/test_core/test_function_rewriter.py +++ b/tests/test_core/test_function_rewriter.py @@ -3,7 +3,8 @@ from mmdeploy.core import FUNCTION_REWRITER, RewriterContext from mmdeploy.core.rewriters.function_rewriter import FunctionRewriter -from mmdeploy.utils.constants import Backend +from mmdeploy.core.rewriters.rewriter_utils import collect_env +from mmdeploy.utils.constants import IR, Backend def test_function_rewriter(): @@ -97,7 +98,6 @@ def test_rewrite_homonymic_functions(self): assert package.module.func() == 1 function_rewriter = FunctionRewriter() - function_rewriter.add_backend(Backend.NCNN.value) @function_rewriter.register_rewriter(func_name=path1) def func_2(ctx): @@ -108,7 +108,7 @@ def func_2(ctx): def func_3(ctx): return 3 - function_rewriter.enter(backend=Backend.NCNN.value) + function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) # This is a feature assert package.func() == 2 assert package.module.func() == 3 @@ -118,7 +118,6 @@ def func_3(ctx): assert package.module.func() == 1 function_rewriter2 = FunctionRewriter() - function_rewriter2.add_backend(Backend.NCNN.value) @function_rewriter2.register_rewriter( func_name=path1, backend=Backend.NCNN.value) @@ -129,7 +128,7 @@ def func_4(ctx): def func_5(ctx): return 5 - function_rewriter2.enter(backend=Backend.NCNN.value) + function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) # This is a feature assert package.func() == 4 assert package.module.func() == 5 @@ -146,7 +145,6 @@ def test_rewrite_homonymic_methods(self): c = package.C() function_rewriter = FunctionRewriter() - function_rewriter.add_backend(Backend.NCNN.value) assert c.method() == 1 @@ -159,14 +157,13 @@ def func_2(ctx, self): def func_3(ctx, self): return 3 - function_rewriter.enter(backend=Backend.NCNN.value) + function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) assert c.method() == 3 function_rewriter.exit() assert c.method() == 1 function_rewriter2 = FunctionRewriter() - function_rewriter2.add_backend(Backend.NCNN.value) @function_rewriter2.register_rewriter( func_name=path1, backend=Backend.NCNN.value) @@ -177,7 +174,7 @@ def func_4(ctx, self): def func_5(ctx, self): return 5 - function_rewriter2.enter(backend=Backend.NCNN.value) + function_rewriter2.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) assert c.method() == 4 function_rewriter2.exit() @@ -196,7 +193,6 @@ def test_rewrite_derived_methods(): assert derived_obj.method() == 1 function_rewriter = FunctionRewriter() - function_rewriter.add_backend(Backend.NCNN.value) @function_rewriter.register_rewriter(func_name=path1) def func_2(ctx, self): @@ -207,12 +203,12 @@ def func_2(ctx, self): def func_3(ctx, self): return 3 - function_rewriter.enter() + function_rewriter.enter(env=collect_env(Backend.DEFAULT, ir=IR.DEFAULT)) assert base_obj.method() == 2 assert derived_obj.method() == 2 function_rewriter.exit() - function_rewriter.enter(backend=Backend.NCNN.value) + function_rewriter.enter(env=collect_env(Backend.NCNN, ir=IR.DEFAULT)) assert base_obj.method() == 2 assert derived_obj.method() == 3 function_rewriter.exit() @@ -221,7 +217,7 @@ def func_3(ctx, self): assert derived_obj.method() == 1 # Check if the recovery is correct - function_rewriter.enter() + function_rewriter.enter(env=collect_env(Backend.DEFAULT, ir=IR.DEFAULT)) assert base_obj.method() == 2 assert derived_obj.method() == 2 function_rewriter.exit() diff --git a/tests/test_core/test_rewriter_registry.py b/tests/test_core/test_rewriter_registry.py deleted file mode 100644 index b577d02623..0000000000 --- a/tests/test_core/test_rewriter_registry.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest - -from mmdeploy.core.rewriters.rewriter_utils import RewriterRegistry -from mmdeploy.utils.constants import Backend - - -def test_check_backend(): - with pytest.raises(Exception): - registry = RewriterRegistry() - registry._check_backend(Backend.ONNXRUNTIME.value) - - -def test_add_backend(): - registry = RewriterRegistry() - registry.add_backend(Backend.ONNXRUNTIME.value) - assert Backend.ONNXRUNTIME.value in registry._rewrite_records - assert Backend.DEFAULT.value in registry._rewrite_records - assert Backend.TENSORRT.value not in registry._rewrite_records - - -def test_register_object(): - registry = RewriterRegistry() - - @registry.register_object('add', backend=Backend.DEFAULT.value) - def add(a, b): - return a + b - - records = registry._rewrite_records[Backend.DEFAULT.value] - assert records is not None - assert records['add'] is not None - assert records['add']['_object'] is not None - add_func = records['add']['_object'] - assert add_func(123, 456) == 123 + 456 - - -def test_get_records(): - registry = RewriterRegistry() - registry.add_backend(Backend.TENSORRT.value) - - @registry.register_object('add', backend=Backend.DEFAULT.value) - def add(a, b): - return a + b - - @registry.register_object('minus', backend=Backend.DEFAULT.value) - def minus(a, b): - return a - b - - @registry.register_object('add', backend=Backend.TENSORRT.value) - def fake_add(a, b): - return a * b - - default_records = dict(registry.get_records(Backend.DEFAULT.value)) - assert default_records['add']['_object'](1, 1) == 2 - assert default_records['minus']['_object'](1, 1) == 0 - - tensorrt_records = dict(registry.get_records(Backend.TENSORRT.value)) - assert tensorrt_records['add']['_object'](1, 1) == 1 - assert tensorrt_records['minus']['_object'](1, 1) == 0 diff --git a/tests/test_core/test_rewriter_utils.py b/tests/test_core/test_rewriter_utils.py new file mode 100644 index 0000000000..c2fa03322c --- /dev/null +++ b/tests/test_core/test_rewriter_utils.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmdeploy.core.rewriters.rewriter_utils import ( + BackendChecker, RewriterRegistry, collect_env) +import mmdeploy.core.rewriters.rewriter_utils as rewriter_utils +from mmdeploy.utils.constants import Backend, IR +import mmdeploy + + +def test_collect_env(): + env_dict = collect_env(Backend.ONNXRUNTIME, IR.ONNX, version='1.0') + assert env_dict['backend'] == Backend.ONNXRUNTIME + assert env_dict['ir'] == IR.ONNX + assert env_dict['version'] == '1.0' + assert env_dict['mmdeploy'] == mmdeploy.__version__ + + +class TestCheker: + env = collect_env(Backend.ONNXRUNTIME, IR.ONNX) + + def test_backend_checker(self): + true_checker = rewriter_utils.BackendChecker(Backend.ONNXRUNTIME) + assert true_checker.check(self.env) is True + + false_checker = rewriter_utils.BackendChecker(Backend.TENSORRT) + assert false_checker.check(self.env) is False + + def test_ir_checker(self): + true_checker = rewriter_utils.IRChecker(IR.ONNX) + assert true_checker.check(self.env) is True + + false_checker = rewriter_utils.IRChecker(IR.TORCHSCRIPT) + assert false_checker.check(self.env) is False + + def test_lib_version_checker(self): + true_checker = rewriter_utils.LibVersionChecker( + 'mmdeploy', mmdeploy.__version__, mmdeploy.__version__) + assert true_checker.check(self.env) is True + + false_checker = rewriter_utils.LibVersionChecker( + 'mmdeploy', max_version='0.0.0') + assert false_checker.check(self.env) is False + + +def test_register_object(): + registry = RewriterRegistry() + checker = rewriter_utils.BackendChecker(Backend.ONNXRUNTIME) + + @registry.register_object('add', backend=Backend.DEFAULT.value, + ir=IR.DEFAULT, extra_checkers=checker) + def add(a, b): + return a + b + + records = registry._rewrite_records + assert records is not None + assert records['add'] is not None + assert isinstance(records['add'][0]['_checkers'], list) + assert isinstance(records['add'][0]['_checkers'][0], BackendChecker) + assert records['add'][0]['_object'] is not None + add_func = records['add'][0]['_object'] + assert add_func(123, 456) == 123 + 456 + + +def test_get_records(): + registry = RewriterRegistry() + + @registry.register_object('get_num', backend=Backend.ONNXRUNTIME.value, + ir=IR.ONNX) + def get_num_1(): + return 1 + + @registry.register_object('get_num', backend=Backend.ONNXRUNTIME.value, + ir=IR.TORCHSCRIPT) + def get_num_2(): + return 2 + + @registry.register_object('get_num', backend=Backend.TENSORRT.value, + ir=IR.ONNX) + def get_num_3(): + return 3 + + @registry.register_object('get_num', backend=Backend.TENSORRT.value, + ir=IR.TORCHSCRIPT) + def get_num_4(): + return 4 + + @registry.register_object('get_num', backend=Backend.DEFAULT.value, + ir=IR.DEFAULT) + def get_num_5(): + return 5 + + records = dict(registry.get_records( + collect_env(Backend.ONNXRUNTIME, IR.ONNX))) + assert records['get_num']['_object']() == 1 + + records = dict(registry.get_records( + collect_env(Backend.ONNXRUNTIME, IR.TORCHSCRIPT))) + assert records['get_num']['_object']() == 2 + + records = dict(registry.get_records( + collect_env(Backend.TENSORRT, IR.ONNX))) + assert records['get_num']['_object']() == 3 + + records = dict(registry.get_records( + collect_env(Backend.TENSORRT, IR.TORCHSCRIPT))) + assert records['get_num']['_object']() == 4 + + records = dict(registry.get_records(collect_env(Backend.NCNN, IR.ONNX))) + assert records['get_num']['_object']() == 5 diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index e9f5ad33c2..ea36c63a69 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import importlib import logging import os import tempfile @@ -440,3 +441,25 @@ def test_get_root_logger(): from mmdeploy.utils import get_root_logger logger = get_root_logger() logger.info('This is a test message') + + +def test_get_library_version(): + assert util.get_library_version('abcdefg') is None + try: + lib = importlib.import_module('setuptools') + except ImportError: + pass + else: + assert util.get_library_version('setuptools') == lib.__version__ + + +def test_get_codebase_version(): + versions = util.get_codebase_version() + for k, v in versions.items(): + assert v == util.get_library_version(k) + + +def test_get_backend_version(): + versions = util.get_backend_version() + for k, v in versions.items(): + assert v == util.get_library_version(k) diff --git a/tools/check_env.py b/tools/check_env.py index 68aa2799e7..3718db1bd5 100644 --- a/tools/check_env.py +++ b/tools/check_env.py @@ -4,49 +4,36 @@ from mmcv.utils import get_git_hash import mmdeploy -from mmdeploy.utils import get_root_logger +from mmdeploy.utils import (get_backend_version, get_codebase_version, + get_root_logger) def collect_env(): """Collect the information of the running environments.""" env_info = collect_base_env() - env_info['MMDeployment'] = f'{mmdeploy.__version__}+{get_git_hash()[:7]}' + env_info['MMDeploy'] = f'{mmdeploy.__version__}+{get_git_hash()[:7]}' return env_info def check_backend(): - try: - import onnxruntime as ort - except ImportError: - ort_version = None - else: - ort_version = ort.__version__ + backend_versions = get_backend_version() + ort_version = backend_versions['onnxruntime'] + trt_version = backend_versions['tensorrt'] + ncnn_version = backend_versions['ncnn'] + import mmdeploy.apis.onnxruntime as ort_apis logger = get_root_logger() - logger.info(f'onnxruntime: {ort_version} ops_is_avaliable : ' + logger.info(f'onnxruntime: {ort_version}\tops_is_avaliable : ' f'{ort_apis.is_available()}') - try: - import tensorrt as trt - except ImportError: - trt_version = None - else: - trt_version = trt.__version__ import mmdeploy.apis.tensorrt as trt_apis - logger.info( - f'tensorrt: {trt_version} ops_is_avaliable : {trt_apis.is_available()}' - ) - - try: - import ncnn - except ImportError: - ncnn_version = None - else: - ncnn_version = ncnn.__version__ + logger.info(f'tensorrt: {trt_version}\tops_is_avaliable : ' + f'{trt_apis.is_available()}') + import mmdeploy.apis.ncnn as ncnn_apis logger.info( - f'ncnn: {ncnn_version} ops_is_avaliable : {ncnn_apis.is_available()}') + f'ncnn: {ncnn_version}\tops_is_avaliable : {ncnn_apis.is_available()}') import mmdeploy.apis.pplnn as pplnn_apis logger.info(f'pplnn_is_avaliable: {pplnn_apis.is_available()}') @@ -56,45 +43,9 @@ def check_backend(): def check_codebase(): - try: - import mmcls - except ImportError: - mmcls_version = None - else: - mmcls_version = mmcls.__version__ - logger.info(f'mmcls: {mmcls_version}') - - try: - import mmdet - except ImportError: - mmdet_version = None - else: - mmdet_version = mmdet.__version__ - logger.info(f'mmdet: {mmdet_version}') - - try: - import mmedit - except ImportError: - mmedit_version = None - else: - mmedit_version = mmedit.__version__ - logger.info(f'mmedit: {mmedit_version}') - - try: - import mmocr - except ImportError: - mmocr_version = None - else: - mmocr_version = mmocr.__version__ - logger.info(f'mmocr: {mmocr_version}') - - try: - import mmseg - except ImportError: - mmseg_version = None - else: - mmseg_version = mmseg.__version__ - logger.info(f'mmseg: {mmseg_version}') + codebase_versions = get_codebase_version() + for k, v in codebase_versions.items(): + logger.info(f'{k}:\t{v}') if __name__ == '__main__': From 52ad0c84f395ca1eddf7c68664bff9fe958a1d18 Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Fri, 11 Feb 2022 14:48:22 +0800 Subject: [PATCH 02/10] lint --- tests/test_core/test_rewriter_utils.py | 54 +++++++++++++------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/tests/test_core/test_rewriter_utils.py b/tests/test_core/test_rewriter_utils.py index c2fa03322c..fe8f9f1df6 100644 --- a/tests/test_core/test_rewriter_utils.py +++ b/tests/test_core/test_rewriter_utils.py @@ -1,11 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -import pytest - -from mmdeploy.core.rewriters.rewriter_utils import ( - BackendChecker, RewriterRegistry, collect_env) -import mmdeploy.core.rewriters.rewriter_utils as rewriter_utils -from mmdeploy.utils.constants import Backend, IR import mmdeploy +import mmdeploy.core.rewriters.rewriter_utils as rewriter_utils +from mmdeploy.core.rewriters.rewriter_utils import (BackendChecker, + RewriterRegistry, + collect_env) +from mmdeploy.utils.constants import IR, Backend def test_collect_env(): @@ -47,8 +46,11 @@ def test_register_object(): registry = RewriterRegistry() checker = rewriter_utils.BackendChecker(Backend.ONNXRUNTIME) - @registry.register_object('add', backend=Backend.DEFAULT.value, - ir=IR.DEFAULT, extra_checkers=checker) + @registry.register_object( + 'add', + backend=Backend.DEFAULT.value, + ir=IR.DEFAULT, + extra_checkers=checker) def add(a, b): return a + b @@ -65,45 +67,45 @@ def add(a, b): def test_get_records(): registry = RewriterRegistry() - @registry.register_object('get_num', backend=Backend.ONNXRUNTIME.value, - ir=IR.ONNX) + @registry.register_object( + 'get_num', backend=Backend.ONNXRUNTIME.value, ir=IR.ONNX) def get_num_1(): return 1 - @registry.register_object('get_num', backend=Backend.ONNXRUNTIME.value, - ir=IR.TORCHSCRIPT) + @registry.register_object( + 'get_num', backend=Backend.ONNXRUNTIME.value, ir=IR.TORCHSCRIPT) def get_num_2(): return 2 - @registry.register_object('get_num', backend=Backend.TENSORRT.value, - ir=IR.ONNX) + @registry.register_object( + 'get_num', backend=Backend.TENSORRT.value, ir=IR.ONNX) def get_num_3(): return 3 - @registry.register_object('get_num', backend=Backend.TENSORRT.value, - ir=IR.TORCHSCRIPT) + @registry.register_object( + 'get_num', backend=Backend.TENSORRT.value, ir=IR.TORCHSCRIPT) def get_num_4(): return 4 - @registry.register_object('get_num', backend=Backend.DEFAULT.value, - ir=IR.DEFAULT) + @registry.register_object( + 'get_num', backend=Backend.DEFAULT.value, ir=IR.DEFAULT) def get_num_5(): return 5 - records = dict(registry.get_records( - collect_env(Backend.ONNXRUNTIME, IR.ONNX))) + records = dict( + registry.get_records(collect_env(Backend.ONNXRUNTIME, IR.ONNX))) assert records['get_num']['_object']() == 1 - records = dict(registry.get_records( - collect_env(Backend.ONNXRUNTIME, IR.TORCHSCRIPT))) + records = dict( + registry.get_records(collect_env(Backend.ONNXRUNTIME, IR.TORCHSCRIPT))) assert records['get_num']['_object']() == 2 - records = dict(registry.get_records( - collect_env(Backend.TENSORRT, IR.ONNX))) + records = dict( + registry.get_records(collect_env(Backend.TENSORRT, IR.ONNX))) assert records['get_num']['_object']() == 3 - records = dict(registry.get_records( - collect_env(Backend.TENSORRT, IR.TORCHSCRIPT))) + records = dict( + registry.get_records(collect_env(Backend.TENSORRT, IR.TORCHSCRIPT))) assert records['get_num']['_object']() == 4 records = dict(registry.get_records(collect_env(Backend.NCNN, IR.ONNX))) From e01fdff81255b909824948d6e772301a330b96ce Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Mon, 14 Feb 2022 14:51:07 +0800 Subject: [PATCH 03/10] resolve comments --- mmdeploy/core/rewriters/rewriter_utils.py | 11 ++++++++--- mmdeploy/utils/env.py | 2 +- tests/test_core/test_rewriter_utils.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 12e29d605e..a5b14ea9fd 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -60,7 +60,7 @@ def import_function(path: str) -> Tuple[Callable, Optional[type]]: def collect_env(backend: Backend, ir: IR, **kwargs): - from mmdeploy.utils import get_codebase_version, get_backend_version + from mmdeploy.utils import get_backend_version, get_codebase_version env = dict(backend=backend, ir=ir) env['mmdeploy'] = mmdeploy.__version__ env.update(get_backend_version()) @@ -173,7 +173,12 @@ def get_records(self, env: Dict) -> List: # Check if the rewriter is default rewriter if len(checkers) == 0: # Process the default rewriter exceptionally - default_rewriter = record + if default_rewriter is not None: + default_rewriter = record + else: + warnings.warn( + 'Detect multiple valid rewriters for' + f'{origin_function}, use the first rewriter.') else: # Check if the checker is valid. # The checker is valid only if all the checks are passed @@ -188,7 +193,7 @@ def get_records(self, env: Dict) -> List: if final_rewriter is not None: warnings.warn( 'Detect multiple valid rewriters for' - f'{origin_function}, use the first rewriter') + f'{origin_function}, use the first rewriter.') else: final_rewriter = record diff --git a/mmdeploy/utils/env.py b/mmdeploy/utils/env.py index 901cc87fd7..8cc2cbd3d5 100644 --- a/mmdeploy/utils/env.py +++ b/mmdeploy/utils/env.py @@ -15,7 +15,7 @@ def get_library_version(lib): """ try: lib = importlib.import_module(lib) - except ImportError: + except Exception: version = None else: version = lib.__version__ diff --git a/tests/test_core/test_rewriter_utils.py b/tests/test_core/test_rewriter_utils.py index fe8f9f1df6..4954a573d8 100644 --- a/tests/test_core/test_rewriter_utils.py +++ b/tests/test_core/test_rewriter_utils.py @@ -15,7 +15,7 @@ def test_collect_env(): assert env_dict['mmdeploy'] == mmdeploy.__version__ -class TestCheker: +class TestChecker: env = collect_env(Backend.ONNXRUNTIME, IR.ONNX) def test_backend_checker(self): From 50693fda538b39297cd0da8bcd176f62bd33d519 Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Mon, 14 Feb 2022 17:33:14 +0800 Subject: [PATCH 04/10] Fix tests --- mmdeploy/core/rewriters/rewriter_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index a5b14ea9fd..fcfbb0f297 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -173,7 +173,7 @@ def get_records(self, env: Dict) -> List: # Check if the rewriter is default rewriter if len(checkers) == 0: # Process the default rewriter exceptionally - if default_rewriter is not None: + if default_rewriter is None: default_rewriter = record else: warnings.warn( From c5d317b0fa3447bb627abbe3b06c534631c92ee5 Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Mon, 14 Feb 2022 19:45:42 +0800 Subject: [PATCH 05/10] docstring & fix --- mmdeploy/core/rewriters/function_rewriter.py | 3 + mmdeploy/core/rewriters/module_rewriter.py | 8 ++- mmdeploy/core/rewriters/rewriter_manager.py | 5 +- mmdeploy/core/rewriters/rewriter_utils.py | 62 +++++++++++++++++--- mmdeploy/core/rewriters/symbolic_rewriter.py | 5 +- 5 files changed, 71 insertions(+), 12 deletions(-) diff --git a/mmdeploy/core/rewriters/function_rewriter.py b/mmdeploy/core/rewriters/function_rewriter.py index 5b02bd5cf6..e80ed41d06 100644 --- a/mmdeploy/core/rewriters/function_rewriter.py +++ b/mmdeploy/core/rewriters/function_rewriter.py @@ -79,6 +79,9 @@ def register_rewriter( Args: func_name (str): The function name/path to rewrite. backend (str): The rewriter will be activated on which backend. + ir (IR): The rewriter will be activated on which IR. + extra_checkers (Checker | List[Checker] | None): Other requirements + defined by Checker. Returns: Callable: The process of registering function. diff --git a/mmdeploy/core/rewriters/module_rewriter.py b/mmdeploy/core/rewriters/module_rewriter.py index 5d2166192d..d0961809a0 100644 --- a/mmdeploy/core/rewriters/module_rewriter.py +++ b/mmdeploy/core/rewriters/module_rewriter.py @@ -39,7 +39,10 @@ def register_rewrite_module( Args: module_type (str): The module type name to rewrite. - backend (str): The inference engine name. + backend (str): The rewriter will be activated on which backend. + ir (IR): The rewriter will be activated on which IR. + extra_checkers (Checker | List[Checker] | None): Other requirements + defined by Checker. Returns: nn.Module: The rewritten model. @@ -51,8 +54,8 @@ def patch_model(self, model: nn.Module, cfg: mmcv.Config, backend: str = Backend.DEFAULT.value, - recursive: bool = True, ir: IR = IR.DEFAULT, + recursive: bool = True, **kwargs) -> nn.Module: """Replace the models that was registered. @@ -60,6 +63,7 @@ def patch_model(self, model (torch.nn.Module): The model to patch. cfg (Dict): Config dictionary of deployment. backend (str): The inference engine name. + ir (IR): The intermeditate representation name. recursive (bool): The flag to enable recursive patching. Returns: diff --git a/mmdeploy/core/rewriters/rewriter_manager.py b/mmdeploy/core/rewriters/rewriter_manager.py index f76072e55d..d92684cf55 100644 --- a/mmdeploy/core/rewriters/rewriter_manager.py +++ b/mmdeploy/core/rewriters/rewriter_manager.py @@ -30,6 +30,7 @@ def __init__(self): def patch_model(model: nn.Module, cfg: mmcv.Config, backend: str = Backend.DEFAULT.value, + ir: IR = IR.DEFAULT, recursive: bool = True, **kwargs) -> nn.Module: """Patch the model, replace the modules that can be rewritten. Note that @@ -39,6 +40,7 @@ def patch_model(model: nn.Module, model (torch.nn.Module): The model to patch. cfg (Dict): Config dictionary of deployment. backend (str): The inference engine name. + ir (IR): The intermeditate representation name. recursive (bool): The flag to enable recursive patching. Returns: @@ -48,7 +50,7 @@ def patch_model(model: nn.Module, >>> from mmdeploy.core import patch_model >>> patched_model = patch_model(model, cfg=deploy_cfg, backend=backend) """ - return MODULE_REWRITER.patch_model(model, cfg, backend, recursive, + return MODULE_REWRITER.patch_model(model, cfg, backend, ir, recursive, **kwargs) @@ -60,6 +62,7 @@ class RewriterContext: Args: cfg (Dict): Config dictionary of deployment. backend (str): The inference engine name. + ir (IR): The intermeditate representation name. rewrite_manager (RewriterManager): An RewriteManager that consists of several rewriters diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index fcfbb0f297..991837a0a2 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -70,52 +70,98 @@ def collect_env(backend: Backend, ir: IR, **kwargs): class Checker(metaclass=ABCMeta): + """The interface for checking whether a rewriter is valid.""" def __init__(self): pass @abstractmethod def check(self, env: Dict) -> bool: + """Check the if the rewriter is valid according to environment. + + Args: + env (Dict): The backend, IR info and version info. + """ pass class BackendChecker(Checker): + """Checker that determines which backend the rewriter must run on. + + Args: + required_backend (Backend): The rewriter will be activated on + which backend. + """ def __init__(self, required_backend: Backend): super().__init__() self.required_backend = required_backend def check(self, env: Dict) -> bool: + """Check the if the rewriter is valid according to backend. + + Args: + env (Dict): The backend, IR info and version info. + """ return env['backend'] == self.required_backend class IRChecker(Checker): + """Checker that determines which IR the rewriter must run on. - def __init__(self, required_ir: Backend): + Args: + required_ir (IR): The rewriter will be activated on which IR. + """ + + def __init__(self, required_ir: IR): super().__init__() self.required_ir = required_ir def check(self, env: Dict) -> bool: + """Check the if the rewriter is valid according to IR. + + Args: + env (Dict): The backend, IR info and version info. + """ return env['ir'] == self.required_ir class LibVersionChecker(Checker): + """Checker that determines which IR the rewriter must run on. + + Args: + lib (str): The name of library. + min_version (str | None): The rewriter should no lower than which + version. Default to `None`. + max_version (str | None): The rewriter should no greater than which + version. Default to `None`. + """ - def __init__(self, lib: str, min_version=None, max_version=None): + def __init__(self, + lib: str, + min_version: Optional[str] = None, + max_version: Optional[str] = None): super().__init__() self.lib = lib self.min_version = min_version self.max_version = max_version def check(self, env: Dict) -> bool: + """Check the if the rewriter is valid according to library version. + + Args: + env (Dict): The backend, IR info and version info. + """ from packaging import version valid = True + # The version should no less than min version and no greater than + # max version. if self.min_version is not None: - valid = version.parse(env[self.lib]) >= version.parse( - self.min_version) + if version.parse(env[self.lib]) < version.parse(self.min_version): + valid = False if self.max_version is not None: - valid = version.parse(env[self.lib]) <= version.parse( - self.max_version) + if version.parse(env[self.lib]) > version.parse(self.max_version): + valid = False return valid @@ -146,7 +192,7 @@ def get_records(self, env: Dict) -> List: """Get all registered records that are valid in the given environment from record table. - If the backend and ir of rewriter are set to 'default', then the + If the backend and IR of rewriter are set to 'default', then the rewriter is regarded as default rewriter. The default rewriter will be activated only when all other rewriters are not valid. If there are multiple rewriters are valid (except default rewriter), we will @@ -154,7 +200,7 @@ def get_records(self, env: Dict) -> List: rewriters are loaded). Args: - env (dict): Environment dictionary that includes backend, ir, + env (dict): Environment dictionary that includes backend, IR, codebase version, etc. Returns: diff --git a/mmdeploy/core/rewriters/symbolic_rewriter.py b/mmdeploy/core/rewriters/symbolic_rewriter.py index f53dbfa261..dd47cd8d58 100644 --- a/mmdeploy/core/rewriters/symbolic_rewriter.py +++ b/mmdeploy/core/rewriters/symbolic_rewriter.py @@ -49,11 +49,14 @@ def register_symbolic(self, Args: func_name (str): The function name/path to override the symbolic. - backend (str): The inference engine name. + backend (str): The rewriter will be activated on which backend. is_pytorch (bool): Enable this flag if func_name is the name of \ a pytorch builtin function. arg_descriptors (Sequence[str]): The argument descriptors of the \ symbol. + ir (IR): The rewriter will be activated on which IR. + extra_checkers (Checker | List[Checker] | None): Other requirements + defined by Checker. Returns: Callable: The process of registered symbolic. From 65701e398e98045d82d004f9552560c8303b2c69 Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Tue, 15 Feb 2022 14:27:34 +0800 Subject: [PATCH 06/10] Complement informations --- mmdeploy/core/rewriters/rewriter_utils.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 991837a0a2..e16e4c31ff 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -59,7 +59,18 @@ def import_function(path: str) -> Tuple[Callable, Optional[type]]: return obj, None -def collect_env(backend: Backend, ir: IR, **kwargs): +def collect_env(backend: Backend, ir: IR, **kwargs) -> Dict: + """Collect current environment informations, including backend, ir, + codebase version, etc. Rewriters will be checked according to env infos. + + Args: + backend (Backend): Current backend. + ir (IR): Current IR. + + Returns: + Dict: Record the value of Backend and IR as well as the versions of + libraries. + """ from mmdeploy.utils import get_backend_version, get_codebase_version env = dict(backend=backend, ir=ir) env['mmdeploy'] = mmdeploy.__version__ From 08beca56e753e5940725cdc0e35be4fb21cefa4d Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Tue, 15 Feb 2022 14:32:06 +0800 Subject: [PATCH 07/10] lint --- mmdeploy/core/rewriters/rewriter_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index e16e4c31ff..f3d2c8b7c4 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -60,8 +60,8 @@ def import_function(path: str) -> Tuple[Callable, Optional[type]]: def collect_env(backend: Backend, ir: IR, **kwargs) -> Dict: - """Collect current environment informations, including backend, ir, - codebase version, etc. Rewriters will be checked according to env infos. + """Collect current environment information, including backend, ir, codebase + version, etc. Rewriters will be checked according to env infos. Args: backend (Backend): Current backend. From db88031d0633f779301d003a3eca0d93454623ee Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Tue, 15 Feb 2022 19:36:45 +0800 Subject: [PATCH 08/10] Add example --- mmdeploy/codebase/mmdet/deploy/utils.py | 28 +++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mmdeploy/codebase/mmdet/deploy/utils.py b/mmdeploy/codebase/mmdet/deploy/utils.py index 1ecd451e2f..32cc95a9a1 100644 --- a/mmdeploy/codebase/mmdet/deploy/utils.py +++ b/mmdeploy/codebase/mmdet/deploy/utils.py @@ -5,6 +5,8 @@ import torch from torch import Tensor +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker from mmdeploy.utils import load_config @@ -69,6 +71,32 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, return x1, y1, x2, y2 +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes', + backend='tensorrt', + extra_checkers=LibVersionChecker('tensorrt', min_version='8')) +def clip_bboxes__trt8(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, + max_shape: Union[Tensor, Sequence[int]]): + """Clip bboxes for onnx. From TensorRT 8 we can do the operators on the + tensors directly. + + Args: + x1 (Tensor): The x1 for bounding boxes. + y1 (Tensor): The y1 for bounding boxes. + x2 (Tensor): The x2 for bounding boxes. + y2 (Tensor): The y2 for bounding boxes. + max_shape (Tensor | Sequence[int]): The (H,W) of original image. + Returns: + tuple(Tensor): The clipped x1, y1, x2, y2. + """ + assert len(max_shape) == 2, '`max_shape` should be [h, w]' + x1 = torch.clamp(x1, 0, max_shape[1]) + y1 = torch.clamp(y1, 0, max_shape[0]) + x2 = torch.clamp(x2, 0, max_shape[1]) + y2 = torch.clamp(y2, 0, max_shape[0]) + return x1, y1, x2, y2 + + def pad_with_value(x: Tensor, pad_dim: int, pad_size: int, From 1f24d993973c47b04d56eb36bbd63fd2e239a8b6 Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Tue, 15 Feb 2022 20:43:41 +0800 Subject: [PATCH 09/10] Fix version --- mmdeploy/codebase/mmdet/deploy/utils.py | 3 ++- mmdeploy/core/rewriters/rewriter_utils.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/mmdeploy/codebase/mmdet/deploy/utils.py b/mmdeploy/codebase/mmdet/deploy/utils.py index 32cc95a9a1..860cb54239 100644 --- a/mmdeploy/codebase/mmdet/deploy/utils.py +++ b/mmdeploy/codebase/mmdet/deploy/utils.py @@ -75,12 +75,13 @@ def clip_bboxes(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, func_name='mmdeploy.codebase.mmdet.deploy.utils.clip_bboxes', backend='tensorrt', extra_checkers=LibVersionChecker('tensorrt', min_version='8')) -def clip_bboxes__trt8(x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, +def clip_bboxes__trt8(ctx, x1: Tensor, y1: Tensor, x2: Tensor, y2: Tensor, max_shape: Union[Tensor, Sequence[int]]): """Clip bboxes for onnx. From TensorRT 8 we can do the operators on the tensors directly. Args: + ctx (ContextCaller): The context with additional information. x1 (Tensor): The x1 for bounding boxes. y1 (Tensor): The y1 for bounding boxes. x2 (Tensor): The x2 for bounding boxes. diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index f3d2c8b7c4..3b8201526f 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -163,6 +163,10 @@ def check(self, env: Dict) -> bool: Args: env (Dict): The backend, IR info and version info. """ + # If the library has not been installed + if env[self.lib] is None: + return False + from packaging import version valid = True # The version should no less than min version and no greater than From 5b1b0ee340a2d07cfed0796fbfdd43cc1deb333a Mon Sep 17 00:00:00 2001 From: SingleZombie Date: Wed, 16 Feb 2022 13:07:58 +0800 Subject: [PATCH 10/10] Remove todo --- mmdeploy/core/rewriters/rewriter_manager.py | 1 - mmdeploy/core/rewriters/rewriter_utils.py | 1 - 2 files changed, 2 deletions(-) diff --git a/mmdeploy/core/rewriters/rewriter_manager.py b/mmdeploy/core/rewriters/rewriter_manager.py index d92684cf55..de3acaffd2 100644 --- a/mmdeploy/core/rewriters/rewriter_manager.py +++ b/mmdeploy/core/rewriters/rewriter_manager.py @@ -82,7 +82,6 @@ def __init__(self, self._cfg = cfg self._kwargs = kwargs self._rewriter_manager = rewriter_manager - # TODO: Make the type of parameter backend to Backend self._env = collect_env(Backend.get(backend), ir) def enter(self): diff --git a/mmdeploy/core/rewriters/rewriter_utils.py b/mmdeploy/core/rewriters/rewriter_utils.py index 3b8201526f..a80fd84738 100644 --- a/mmdeploy/core/rewriters/rewriter_utils.py +++ b/mmdeploy/core/rewriters/rewriter_utils.py @@ -199,7 +199,6 @@ class RewriterRegistry: >>> records = FUNCTION_REGISTRY.get_record("default") """ - # TODO: replace backend string with "Backend" constant def __init__(self): self._rewrite_records = dict()