Skip to content

Commit

Permalink
[Feture] Export preprocess and deploy information to SDK (open-mmlab#65)
Browse files Browse the repository at this point in the history
* add export info

* add dump-info funciton

* add collect info

* fix lint

* add docstring

* docstring

* docstring
  • Loading branch information
AllentDan authored Sep 13, 2021
1 parent 745c51f commit 10793f4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
38 changes: 38 additions & 0 deletions mmdeploy/utils/export_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Union

import mmcv

from mmdeploy.utils import load_config


def dump_info(deploy_cfg: Union[str, mmcv.Config],
model_cfg: Union[str, mmcv.Config], work_dir: str):
"""Export information to SDK.
Args:
deploy_cfg (str | mmcv.Config): deploy config file or dict
model_cfg (str | mmcv.Config): model config file or dict
work_dir (str): work dir to save json files
"""
# TODO dump default values of transformation function to json
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
meta_keys = [
'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape',
'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg'
]
if 'transforms' in model_cfg.data.test.pipeline[-1]:
model_cfg.data.test.pipeline[-1]['transforms'][-1][
'meta_keys'] = meta_keys
else:
model_cfg.data.test.pipeline[-1]['meta_keys'] = meta_keys
mmcv.dump(
model_cfg.data.test.pipeline,
'{}/preprocess.json'.format(work_dir),
sort_keys=False,
indent=4)

mmcv.dump(
deploy_cfg._cfg_dict,
'{}/deploy_cfg.json'.format(work_dir),
sort_keys=False,
indent=4)
8 changes: 7 additions & 1 deletion tools/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mmdeploy.apis.utils import get_partition_cfg
from mmdeploy.utils.config_utils import (Backend, get_backend, get_codebase,
load_config)
from mmdeploy.utils.export_info import dump_info


def parse_args():
Expand All @@ -37,6 +38,8 @@ def parse_args():
choices=list(logging._nameToLevel.keys()))
parser.add_argument(
'--show', action='store_true', help='Show detection outputs')
parser.add_argument(
'--dump-info', action='store_true', help='Output information for SDK')
args = parser.parse_args()

return args
Expand Down Expand Up @@ -87,7 +90,10 @@ def main():
checkpoint_path = args.checkpoint

# load deploy_cfg
deploy_cfg = load_config(deploy_cfg_path)[0]
deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

if args.dump_info:
dump_info(deploy_cfg, model_cfg, args.work_dir, args.img, args.device)

# create work_dir if not
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
Expand Down

0 comments on commit 10793f4

Please sign in to comment.