Add type hints for mmcv/runner (#2002)
* Fix

* Fix

* fix type hint

* minor fix

* remove some type hints of functions or methods

* minor fix

* Apply suggestions from code review

* minor fix

* minor refinement

Co-authored-by: HAOCHENYE <[email protected]>
Co-authored-by: Zaida Zhou <[email protected]>
Co-authored-by: zhouzaida <[email protected]>
4 people authored Jun 20, 2022
1 parent b9a96e5 commit 1f25001
Showing 6 changed files with 44 additions and 38 deletions.
4 changes: 2 additions & 2 deletions mmcv/ops/merge_cells.py
@@ -64,8 +64,8 @@ def __init__(self,

if self.with_out_conv:
self.out_conv = ConvModule(
- fused_channels,
- out_channels,
+ fused_channels, # type: ignore
+ out_channels, # type: ignore
**out_conv_cfg,
norm_cfg=out_norm_cfg,
order=out_conv_order)
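
An illustrative aside, not part of this diff: the two `# type: ignore` comments above silence mypy on arguments whose inferred types do not match what `ConvModule` declares, rather than widening the signature. A minimal sketch of the same pattern, with made-up `make_conv` and `build` helpers:

    from typing import Optional

    def make_conv(in_channels: int, out_channels: int) -> str:
        return f'conv_{in_channels}x{out_channels}'

    def build(channels: Optional[int]) -> str:
        # mypy only sees Optional[int] here, so passing `channels` to an int
        # parameter is an arg-type error; the inline comment suppresses it for
        # this one line, on the assumption the value is an int at runtime.
        return make_conv(channels, 64)  # type: ignore

    print(build(256))  # conv_256x64
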
8 changes: 4 additions & 4 deletions mmcv/runner/base_module.py
@@ -50,10 +50,10 @@ def __init__(self, init_cfg: Optional[dict] = None):
# self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

@property
- def is_init(self):
+ def is_init(self) -> bool:
return self._is_init

- def init_weights(self):
+ def init_weights(self) -> None:
"""Initialize the weights."""

is_top_level_module = False
@@ -68,7 +68,7 @@ def init_weights(self):
# which indicates whether the parameter has been modified.
# this attribute would be deleted after all parameters
# is initialized.
- self._params_init_info = defaultdict(dict)
+ self._params_init_info: defaultdict = defaultdict(dict)
is_top_level_module = True

# Initialize the `_params_init_info`,
@@ -134,7 +134,7 @@ def init_weights(self):
del sub_module._params_init_info

@master_only
- def _dump_init_info(self, logger_name: str):
+ def _dump_init_info(self, logger_name: str) -> None:
"""Dump the initialization information to a file named
`initialization.log.json` in workdir.
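
For context, `is_init` and `init_weights` are the parts of `BaseModule` that subclasses and downstream code call directly, so the new `-> bool` / `-> None` annotations only document existing behaviour. A hedged usage sketch; the `ToyHead` class and its `init_cfg` are invented for illustration:

    import torch.nn as nn
    from mmcv.runner import BaseModule

    class ToyHead(BaseModule):
        """Illustrative subclass; the layer size and init_cfg are arbitrary."""

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg=init_cfg)
            self.fc = nn.Linear(8, 2)

    head = ToyHead(init_cfg=dict(type='Normal', layer='Linear', std=0.01))
    head.init_weights()   # returns None, as annotated
    print(head.is_init)   # True after init_weights(), matching the bool hint
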
40 changes: 22 additions & 18 deletions mmcv/runner/checkpoint.py
@@ -10,9 +10,10 @@
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory
- from typing import Callable, List, Optional, Sequence, Union
+ from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
+ import torch.nn as nn
import torchvision
from torch.optim import Optimizer

@@ -28,7 +29,7 @@
DEFAULT_CACHE_DIR = '~/.cache'


- def _get_mmcv_home():
+ def _get_mmcv_home() -> str:
mmcv_home = os.path.expanduser(
os.getenv(
ENV_MMCV_HOME,
@@ -39,7 +40,7 @@ def _get_mmcv_home():
return mmcv_home


- def load_state_dict(module: torch.nn.Module,
+ def load_state_dict(module: nn.Module,
state_dict: Union[dict, OrderedDict],
strict: bool = False,
logger: Optional[logging.Logger] = None) -> None:
@@ -51,19 +52,19 @@ def load_state_dict(module: torch.nn.Module,
Args:
module (Module): Module that receives the state_dict.
- state_dict (OrderedDict): Weights.
+ state_dict (dict or OrderedDict): Weights.
strict (bool): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
logger (:obj:`logging.Logger`, optional): Logger to log the error
message. If not specified, print function will be used.
"""
- unexpected_keys: List = []
- all_missing_keys: List = []
- err_msg: List = []
+ unexpected_keys: List[str] = []
+ all_missing_keys: List[str] = []
+ err_msg: List[str] = []

metadata = getattr(state_dict, '_metadata', None)
- state_dict = state_dict.copy()
+ state_dict = state_dict.copy() # type: ignore
if metadata is not None:
state_dict._metadata = metadata # type: ignore

@@ -187,7 +188,7 @@ def get_deprecated_model_names():
return deprecate_urls


- def _process_mmcls_checkpoint(checkpoint):
+ def _process_mmcls_checkpoint(checkpoint: Dict) -> Dict:
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
@@ -209,7 +210,10 @@ class CheckpointLoader:
_schemes: dict = {}

@classmethod
- def _register_scheme(cls, prefixes, loader, force=False):
+ def _register_scheme(cls,
+ prefixes: Union[str, List, Tuple],
+ loader: Callable,
+ force: bool = False) -> None:
if isinstance(prefixes, str):
prefixes = [prefixes]
else:
@@ -227,9 +231,9 @@ def _register_scheme(cls, prefixes, loader, force=False):

@classmethod
def register_scheme(cls,
- prefixes: Union[str, Sequence[str]],
+ prefixes: Union[str, List[str], Tuple[str, ...]],
loader: Optional[Callable] = None,
- force: bool = False):
+ force: bool = False) -> Callable:
"""Register a loader to CheckpointLoader.
This method can be used as a normal class method or a decorator.
@@ -246,7 +250,7 @@ def register_scheme(cls,

if loader is not None:
cls._register_scheme(prefixes, loader, force=force)
- return
+ return # type: ignore

def _register(loader_cls):
cls._register_scheme(prefixes, loader_cls, force=force)
@@ -255,7 +259,7 @@ def _register(loader_cls):
return _register

@classmethod
- def _get_checkpoint_loader(cls, path):
+ def _get_checkpoint_loader(cls, path: str):
"""Finds a loader that supports the given path. Falls back to the local
loader if no other loader is found.
@@ -293,10 +297,10 @@ def load_checkpoint(
"""

checkpoint_loader = cls._get_checkpoint_loader(filename)
- class_name = checkpoint_loader.__name__
+ class_name = checkpoint_loader.__name__ # type: ignore
mmcv.print_log(
f'load checkpoint from {class_name[10:]} path: {filename}', logger)
- return checkpoint_loader(filename, map_location)
+ return checkpoint_loader(filename, map_location) # type: ignore


@CheckpointLoader.register_scheme(prefixes='')
@@ -719,7 +723,7 @@ def get_state_dict(module: torch.nn.Module,
destination._metadata = OrderedDict() # type: ignore
destination._metadata[prefix[:-1]] = local_metadata = dict( # type: ignore
version=module._version)
- _save_to_state_dict(module, destination, prefix, keep_vars)
+ _save_to_state_dict(module, destination, prefix, keep_vars) # type: ignore
for name, child in module._modules.items():
if child is not None:
get_state_dict(
@@ -766,7 +770,7 @@ def save_checkpoint(model: torch.nn.Module,

checkpoint = {
'meta': meta,
- 'state_dict': weights_to_cpu(get_state_dict(model))
+ 'state_dict': weights_to_cpu(get_state_dict(model)) # type: ignore
}
# save optimizer state dict in the checkpoint
if isinstance(optimizer, Optimizer):
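
An aside on the annotations above: `register_scheme` works both as a plain class method and as a decorator, which is why it is hinted `-> Callable` while the non-decorator branch ends in `return # type: ignore`. A hedged sketch of the decorator path; the `toy://` prefix and `load_from_toy` function are hypothetical:

    from mmcv.runner.checkpoint import CheckpointLoader

    @CheckpointLoader.register_scheme(prefixes='toy://')
    def load_from_toy(filename, map_location=None):
        """Hypothetical loader for a made-up 'toy://' scheme."""
        # A real loader would fetch the file and return a checkpoint dict.
        return {'state_dict': {}, 'meta': {'source': filename}}

    # CheckpointLoader.load_checkpoint('toy://model.pth') would now dispatch here.
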
14 changes: 8 additions & 6 deletions mmcv/runner/dist_utils.py
@@ -16,7 +16,7 @@
from mmcv.utils import IS_MLU_AVAILABLE


- def _find_free_port():
+ def _find_free_port() -> str:
# Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Binding to port 0 will cause the OS to find an available port for us
@@ -27,7 +27,7 @@ def _is_free_port(port):
return port


- def _is_free_port(port):
+ def _is_free_port(port: int) -> bool:
ips = socket.gethostbyname_ex(socket.gethostname())[-1]
ips.append('localhost')
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -47,7 +47,7 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
raise ValueError(f'Invalid launcher type: {launcher}')


- def _init_dist_pytorch(backend: str, **kwargs):
+ def _init_dist_pytorch(backend: str, **kwargs) -> None:
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
if IS_MLU_AVAILABLE:
@@ -64,7 +64,7 @@ def _init_dist_pytorch(backend: str, **kwargs):
dist.init_process_group(backend=backend, **kwargs)


- def _init_dist_mpi(backend: str, **kwargs):
+ def _init_dist_mpi(backend: str, **kwargs) -> None:
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
torch.cuda.set_device(local_rank)
if 'MASTER_PORT' not in os.environ:
@@ -77,7 +77,7 @@ def _init_dist_mpi(backend: str, **kwargs):
dist.init_process_group(backend=backend, **kwargs)


- def _init_dist_slurm(backend: str, port: Optional[int] = None):
+ def _init_dist_slurm(backend: str, port: Optional[int] = None) -> None:
"""Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system
@@ -187,7 +187,9 @@ def allreduce_grads(params: List[torch.nn.Parameter],
dist.all_reduce(tensor.div_(world_size))


- def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+ def _allreduce_coalesced(tensors: torch.Tensor,
+ world_size: int,
+ bucket_size_mb: int = -1) -> None:
if bucket_size_mb > 0:
bucket_size_bytes = bucket_size_mb * 1024 * 1024
buckets = _take_tensors(tensors, bucket_size_bytes)
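
The `_find_free_port` helper shown in the first hunk relies on the OS handing out an unused port when a socket binds to port 0. A standalone sketch of that trick, not the mmcv code verbatim:

    import socket

    def find_free_port() -> str:
        # Binding to port 0 makes the OS pick an unused port; we read it back,
        # close the socket, and return the number as a string.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(('', 0))
            port = sock.getsockname()[1]
        return str(port)

    print(find_free_port())  # e.g. '54321'; the port is only guaranteed free briefly
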
12 changes: 6 additions & 6 deletions mmcv/runner/fp16_utils.py
@@ -103,10 +103,10 @@ def auto_fp16(
>>> pass
"""

- def auto_fp16_wrapper(old_func):
+ def auto_fp16_wrapper(old_func: Callable) -> Callable:

@functools.wraps(old_func)
- def new_func(*args, **kwargs):
+ def new_func(*args, **kwargs) -> Callable:
# check if the module has set the attribute `fp16_enabled`, if not,
# just fallback to the original method.
if not isinstance(args[0], supported_types):
@@ -195,7 +195,7 @@ def force_fp32(apply_to: Optional[Iterable] = None,
def force_fp32_wrapper(old_func):

@functools.wraps(old_func)
- def new_func(*args, **kwargs):
+ def new_func(*args, **kwargs) -> Callable:
# check if the module has set the attribute `fp16_enabled`, if not,
# just fallback to the original method.
if not isinstance(args[0], torch.nn.Module):
@@ -380,7 +380,7 @@ def has_overflow(self, params: List[Parameter]) -> bool:
return True
return False

- def _has_inf_or_nan(x):
+ def _has_inf_or_nan(x: torch.Tensor) -> bool:
"""Check if params contain NaN."""
try:
cpu_sum = float(x.float().sum())
@@ -407,7 +407,7 @@ def update_scale(self, overflow: bool) -> None:
self.cur_scale *= self.scale_factor
self.cur_iter += 1

- def state_dict(self):
+ def state_dict(self) -> dict:
"""Returns the state of the scaler as a :class:`dict`."""
return dict(
cur_scale=self.cur_scale,
@@ -431,5 +431,5 @@ def load_state_dict(self, state_dict: dict) -> None:
self.scale_window = state_dict['scale_window']

@property
- def loss_scale(self):
+ def loss_scale(self) -> float:
return self.cur_scale
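
The `auto_fp16` / `force_fp32` wrappers annotated above only convert arguments once a module sets `fp16_enabled`; otherwise, as the comments in the hunks note, they fall back to the original method. A hedged usage sketch; `ToyModel` and its sizes are invented:

    import torch
    import torch.nn as nn
    from mmcv.runner import auto_fp16

    class ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.fp16_enabled = False  # flipped to True by mmcv's fp16 machinery
            self.fc = nn.Linear(4, 2)

        @auto_fp16(apply_to=('x', ))
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc(x)

    out = ToyModel()(torch.randn(1, 4))  # stays in fp32 while fp16_enabled is False
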
4 changes: 2 additions & 2 deletions mmcv/runner/utils.py
@@ -15,7 +15,7 @@
import mmcv


- def get_host_info():
+ def get_host_info() -> str:
"""Get hostname and username.
Return empty string if exception raised, e.g. ``getpass.getuser()`` will
@@ -30,7 +30,7 @@ def get_host_info():
return host


- def get_time_str():
+ def get_time_str() -> str:
return time.strftime('%Y%m%d_%H%M%S', time.localtime())


