From 244aaf4ef257d28cbc0142deb8d5cea44ddcb726 Mon Sep 17 00:00:00 2001 From: ly015 Date: Thu, 3 Jun 2021 16:47:42 +0800 Subject: [PATCH] EvalHook uses case-insensitive key indicator matching and configurable test functions --- mmcv/runner/hooks/evaluation.py | 48 ++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/mmcv/runner/hooks/evaluation.py b/mmcv/runner/hooks/evaluation.py index 6d2a31c21a..b0475a339f 100644 --- a/mmcv/runner/hooks/evaluation.py +++ b/mmcv/runner/hooks/evaluation.py @@ -41,6 +41,10 @@ class EvalHook(Hook): .etc will be inferred by 'greater' rule. Keys contain 'loss' will be inferred by 'less' rule. Options are 'greater', 'less', None. Default: None. + test_fn (callable, optional): test a model with samples from a + dataloader, and return the test results. If ``None``, the default + test function ``mmcv.engine.single_gpu_test`` will be used. + (default: ``None``) **eval_kwargs: Evaluation arguments fed into the evaluate function of the dataset. @@ -68,6 +72,7 @@ def __init__(self, by_epoch=True, save_best=None, rule=None, + test_fn=None, **eval_kwargs): if not isinstance(dataloader, DataLoader): raise TypeError(f'dataloader must be a pytorch DataLoader, ' @@ -99,11 +104,14 @@ def __init__(self, self.best_ckpt_path = None self._init_rule(rule, self.save_best) + self.test_fn = test_fn + def _init_rule(self, rule, key_indicator): """Initialize rule, key_indicator, comparison_func, and best score. Here is the rule to determine which rule is used for key indicator - when the rule is not specific: + when the rule is not specified (note that the key indicator matching + is case-insensitive): 1. If the key indicator is in ``self.greater_keys``, the rule will be specified as 'greater'. 2. 
Or if the key indicator is in ``self.less_keys``, the rule will be @@ -124,13 +132,18 @@ def _init_rule(self, rule, key_indicator): if rule is None: if key_indicator != 'auto': - if key_indicator in self.greater_keys: + # use case-insensitive key indicator to infer the rule + key_indicator_lc = key_indicator.lower() + greater_keys = [key.lower() for key in self.greater_keys] + less_keys = [key.lower() for key in self.less_keys] + + if key_indicator_lc in greater_keys: rule = 'greater' - elif key_indicator in self.less_keys: + elif key_indicator_lc in less_keys: rule = 'less' - elif any(key in key_indicator for key in self.greater_keys): + elif any(key in key_indicator_lc for key in greater_keys): rule = 'greater' - elif any(key in key_indicator for key in self.less_keys): + elif any(key in key_indicator_lc for key in less_keys): rule = 'less' else: raise ValueError(f'Cannot infer the rule for key ' @@ -181,8 +194,13 @@ def _do_evaluate(self, runner): if not self._should_evaluate(runner): return - from mmcv.engine import single_gpu_test - results = single_gpu_test(runner.model, self.dataloader) + if self.test_fn is None: + from mmcv.engine import single_gpu_test + test_fn = single_gpu_test + else: + test_fn = self.test_fn + + results = test_fn(runner.model, self.dataloader) runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) key_score = self.evaluate(runner, results) if self.save_best: @@ -311,6 +329,10 @@ class DistEvalHook(EvalHook): .etc will be inferred by 'greater' rule. Keys contain 'loss' will be inferred by 'less' rule. Options are 'greater', 'less', None. Default: None. + test_fn (callable, optional): test a model with samples from a + dataloader in a multi-gpu manner, and return the test results. If + ``None``, the default test function ``mmcv.engine.multi_gpu_test`` + will be used. (default: ``None``) tmpdir (str | None): Temporary directory to save the results of all processes. Default: None. 
gpu_collect (bool): Whether to use gpu or cpu to collect results. @@ -318,6 +340,7 @@ class DistEvalHook(EvalHook): broadcast_bn_buffer (bool): Whether to broadcast the buffer(running_mean and running_var) of rank 0 to other rank before evaluation. Default: True. + multi_ **eval_kwargs: Evaluation arguments fed into the evaluate function of the dataset. """ @@ -329,6 +352,7 @@ def __init__(self, by_epoch=True, save_best=None, rule=None, + test_fn=None, broadcast_bn_buffer=True, tmpdir=None, gpu_collect=False, @@ -340,6 +364,7 @@ def __init__(self, by_epoch=by_epoch, save_best=save_best, rule=rule, + test_fn=test_fn, **eval_kwargs) self.broadcast_bn_buffer = broadcast_bn_buffer self.tmpdir = tmpdir @@ -367,8 +392,13 @@ def _do_evaluate(self, runner): if tmpdir is None: tmpdir = osp.join(runner.work_dir, '.eval_hook') - from mmcv.engine import multi_gpu_test - results = multi_gpu_test( + if self.test_fn is None: + from mmcv.engine import multi_gpu_test + test_fn = multi_gpu_test + else: + test_fn = self.test_fn + + results = test_fn( runner.model, self.dataloader, tmpdir=tmpdir,