Skip to content

Commit

Permalink
EvalHook uses case-insensitive key indicator matching and configurabl… (
Browse files Browse the repository at this point in the history
#1076)

* EvalHook uses case-insensitive key indicator matching and configurable test functions

* * fix docstring

* * move test_fn import into __init__
* configurable greater/less keys

* * update unittest
* update DistEvalHook

* fix comments and remove debug code

* support single greater/less key
  • Loading branch information
ly015 authored Jun 24, 2021
1 parent 49a1d34 commit 560719d
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 13 deletions.
79 changes: 68 additions & 11 deletions mmcv/runner/hooks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader

from mmcv.utils import is_seq_of
from .hook import Hook


Expand Down Expand Up @@ -41,6 +42,16 @@ class EvalHook(Hook):
.etc will be inferred by 'greater' rule. Keys contain 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader, and return the test results. If ``None``, the default
test function ``mmcv.engine.single_gpu_test`` will be used.
(default: ``None``)
greater_keys (List[str] | None, optional): Metric keys that will be
inferred by 'greater' comparison rule rule. If ``None``,
_default_greater_keys will be used. (default: ``None``)
less_keys (List[str] | None, optional): Metric keys that will be
inferred by 'less' comparison rule. If ``None``, _default_less_keys
will be used. (default: ``None``)
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
Expand All @@ -55,11 +66,11 @@ class EvalHook(Hook):

rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -inf, 'less': inf}
greater_keys = [
_default_greater_keys = [
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
'mAcc', 'aAcc'
]
less_keys = ['loss']
_default_less_keys = ['loss']

def __init__(self,
dataloader,
Expand All @@ -68,6 +79,9 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
Expand Down Expand Up @@ -95,6 +109,28 @@ def __init__(self,
self.eval_kwargs = eval_kwargs
self.initial_flag = True

if test_fn is None:
from mmcv.engine import single_gpu_test
self.test_fn = single_gpu_test
else:
self.test_fn = test_fn

if greater_keys is None:
self.greater_keys = self._default_greater_keys
else:
if not isinstance(greater_keys, (list, tuple)):
greater_keys = (greater_keys, )
assert is_seq_of(greater_keys, str)
self.greater_keys = greater_keys

if less_keys is None:
self.less_keys = self._default_less_keys
else:
if not isinstance(less_keys, (list, tuple)):
less_keys = (less_keys, )
assert is_seq_of(less_keys, str)
self.less_keys = less_keys

if self.save_best is not None:
self.best_ckpt_path = None
self._init_rule(rule, self.save_best)
Expand All @@ -103,7 +139,8 @@ def _init_rule(self, rule, key_indicator):
"""Initialize rule, key_indicator, comparison_func, and best score.
Here is the rule to determine which rule is used for key indicator
when the rule is not specific:
when the rule is not specific (note that the key indicator matching
is case-insensitive):
1. If the key indicator is in ``self.greater_keys``, the rule will be
specified as 'greater'.
2. Or if the key indicator is in ``self.less_keys``, the rule will be
Expand All @@ -124,13 +161,19 @@ def _init_rule(self, rule, key_indicator):

if rule is None:
if key_indicator != 'auto':
if key_indicator in self.greater_keys:
# `_lc` here means we use the lower case of keys for
# case-insensitive matching
key_indicator_lc = key_indicator.lower()
greater_keys = [key.lower() for key in self.greater_keys]
less_keys = [key.lower() for key in self.less_keys]

if key_indicator_lc in greater_keys:
rule = 'greater'
elif key_indicator in self.less_keys:
elif key_indicator_lc in less_keys:
rule = 'less'
elif any(key in key_indicator for key in self.greater_keys):
elif any(key in key_indicator_lc for key in greater_keys):
rule = 'greater'
elif any(key in key_indicator for key in self.less_keys):
elif any(key in key_indicator_lc for key in less_keys):
rule = 'less'
else:
raise ValueError(f'Cannot infer the rule for key '
Expand Down Expand Up @@ -181,8 +224,7 @@ def _do_evaluate(self, runner):
if not self._should_evaluate(runner):
return

from mmcv.engine import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader)
results = self.test_fn(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
Expand Down Expand Up @@ -311,6 +353,10 @@ class DistEvalHook(EvalHook):
.etc will be inferred by 'greater' rule. Keys contain 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader in a multi-gpu manner, and return the test results. If
``None``, the default test function ``mmcv.engine.multi_gpu_test``
will be used. (default: ``None``)
tmpdir (str | None): Temporary directory to save the results of all
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
Expand All @@ -329,18 +375,30 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
**eval_kwargs):

if test_fn is None:
from mmcv.engine import multi_gpu_test
test_fn = multi_gpu_test

super().__init__(
dataloader,
start=start,
interval=interval,
by_epoch=by_epoch,
save_best=save_best,
rule=rule,
test_fn=test_fn,
greater_keys=greater_keys,
less_keys=less_keys,
**eval_kwargs)

self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect
Expand All @@ -367,8 +425,7 @@ def _do_evaluate(self, runner):
if tmpdir is None:
tmpdir = osp.join(runner.work_dir, '.eval_hook')

from mmcv.engine import multi_gpu_test
results = multi_gpu_test(
results = self.test_fn(
runner.model,
self.dataloader,
tmpdir=tmpdir,
Expand Down
29 changes: 27 additions & 2 deletions tests/test_runner/test_eval_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def _build_iter_runner():

class EvalHook(BaseEvalHook):

greater_keys = ['acc', 'top']
less_keys = ['loss', 'loss_top']
_default_greater_keys = ['acc', 'top']
_default_less_keys = ['loss', 'loss_top']

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -273,6 +273,31 @@ def test_eval_hook():
assert runner.meta['hook_msgs']['best_score'] == 7
assert not osp.exists(old_ckpt_path)

# test EvalHook with customer test_fn and greater/less keys
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())

eval_hook = EvalHook(
data_loader,
save_best='acc',
test_fn=mock.MagicMock(return_value={}),
greater_keys=[],
less_keys=['acc'])

with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == -3


@patch('mmcv.engine.single_gpu_test', MagicMock)
@patch('mmcv.engine.multi_gpu_test', MagicMock)
Expand Down

0 comments on commit 560719d

Please sign in to comment.