diff --git a/mmcv/runner/base_runner.py b/mmcv/runner/base_runner.py index 263548d1e7..870daa5f72 100644 --- a/mmcv/runner/base_runner.py +++ b/mmcv/runner/base_runner.py @@ -394,7 +394,7 @@ def register_lr_hook(self, lr_config): hook = mmcv.build_from_cfg(lr_config, HOOKS) else: hook = lr_config - self.register_hook(hook, priority=10) + self.register_hook(hook, priority='VERY_HIGH') def register_momentum_hook(self, momentum_config): if momentum_config is None: @@ -415,7 +415,7 @@ def register_momentum_hook(self, momentum_config): hook = mmcv.build_from_cfg(momentum_config, HOOKS) else: hook = momentum_config - self.register_hook(hook, priority=30) + self.register_hook(hook, priority='HIGH') def register_optimizer_hook(self, optimizer_config): if optimizer_config is None: @@ -425,7 +425,7 @@ def register_optimizer_hook(self, optimizer_config): hook = mmcv.build_from_cfg(optimizer_config, HOOKS) else: hook = optimizer_config - self.register_hook(hook, priority=50) + self.register_hook(hook, priority='ABOVE_NORMAL') def register_checkpoint_hook(self, checkpoint_config): if checkpoint_config is None: @@ -435,7 +435,7 @@ def register_checkpoint_hook(self, checkpoint_config): hook = mmcv.build_from_cfg(checkpoint_config, HOOKS) else: hook = checkpoint_config - self.register_hook(hook, priority=70) + self.register_hook(hook, priority='NORMAL') def register_logger_hooks(self, log_config): if log_config is None: @@ -444,7 +444,7 @@ def register_logger_hooks(self, log_config): for info in log_config['hooks']: logger_hook = mmcv.build_from_cfg( info, HOOKS, default_args=dict(interval=log_interval)) - self.register_hook(logger_hook, priority=90) + self.register_hook(logger_hook, priority='VERY_LOW') def register_timer_hook(self, timer_config): if timer_config is None: @@ -454,7 +454,7 @@ def register_timer_hook(self, timer_config): hook = mmcv.build_from_cfg(timer_config_, HOOKS) else: hook = timer_config - self.register_hook(hook, priority=80) + self.register_hook(hook, priority='LOW') def register_custom_hooks(self, custom_config): if custom_config is None: @@ -491,14 +491,26 @@ def register_training_hooks(self, Default and custom hooks include: - Hooks Priority - - LrUpdaterHook 10 - - MomentumUpdaterHook 30 - - OptimizerStepperHook 50 - - CheckpointSaverHook 70 - - IterTimerHook 80 - - LoggerHook(s) 90 - - CustomHook(s) 50 (default) + +----------------------+-------------------------+ + | Hooks | Priority | + +======================+=========================+ + | LrUpdaterHook | VERY_HIGH (10) | + +----------------------+-------------------------+ + | MomentumUpdaterHook | HIGH (30) | + +----------------------+-------------------------+ + | OptimizerStepperHook | ABOVE_NORMAL (40) | + +----------------------+-------------------------+ + | CheckpointSaverHook | NORMAL (50) | + +----------------------+-------------------------+ + | IterTimerHook | LOW (70) | + +----------------------+-------------------------+ + | LoggerHook(s) | VERY_LOW (90) | + +----------------------+-------------------------+ + | CustomHook(s) | defaults to NORMAL (50) | + +----------------------+-------------------------+ + + If custom hooks have same priority with default hooks, custom hooks + will be triggered after default hooks. """ self.register_lr_hook(lr_config) self.register_momentum_hook(momentum_config) diff --git a/mmcv/runner/priority.py b/mmcv/runner/priority.py index b58c67e313..4a9383aa4e 100644 --- a/mmcv/runner/priority.py +++ b/mmcv/runner/priority.py @@ -5,29 +5,35 @@ class Priority(Enum): """Hook priority levels. - +------------+------------+ - | Level | Value | - +============+============+ - | HIGHEST | 0 | - +------------+------------+ - | VERY_HIGH | 10 | - +------------+------------+ - | HIGH | 30 | - +------------+------------+ - | NORMAL | 50 | - +------------+------------+ - | LOW | 70 | - +------------+------------+ - | VERY_LOW | 90 | - +------------+------------+ - | LOWEST | 100 | - +------------+------------+ + +--------------+------------+ + | Level | Value | + +==============+============+ + | HIGHEST | 0 | + +--------------+------------+ + | VERY_HIGH | 10 | + +--------------+------------+ + | HIGH | 30 | + +--------------+------------+ + | ABOVE_NORMAL | 40 | + +--------------+------------+ + | NORMAL | 50 | + +--------------+------------+ + | BELOW_NORMAL | 60 | + +--------------+------------+ + | LOW | 70 | + +--------------+------------+ + | VERY_LOW | 90 | + +--------------+------------+ + | LOWEST | 100 | + +--------------+------------+ """ HIGHEST = 0 VERY_HIGH = 10 HIGH = 30 + ABOVE_NORMAL = 40 NORMAL = 50 + BELOW_NORMAL = 60 LOW = 70 VERY_LOW = 90 LOWEST = 100 diff --git a/tests/test_runner/test_hooks.py b/tests/test_runner/test_hooks.py index fa63129d98..3f9ba7c03a 100644 --- a/tests/test_runner/test_hooks.py +++ b/tests/test_runner/test_hooks.py @@ -6,6 +6,7 @@ """ import logging import os.path as osp +import random import re import shutil import sys @@ -149,10 +150,27 @@ def __init__(self, info, *args, **kwargs): assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default' shutil.rmtree(runner.work_dir) + runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1) + # test custom_hooks with string priority setting + priority_ranks = [ + 'HIGHEST', 'VERY_HIGH', 'HIGH', 'ABOVE_NORMAL', 'NORMAL', + 'BELOW_NORMAL', 'LOW', 'VERY_LOW', 'LOWEST' + ] + random_priority_ranks = priority_ranks.copy() + random.shuffle(random_priority_ranks) + custom_hooks_cfg = [ + dict(type='ToyHook', priority=rank, info=rank) + for rank in random_priority_ranks + ] + runner.register_custom_hooks(custom_hooks_cfg) + assert [hook.info for hook in runner.hooks] == priority_ranks + shutil.rmtree(runner.work_dir) + runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1) # test register_training_hooks order custom_hooks_cfg = [ dict(type='ToyHook', priority=1, info='custom 1'), + dict(type='ToyHook', priority='NORMAL', info='custom normal'), dict(type='ToyHook', priority=89, info='custom 89') ] runner.register_training_hooks( @@ -163,9 +181,11 @@ def __init__(self, info, *args, **kwargs): momentum_config=ToyHook('momentum'), timer_config=ToyHook('timer'), custom_hooks_config=custom_hooks_cfg) + # If custom hooks have same priority with default hooks, custom hooks + # will be triggered after default hooks. hooks_order = [ - 'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer', - 'custom 89', 'log' + 'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', + 'custom normal', 'timer', 'custom 89', 'log' ] assert [hook.info for hook in runner.hooks] == hooks_order shutil.rmtree(runner.work_dir)