Skip to content

Commit

Permalink
EvalHook uses case-insensitive key indicator matching and configurabl…
Browse files Browse the repository at this point in the history
…e test functions
  • Loading branch information
ly015 committed Jun 3, 2021
1 parent d212bd5 commit 244aaf4
Showing 1 changed file with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions mmcv/runner/hooks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ class EvalHook(Hook):
.etc will be inferred by 'greater' rule. Keys contain 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader, and return the test results. If ``None``, the default
test function ``mmcv.engine.single_gpu_test`` will be used.
(default: ``None``)
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
Expand Down Expand Up @@ -68,6 +72,7 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError(f'dataloader must be a pytorch DataLoader, '
Expand Down Expand Up @@ -99,11 +104,14 @@ def __init__(self,
self.best_ckpt_path = None
self._init_rule(rule, self.save_best)

self.test_fn = test_fn

def _init_rule(self, rule, key_indicator):
"""Initialize rule, key_indicator, comparison_func, and best score.
Here is the rule to determine which rule is used for key indicator
when the rule is not specific:
when the rule is not specific (note that the key indicator matching
is case-insensitive):
1. If the key indicator is in ``self.greater_keys``, the rule will be
specified as 'greater'.
2. Or if the key indicator is in ``self.less_keys``, the rule will be
Expand All @@ -124,13 +132,18 @@ def _init_rule(self, rule, key_indicator):

if rule is None:
if key_indicator != 'auto':
if key_indicator in self.greater_keys:
# use case-insensitive key indicator to infer the rule
key_indicator_lc = key_indicator.lower()
greater_keys = [key.lower() for key in self.greater_keys]
less_keys = [key.lower() for key in self.less_keys]

if key_indicator_lc in greater_keys:
rule = 'greater'
elif key_indicator in self.less_keys:
elif key_indicator_lc in less_keys:
rule = 'less'
elif any(key in key_indicator for key in self.greater_keys):
elif any(key in key_indicator_lc for key in greater_keys):
rule = 'greater'
elif any(key in key_indicator for key in self.less_keys):
elif any(key in key_indicator_lc for key in less_keys):
rule = 'less'
else:
raise ValueError(f'Cannot infer the rule for key '
Expand Down Expand Up @@ -181,8 +194,13 @@ def _do_evaluate(self, runner):
if not self._should_evaluate(runner):
return

from mmcv.engine import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader)
if self.test_fn is None:
from mmcv.engine import single_gpu_test
test_fn = single_gpu_test
else:
test_fn = self.test_fn

results = test_fn(runner.model, self.dataloader)
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
Expand Down Expand Up @@ -311,13 +329,18 @@ class DistEvalHook(EvalHook):
.etc will be inferred by 'greater' rule. Keys contain 'loss' will
be inferred by 'less' rule. Options are 'greater', 'less', None.
Default: None.
test_fn (callable, optional): test a model with samples from a
dataloader in a multi-gpu manner, and return the test results. If
``None``, the default test function ``mmcv.engine.multi_gpu_test``
will be used. (default: ``None``)
tmpdir (str | None): Temporary directory to save the results of all
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
Default: False.
broadcast_bn_buffer (bool): Whether to broadcast the
buffer(running_mean and running_var) of rank 0 to other rank
before evaluation. Default: True.
multi_
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""
Expand All @@ -329,6 +352,7 @@ def __init__(self,
by_epoch=True,
save_best=None,
rule=None,
test_fn=None,
broadcast_bn_buffer=True,
tmpdir=None,
gpu_collect=False,
Expand All @@ -340,6 +364,7 @@ def __init__(self,
by_epoch=by_epoch,
save_best=save_best,
rule=rule,
test_fn=test_fn,
**eval_kwargs)
self.broadcast_bn_buffer = broadcast_bn_buffer
self.tmpdir = tmpdir
Expand Down Expand Up @@ -367,8 +392,13 @@ def _do_evaluate(self, runner):
if tmpdir is None:
tmpdir = osp.join(runner.work_dir, '.eval_hook')

from mmcv.engine import multi_gpu_test
results = multi_gpu_test(
if self.test_fn is None:
from mmcv.engine import multi_gpu_test
test_fn = multi_gpu_test
else:
test_fn = self.test_fn

results = test_fn(
runner.model,
self.dataloader,
tmpdir=tmpdir,
Expand Down

0 comments on commit 244aaf4

Please sign in to comment.