Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EvalHook uses case-insensitive key indicator matching and configurabl… #1076

Merged
merged 6 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions mmcv/runner/hooks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader

from mmcv.utils import is_seq_of
from .hook import Hook


Expand Down Expand Up @@ -41,6 +42,16 @@ class EvalHook(Hook):
etc. will be inferred by the 'greater' rule. Keys containing 'loss' will
be inferred by the 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader, and return the test results. If ``None``, the default
test function ``mmcv.engine.single_gpu_test`` will be used.
(default: ``None``)
greater_keys (List[str] | None, optional): Metric keys that will be
inferred by 'greater' comparison rule. If ``None``,
_default_greater_keys will be used. (default: ``None``)
less_keys (List[str] | None, optional): Metric keys that will be
inferred by 'less' comparison rule. If ``None``, _default_less_keys
will be used. (default: ``None``)
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.

Expand All @@ -55,11 +66,11 @@ class EvalHook(Hook):

rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -inf, 'less': inf}
greater_keys = [
_default_greater_keys = [
'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
'mAcc', 'aAcc'
]
less_keys = ['loss']
_default_less_keys = ['loss']

def __init__(self,
dataloader,
Expand All @@ -68,6 +79,9 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
Expand Down Expand Up @@ -95,6 +109,28 @@ def __init__(self,
self.eval_kwargs = eval_kwargs
self.initial_flag = True

if test_fn is None:
from mmcv.engine import single_gpu_test
self.test_fn = single_gpu_test
else:
self.test_fn = test_fn

if greater_keys is None:
self.greater_keys = self._default_greater_keys
else:
if not isinstance(greater_keys, (list, tuple)):
greater_keys = (greater_keys, )
assert is_seq_of(greater_keys, str)
self.greater_keys = greater_keys

if less_keys is None:
self.less_keys = self._default_less_keys
else:
if not isinstance(less_keys, (list, tuple)):
less_keys = (less_keys, )
assert is_seq_of(less_keys, str)
self.less_keys = less_keys

if self.save_best is not None:
self.best_ckpt_path = None
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
self._init_rule(rule, self.save_best)
Expand All @@ -103,7 +139,8 @@ def _init_rule(self, rule, key_indicator):
"""Initialize rule, key_indicator, comparison_func, and best score.

Here is the rule to determine which rule is used for key indicator
when the rule is not specified:
when the rule is not specified (note that the key indicator matching
is case-insensitive):
1. If the key indicator is in ``self.greater_keys``, the rule will be
specified as 'greater'.
2. Or if the key indicator is in ``self.less_keys``, the rule will be
Expand All @@ -124,13 +161,19 @@ def _init_rule(self, rule, key_indicator):

if rule is None:
if key_indicator != 'auto':
if key_indicator in self.greater_keys:
# `_lc` here means we use the lower case of keys for
# case-insensitive matching
key_indicator_lc = key_indicator.lower()
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
greater_keys = [key.lower() for key in self.greater_keys]
less_keys = [key.lower() for key in self.less_keys]

if key_indicator_lc in greater_keys:
rule = 'greater'
elif key_indicator in self.less_keys:
elif key_indicator_lc in less_keys:
rule = 'less'
elif any(key in key_indicator for key in self.greater_keys):
elif any(key in key_indicator_lc for key in greater_keys):
rule = 'greater'
elif any(key in key_indicator for key in self.less_keys):
elif any(key in key_indicator_lc for key in less_keys):
rule = 'less'
else:
raise ValueError(f'Cannot infer the rule for key '
Expand Down Expand Up @@ -181,8 +224,7 @@ def _do_evaluate(self, runner):
if not self._should_evaluate(runner):
return

from mmcv.engine import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader)
results = self.test_fn(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
Expand Down Expand Up @@ -311,6 +353,10 @@ class DistEvalHook(EvalHook):
etc. will be inferred by the 'greater' rule. Keys containing 'loss' will
be inferred by the 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader in a multi-gpu manner, and return the test results. If
``None``, the default test function ``mmcv.engine.multi_gpu_test``
will be used. (default: ``None``)
tmpdir (str | None): Temporary directory to save the results of all
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
Expand All @@ -329,18 +375,30 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
greater_keys=None,
less_keys=None,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
**eval_kwargs):

if test_fn is None:
from mmcv.engine import multi_gpu_test
test_fn = multi_gpu_test

super().__init__(
dataloader,
start=start,
interval=interval,
by_epoch=by_epoch,
save_best=save_best,
rule=rule,
test_fn=test_fn,
greater_keys=greater_keys,
less_keys=less_keys,
**eval_kwargs)

self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect
Expand All @@ -367,8 +425,7 @@ def _do_evaluate(self, runner):
if tmpdir is None:
tmpdir = osp.join(runner.work_dir, '.eval_hook')

from mmcv.engine import multi_gpu_test
results = multi_gpu_test(
results = self.test_fn(
runner.model,
self.dataloader,
tmpdir=tmpdir,
Expand Down
29 changes: 27 additions & 2 deletions tests/test_runner/test_eval_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def _build_iter_runner():

class EvalHook(BaseEvalHook):

greater_keys = ['acc', 'top']
less_keys = ['loss', 'loss_top']
_default_greater_keys = ['acc', 'top']
_default_less_keys = ['loss', 'loss_top']

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -273,6 +273,31 @@ def test_eval_hook():
assert runner.meta['hook_msgs']['best_score'] == 7
assert not osp.exists(old_ckpt_path)

# test EvalHook with custom test_fn and greater/less keys
loader = DataLoader(EvalDataset())
model = Model()
data_loader = DataLoader(EvalDataset())

eval_hook = EvalHook(
data_loader,
save_best='acc',
test_fn=mock.MagicMock(return_value={}),
greater_keys=[],
less_keys=['acc'])

with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger)
runner.register_checkpoint_hook(dict(interval=1))
runner.register_hook(eval_hook)
runner.run([loader], [('train', 1)], 8)

ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == ckpt_path
assert osp.exists(ckpt_path)
assert runner.meta['hook_msgs']['best_score'] == -3


@patch('mmcv.engine.single_gpu_test', MagicMock)
@patch('mmcv.engine.multi_gpu_test', MagicMock)
Expand Down