diff --git a/mmcv/ops/merge_cells.py b/mmcv/ops/merge_cells.py
index a75fc434ad..19c3fe6582 100644
--- a/mmcv/ops/merge_cells.py
+++ b/mmcv/ops/merge_cells.py
@@ -64,8 +64,8 @@ def __init__(self,
 
         if self.with_out_conv:
             self.out_conv = ConvModule(
-                fused_channels,
-                out_channels,
+                fused_channels,  # type: ignore
+                out_channels,  # type: ignore
                 **out_conv_cfg,
                 norm_cfg=out_norm_cfg,
                 order=out_conv_order)
diff --git a/mmcv/runner/base_module.py b/mmcv/runner/base_module.py
index 7e64bdfb17..845e8c8ff2 100644
--- a/mmcv/runner/base_module.py
+++ b/mmcv/runner/base_module.py
@@ -50,10 +50,10 @@ def __init__(self, init_cfg: Optional[dict] = None):
         #     self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
 
     @property
-    def is_init(self):
+    def is_init(self) -> bool:
         return self._is_init
 
-    def init_weights(self):
+    def init_weights(self) -> None:
         """Initialize the weights."""
 
         is_top_level_module = False
@@ -68,7 +68,7 @@ def init_weights(self):
             # which indicates whether the parameter has been modified.
             # this attribute would be deleted after all parameters
             # is initialized.
-            self._params_init_info = defaultdict(dict)
+            self._params_init_info: defaultdict = defaultdict(dict)
             is_top_level_module = True
 
             # Initialize the `_params_init_info`,
@@ -134,7 +134,7 @@ def init_weights(self):
                 del sub_module._params_init_info
 
     @master_only
-    def _dump_init_info(self, logger_name: str):
+    def _dump_init_info(self, logger_name: str) -> None:
         """Dump the initialization information to a file named
         `initialization.log.json` in workdir.
 
diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py
index aa042452ca..0811856642 100644
--- a/mmcv/runner/checkpoint.py
+++ b/mmcv/runner/checkpoint.py
@@ -10,9 +10,10 @@
 from collections import OrderedDict
 from importlib import import_module
 from tempfile import TemporaryDirectory
-from typing import Callable, List, Optional, Sequence, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 import torchvision
 from torch.optim import Optimizer
 
@@ -28,7 +29,7 @@
 DEFAULT_CACHE_DIR = '~/.cache'
 
 
-def _get_mmcv_home():
+def _get_mmcv_home() -> str:
     mmcv_home = os.path.expanduser(
         os.getenv(
             ENV_MMCV_HOME,
@@ -39,7 +40,7 @@ def _get_mmcv_home():
     return mmcv_home
 
 
-def load_state_dict(module: torch.nn.Module,
+def load_state_dict(module: nn.Module,
                     state_dict: Union[dict, OrderedDict],
                     strict: bool = False,
                     logger: Optional[logging.Logger] = None) -> None:
@@ -51,19 +52,19 @@ def load_state_dict(module: torch.nn.Module,
 
     Args:
         module (Module): Module that receives the state_dict.
-        state_dict (OrderedDict): Weights.
+        state_dict (dict or OrderedDict): Weights.
         strict (bool): whether to strictly enforce that the keys
             in :attr:`state_dict` match the keys returned by this module's
             :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
         logger (:obj:`logging.Logger`, optional): Logger to log the error
             message. If not specified, print function will be used.
     """
-    unexpected_keys: List = []
-    all_missing_keys: List = []
-    err_msg: List = []
+    unexpected_keys: List[str] = []
+    all_missing_keys: List[str] = []
+    err_msg: List[str] = []
 
     metadata = getattr(state_dict, '_metadata', None)
-    state_dict = state_dict.copy()
+    state_dict = state_dict.copy()  # type: ignore
     if metadata is not None:
         state_dict._metadata = metadata  # type: ignore
 
@@ -187,7 +188,7 @@ def get_deprecated_model_names():
     return deprecate_urls
 
 
-def _process_mmcls_checkpoint(checkpoint):
+def _process_mmcls_checkpoint(checkpoint: Dict) -> Dict:
     if 'state_dict' in checkpoint:
         state_dict = checkpoint['state_dict']
     else:
@@ -209,7 +210,10 @@ class CheckpointLoader:
     _schemes: dict = {}
 
     @classmethod
-    def _register_scheme(cls, prefixes, loader, force=False):
+    def _register_scheme(cls,
+                         prefixes: Union[str, List, Tuple],
+                         loader: Callable,
+                         force: bool = False) -> None:
         if isinstance(prefixes, str):
             prefixes = [prefixes]
         else:
@@ -227,9 +231,9 @@ def _register_scheme(cls, prefixes, loader, force=False):
 
     @classmethod
     def register_scheme(cls,
-                        prefixes: Union[str, Sequence[str]],
+                        prefixes: Union[str, List[str], Tuple[str, ...]],
                         loader: Optional[Callable] = None,
-                        force: bool = False):
+                        force: bool = False) -> Callable:
         """Register a loader to CheckpointLoader.
 
         This method can be used as a normal class method or a decorator.
@@ -246,7 +250,7 @@ def register_scheme(cls,
 
         if loader is not None:
             cls._register_scheme(prefixes, loader, force=force)
-            return
+            return  # type: ignore
 
         def _register(loader_cls):
             cls._register_scheme(prefixes, loader_cls, force=force)
@@ -255,7 +259,7 @@ def _register(loader_cls):
         return _register
 
     @classmethod
-    def _get_checkpoint_loader(cls, path):
+    def _get_checkpoint_loader(cls, path: str):
        """Finds a loader that supports the given path. Falls back to the local
        loader if no other loader is found.
 
@@ -293,10 +297,10 @@ def load_checkpoint(
         """
 
         checkpoint_loader = cls._get_checkpoint_loader(filename)
-        class_name = checkpoint_loader.__name__
+        class_name = checkpoint_loader.__name__  # type: ignore
         mmcv.print_log(
             f'load checkpoint from {class_name[10:]} path: {filename}', logger)
-        return checkpoint_loader(filename, map_location)
+        return checkpoint_loader(filename, map_location)  # type: ignore
 
 
 @CheckpointLoader.register_scheme(prefixes='')
@@ -719,7 +723,7 @@ def get_state_dict(module: torch.nn.Module,
         destination._metadata = OrderedDict()  # type: ignore
     destination._metadata[prefix[:-1]] = local_metadata = dict(  # type: ignore
         version=module._version)
-    _save_to_state_dict(module, destination, prefix, keep_vars)
+    _save_to_state_dict(module, destination, prefix, keep_vars)  # type: ignore
     for name, child in module._modules.items():
         if child is not None:
             get_state_dict(
@@ -766,7 +770,7 @@ def save_checkpoint(model: torch.nn.Module,
 
     checkpoint = {
         'meta': meta,
-        'state_dict': weights_to_cpu(get_state_dict(model))
+        'state_dict': weights_to_cpu(get_state_dict(model))  # type: ignore
     }
     # save optimizer state dict in the checkpoint
     if isinstance(optimizer, Optimizer):
diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py
index abed57d2ca..ee55dfda36 100644
--- a/mmcv/runner/dist_utils.py
+++ b/mmcv/runner/dist_utils.py
@@ -16,7 +16,7 @@
 from mmcv.utils import IS_MLU_AVAILABLE
 
 
-def _find_free_port():
+def _find_free_port() -> str:
     # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     # Binding to port 0 will cause the OS to find an available port for us
@@ -27,7 +27,7 @@ def _find_free_port():
     return port
 
 
-def _is_free_port(port):
+def _is_free_port(port: int) -> bool:
     ips = socket.gethostbyname_ex(socket.gethostname())[-1]
     ips.append('localhost')
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -47,7 +47,7 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
         raise ValueError(f'Invalid launcher type: {launcher}')
 
 
-def _init_dist_pytorch(backend: str, **kwargs):
+def _init_dist_pytorch(backend: str, **kwargs) -> None:
     # TODO: use local_rank instead of rank % num_gpus
     rank = int(os.environ['RANK'])
     if IS_MLU_AVAILABLE:
@@ -64,7 +64,7 @@ def _init_dist_pytorch(backend: str, **kwargs):
     dist.init_process_group(backend=backend, **kwargs)
 
 
-def _init_dist_mpi(backend: str, **kwargs):
+def _init_dist_mpi(backend: str, **kwargs) -> None:
     local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
     torch.cuda.set_device(local_rank)
     if 'MASTER_PORT' not in os.environ:
@@ -77,7 +77,7 @@ def _init_dist_mpi(backend: str, **kwargs):
     dist.init_process_group(backend=backend, **kwargs)
 
 
-def _init_dist_slurm(backend: str, port: Optional[int] = None):
+def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
     """Initialize slurm distributed training environment.
 
     If argument ``port`` is not specified, then the master port will be system
@@ -187,7 +187,9 @@ def allreduce_grads(params: List[torch.nn.Parameter],
         dist.all_reduce(tensor.div_(world_size))
 
 
-def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+def _allreduce_coalesced(tensors: torch.Tensor,
+                         world_size: int,
+                         bucket_size_mb: int = -1) -> None:
     if bucket_size_mb > 0:
         bucket_size_bytes = bucket_size_mb * 1024 * 1024
         buckets = _take_tensors(tensors, bucket_size_bytes)
diff --git a/mmcv/runner/fp16_utils.py b/mmcv/runner/fp16_utils.py
index 9deb228a5d..4674d27a44 100644
--- a/mmcv/runner/fp16_utils.py
+++ b/mmcv/runner/fp16_utils.py
@@ -103,10 +103,10 @@ def auto_fp16(
        >>>     pass
     """
 
-    def auto_fp16_wrapper(old_func):
+    def auto_fp16_wrapper(old_func: Callable) -> Callable:
 
         @functools.wraps(old_func)
-        def new_func(*args, **kwargs):
+        def new_func(*args, **kwargs) -> Callable:
             # check if the module has set the attribute `fp16_enabled`, if not,
             # just fallback to the original method.
             if not isinstance(args[0], supported_types):
@@ -195,7 +195,7 @@ def force_fp32(apply_to: Optional[Iterable] = None,
     def force_fp32_wrapper(old_func):
 
         @functools.wraps(old_func)
-        def new_func(*args, **kwargs):
+        def new_func(*args, **kwargs) -> Callable:
             # check if the module has set the attribute `fp16_enabled`, if not,
             # just fallback to the original method.
             if not isinstance(args[0], torch.nn.Module):
@@ -380,7 +380,7 @@ def has_overflow(self, params: List[Parameter]) -> bool:
                 return True
         return False
 
-    def _has_inf_or_nan(x):
+    def _has_inf_or_nan(x: torch.Tensor) -> bool:
         """Check if params contain NaN."""
         try:
             cpu_sum = float(x.float().sum())
@@ -407,7 +407,7 @@ def update_scale(self, overflow: bool) -> None:
             self.cur_scale *= self.scale_factor
         self.cur_iter += 1
 
-    def state_dict(self):
+    def state_dict(self) -> dict:
         """Returns the state of the scaler as a :class:`dict`."""
         return dict(
             cur_scale=self.cur_scale,
@@ -431,5 +431,5 @@ def load_state_dict(self, state_dict: dict) -> None:
         self.scale_window = state_dict['scale_window']
 
     @property
-    def loss_scale(self):
+    def loss_scale(self) -> float:
         return self.cur_scale
diff --git a/mmcv/runner/utils.py b/mmcv/runner/utils.py
index 34e4ed5a68..8cdc6faddb 100644
--- a/mmcv/runner/utils.py
+++ b/mmcv/runner/utils.py
@@ -15,7 +15,7 @@
 import mmcv
 
 
-def get_host_info():
+def get_host_info() -> str:
     """Get hostname and username.
 
     Return empty string if exception raised, e.g. ``getpass.getuser()`` will
@@ -30,7 +30,7 @@ def get_host_info():
     return host
 
 
-def get_time_str():
+def get_time_str() -> str:
     return time.strftime('%Y%m%d_%H%M%S', time.localtime())
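
Note on the `CheckpointLoader.register_scheme` hunks above: the docstring visible in the patch states the method can be used as a normal class method or a decorator, and this change only tightens its annotations. A minimal usage sketch of the decorator form follows; the `myscheme://` prefix and `load_from_myscheme` function are hypothetical and not part of this patch.

import torch

from mmcv.runner.checkpoint import CheckpointLoader


# Register a loader for a hypothetical custom prefix. After registration,
# CheckpointLoader.load_checkpoint('myscheme://...') dispatches to this
# function based on the longest matching registered prefix.
@CheckpointLoader.register_scheme(prefixes='myscheme://')
def load_from_myscheme(filename, map_location=None):
    # Strip the prefix and fall back to a plain torch.load on the local path.
    local_path = filename[len('myscheme://'):]
    return torch.load(local_path, map_location=map_location)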