Skip to content

Commit

Permalink
Fix some type of error (open-mmlab#18)
Browse files Browse the repository at this point in the history
* Fix some type

* Fix lint

* update

* update

* fix docstring

Co-authored-by: AllentDan <[email protected]>
  • Loading branch information
hhaAndroid and AllentDan authored Jan 6, 2022
1 parent 3fd17ab commit b5d6c03
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion mmdeploy/apis/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def torch2onnx(img: Any,
model_cfg: Union[str, mmcv.Config],
model_checkpoint: Optional[str] = None,
device: str = 'cuda:0'):
"""Convert PyToch model to ONNX model.
"""Convert PyTorch model to ONNX model.
Args:
img (str | np.ndarray | torch.Tensor): Input image used to assist
Expand Down
2 changes: 1 addition & 1 deletion mmdeploy/apis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
device: str) -> BaseTask:
"""Build a task processor to manage the deploy pipeline.
"""Build a task processor to manage the deployment pipeline.
Args:
model_cfg (str | mmcv.Config): Model config file.
Expand Down
7 changes: 4 additions & 3 deletions mmdeploy/codebase/base/mmcodebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
from torch.utils.data import DataLoader, Dataset

from mmdeploy.utils import Codebase, Task
from .task import BaseTask


class MMCodebase(metaclass=ABCMeta):
"""Wrap the apis of OpenMMLab Codebase."""

task_registry: Registry = None

def __init__() -> None:
def __init__(self) -> None:
pass

@classmethod
def get_task_class(cls, task: Task) -> type:
def get_task_class(cls, task: Task) -> BaseTask:
"""Get the task processors class according to the task type.
Args:
Expand Down Expand Up @@ -111,7 +112,7 @@ def __build_codebase_class(codebase: Codebase, registry: Registry):
CODEBASE = Registry('Codebases', build_func=__build_codebase_class)


def get_codebase_class(codebase: Codebase) -> type:
def get_codebase_class(codebase: Codebase) -> MMCodebase:
"""Get the codebase class from the registry.
Args:
Expand Down
4 changes: 1 addition & 3 deletions mmdeploy/codebase/base/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def create_input(self,
"""Create input for model.
Args:
imgs (str | np.ndarray): Input image(s), accpeted data types are
imgs (str | np.ndarray): Input image(s), accepted data types are
`str`, `np.ndarray`.
input_shape (list[int]): Input shape of image in (width, height)
format, defaults to `None`.
Expand All @@ -167,7 +167,6 @@ def visualize(self,
image (str | np.ndarray): Input image to draw predictions on.
result (list): A list of predictions.
output_file (str): Output file to save drawn image.
backend (Backend): Specifying backend type.
window_name (str): The name of visualization window. Defaults to
an empty string.
show_result (bool): Whether to show result in windows, defaults
Expand Down Expand Up @@ -233,7 +232,6 @@ def evaluate_outputs(model_cfg,
outputs (list): A list of predictions of model inference.
dataset (Dataset): Input dataset to run test.
model_cfg (mmcv.Config): The model config.
codebase (Codebase): Specifying codebase type.
metrics (str): Evaluation metrics, which depends on
the codebase and the dataset, e.g., "bbox", "segm", "proposal"
for COCO, and "mAP", "recall" for PASCAL VOC in mmdet;
Expand Down
8 changes: 4 additions & 4 deletions mmdeploy/core/rewriters/module_rewriter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from copy import deepcopy
from typing import Dict

import mmcv
from torch import nn

from mmdeploy.utils.constants import Backend
Expand Down Expand Up @@ -48,7 +48,7 @@ def register_rewrite_module(self,

def patch_model(self,
model: nn.Module,
cfg: Dict,
cfg: mmcv.Config,
backend: str = Backend.DEFAULT.value,
recursive: bool = True,
**kwargs) -> nn.Module:
Expand Down Expand Up @@ -91,8 +91,8 @@ def _replace_one_module(self, module, cfg, **kwargs):

return module_class(module, cfg, **input_args)

def _replace_module(self, model: nn.Module, cfg: Dict, recursive: bool,
**kwargs):
def _replace_module(self, model: nn.Module, cfg: mmcv.Config,
recursive: bool, **kwargs):
"""DFS and replace target models."""

def _replace_module_impl(model, cfg, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion mmdeploy/core/rewriters/rewriter_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import mmcv
import torch.nn as nn

from mmdeploy.utils.constants import Backend
Expand Down Expand Up @@ -38,7 +39,7 @@ def add_backend(self, backend: str):


def patch_model(model: nn.Module,
cfg: Dict,
cfg: mmcv.Config,
backend: str = Backend.DEFAULT.value,
recursive: bool = True,
**kwargs) -> nn.Module:
Expand Down
11 changes: 6 additions & 5 deletions mmdeploy/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def get_input_shape(deploy_cfg: Union[str, mmcv.Config]) -> List[int]:
return input_shape


def cfg_apply_marks(deploy_cfg: Union[str, mmcv.Config]) -> Union[bool, None]:
def cfg_apply_marks(deploy_cfg: Union[str, mmcv.Config]) -> Optional[bool]:
"""Check if the model needs to be partitioned by checking if the config
contains 'apply_marks'.
Expand All @@ -253,15 +253,16 @@ def cfg_apply_marks(deploy_cfg: Union[str, mmcv.Config]) -> Union[bool, None]:
return apply_marks


def get_partition_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict:
def get_partition_config(
deploy_cfg: Union[str, mmcv.Config]) -> Optional[Dict]:
"""Check if the model needs to be partitioned and get the config of
partition.
Args:
deploy_cfg (str | mmcv.Config): The path or content of config.
Returns:
dict: The config of partition.
dict or None: The config of partition.
"""
partition_config = deploy_cfg.get('partition_config', None)
if partition_config is None:
Expand All @@ -288,14 +289,14 @@ def get_calib_config(deploy_cfg: Union[str, mmcv.Config]) -> Dict:
return calib_config


def get_calib_filename(deploy_cfg: Union[str, mmcv.Config]) -> str:
def get_calib_filename(deploy_cfg: Union[str, mmcv.Config]) -> Optional[str]:
"""Check if the model needs to create calib and get filename of calib.
Args:
deploy_cfg (str | mmcv.Config): The path or content of config.
Returns:
str: The filename of output calib file.
str | None: Could be the filename of output calib file or None.
"""

calib_config = get_calib_config(deploy_cfg)
Expand Down

0 comments on commit b5d6c03

Please sign in to comment.