Skip to content

Commit

Permalink
Refine default hooks and custom hooks priority rank. (#1120)
Browse files Browse the repository at this point in the history
* Refine default hooks and custom hooks priority rank.

* Add unit tests for custom hooks with string priority.

* Use priority `ABOVE_NORMAL` and `BELOW_NORMAL` instead of `HIGHER` and
`LOWER`.

And add unit tests for custom hook with the same priority as
default hooks.
  • Loading branch information
mzr1996 authored Jun 25, 2021
1 parent d9effbd commit 6fe3722
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 33 deletions.
40 changes: 26 additions & 14 deletions mmcv/runner/base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def register_lr_hook(self, lr_config):
hook = mmcv.build_from_cfg(lr_config, HOOKS)
else:
hook = lr_config
self.register_hook(hook, priority=10)
self.register_hook(hook, priority='VERY_HIGH')

def register_momentum_hook(self, momentum_config):
if momentum_config is None:
Expand All @@ -415,7 +415,7 @@ def register_momentum_hook(self, momentum_config):
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
else:
hook = momentum_config
self.register_hook(hook, priority=30)
self.register_hook(hook, priority='HIGH')

def register_optimizer_hook(self, optimizer_config):
if optimizer_config is None:
Expand All @@ -425,7 +425,7 @@ def register_optimizer_hook(self, optimizer_config):
hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
else:
hook = optimizer_config
self.register_hook(hook, priority=50)
self.register_hook(hook, priority='ABOVE_NORMAL')

def register_checkpoint_hook(self, checkpoint_config):
if checkpoint_config is None:
Expand All @@ -435,7 +435,7 @@ def register_checkpoint_hook(self, checkpoint_config):
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
else:
hook = checkpoint_config
self.register_hook(hook, priority=70)
self.register_hook(hook, priority='NORMAL')

def register_logger_hooks(self, log_config):
if log_config is None:
Expand All @@ -444,7 +444,7 @@ def register_logger_hooks(self, log_config):
for info in log_config['hooks']:
logger_hook = mmcv.build_from_cfg(
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority=90)
self.register_hook(logger_hook, priority='VERY_LOW')

def register_timer_hook(self, timer_config):
if timer_config is None:
Expand All @@ -454,7 +454,7 @@ def register_timer_hook(self, timer_config):
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook, priority=80)
self.register_hook(hook, priority='LOW')

def register_custom_hooks(self, custom_config):
if custom_config is None:
Expand Down Expand Up @@ -491,14 +491,26 @@ def register_training_hooks(self,
Default and custom hooks include:
Hooks Priority
- LrUpdaterHook 10
- MomentumUpdaterHook 30
- OptimizerStepperHook 50
- CheckpointSaverHook 70
- IterTimerHook 80
- LoggerHook(s) 90
- CustomHook(s) 50 (default)
+----------------------+-------------------------+
| Hooks | Priority |
+======================+=========================+
| LrUpdaterHook | VERY_HIGH (10) |
+----------------------+-------------------------+
| MomentumUpdaterHook | HIGH (30) |
+----------------------+-------------------------+
| OptimizerStepperHook | ABOVE_NORMAL (40) |
+----------------------+-------------------------+
| CheckpointSaverHook | NORMAL (50) |
+----------------------+-------------------------+
| IterTimerHook | LOW (70) |
+----------------------+-------------------------+
| LoggerHook(s) | VERY_LOW (90) |
+----------------------+-------------------------+
| CustomHook(s) | defaults to NORMAL (50) |
+----------------------+-------------------------+
If custom hooks have same priority with default hooks, custom hooks
will be triggered after default hooks.
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
Expand Down
40 changes: 23 additions & 17 deletions mmcv/runner/priority.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,35 @@
class Priority(Enum):
"""Hook priority levels.
+------------+------------+
| Level | Value |
+============+============+
| HIGHEST | 0 |
+------------+------------+
| VERY_HIGH | 10 |
+------------+------------+
| HIGH | 30 |
+------------+------------+
| NORMAL | 50 |
+------------+------------+
| LOW | 70 |
+------------+------------+
| VERY_LOW | 90 |
+------------+------------+
| LOWEST | 100 |
+------------+------------+
+--------------+------------+
| Level | Value |
+==============+============+
| HIGHEST | 0 |
+--------------+------------+
| VERY_HIGH | 10 |
+--------------+------------+
| HIGH | 30 |
+--------------+------------+
| ABOVE_NORMAL | 40 |
+--------------+------------+
| NORMAL | 50 |
+--------------+------------+
| BELOW_NORMAL | 60 |
+--------------+------------+
| LOW | 70 |
+--------------+------------+
| VERY_LOW | 90 |
+--------------+------------+
| LOWEST | 100 |
+--------------+------------+
"""

HIGHEST = 0
VERY_HIGH = 10
HIGH = 30
ABOVE_NORMAL = 40
NORMAL = 50
BELOW_NORMAL = 60
LOW = 70
VERY_LOW = 90
LOWEST = 100
Expand Down
24 changes: 22 additions & 2 deletions tests/test_runner/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import logging
import os.path as osp
import random
import re
import shutil
import sys
Expand Down Expand Up @@ -149,10 +150,27 @@ def __init__(self, info, *args, **kwargs):
assert len(runner.hooks) == 3 and runner.hooks[1].info == 'default'
shutil.rmtree(runner.work_dir)

runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test custom_hooks with string priority setting
priority_ranks = [
'HIGHEST', 'VERY_HIGH', 'HIGH', 'ABOVE_NORMAL', 'NORMAL',
'BELOW_NORMAL', 'LOW', 'VERY_LOW', 'LOWEST'
]
random_priority_ranks = priority_ranks.copy()
random.shuffle(random_priority_ranks)
custom_hooks_cfg = [
dict(type='ToyHook', priority=rank, info=rank)
for rank in random_priority_ranks
]
runner.register_custom_hooks(custom_hooks_cfg)
assert [hook.info for hook in runner.hooks] == priority_ranks
shutil.rmtree(runner.work_dir)

runner = _build_demo_runner_without_hook('EpochBasedRunner', max_epochs=1)
# test register_training_hooks order
custom_hooks_cfg = [
dict(type='ToyHook', priority=1, info='custom 1'),
dict(type='ToyHook', priority='NORMAL', info='custom normal'),
dict(type='ToyHook', priority=89, info='custom 89')
]
runner.register_training_hooks(
Expand All @@ -163,9 +181,11 @@ def __init__(self, info, *args, **kwargs):
momentum_config=ToyHook('momentum'),
timer_config=ToyHook('timer'),
custom_hooks_config=custom_hooks_cfg)
# If custom hooks have same priority with default hooks, custom hooks
# will be triggered after default hooks.
hooks_order = [
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint', 'timer',
'custom 89', 'log'
'custom 1', 'lr', 'momentum', 'optimizer', 'checkpoint',
'custom normal', 'timer', 'custom 89', 'log'
]
assert [hook.info for hook in runner.hooks] == hooks_order
shutil.rmtree(runner.work_dir)
Expand Down

0 comments on commit 6fe3722

Please sign in to comment.