diff --git a/mmengine/infer/infer.py b/mmengine/infer/infer.py index d72e986c0b..e27d1b233d 100644 --- a/mmengine/infer/infer.py +++ b/mmengine/infer/infer.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import importlib import os.path as osp import re import warnings @@ -24,7 +25,6 @@ from mmengine.runner.checkpoint import (_load_checkpoint, _load_checkpoint_to_model) from mmengine.structures import InstanceData -from mmengine.utils import get_installed_path, is_installed from mmengine.visualization import Visualizer InstanceList = List[InstanceData] @@ -368,10 +368,10 @@ def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]: assert self.scope in MODULE2PACKAGE, ( f'{self.scope} not in {MODULE2PACKAGE}!,' 'please pass a valid scope.') - project = MODULE2PACKAGE[self.scope] - assert is_installed(project), f'Please install {project}' - package_path = get_installed_path(project) - for model_cfg in BaseInferencer._get_models_from_package(package_path): + + repo_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(self.scope) + for model_cfg in BaseInferencer._get_models_from_metafile( + repo_or_mim_dir): model_name = model_cfg['Name'].lower() model_aliases = model_cfg.get('Alias', []) if isinstance(model_aliases, str): @@ -380,11 +380,50 @@ def _load_model_from_metafile(self, model: str) -> Tuple[Config, str]: model_aliases = [alias.lower() for alias in model_aliases] if (model_name == model or model in model_aliases): cfg = Config.fromfile( - osp.join(package_path, '.mim', model_cfg['Config'])) + osp.join(repo_or_mim_dir, model_cfg['Config'])) weights = model_cfg['Weights'] weights = weights[0] if isinstance(weights, list) else weights return cfg, weights - raise ValueError(f'Cannot find model: {model} in {project}') + raise ValueError(f'Cannot find model: {model} in {self.scope}') + + @staticmethod + def _get_repo_or_mim_dir(scope): + """Get the directory where the ``Configs`` located when the package is + installed or ``PYTHONPATH`` is set. + + Args: + scope (str): The scope of repository. + + Returns: + str: The directory where the ``Configs`` is located. + """ + try: + module = importlib.import_module(scope) + except ImportError: + if scope not in MODULE2PACKAGE: + raise KeyError( + f'{scope} is not a valid scope. The available scopes ' + f'are {MODULE2PACKAGE.keys()}') + else: + project = MODULE2PACKAGE[scope] + raise ImportError( + f'Cannot import {scope} correctly, please try to install ' + f'the {project} by "pip install {project}"') + # Since none of OpenMMLab series packages are namespace packages + # (https://docs.python.org/3/glossary.html#term-namespace-package), + # The first element of module.__path__ means package installation path. + package_path = module.__path__[0] + + if osp.exists(osp.join(osp.dirname(package_path), 'configs')): + repo_dir = osp.dirname(package_path) + return repo_dir + else: + mim_dir = osp.join(package_path, '.mim') + if not osp.exists(osp.join(mim_dir, 'Configs')): + raise FileNotFoundError( + f'Cannot find Configs directory in {package_path}!, ' + f'please check the completeness of the {scope}.') + return mim_dir def _init_model( self, @@ -591,19 +630,21 @@ def _dispatch_kwargs(self, **kwargs) -> Tuple[Dict, Dict, Dict, Dict]: ) @staticmethod - def _get_models_from_package(package_path: str): + def _get_models_from_metafile(dir: str): """Load model config defined in metafile from package path. Args: - package_path (str): Path to the package. + dir (str): Path to the directory of Config. It requires the + directory ``Config``, file ``model-index.yml`` exists in the + ``dir``. Yields: dict: Model config defined in metafile. """ - meta_indexes = load(osp.join(package_path, '.mim', 'model-index.yml')) + meta_indexes = load(osp.join(dir, 'model-index.yml')) for meta_path in meta_indexes['Import']: # meta_path example: mmcls/.mim/configs/conformer/metafile.yml - meta_path = osp.join(package_path, '.mim', meta_path) + meta_path = osp.join(dir, meta_path) metainfo = load(meta_path) yield from metainfo['Models'] @@ -631,11 +672,9 @@ def list_models(scope: Optional[str] = None, patterns: str = r'.*'): assert scope in MODULE2PACKAGE, ( f'{scope} not in {MODULE2PACKAGE}!, please make pass a valid ' 'scope.') - project = MODULE2PACKAGE[scope] - assert is_installed(project), (f'Please install {project}') - package_path = get_installed_path(project) - - for model_cfg in BaseInferencer._get_models_from_package(package_path): + root_or_mim_dir = BaseInferencer._get_repo_or_mim_dir(scope) + for model_cfg in BaseInferencer._get_models_from_metafile( + root_or_mim_dir): model_name = [model_cfg['Name']] model_name.extend(model_cfg.get('Alias', [])) for name in model_name: