Skip to content

Commit

Permalink
[Enhance] Get model configs from metafile without installation (open-…
Browse files Browse the repository at this point in the history
…mmlab#901)

* Support get config from model-index without installing downstream repo

* Rename _get_models_from_package to _get_models_from_config_dir

* adjust priority

* Fix as comment

* Refine exception

* Replace osp.xxx with fileio.xxx

* Refine as comment

* Revert "Replace osp.xxx with fileio.xxx"

This reverts commit 6aed9b2e88f5cf98614772ddbd89ccad22fa7d2f.

* replace fileio with osp

* fix

* Fix as comment
  • Loading branch information
HAOCHENYE authored Feb 23, 2023
1 parent d8abf9a commit 7e1b273
Showing 1 changed file with 55 additions and 16 deletions.
71 changes: 55 additions & 16 deletions mmengine/infer/infer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import importlib
import os.path as osp
import re
import warnings
Expand All @@ -24,7 +25,6 @@
from mmengine.runner.checkpoint import (_load_checkpoint,
_load_checkpoint_to_model)
from mmengine.structures import InstanceData
from mmengine.utils import get_installed_path, is_installed
from mmengine.visualization import Visualizer

InstanceList = List[InstanceData]
Expand Down Expand Up @@ -368,10 +368,10 @@ def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]:
assert self.scope in MODULE2PACKAGE, (
f'{self.scope} not in {MODULE2PACKAGE}!,'
'please pass a valid scope.')
project = MODULE2PACKAGE[self.scope]
assert is_installed(project), f'Please install {project}'
package_path = get_installed_path(project)
for model_cfg in BaseInferencer._get_models_from_package(package_path):

repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(self.scope)
for model_cfg in BaseInferencer._get_models_from_metafile(
repo_or_mim_dir):
model_name = model_cfg['Name'].lower()
model_aliases = model_cfg.get('Alias', [])
if isinstance(model_aliases, str):
Expand All @@ -380,11 +380,50 @@ def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]:
model_aliases = [alias.lower() for alias in model_aliases]
if (model_name == model or model in model_aliases):
cfg = Config.fromfile(
osp.join(package_path, '.mim', model_cfg['Config']))
osp.join(repo_or_mim_dir, model_cfg['Config']))
weights = model_cfg['Weights']
weights = weights[0] if isinstance(weights, list) else weights
return cfg, weights
raise ValueError(f'Cannot find model: {model} in {project}')
raise ValueError(f'Cannot find model: {model} in {self.scope}')

@staticmethod
def _get_repo_or_mim_dir(scope):
"""Get the directory where the ``Configs`` located when the package is
installed or ``PYTHONPATH`` is set.
Args:
scope (str): The scope of repository.
Returns:
str: The directory where the ``Configs`` is located.
"""
try:
module = importlib.import_module(scope)
except ImportError:
if scope not in MODULE2PACKAGE:
raise KeyError(
f'{scope} is not a valid scope. The available scopes '
f'are {MODULE2PACKAGE.keys()}')
else:
project = MODULE2PACKAGE[scope]
raise ImportError(
f'Cannot import {scope} correctly, please try to install '
f'the {project} by "pip install {project}"')
# Since none of OpenMMLab series packages are namespace packages
# (https://docs.python.org/3/glossary.html#term-namespace-package),
# The first element of module.__path__ means package installation path.
package_path = module.__path__[0]

if osp.exists(osp.join(osp.dirname(package_path), 'configs')):
repo_dir = osp.dirname(package_path)
return repo_dir
else:
mim_dir = osp.join(package_path, '.mim')
if not osp.exists(osp.join(mim_dir, 'Configs')):
raise FileNotFoundError(
f'Cannot find Configs directory in {package_path}!, '
f'please check the completeness of the {scope}.')
return mim_dir

def _init_model(
self,
Expand Down Expand Up @@ -591,19 +630,21 @@ def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]:
)

@staticmethod
def _get_models_from_package(package_path: str):
def _get_models_from_metafile(dir: str):
"""Load model config defined in metafile from package path.
Args:
package_path (str): Path to the package.
dir (str): Path to the directory of Config. It requires the
directory ``Config``, file ``model-index.yml`` exists in the
``dir``.
Yields:
dict: Model config defined in metafile.
"""
meta_indexes = load(osp.join(package_path, '.mim', 'model-index.yml'))
meta_indexes = load(osp.join(dir, 'model-index.yml'))
for meta_path in meta_indexes['Import']:
# meta_path example: mmcls/.mim/configs/conformer/metafile.yml
meta_path = osp.join(package_path, '.mim', meta_path)
meta_path = osp.join(dir, meta_path)
metainfo = load(meta_path)
yield from metainfo['Models']

Expand Down Expand Up @@ -631,11 +672,9 @@ def list_models(scope: Optional[str] = None, patterns: str = r'.*'):
assert scope in MODULE2PACKAGE, (
f'{scope} not in {MODULE2PACKAGE}!, please make pass a valid '
'scope.')
project = MODULE2PACKAGE[scope]
assert is_installed(project), (f'Please install {project}')
package_path = get_installed_path(project)

for model_cfg in BaseInferencer._get_models_from_package(package_path):
root_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope)
for model_cfg in BaseInferencer._get_models_from_metafile(
root_or_mim_dir):
model_name = [model_cfg['Name']]
model_name.extend(model_cfg.get('Alias', []))
for name in model_name:
Expand Down

0 comments on commit 7e1b273

Please sign in to comment.