From 460e4bfd972d5a63cc7d7714a9a9a3750911a2d8 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Sat, 28 Nov 2020 21:57:17 +0800 Subject: [PATCH 01/16] polish eval hooks --- mmaction/core/evaluation/eval_hooks.py | 176 +++++++++++------------ mmaction/datasets/activitynet_dataset.py | 3 +- mmaction/datasets/base.py | 4 +- mmaction/datasets/hvu_dataset.py | 4 +- mmaction/datasets/ssn_dataset.py | 3 +- tests/test_runtime/test_train.py | 3 +- tools/analysis/eval_metric.py | 3 +- 7 files changed, 95 insertions(+), 101 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 7427f66a46..1a8694ecb4 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -26,21 +26,22 @@ class EpochEvalHook(Hook): epoch. If None, whether to evaluate is merely decided by ``interval``. Default: None. interval (int): Evaluation interval (by epochs). Default: 1. - save_best (bool): Whether to save best checkpoint during evaluation. - Default: True. - key_indicator (str | None): Key indicator to measure the best - checkpoint during evaluation when ``save_best`` is set to True. + save_best (str | None, optional): If a metric is specified, it would + measure the best checkpoint during evaluation. The information + about best checkpoint would be save in best.json. Options are the evaluation metrics to the test dataset. e.g., ``top1_acc``, ``top5_acc``, ``mean_class_accuracy``, ``mean_average_precision``, ``mmit_mean_average_precision`` for action recognition dataset (RawframeDataset and VideoDataset). - ``AR@AN``, ``auc`` for action localization dataset + ``AR@AN``, ``auc`` for action localization dataset. (ActivityNetDataset). ``Recall@0.5@100``, ``AR@100``, ``mAP@0.5IOU`` for spatio-temporal action detection dataset (AVADataset). Default: `top1_acc`. - rule (str | None): Comparison rule for best score. Options are None, - 'greater' and 'less'. If set to None, it will infer a reasonable - rule. Default: 'None'. + rule (str | None, optional): Comparison rule for best score. If set to + None, it will infer a reasonable rule. Keys such as 'acc', 'top' + .etc will be inferred by 'greater' rule. Keys contain 'loss' will + be inferred by 'less' rule. Options are 'greater', 'less', None. + Default: None. **eval_kwargs: Evaluation arguments fed into the evaluate function of the dataset. 
""" @@ -54,58 +55,68 @@ def __init__(self, dataloader, start=None, interval=1, - save_best=True, - key_indicator='top1_acc', + save_best=None, rule=None, **eval_kwargs): if not isinstance(dataloader, DataLoader): raise TypeError(f'dataloader must be a pytorch DataLoader, ' f'but got {type(dataloader)}') - if not isinstance(save_best, bool): - raise TypeError("'save_best' should be a boolean") - - if save_best and not key_indicator: - raise ValueError('key_indicator should not be None, when ' - 'save_best is set to True.') - if rule not in self.rule_map and rule is not None: - raise KeyError(f'rule must be greater, less or None, ' - f'but got {rule}.') - - if rule is None and save_best: - if any(key in key_indicator for key in self.greater_keys): - rule = 'greater' - elif any(key in key_indicator for key in self.less_keys): - rule = 'less' - else: - raise ValueError( - f'key_indicator must be in {self.greater_keys} ' - f'or in {self.less_keys} when rule is None, ' - f'but got {key_indicator}') if interval <= 0: raise ValueError(f'interval must be positive, but got {interval}') + if start is not None and start < 0: warnings.warn( f'The evaluation start epoch {start} is smaller than 0, ' f'use 0 instead', UserWarning) start = 0 - self.dataloader = dataloader self.interval = interval self.start = start - self.eval_kwargs = eval_kwargs + + assert isinstance(save_best, str) or save_best is None self.save_best = save_best - self.key_indicator = key_indicator - self.rule = rule + self.eval_kwargs = eval_kwargs + self.initial_epoch_flag = True self.logger = get_root_logger() - if self.save_best: + if self.save_best is not None: + self._init_rule(rule, self.save_best) + + def _init_rule(self, rule, key_indicator): + """Initialize rule, key_indicator, comparison_func, and best score. + + Args: + rule (str | None): Comparison rule for best score. + key_indicator (str | None): Key indicator to determine the + comparison rule. + """ + if rule not in self.rule_map and rule is not None: + raise KeyError(f'rule must be greater, less or None, ' + f'but got {rule}.') + + if rule is None: + if key_indicator != 'auto': + if any(key in key_indicator for key in self.greater_keys): + rule = 'greater' + elif any(key in key_indicator for key in self.less_keys): + rule = 'less' + else: + raise ValueError(f'Cannot infer the rule for key ' + f'{key_indicator}, thus a specific rule ' + f'must be specified.') + self.rule = rule + self.key_indicator = key_indicator + if self.rule is not None: self.compare_func = self.rule_map[self.rule] - self.best_score = self.init_value_map[self.rule] - self.best_json = dict() - self.initial_epoch_flag = True + def before_run(self, runner): + if self.save_best is not None: + if runner.meta is None: + warnings.warn('runner.meta is None. 
Creating a empty one.') + runner.meta = dict() + runner.meta.setdefault('hook_msgs', dict()) def before_train_epoch(self, runner): """Evaluate the model only at the start of training.""" @@ -139,27 +150,26 @@ def after_train_epoch(self, runner): if not self.evaluation_flag(runner): return - current_ckpt_path = osp.join(runner.work_dir, - f'epoch_{runner.epoch + 1}.pth') - json_path = osp.join(runner.work_dir, 'best.json') - - if osp.exists(json_path) and len(self.best_json) == 0: - self.best_json = mmcv.load(json_path) - self.best_score = self.best_json['best_score'] - self.best_ckpt = self.best_json['best_ckpt'] - self.key_indicator = self.best_json['key_indicator'] - from mmaction.apis import single_gpu_test results = single_gpu_test(runner.model, self.dataloader) key_score = self.evaluate(runner, results) - if (self.save_best and self.compare_func(key_score, self.best_score)): - self.best_score = key_score + if self.save_best: + self._save_ckpt(runner, key_score) + + def _save_ckpt(self, runner, key_score): + best_score = runner.meta['hook_msgs'].get( + 'best_score', self.init_value_map[self.rule]) + if self.compare_func(key_score, best_score): + best_score = key_score + runner.meta['hook_msgs']['best_score'] = best_score + last_ckpt = runner.meta['hook_msgs']['last_ckpt'] + runner.meta['hook_msgs']['best_ckpt'] = last_ckpt + mmcv.symlink( + last_ckpt, + osp.join(runner.work_dir, f'best_{self.key_indicator}.pth')) self.logger.info( - f'Now best checkpoint is epoch_{runner.epoch + 1}.pth') - self.best_json['best_score'] = self.best_score - self.best_json['best_ckpt'] = current_ckpt_path - self.best_json['key_indicator'] = self.key_indicator - mmcv.dump(self.best_json, json_path) + f'Now best checkpoint is epoch_{runner.epoch + 1}.pth.' + f'Best {self.key_indicator} is {best_score:0.4f}') def evaluate(self, runner, results): """Evaluate the results. @@ -173,12 +183,10 @@ def evaluate(self, runner, results): for name, val in eval_res.items(): runner.log_buffer.output[name] = val runner.log_buffer.ready = True - if self.key_indicator is not None: - if self.key_indicator not in eval_res: - warnings.warn('The key indicator for evaluation is not ' - 'included in evaluation result, please specify ' - 'it in config file') - return None + if self.save_best is not None: + if self.key_indicator == 'auto': + # infer from eval_results + self._init_rule(self.rule, list(eval_res.keys())[0]) return eval_res[self.key_indicator] return None @@ -197,19 +205,20 @@ class DistEpochEvalHook(EpochEvalHook): epoch. If None, whether to evaluate is merely decided by ``interval``. Default: None. interval (int): Evaluation interval (by epochs). Default: 1. - save_best (bool): Whether to save best checkpoint during evaluation. - Default: True. - key_indicator (str | None): Key indicator to measure the best - checkpoint during evaluation when ``save_best`` is set to True. + save_best (str | None, optional): If a metric is specified, it would + measure the best checkpoint during evaluation. The information + about best checkpoint would be save in best.json. Options are the evaluation metrics to the test dataset. e.g., ``top1_acc``, ``top5_acc``, ``mean_class_accuracy``, ``mean_average_precision``, ``mmit_mean_average_precision`` for action recognition dataset (RawframeDataset and VideoDataset). - ``AR@AN``, ``auc`` for action localization dataset - (ActivityNetDataset). Default: `top1_acc`. - rule (str | None): Comparison rule for best score. Options are None, - 'greater' and 'less'. 
If set to None, it will infer a reasonable - rule. Default: 'None'. + ``AR@AN``, ``auc`` for action localization dataset. + (ActivityNetDataset). Default: None. + rule (str | None, optional): Comparison rule for best score. If set to + None, it will infer a reasonable rule. Keys such as 'acc', 'top' + .etc will be inferred by 'greater' rule. Keys contain 'loss' will + be inferred by 'less' rule. Options are 'greater', 'less', None. + Default: None. tmpdir (str | None): Temporary directory to save the results of all processes. Default: None. gpu_collect (bool): Whether to use gpu or cpu to collect results. @@ -222,8 +231,7 @@ def __init__(self, dataloader, start=None, interval=1, - save_best=True, - key_indicator='top1_acc', + save_best=None, rule=None, tmpdir=None, gpu_collect=False, @@ -233,7 +241,6 @@ def __init__(self, start=start, interval=interval, save_best=save_best, - key_indicator=key_indicator, rule=rule, **eval_kwargs) self.tmpdir = tmpdir @@ -244,18 +251,7 @@ def after_train_epoch(self, runner): if not self.evaluation_flag(runner): return - current_ckpt_path = osp.join(runner.work_dir, - f'epoch_{runner.epoch + 1}.pth') - json_path = osp.join(runner.work_dir, 'best.json') - - if osp.exists(json_path) and len(self.best_json) == 0: - self.best_json = mmcv.load(json_path) - self.best_score = self.best_json['best_score'] - self.best_ckpt = self.best_json['best_ckpt'] - self.key_indicator = self.best_json['key_indicator'] - from mmaction.apis import multi_gpu_test - tmpdir = self.tmpdir if tmpdir is None: tmpdir = osp.join(runner.work_dir, '.eval_hook') @@ -268,12 +264,6 @@ def after_train_epoch(self, runner): if runner.rank == 0: print('\n') key_score = self.evaluate(runner, results) - if (self.save_best and key_score is not None - and self.compare_func(key_score, self.best_score)): - self.best_score = key_score - self.logger.info( - f'Now best checkpoint is epoch_{runner.epoch + 1}.pth') - self.best_json['best_score'] = self.best_score - self.best_json['best_ckpt'] = current_ckpt_path - self.best_json['key_indicator'] = self.key_indicator - mmcv.dump(self.best_json, json_path) + + if self.save_best: + self._save_ckpt(runner, key_score) diff --git a/mmaction/datasets/activitynet_dataset.py b/mmaction/datasets/activitynet_dataset.py index d448cf19f5..018fd7b889 100644 --- a/mmaction/datasets/activitynet_dataset.py +++ b/mmaction/datasets/activitynet_dataset.py @@ -2,6 +2,7 @@ import os import os.path as osp import warnings +from collections import OrderedDict import mmcv import numpy as np @@ -238,7 +239,7 @@ def evaluate( if metric not in allowed_metrics: raise KeyError(f'metric {metric} is not supported') - eval_results = {} + eval_results = OrderedDict() ground_truth = self._import_ground_truth() proposal, num_proposals = self._import_proposals(results) diff --git a/mmaction/datasets/base.py b/mmaction/datasets/base.py index fbfc7f9bd9..fbc2c79c44 100644 --- a/mmaction/datasets/base.py +++ b/mmaction/datasets/base.py @@ -2,7 +2,7 @@ import os.path as osp import warnings from abc import ABCMeta, abstractmethod -from collections import defaultdict +from collections import OrderedDict, defaultdict import mmcv import numpy as np @@ -170,7 +170,7 @@ def evaluate(self, if metric not in allowed_metrics: raise KeyError(f'metric {metric} is not supported') - eval_results = {} + eval_results = OrderedDict() gt_labels = [ann['label'] for ann in self.video_infos] for metric in metrics: diff --git a/mmaction/datasets/hvu_dataset.py b/mmaction/datasets/hvu_dataset.py index 
b523748093..12beeb64aa 100644 --- a/mmaction/datasets/hvu_dataset.py +++ b/mmaction/datasets/hvu_dataset.py @@ -1,5 +1,6 @@ import copy import os.path as osp +from collections import OrderedDict import mmcv import numpy as np @@ -164,7 +165,8 @@ def evaluate(self, gt_labels = [ann['label'] for ann in self.video_infos] - eval_results = {} + eval_results = OrderedDict() + for category in self.tag_categories: start_idx = self.category2startidx[category] diff --git a/mmaction/datasets/ssn_dataset.py b/mmaction/datasets/ssn_dataset.py index 26cb7de436..374926d42a 100644 --- a/mmaction/datasets/ssn_dataset.py +++ b/mmaction/datasets/ssn_dataset.py @@ -1,6 +1,7 @@ import copy import os.path as osp import warnings +from collections import OrderedDict import mmcv import numpy as np @@ -470,7 +471,7 @@ def evaluate(self, for x in dets.tolist()]) plain_detections[class_idx] = detection_list - eval_results = {} + eval_results = OrderedDict() for metric in metrics: if metric == 'mAP': eval_dataset = metric_options.setdefault('mAP', {}).setdefault( diff --git a/tests/test_runtime/test_train.py b/tests/test_runtime/test_train.py index 6d5357db0b..13ccf88ff5 100644 --- a/tests/test_runtime/test_train.py +++ b/tests/test_runtime/test_train.py @@ -1,5 +1,6 @@ import copy import tempfile +from collections import OrderedDict import pytest import torch @@ -18,7 +19,7 @@ def __init__(self, test_mode=False): self.test_mode = test_mode def evaluate(self, results, logger=None): - eval_results = dict() + eval_results = OrderedDict() eval_results['acc'] = 1 return eval_results diff --git a/tools/analysis/eval_metric.py b/tools/analysis/eval_metric.py index ca7d0b1dd0..84c17db058 100644 --- a/tools/analysis/eval_metric.py +++ b/tools/analysis/eval_metric.py @@ -53,8 +53,7 @@ def main(): eval_kwargs = cfg.get('evaluation', {}).copy() # hard-code way to remove EpochEvalHook args for key in [ - 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule', - 'key_indicator' + 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule' ]: eval_kwargs.pop(key, None) eval_kwargs.update(dict(metrics=args.eval, **kwargs)) From 7bf40dea080b1188c1f9d04772171330a422ce9a Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Sat, 28 Nov 2020 22:59:53 +0800 Subject: [PATCH 02/16] update unittest --- tests/test_runtime/test_eval_hook.py | 315 ++++++++++++++------------- 1 file changed, 162 insertions(+), 153 deletions(-) diff --git a/tests/test_runtime/test_eval_hook.py b/tests/test_runtime/test_eval_hook.py index af02be5cbb..d1bc9d1f96 100644 --- a/tests/test_runtime/test_eval_hook.py +++ b/tests/test_runtime/test_eval_hook.py @@ -1,9 +1,9 @@ import os.path as osp import tempfile import unittest.mock as mock +from collections import OrderedDict from unittest.mock import MagicMock, patch -import mmcv import pytest import torch import torch.nn as nn @@ -36,7 +36,7 @@ class EvalDataset(ExampleDataset): def evaluate(self, results, logger=None): acc = self.eval_result[self.index] - output = dict(acc=acc, index=self.index, score=acc) + output = OrderedDict(acc=acc, index=self.index, score=acc) self.index += 1 return output @@ -88,8 +88,8 @@ def val_step(self, x, optimizer, **kwargs): def test_eval_hook(): - with pytest.raises(TypeError): - # `save_best` should be a boolean + with pytest.raises(AssertionError): + # `save_best` should be a str test_dataset = ExampleModel() data_loader = DataLoader( test_dataset, @@ -97,7 +97,7 @@ def test_eval_hook(): sampler=None, num_workers=0, shuffle=False) - 
EpochEvalHook(data_loader, save_best='True') + EpochEvalHook(data_loader, save_best=True) with pytest.raises(TypeError): # dataloader must be a pytorch DataLoader @@ -113,15 +113,15 @@ def test_eval_hook(): EpochEvalHook(data_loader) with pytest.raises(ValueError): - # when `save_best` is True, `key_indicator` should not be None - test_dataset = ExampleModel() + # key_indicator must be valid when rule_map is None + test_dataset = ExampleDataset() data_loader = DataLoader( test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False) - EpochEvalHook(data_loader, key_indicator=None) + EpochEvalHook(data_loader, save_best='unsupport') with pytest.raises(KeyError): # rule must be in keys of rule_map @@ -132,18 +132,7 @@ def test_eval_hook(): sampler=None, num_workers=0, shuffle=False) - EpochEvalHook(data_loader, save_best=False, rule='unsupport') - - with pytest.raises(ValueError): - # key_indicator must be valid when rule_map is None - test_dataset = ExampleModel() - data_loader = DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_workers=0, - shuffle=False) - EpochEvalHook(data_loader, key_indicator='unsupport') + EpochEvalHook(data_loader, save_best='auto', rule='unsupport') optimizer_cfg = dict( type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) @@ -154,7 +143,7 @@ def test_eval_hook(): optimizer = build_optimizer(model, optimizer_cfg) data_loader = DataLoader(test_dataset, batch_size=1) - eval_hook = EpochEvalHook(data_loader, save_best=False) + eval_hook = EpochEvalHook(data_loader, save_best=None) with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( @@ -167,128 +156,150 @@ def test_eval_hook(): runner.run([loader], [('train', 1)], 1) test_dataset.evaluate.assert_called_with( test_dataset, [torch.tensor([1])], logger=runner.logger) - - best_json_path = osp.join(tmpdir, 'best.json') - assert not osp.exists(best_json_path) - - loader = DataLoader(EvalDataset(), batch_size=1) - model = ExampleModel() - data_loader = DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook( - data_loader, interval=1, save_best=True, key_indicator='acc') - - with tempfile.TemporaryDirectory() as tmpdir: - logger = get_logger('test_eval') - runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) - runner.register_checkpoint_hook(dict(interval=1)) - runner.register_hook(eval_hook) - runner.run([loader], [('train', 1)], 8) - - best_json_path = osp.join(tmpdir, 'best.json') - best_json = mmcv.load(best_json_path) - real_path = osp.join(tmpdir, 'epoch_4.pth') - - assert best_json['best_ckpt'] == osp.realpath(real_path) - assert best_json['best_score'] == 7 - assert best_json['key_indicator'] == 'acc' - - data_loader = DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook( - data_loader, - interval=1, - save_best=True, - key_indicator='score', - rule='greater') - with tempfile.TemporaryDirectory() as tmpdir: - logger = get_logger('test_eval') - runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) - runner.register_checkpoint_hook(dict(interval=1)) - runner.register_hook(eval_hook) - runner.run([loader], [('train', 1)], 8) - - best_json_path = osp.join(tmpdir, 'best.json') - best_json = mmcv.load(best_json_path) - real_path = osp.join(tmpdir, 'epoch_4.pth') - - assert best_json['best_ckpt'] == osp.realpath(real_path) - assert best_json['best_score'] == 7 - assert 
best_json['key_indicator'] == 'score' - - data_loader = DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook(data_loader, rule='less', key_indicator='acc') - with tempfile.TemporaryDirectory() as tmpdir: - logger = get_logger('test_eval') - runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) - runner.register_checkpoint_hook(dict(interval=1)) - runner.register_hook(eval_hook) - runner.run([loader], [('train', 1)], 8) - - best_json_path = osp.join(tmpdir, 'best.json') - best_json = mmcv.load(best_json_path) - real_path = osp.join(tmpdir, 'epoch_6.pth') - - assert best_json['best_ckpt'] == osp.realpath(real_path) - assert best_json['best_score'] == -3 - assert best_json['key_indicator'] == 'acc' - - data_loader = DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook(data_loader, key_indicator='acc') - with tempfile.TemporaryDirectory() as tmpdir: - logger = get_logger('test_eval') - runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) - runner.register_checkpoint_hook(dict(interval=1)) - runner.register_hook(eval_hook) - runner.run([loader], [('train', 1)], 2) - - best_json_path = osp.join(tmpdir, 'best.json') - best_json = mmcv.load(best_json_path) - real_path = osp.join(tmpdir, 'epoch_2.pth') - - assert best_json['best_ckpt'] == osp.realpath(real_path) - assert best_json['best_score'] == 4 - assert best_json['key_indicator'] == 'acc' - - resume_from = osp.join(tmpdir, 'latest.pth') - loader = DataLoader(ExampleDataset(), batch_size=1) - eval_hook = EpochEvalHook(data_loader, key_indicator='acc') - runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) - runner.register_checkpoint_hook(dict(interval=1)) - runner.register_hook(eval_hook) - runner.resume(resume_from) - runner.run([loader], [('train', 1)], 8) - - best_json_path = osp.join(tmpdir, 'best.json') - best_json = mmcv.load(best_json_path) - real_path = osp.join(tmpdir, 'epoch_4.pth') - - assert best_json['best_ckpt'] == osp.realpath(real_path) - assert best_json['best_score'] == 7 - assert best_json['key_indicator'] == 'acc' + assert runner.meta is None or 'best_score' not in runner.meta[ + 'hook_msgs'] + assert runner.meta is None or 'best_ckpt' not in runner.meta[ + 'hook_msgs'] + + # when `save_best` is set to 'auto', first metric will be used. 
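The expected values in the checks that follow come straight from `EvalDataset.eval_result = [1, 4, 3, 7, 2, -3, 4, 6]`, one score per epoch; a quick sanity check of which epoch should win under each rule:

```python
# One score per epoch, as returned by EvalDataset.evaluate() above.
eval_result = [1, 4, 3, 7, 2, -3, 4, 6]

# rule='greater': the best score 7 is reached at epoch 4.
assert max(eval_result) == 7 and eval_result.index(7) + 1 == 4
# rule='less': the best score -3 is reached at epoch 6.
assert min(eval_result) == -3 and eval_result.index(-3) + 1 == 6
# A 2-epoch run only sees the first two scores, so the best is 4 at epoch 2.
assert max(eval_result[:2]) == 4 and eval_result.index(4) + 1 == 2
```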
+ loader = DataLoader(EvalDataset(), batch_size=1) + model = ExampleModel() + data_loader = DataLoader(EvalDataset(), batch_size=1) + eval_hook = EpochEvalHook(data_loader, interval=1, save_best='auto') + + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=tmpdir, + logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + real_path = osp.join(tmpdir, 'epoch_4.pth') + link_path = osp.join(tmpdir, 'best_acc.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( + real_path) + assert osp.exists(link_path) + assert runner.meta['hook_msgs']['best_score'] == 7 + + loader = DataLoader(EvalDataset(), batch_size=1) + model = ExampleModel() + data_loader = DataLoader(EvalDataset(), batch_size=1) + eval_hook = EpochEvalHook(data_loader, interval=1, save_best='acc') + + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=tmpdir, + logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + real_path = osp.join(tmpdir, 'epoch_4.pth') + link_path = osp.join(tmpdir, 'best_acc.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( + real_path) + assert osp.exists(link_path) + assert runner.meta['hook_msgs']['best_score'] == 7 + + data_loader = DataLoader(EvalDataset(), batch_size=1) + eval_hook = EpochEvalHook( + data_loader, interval=1, save_best='score', rule='greater') + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=tmpdir, + logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + real_path = osp.join(tmpdir, 'epoch_4.pth') + link_path = osp.join(tmpdir, 'best_score.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( + real_path) + assert osp.exists(link_path) + assert runner.meta['hook_msgs']['best_score'] == 7 + + data_loader = DataLoader(EvalDataset(), batch_size=1) + eval_hook = EpochEvalHook(data_loader, save_best='acc', rule='less') + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=tmpdir, + logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 8) + + real_path = osp.join(tmpdir, 'epoch_6.pth') + link_path = osp.join(tmpdir, 'best_acc.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( + real_path) + assert osp.exists(link_path) + assert runner.meta['hook_msgs']['best_score'] == -3 + + data_loader = DataLoader(EvalDataset(), batch_size=1) + eval_hook = EpochEvalHook(data_loader, save_best='acc') + with tempfile.TemporaryDirectory() as tmpdir: + logger = get_logger('test_eval') + runner = EpochBasedRunner( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=tmpdir, + logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.run([loader], [('train', 1)], 2) + + real_path = 
osp.join(tmpdir, 'epoch_2.pth') + link_path = osp.join(tmpdir, 'best_acc.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( + real_path) + assert osp.exists(link_path) + assert runner.meta['hook_msgs']['best_score'] == 4 + + resume_from = osp.join(tmpdir, 'latest.pth') + loader = DataLoader(ExampleDataset(), batch_size=1) + eval_hook = EpochEvalHook(data_loader, save_best='acc') + runner = EpochBasedRunner( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=tmpdir, + logger=logger) + runner.register_checkpoint_hook(dict(interval=1)) + runner.register_hook(eval_hook) + runner.resume(resume_from) + runner.run([loader], [('train', 1)], 8) + + real_path = osp.join(tmpdir, 'epoch_4.pth') + link_path = osp.join(tmpdir, 'best_acc.pth') + + assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( + real_path) + assert osp.exists(link_path) + assert runner.meta['hook_msgs']['best_score'] == 7 @patch('mmaction.apis.single_gpu_test', MagicMock) @@ -309,7 +320,7 @@ def test_start_param(EpochEvalHookParam): # 1. start=None, interval=1: perform evaluation after each epoch. runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, interval=1, save_best=False) + evalhook = EpochEvalHookParam(dataloader, interval=1) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -317,8 +328,7 @@ def test_start_param(EpochEvalHookParam): # 2. start=1, interval=1: perform evaluation after each epoch. runner = _build_demo_runner() - evalhook = EpochEvalHookParam( - dataloader, start=1, interval=1, save_best=False) + evalhook = EpochEvalHookParam(dataloader, start=1, interval=1) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -326,7 +336,7 @@ def test_start_param(EpochEvalHookParam): # 3. start=None, interval=2: perform evaluation after epoch 2, 4, 6, etc runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, interval=2, save_best=False) + evalhook = EpochEvalHookParam(dataloader, interval=2) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -334,8 +344,7 @@ def test_start_param(EpochEvalHookParam): # 4. start=1, interval=2: perform evaluation after epoch 1, 3, 5, etc runner = _build_demo_runner() - evalhook = EpochEvalHookParam( - dataloader, start=1, interval=2, save_best=False) + evalhook = EpochEvalHookParam(dataloader, start=1, interval=2) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 3) @@ -344,7 +353,7 @@ def test_start_param(EpochEvalHookParam): # 5. start=0/negative, interval=1: perform evaluation after each epoch and # before epoch 1. runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, start=0, save_best=False) + evalhook = EpochEvalHookParam(dataloader, start=0) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -352,7 +361,7 @@ def test_start_param(EpochEvalHookParam): runner = _build_demo_runner() with pytest.warns(UserWarning): - evalhook = EpochEvalHookParam(dataloader, start=-2, save_best=False) + evalhook = EpochEvalHookParam(dataloader, start=-2) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -361,7 +370,7 @@ def test_start_param(EpochEvalHookParam): # 6. 
resuming from epoch i, start = x (x<=i), interval =1: perform # evaluation after each epoch and before the first epoch. runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, start=1, save_best=False) + evalhook = EpochEvalHookParam(dataloader, start=1) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner._epoch = 2 @@ -371,7 +380,7 @@ def test_start_param(EpochEvalHookParam): # 7. resuming from epoch i, start = i+1/None, interval =1: perform # evaluation after each epoch. runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, start=2, save_best=False) + evalhook = EpochEvalHookParam(dataloader, start=2) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner._epoch = 1 From 9b6a3f15c9c2280289126ebd4523f60c599d8f76 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Sat, 19 Dec 2020 12:44:36 +0800 Subject: [PATCH 03/16] add by_epoch --- docs/tutorials/7_customize_runtime.md | 6 +- mmaction/apis/train.py | 6 +- mmaction/core/evaluation/__init__.py | 13 ++-- mmaction/core/evaluation/eval_hooks.py | 102 +++++++++++++++++-------- tests/test_runtime/test_eval_hook.py | 49 ++++++------ tools/analysis/eval_metric.py | 2 +- 6 files changed, 109 insertions(+), 69 deletions(-) diff --git a/docs/tutorials/7_customize_runtime.md b/docs/tutorials/7_customize_runtime.md index 44d4b03607..5b0d8754f4 100644 --- a/docs/tutorials/7_customize_runtime.md +++ b/docs/tutorials/7_customize_runtime.md @@ -191,7 +191,7 @@ We support many other learning rate schedule [here](https://github.com/open-mmla ## Customize Workflow -By default, we recommend users to use `EpochEvalHook` to do evaluation after training epoch, but they can still use `val` workflow as an alternative. +By default, we recommend users to use `EvalHook` to do evaluation after training epoch, but they can still use `val` workflow as an alternative. Workflow is a list of (phase, epochs) to specify the running order and epochs. By default it is set to be @@ -213,7 +213,7 @@ so that 1 epoch for training and 1 epoch for validation will be run iteratively. 1. The parameters of model will not be updated during val epoch. 2. Keyword `total_epochs` in the config only controls the number of training epochs and will not affect the validation workflow. -3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EpochEvalHook` because `EpochEvalHook` is called by `after_train_epoch` and validation workflow only affect hooks that are called through `after_val_epoch`. +3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EvalHook` because `EvalHook` is called by `after_train_epoch` and validation workflow only affect hooks that are called through `after_val_epoch`. Therefore, the only difference between `[('train', 1), ('val', 1)]` and ``[('train', 1)]`` is that the runner will calculate losses on validation set after each training epoch. ## Customize Hooks @@ -344,7 +344,7 @@ log_config = dict( #### Evaluation config -The config of `evaluation` will be used to initialize the [`EpochEvalHook`](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_hooks.py#L12). +The config of `evaluation` will be used to initialize the [`EvalHook`](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_hooks.py#L12). 
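For instance, a config along these lines would do (a sketch only; the exact metric names depend on the dataset):

```python
# A sketch of an `evaluation` config; `save_best` would make the hook keep a
# `best_top1_acc.pth` symlink to the best checkpoint (names illustrative).
evaluation = dict(
    interval=1,                # evaluate after every epoch
    save_best='top1_acc',      # metric used to pick the best checkpoint
    metrics=['top_k_accuracy', 'mean_class_accuracy'])
```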
Except the key `interval`, other arguments such as `metrics` will be passed to the `dataset.evaluate()` ```python diff --git a/mmaction/apis/train.py b/mmaction/apis/train.py index b68360e63c..2e88eab40f 100644 --- a/mmaction/apis/train.py +++ b/mmaction/apis/train.py @@ -6,8 +6,8 @@ build_optimizer) from mmcv.runner.hooks import Fp16OptimizerHook -from ..core import (DistEpochEvalHook, EpochEvalHook, - OmniSourceDistSamplerSeedHook, OmniSourceRunner) +from ..core import (DistEvalHook, EvalHook, OmniSourceDistSamplerSeedHook, + OmniSourceRunner) from ..datasets import build_dataloader, build_dataset from ..utils import PreciseBNHook, get_root_logger @@ -143,7 +143,7 @@ def train_model(model, dataloader_setting = dict(dataloader_setting, **cfg.data.get('val_dataloader', {})) val_dataloader = build_dataloader(val_dataset, **dataloader_setting) - eval_hook = DistEpochEvalHook if distributed else EpochEvalHook + eval_hook = DistEvalHook if distributed else EvalHook runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) if cfg.resume_from: diff --git a/mmaction/core/evaluation/__init__.py b/mmaction/core/evaluation/__init__.py index acb76e3844..c8000ed162 100644 --- a/mmaction/core/evaluation/__init__.py +++ b/mmaction/core/evaluation/__init__.py @@ -5,13 +5,12 @@ mmit_mean_average_precision, pairwise_temporal_iou, softmax, top_k_accuracy) from .eval_detection import ActivityNetLocalization -from .eval_hooks import DistEpochEvalHook, EpochEvalHook +from .eval_hooks import DistEvalHook, EvalHook __all__ = [ - 'DistEpochEvalHook', 'EpochEvalHook', 'top_k_accuracy', - 'mean_class_accuracy', 'confusion_matrix', 'mean_average_precision', - 'get_weighted_score', 'average_recall_at_avg_proposals', - 'pairwise_temporal_iou', 'average_precision_at_temporal_iou', - 'ActivityNetLocalization', 'softmax', 'interpolated_precision_recall', - 'mmit_mean_average_precision' + 'DistEvalHook', 'EvalHook', 'top_k_accuracy', 'mean_class_accuracy', + 'confusion_matrix', 'mean_average_precision', 'get_weighted_score', + 'average_recall_at_avg_proposals', 'pairwise_temporal_iou', + 'average_precision_at_temporal_iou', 'ActivityNetLocalization', 'softmax', + 'interpolated_precision_recall', 'mmit_mean_average_precision' ] diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 1a8694ecb4..dbc457f8c5 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -9,11 +9,11 @@ from mmaction.utils import get_root_logger -class EpochEvalHook(Hook): - """Non-Distributed evaluation hook based on epochs. +class EvalHook(Hook): + """Non-Distributed evaluation hook. Notes: - If new arguments are added for EpochEvalHook, tools/test.py, + If new arguments are added for EvalHook, tools/test.py, tools/eval_metric.py may be effected. This hook will regularly perform evaluation in a given interval when @@ -25,7 +25,10 @@ class EpochEvalHook(Hook): evaluation before the training starts if ``start`` <= the resuming epoch. If None, whether to evaluate is merely decided by ``interval``. Default: None. - interval (int): Evaluation interval (by epochs). Default: 1. + interval (int): Evaluation interval. Default: 1. + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + default: True. save_best (str | None, optional): If a metric is specified, it would measure the best checkpoint during evaluation. The information about best checkpoint would be save in best.json. 
@@ -55,6 +58,7 @@ def __init__(self, dataloader, start=None, interval=1, + by_epoch=True, save_best=None, rule=None, **eval_kwargs): @@ -65,6 +69,8 @@ def __init__(self, if interval <= 0: raise ValueError(f'interval must be positive, but got {interval}') + assert isinstance(by_epoch, bool) + if start is not None and start < 0: warnings.warn( f'The evaluation start epoch {start} is smaller than 0, ' @@ -73,6 +79,7 @@ def __init__(self, self.dataloader = dataloader self.interval = interval self.start = start + self.by_epoch = by_epoch assert isinstance(save_best, str) or save_best is None self.save_best = save_best @@ -118,45 +125,77 @@ def before_run(self, runner): runner.meta = dict() runner.meta.setdefault('hook_msgs', dict()) + def before_train_iter(self, runner): + """Evaluate the model only at the start of training by iteration.""" + if self.by_epoch: + return + if not self.initial_epoch_flag: + return + if self.start is not None and runner.iter >= self.start: + self.after_train_iter(runner) + self.initial_epoch_flag = False + def before_train_epoch(self, runner): - """Evaluate the model only at the start of training.""" + """Evaluate the model only at the start of training by epoch.""" + if not self.by_epoch: + return if not self.initial_epoch_flag: return if self.start is not None and runner.epoch >= self.start: self.after_train_epoch(runner) self.initial_epoch_flag = False + def after_train_iter(self, runner): + """Called after every training iter to evaluate the results.""" + self._do_evaluate(runner) + + def after_train_epoch(self, runner): + """Called after every training epoch to evaluate the results.""" + self._do_evaluate(runner) + + def _do_evaluate(self, runner): + """perform evaluation and save ckpt.""" + if not self.evaluation_flag(runner): + return + + from mmaction.apis import single_gpu_test + results = single_gpu_test(runner.model, self.dataloader) + key_score = self.evaluate(runner, results) + if self.save_best: + self._save_ckpt(runner, key_score) + def evaluation_flag(self, runner): - """Judge whether to perform_evaluation after this epoch. + """Judge whether to perform_evaluation. Returns: bool: The flag indicating whether to perform evaluation. """ + if self.by_epoch: + current = runner.epoch + check_time = self.every_n_epochs + else: + current = runner.iter + check_time = self.every_n_iters + if self.start is None: - if not self.every_n_epochs(runner, self.interval): - # No evaluation during the interval epochs. + if not check_time(runner, self.interval): + # No evaluation during the interval. return False - elif (runner.epoch + 1) < self.start: - # No evaluation if start is larger than the current epoch. + elif (current + 1) < self.start: + # No evaluation if start is larger than the current time. return False else: # Evaluation only at epochs 3, 5, 7... 
if start==3 and interval==2 - if (runner.epoch + 1 - self.start) % self.interval: + if (current + 1 - self.start) % self.interval: return False return True - def after_train_epoch(self, runner): - """Called after every training epoch to evaluate the results.""" - if not self.evaluation_flag(runner): - return - - from mmaction.apis import single_gpu_test - results = single_gpu_test(runner.model, self.dataloader) - key_score = self.evaluate(runner, results) - if self.save_best: - self._save_ckpt(runner, key_score) - def _save_ckpt(self, runner, key_score): + if self.by_epoch: + current = f'epoch_{runner.epoch + 1}' + else: + current = f'iter_{runner.epoch + 1}' + best_score = runner.meta['hook_msgs'].get( 'best_score', self.init_value_map[self.rule]) if self.compare_func(key_score, best_score): @@ -167,9 +206,8 @@ def _save_ckpt(self, runner, key_score): mmcv.symlink( last_ckpt, osp.join(runner.work_dir, f'best_{self.key_indicator}.pth')) - self.logger.info( - f'Now best checkpoint is epoch_{runner.epoch + 1}.pth.' - f'Best {self.key_indicator} is {best_score:0.4f}') + self.logger.info(f'Now best checkpoint is {current}.pth.' + f'Best {self.key_indicator} is {best_score:0.4f}') def evaluate(self, runner, results): """Evaluate the results. @@ -192,8 +230,8 @@ def evaluate(self, runner, results): return None -class DistEpochEvalHook(EpochEvalHook): - """Distributed evaluation hook based on epochs. +class DistEvalHook(EvalHook): + """Distributed evaluation hook. This hook will regularly perform evaluation in a given interval when performing in distributed environment. @@ -204,7 +242,10 @@ class DistEpochEvalHook(EpochEvalHook): evaluation before the training starts if ``start`` <= the resuming epoch. If None, whether to evaluate is merely decided by ``interval``. Default: None. - interval (int): Evaluation interval (by epochs). Default: 1. + interval (int): Evaluation interval. Default: 1. + by_epoch (bool): Determine perform evaluation by epoch or by iteration. + If set to True, it will perform by epoch. Otherwise, by iteration. + default: True. save_best (str | None, optional): If a metric is specified, it would measure the best checkpoint during evaluation. The information about best checkpoint would be save in best.json. 
@@ -231,6 +272,7 @@ def __init__(self, dataloader, start=None, interval=1, + by_epoch=True, save_best=None, rule=None, tmpdir=None, @@ -240,14 +282,14 @@ def __init__(self, dataloader, start=start, interval=interval, + by_epoch=by_epoch, save_best=save_best, rule=rule, **eval_kwargs) self.tmpdir = tmpdir self.gpu_collect = gpu_collect - def after_train_epoch(self, runner): - """Called after each training epoch to evaluate the model.""" + def _do_evaluate(self, runner): if not self.evaluation_flag(runner): return diff --git a/tests/test_runtime/test_eval_hook.py b/tests/test_runtime/test_eval_hook.py index d1bc9d1f96..3342730f62 100644 --- a/tests/test_runtime/test_eval_hook.py +++ b/tests/test_runtime/test_eval_hook.py @@ -11,7 +11,7 @@ from mmcv.utils import get_logger from torch.utils.data import DataLoader, Dataset -from mmaction.core import DistEpochEvalHook, EpochEvalHook +from mmaction.core import DistEvalHook, EvalHook class ExampleDataset(Dataset): @@ -97,7 +97,7 @@ def test_eval_hook(): sampler=None, num_workers=0, shuffle=False) - EpochEvalHook(data_loader, save_best=True) + EvalHook(data_loader, save_best=True) with pytest.raises(TypeError): # dataloader must be a pytorch DataLoader @@ -110,7 +110,7 @@ def test_eval_hook(): num_worker=0, shuffle=False) ] - EpochEvalHook(data_loader) + EvalHook(data_loader) with pytest.raises(ValueError): # key_indicator must be valid when rule_map is None @@ -121,7 +121,7 @@ def test_eval_hook(): sampler=None, num_workers=0, shuffle=False) - EpochEvalHook(data_loader, save_best='unsupport') + EvalHook(data_loader, save_best='unsupport') with pytest.raises(KeyError): # rule must be in keys of rule_map @@ -132,7 +132,7 @@ def test_eval_hook(): sampler=None, num_workers=0, shuffle=False) - EpochEvalHook(data_loader, save_best='auto', rule='unsupport') + EvalHook(data_loader, save_best='auto', rule='unsupport') optimizer_cfg = dict( type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) @@ -143,7 +143,7 @@ def test_eval_hook(): optimizer = build_optimizer(model, optimizer_cfg) data_loader = DataLoader(test_dataset, batch_size=1) - eval_hook = EpochEvalHook(data_loader, save_best=None) + eval_hook = EvalHook(data_loader, save_best=None) with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( @@ -165,7 +165,7 @@ def test_eval_hook(): loader = DataLoader(EvalDataset(), batch_size=1) model = ExampleModel() data_loader = DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook(data_loader, interval=1, save_best='auto') + eval_hook = EvalHook(data_loader, interval=1, save_best='auto') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') @@ -190,7 +190,7 @@ def test_eval_hook(): loader = DataLoader(EvalDataset(), batch_size=1) model = ExampleModel() data_loader = DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook(data_loader, interval=1, save_best='acc') + eval_hook = EvalHook(data_loader, interval=1, save_best='acc') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') @@ -213,7 +213,7 @@ def test_eval_hook(): assert runner.meta['hook_msgs']['best_score'] == 7 data_loader = DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook( + eval_hook = EvalHook( data_loader, interval=1, save_best='score', rule='greater') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') @@ -236,7 +236,7 @@ def test_eval_hook(): assert runner.meta['hook_msgs']['best_score'] == 7 data_loader = 
DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook(data_loader, save_best='acc', rule='less') + eval_hook = EvalHook(data_loader, save_best='acc', rule='less') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( @@ -258,7 +258,7 @@ def test_eval_hook(): assert runner.meta['hook_msgs']['best_score'] == -3 data_loader = DataLoader(EvalDataset(), batch_size=1) - eval_hook = EpochEvalHook(data_loader, save_best='acc') + eval_hook = EvalHook(data_loader, save_best='acc') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( @@ -281,7 +281,7 @@ def test_eval_hook(): resume_from = osp.join(tmpdir, 'latest.pth') loader = DataLoader(ExampleDataset(), batch_size=1) - eval_hook = EpochEvalHook(data_loader, save_best='acc') + eval_hook = EvalHook(data_loader, save_best='acc') runner = EpochBasedRunner( model=model, batch_processor=None, @@ -304,23 +304,22 @@ def test_eval_hook(): @patch('mmaction.apis.single_gpu_test', MagicMock) @patch('mmaction.apis.multi_gpu_test', MagicMock) -@pytest.mark.parametrize('EpochEvalHookParam', - (EpochEvalHook, DistEpochEvalHook)) -def test_start_param(EpochEvalHookParam): +@pytest.mark.parametrize('EvalHookParam', (EvalHook, DistEvalHook)) +def test_start_param(EvalHookParam): # create dummy data dataloader = DataLoader(torch.ones((5, 2))) # 0.1. dataloader is not a DataLoader object with pytest.raises(TypeError): - EpochEvalHookParam(dataloader=MagicMock(), interval=-1) + EvalHookParam(dataloader=MagicMock(), interval=-1) # 0.2. negative interval with pytest.raises(ValueError): - EpochEvalHookParam(dataloader, interval=-1) + EvalHookParam(dataloader, interval=-1) # 1. start=None, interval=1: perform evaluation after each epoch. runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, interval=1) + evalhook = EvalHookParam(dataloader, interval=1) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -328,7 +327,7 @@ def test_start_param(EpochEvalHookParam): # 2. start=1, interval=1: perform evaluation after each epoch. runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, start=1, interval=1) + evalhook = EvalHookParam(dataloader, start=1, interval=1) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -336,7 +335,7 @@ def test_start_param(EpochEvalHookParam): # 3. start=None, interval=2: perform evaluation after epoch 2, 4, 6, etc runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, interval=2) + evalhook = EvalHookParam(dataloader, interval=2) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -344,7 +343,7 @@ def test_start_param(EpochEvalHookParam): # 4. start=1, interval=2: perform evaluation after epoch 1, 3, 5, etc runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, start=1, interval=2) + evalhook = EvalHookParam(dataloader, start=1, interval=2) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 3) @@ -353,7 +352,7 @@ def test_start_param(EpochEvalHookParam): # 5. start=0/negative, interval=1: perform evaluation after each epoch and # before epoch 1. 
runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, start=0) + evalhook = EvalHookParam(dataloader, start=0) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -361,7 +360,7 @@ def test_start_param(EpochEvalHookParam): runner = _build_demo_runner() with pytest.warns(UserWarning): - evalhook = EpochEvalHookParam(dataloader, start=-2) + evalhook = EvalHookParam(dataloader, start=-2) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -370,7 +369,7 @@ def test_start_param(EpochEvalHookParam): # 6. resuming from epoch i, start = x (x<=i), interval =1: perform # evaluation after each epoch and before the first epoch. runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, start=1) + evalhook = EvalHookParam(dataloader, start=1) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner._epoch = 2 @@ -380,7 +379,7 @@ def test_start_param(EpochEvalHookParam): # 7. resuming from epoch i, start = i+1/None, interval =1: perform # evaluation after each epoch. runner = _build_demo_runner() - evalhook = EpochEvalHookParam(dataloader, start=2) + evalhook = EvalHookParam(dataloader, start=2) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner._epoch = 1 diff --git a/tools/analysis/eval_metric.py b/tools/analysis/eval_metric.py index 84c17db058..a68e566585 100644 --- a/tools/analysis/eval_metric.py +++ b/tools/analysis/eval_metric.py @@ -51,7 +51,7 @@ def main(): kwargs = {} if args.eval_options is None else args.eval_options eval_kwargs = cfg.get('evaluation', {}).copy() - # hard-code way to remove EpochEvalHook args + # hard-code way to remove EvalHook args for key in [ 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule' ]: From bb7b47e14c9e3889e28faa1b0baef0534060be2c Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Sat, 19 Dec 2020 18:04:16 +0800 Subject: [PATCH 04/16] correct unittest --- mmaction/core/evaluation/eval_hooks.py | 6 ++- tests/test_runtime/test_eval_hook.py | 70 +++++++++++++++++--------- 2 files changed, 50 insertions(+), 26 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index dbc457f8c5..0d979219a8 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -147,11 +147,13 @@ def before_train_epoch(self, runner): def after_train_iter(self, runner): """Called after every training iter to evaluate the results.""" - self._do_evaluate(runner) + if not self.by_epoch: + self._do_evaluate(runner) def after_train_epoch(self, runner): """Called after every training epoch to evaluate the results.""" - self._do_evaluate(runner) + if self.by_epoch: + self._do_evaluate(runner) def _do_evaluate(self, runner): """perform evaluation and save ckpt.""" diff --git a/tests/test_runtime/test_eval_hook.py b/tests/test_runtime/test_eval_hook.py index 3342730f62..8a9d03b16f 100644 --- a/tests/test_runtime/test_eval_hook.py +++ b/tests/test_runtime/test_eval_hook.py @@ -7,7 +7,7 @@ import pytest import torch import torch.nn as nn -from mmcv.runner import EpochBasedRunner, build_optimizer +from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer from mmcv.utils import get_logger from torch.utils.data import DataLoader, Dataset @@ -62,22 +62,23 @@ def train_step(self, data_batch, optimizer, **kwargs): return outputs -def _build_demo_runner(): +class Model(nn.Module): - class 
Model(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(2, 1) + + def forward(self, x): + return self.linear(x) - def __init__(self): - super().__init__() - self.linear = nn.Linear(2, 1) + def train_step(self, x, optimizer, **kwargs): + return dict(loss=self(x)) - def forward(self, x): - return self.linear(x) + def val_step(self, x, optimizer, **kwargs): + return dict(loss=self(x)) - def train_step(self, x, optimizer, **kwargs): - return dict(loss=self(x)) - def val_step(self, x, optimizer, **kwargs): - return dict(loss=self(x)) +def _build_epoch_runner(): model = Model() tmp_dir = tempfile.mkdtemp() @@ -87,6 +88,16 @@ def val_step(self, x, optimizer, **kwargs): return runner +def _build_iter_runner(): + + model = Model() + tmp_dir = tempfile.mkdtemp() + + runner = IterBasedRunner( + model=model, work_dir=tmp_dir, logger=get_logger('demo')) + return runner + + def test_eval_hook(): with pytest.raises(AssertionError): # `save_best` should be a str @@ -304,8 +315,11 @@ def test_eval_hook(): @patch('mmaction.apis.single_gpu_test', MagicMock) @patch('mmaction.apis.multi_gpu_test', MagicMock) -@pytest.mark.parametrize('EvalHookParam', (EvalHook, DistEvalHook)) -def test_start_param(EvalHookParam): +@pytest.mark.parametrize('EvalHookParam', [EvalHook, DistEvalHook]) +@pytest.mark.parametrize('_build_demo_runner,by_epoch', + [(_build_epoch_runner, True), + (_build_iter_runner, False)]) +def test_start_param(EvalHookParam, _build_demo_runner, by_epoch): # create dummy data dataloader = DataLoader(torch.ones((5, 2))) @@ -319,7 +333,7 @@ def test_start_param(EvalHookParam): # 1. start=None, interval=1: perform evaluation after each epoch. runner = _build_demo_runner() - evalhook = EvalHookParam(dataloader, interval=1) + evalhook = EvalHookParam(dataloader, interval=1, by_epoch=by_epoch) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -327,7 +341,8 @@ def test_start_param(EvalHookParam): # 2. start=1, interval=1: perform evaluation after each epoch. runner = _build_demo_runner() - evalhook = EvalHookParam(dataloader, start=1, interval=1) + evalhook = EvalHookParam( + dataloader, start=1, interval=1, by_epoch=by_epoch) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -335,7 +350,7 @@ def test_start_param(EvalHookParam): # 3. start=None, interval=2: perform evaluation after epoch 2, 4, 6, etc runner = _build_demo_runner() - evalhook = EvalHookParam(dataloader, interval=2) + evalhook = EvalHookParam(dataloader, interval=2, by_epoch=by_epoch) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -343,7 +358,8 @@ def test_start_param(EvalHookParam): # 4. start=1, interval=2: perform evaluation after epoch 1, 3, 5, etc runner = _build_demo_runner() - evalhook = EvalHookParam(dataloader, start=1, interval=2) + evalhook = EvalHookParam( + dataloader, start=1, interval=2, by_epoch=by_epoch) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 3) @@ -352,7 +368,7 @@ def test_start_param(EvalHookParam): # 5. start=0/negative, interval=1: perform evaluation after each epoch and # before epoch 1. 
runner = _build_demo_runner() - evalhook = EvalHookParam(dataloader, start=0) + evalhook = EvalHookParam(dataloader, start=0, by_epoch=by_epoch) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -360,7 +376,7 @@ def test_start_param(EvalHookParam): runner = _build_demo_runner() with pytest.warns(UserWarning): - evalhook = EvalHookParam(dataloader, start=-2) + evalhook = EvalHookParam(dataloader, start=-2, by_epoch=by_epoch) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) runner.run([dataloader], [('train', 1)], 2) @@ -369,19 +385,25 @@ def test_start_param(EvalHookParam): # 6. resuming from epoch i, start = x (x<=i), interval =1: perform # evaluation after each epoch and before the first epoch. runner = _build_demo_runner() - evalhook = EvalHookParam(dataloader, start=1) + evalhook = EvalHookParam(dataloader, start=1, by_epoch=by_epoch) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) - runner._epoch = 2 + if by_epoch: + runner._epoch = 2 + else: + runner._iter = 2 runner.run([dataloader], [('train', 1)], 3) assert evalhook.evaluate.call_count == 2 # before & after epoch 3 # 7. resuming from epoch i, start = i+1/None, interval =1: perform # evaluation after each epoch. runner = _build_demo_runner() - evalhook = EvalHookParam(dataloader, start=2) + evalhook = EvalHookParam(dataloader, start=2, by_epoch=by_epoch) evalhook.evaluate = MagicMock() runner.register_hook(evalhook) - runner._epoch = 1 + if by_epoch: + runner._epoch = 1 + else: + runner._iter = 1 runner.run([dataloader], [('train', 1)], 3) assert evalhook.evaluate.call_count == 2 # after epoch 2 & 3 From 9dfe71b0b256bd38b381699c8af4570b63f3b190 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Sat, 19 Dec 2020 18:55:21 +0800 Subject: [PATCH 05/16] correct unittest --- tests/test_runtime/test_eval_hook.py | 153 ++++++++------------------- 1 file changed, 43 insertions(+), 110 deletions(-) diff --git a/tests/test_runtime/test_eval_hook.py b/tests/test_runtime/test_eval_hook.py index 8a9d03b16f..77fc6f962c 100644 --- a/tests/test_runtime/test_eval_hook.py +++ b/tests/test_runtime/test_eval_hook.py @@ -7,7 +7,7 @@ import pytest import torch import torch.nn as nn -from mmcv.runner import EpochBasedRunner, IterBasedRunner, build_optimizer +from mmcv.runner import EpochBasedRunner, IterBasedRunner from mmcv.utils import get_logger from torch.utils.data import DataLoader, Dataset @@ -21,7 +21,7 @@ def __init__(self): self.eval_result = [1, 4, 3, 7, 2, -3, 4, 6] def __getitem__(self, idx): - results = dict(imgs=torch.tensor([1])) + results = dict(x=torch.tensor([1])) return results def __len__(self): @@ -41,38 +41,19 @@ def evaluate(self, results, logger=None): return output -class ExampleModel(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Linear(1, 1) - self.test_cfg = None - - def forward(self, imgs, return_loss=False): - return imgs - - def train_step(self, data_batch, optimizer, **kwargs): - outputs = { - 'loss': 0.5, - 'log_vars': { - 'accuracy': 0.98 - }, - 'num_samples': 1 - } - return outputs - - class Model(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(2, 1) - def forward(self, x): - return self.linear(x) + def forward(self, x, **kwargs): + return x - def train_step(self, x, optimizer, **kwargs): - return dict(loss=self(x)) + def train_step(self, data_batch, optimizer, **kwargs): + if not isinstance(data_batch, dict): + data_batch = dict(x=data_batch) + return data_batch 
def val_step(self, x, optimizer, **kwargs): return dict(loss=self(x)) @@ -101,68 +82,39 @@ def _build_iter_runner(): def test_eval_hook(): with pytest.raises(AssertionError): # `save_best` should be a str - test_dataset = ExampleModel() - data_loader = DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_workers=0, - shuffle=False) + test_dataset = Model() + data_loader = DataLoader(test_dataset) EvalHook(data_loader, save_best=True) with pytest.raises(TypeError): # dataloader must be a pytorch DataLoader - test_dataset = ExampleModel() - data_loader = [ - DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_worker=0, - shuffle=False) - ] + test_dataset = Model() + data_loader = [DataLoader(test_dataset)] EvalHook(data_loader) with pytest.raises(ValueError): # key_indicator must be valid when rule_map is None test_dataset = ExampleDataset() - data_loader = DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_workers=0, - shuffle=False) + data_loader = DataLoader(test_dataset) EvalHook(data_loader, save_best='unsupport') with pytest.raises(KeyError): # rule must be in keys of rule_map - test_dataset = ExampleModel() - data_loader = DataLoader( - test_dataset, - batch_size=1, - sampler=None, - num_workers=0, - shuffle=False) + test_dataset = Model() + data_loader = DataLoader(test_dataset) EvalHook(data_loader, save_best='auto', rule='unsupport') - optimizer_cfg = dict( - type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) - test_dataset = ExampleDataset() - loader = DataLoader(test_dataset, batch_size=1) - model = ExampleModel() - optimizer = build_optimizer(model, optimizer_cfg) - - data_loader = DataLoader(test_dataset, batch_size=1) + loader = DataLoader(test_dataset) + model = Model() + data_loader = DataLoader(test_dataset) eval_hook = EvalHook(data_loader, save_best=None) + with tempfile.TemporaryDirectory() as tmpdir: + + # total_epochs = 1 logger = get_logger('test_eval') - runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) + runner = EpochBasedRunner(model=model, work_dir=tmpdir, logger=logger) runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 1) test_dataset.evaluate.assert_called_with( @@ -173,19 +125,15 @@ def test_eval_hook(): 'hook_msgs'] # when `save_best` is set to 'auto', first metric will be used. 
- loader = DataLoader(EvalDataset(), batch_size=1) - model = ExampleModel() - data_loader = DataLoader(EvalDataset(), batch_size=1) + loader = DataLoader(EvalDataset()) + model = Model() + data_loader = DataLoader(EvalDataset()) eval_hook = EvalHook(data_loader, interval=1, save_best='auto') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) + model=model, work_dir=tmpdir, logger=logger) runner.register_checkpoint_hook(dict(interval=1)) runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 8) @@ -198,19 +146,16 @@ def test_eval_hook(): assert osp.exists(link_path) assert runner.meta['hook_msgs']['best_score'] == 7 - loader = DataLoader(EvalDataset(), batch_size=1) - model = ExampleModel() - data_loader = DataLoader(EvalDataset(), batch_size=1) + # total_epochs = 8, return the best acc and corresponding epoch + loader = DataLoader(EvalDataset()) + model = Model() + data_loader = DataLoader(EvalDataset()) eval_hook = EvalHook(data_loader, interval=1, save_best='acc') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) + model=model, work_dir=tmpdir, logger=logger) runner.register_checkpoint_hook(dict(interval=1)) runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 8) @@ -223,17 +168,14 @@ def test_eval_hook(): assert osp.exists(link_path) assert runner.meta['hook_msgs']['best_score'] == 7 - data_loader = DataLoader(EvalDataset(), batch_size=1) + # total_epochs = 8, return the best score and corresponding epoch + data_loader = DataLoader(EvalDataset()) eval_hook = EvalHook( data_loader, interval=1, save_best='score', rule='greater') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) + model=model, work_dir=tmpdir, logger=logger) runner.register_checkpoint_hook(dict(interval=1)) runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 8) @@ -246,16 +188,14 @@ def test_eval_hook(): assert osp.exists(link_path) assert runner.meta['hook_msgs']['best_score'] == 7 - data_loader = DataLoader(EvalDataset(), batch_size=1) + # total_epochs = 8, return the best score using less compare func + # and indicate corresponding epoch + data_loader = DataLoader(EvalDataset()) eval_hook = EvalHook(data_loader, save_best='acc', rule='less') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) + model=model, work_dir=tmpdir, logger=logger) runner.register_checkpoint_hook(dict(interval=1)) runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 8) @@ -268,16 +208,13 @@ def test_eval_hook(): assert osp.exists(link_path) assert runner.meta['hook_msgs']['best_score'] == -3 - data_loader = DataLoader(EvalDataset(), batch_size=1) + # Test the EvalHook when resume happend + data_loader = DataLoader(EvalDataset()) eval_hook = EvalHook(data_loader, save_best='acc') with tempfile.TemporaryDirectory() as tmpdir: logger = get_logger('test_eval') runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) + 
model=model, work_dir=tmpdir, logger=logger) runner.register_checkpoint_hook(dict(interval=1)) runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 2) @@ -291,14 +228,10 @@ def test_eval_hook(): assert runner.meta['hook_msgs']['best_score'] == 4 resume_from = osp.join(tmpdir, 'latest.pth') - loader = DataLoader(ExampleDataset(), batch_size=1) + loader = DataLoader(ExampleDataset()) eval_hook = EvalHook(data_loader, save_best='acc') runner = EpochBasedRunner( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=tmpdir, - logger=logger) + model=model, work_dir=tmpdir, logger=logger) runner.register_checkpoint_hook(dict(interval=1)) runner.register_hook(eval_hook) runner.resume(resume_from) From e0036c0b9152182f8a39d509e9fbae33ee506409 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Mon, 21 Dec 2020 13:26:32 +0800 Subject: [PATCH 06/16] use runner.hook --- mmaction/core/evaluation/eval_hooks.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 0d979219a8..3cd3f8603f 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -6,8 +6,6 @@ from mmcv.runner import Hook from torch.utils.data import DataLoader -from mmaction.utils import get_root_logger - class EvalHook(Hook): """Non-Distributed evaluation hook. @@ -86,8 +84,6 @@ def __init__(self, self.eval_kwargs = eval_kwargs self.initial_epoch_flag = True - self.logger = get_root_logger() - if self.save_best is not None: self._init_rule(rule, self.save_best) @@ -208,8 +204,9 @@ def _save_ckpt(self, runner, key_score): mmcv.symlink( last_ckpt, osp.join(runner.work_dir, f'best_{self.key_indicator}.pth')) - self.logger.info(f'Now best checkpoint is {current}.pth.' - f'Best {self.key_indicator} is {best_score:0.4f}') + runner.logger.info( + f'Now best checkpoint is {current}.pth.' + f'Best {self.key_indicator} is {best_score:0.4f}') def evaluate(self, runner, results): """Evaluate the results. From f4ea7558409d2d8b93afebaa1fe39edff82d959d Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Wed, 23 Dec 2020 11:42:04 +0800 Subject: [PATCH 07/16] add comment --- mmaction/core/evaluation/eval_hooks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 3cd3f8603f..96c65fe954 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -252,8 +252,10 @@ class DistEvalHook(EvalHook): ``top1_acc``, ``top5_acc``, ``mean_class_accuracy``, ``mean_average_precision``, ``mmit_mean_average_precision`` for action recognition dataset (RawframeDataset and VideoDataset). - ``AR@AN``, ``auc`` for action localization dataset. - (ActivityNetDataset). Default: None. + ``AR@AN``, ``auc`` for action localization dataset + (ActivityNetDataset). If ``save_best`` is ``auto``, the first key + will be used. The interval of ``CheckpointHook`` should device + EvalHook. Default: None. rule (str | None, optional): Comparison rule for best score. If set to None, it will infer a reasonable rule. Keys such as 'acc', 'top' .etc will be inferred by 'greater' rule. 
Keys contain 'loss' will From d3faa86caa4109bb9160b81eeab86f21d58ec042 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Wed, 23 Dec 2020 12:14:04 +0800 Subject: [PATCH 08/16] polish --- tools/analysis/eval_metric.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/analysis/eval_metric.py b/tools/analysis/eval_metric.py index a68e566585..4335ea8c84 100644 --- a/tools/analysis/eval_metric.py +++ b/tools/analysis/eval_metric.py @@ -53,7 +53,8 @@ def main(): eval_kwargs = cfg.get('evaluation', {}).copy() # hard-code way to remove EvalHook args for key in [ - 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule' + 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 'rule', + 'by_epoch' ]: eval_kwargs.pop(key, None) eval_kwargs.update(dict(metrics=args.eval, **kwargs)) From b6b030aa518e7b355f502ff58bb14b17761780c6 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Wed, 23 Dec 2020 16:28:52 +0800 Subject: [PATCH 09/16] polish again --- mmaction/core/evaluation/eval_hooks.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 96c65fe954..851015002d 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -34,10 +34,17 @@ class EvalHook(Hook): ``top1_acc``, ``top5_acc``, ``mean_class_accuracy``, ``mean_average_precision``, ``mmit_mean_average_precision`` for action recognition dataset (RawframeDataset and VideoDataset). +<<<<<<< HEAD ``AR@AN``, ``auc`` for action localization dataset. (ActivityNetDataset). ``Recall@0.5@100``, ``AR@100``, ``mAP@0.5IOU`` for spatio-temporal action detection dataset (AVADataset). Default: `top1_acc`. +======= + ``AR@AN``, ``auc`` for action localization dataset + (ActivityNetDataset). If ``save_best`` is ``auto``, the first key + of the returned ``OrderedDict`` result will be used. The interval + of ``CheckpointHook`` should device EvalHook. Default: None. +>>>>>>> polish again rule (str | None, optional): Comparison rule for best score. If set to None, it will infer a reasonable rule. Keys such as 'acc', 'top' .etc will be inferred by 'greater' rule. Keys contain 'loss' will @@ -254,8 +261,8 @@ class DistEvalHook(EvalHook): for action recognition dataset (RawframeDataset and VideoDataset). ``AR@AN``, ``auc`` for action localization dataset (ActivityNetDataset). If ``save_best`` is ``auto``, the first key - will be used. The interval of ``CheckpointHook`` should device - EvalHook. Default: None. + of the returned ``OrderedDict`` result will be used. The interval + of ``CheckpointHook`` should device EvalHook. Default: None. rule (str | None, optional): Comparison rule for best score. If set to None, it will infer a reasonable rule. Keys such as 'acc', 'top' .etc will be inferred by 'greater' rule. 
Keys contain 'loss' will From 9793e2100dd6d407071f9dc9bd41cd15eed80094 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Fri, 25 Dec 2020 16:48:42 +0800 Subject: [PATCH 10/16] fix --- mmaction/core/evaluation/eval_hooks.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 851015002d..6c53ef67fb 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -89,7 +89,7 @@ def __init__(self, assert isinstance(save_best, str) or save_best is None self.save_best = save_best self.eval_kwargs = eval_kwargs - self.initial_epoch_flag = True + self.initial_flag = True if self.save_best is not None: self._init_rule(rule, self.save_best) @@ -132,21 +132,21 @@ def before_train_iter(self, runner): """Evaluate the model only at the start of training by iteration.""" if self.by_epoch: return - if not self.initial_epoch_flag: + if not self.initial_flag: return if self.start is not None and runner.iter >= self.start: self.after_train_iter(runner) - self.initial_epoch_flag = False + self.initial_flag = False def before_train_epoch(self, runner): """Evaluate the model only at the start of training by epoch.""" if not self.by_epoch: return - if not self.initial_epoch_flag: + if not self.initial_flag: return if self.start is not None and runner.epoch >= self.start: self.after_train_epoch(runner) - self.initial_epoch_flag = False + self.initial_flag = False def after_train_iter(self, runner): """Called after every training iter to evaluate the results.""" @@ -190,7 +190,8 @@ def evaluation_flag(self, runner): # No evaluation if start is larger than the current time. return False else: - # Evaluation only at epochs 3, 5, 7... if start==3 and interval==2 + # Evaluation only at epochs/iters 3, 5, 7... 
+ # if start==3 and interval==2 if (current + 1 - self.start) % self.interval: return False return True @@ -199,7 +200,7 @@ def _save_ckpt(self, runner, key_score): if self.by_epoch: current = f'epoch_{runner.epoch + 1}' else: - current = f'iter_{runner.epoch + 1}' + current = f'iter_{runner.iter + 1}' best_score = runner.meta['hook_msgs'].get( 'best_score', self.init_value_map[self.rule]) From 19c50e613fe5b7db8c7a22c50ae11913435c4d0f Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Sat, 9 Jan 2021 15:41:12 +0800 Subject: [PATCH 11/16] polish --- mmaction/core/evaluation/eval_hooks.py | 20 ++++++++++++++------ tests/test_runtime/test_eval_hook.py | 12 ++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 6c53ef67fb..83beb3a950 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -1,8 +1,8 @@ +import os import os.path as osp import warnings from math import inf -import mmcv from mmcv.runner import Hook from torch.utils.data import DataLoader @@ -92,6 +92,7 @@ def __init__(self, self.initial_flag = True if self.save_best is not None: + self.best_ckpt_path = None self._init_rule(rule, self.save_best) def _init_rule(self, rule, key_indicator): @@ -209,12 +210,19 @@ def _save_ckpt(self, runner, key_score): runner.meta['hook_msgs']['best_score'] = best_score last_ckpt = runner.meta['hook_msgs']['last_ckpt'] runner.meta['hook_msgs']['best_ckpt'] = last_ckpt - mmcv.symlink( - last_ckpt, - osp.join(runner.work_dir, f'best_{self.key_indicator}.pth')) + + if self.best_ckpt_path and osp.isfile(self.best_ckpt_path): + os.remove(self.best_ckpt_path) + + best_ckpt_name = f'best_{self.key_indicator}_{current}.pth' + runner.save_checkpoint( + runner.work_dir, best_ckpt_name, create_symlink=False) + self.best_ckpt_path = osp.join(runner.work_dir, best_ckpt_name) + runner.logger.info( - f'Now best checkpoint is {current}.pth.' - f'Best {self.key_indicator} is {best_score:0.4f}') + f'Now best checkpoint is saved as {best_ckpt_name}.' + f'Best {self.key_indicator} is {best_score:0.4f} at {current}.' + ) def evaluate(self, runner, results): """Evaluate the results. 
diff --git a/tests/test_runtime/test_eval_hook.py b/tests/test_runtime/test_eval_hook.py index 77fc6f962c..f670ce9613 100644 --- a/tests/test_runtime/test_eval_hook.py +++ b/tests/test_runtime/test_eval_hook.py @@ -139,7 +139,7 @@ def test_eval_hook(): runner.run([loader], [('train', 1)], 8) real_path = osp.join(tmpdir, 'epoch_4.pth') - link_path = osp.join(tmpdir, 'best_acc.pth') + link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( real_path) @@ -161,7 +161,7 @@ def test_eval_hook(): runner.run([loader], [('train', 1)], 8) real_path = osp.join(tmpdir, 'epoch_4.pth') - link_path = osp.join(tmpdir, 'best_acc.pth') + link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( real_path) @@ -181,7 +181,7 @@ def test_eval_hook(): runner.run([loader], [('train', 1)], 8) real_path = osp.join(tmpdir, 'epoch_4.pth') - link_path = osp.join(tmpdir, 'best_score.pth') + link_path = osp.join(tmpdir, 'best_score_epoch_4.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( real_path) @@ -201,7 +201,7 @@ def test_eval_hook(): runner.run([loader], [('train', 1)], 8) real_path = osp.join(tmpdir, 'epoch_6.pth') - link_path = osp.join(tmpdir, 'best_acc.pth') + link_path = osp.join(tmpdir, 'best_acc_epoch_6.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( real_path) @@ -220,7 +220,7 @@ def test_eval_hook(): runner.run([loader], [('train', 1)], 2) real_path = osp.join(tmpdir, 'epoch_2.pth') - link_path = osp.join(tmpdir, 'best_acc.pth') + link_path = osp.join(tmpdir, 'best_acc_epoch_2.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( real_path) @@ -238,7 +238,7 @@ def test_eval_hook(): runner.run([loader], [('train', 1)], 8) real_path = osp.join(tmpdir, 'epoch_4.pth') - link_path = osp.join(tmpdir, 'best_acc.pth') + link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( real_path) From 62757022370b674006f1e9b779ab059fa5d733fb Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Sat, 9 Jan 2021 20:50:14 +0800 Subject: [PATCH 12/16] polish logger info --- mmaction/core/evaluation/eval_hooks.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 83beb3a950..1843fa385b 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -200,16 +200,16 @@ def evaluation_flag(self, runner): def _save_ckpt(self, runner, key_score): if self.by_epoch: current = f'epoch_{runner.epoch + 1}' + cur_type, cur_time = 'epoch', runner.epoch + 1 else: current = f'iter_{runner.iter + 1}' + cur_type, cur_time = 'iter', runner.iter + 1 best_score = runner.meta['hook_msgs'].get( 'best_score', self.init_value_map[self.rule]) if self.compare_func(key_score, best_score): best_score = key_score runner.meta['hook_msgs']['best_score'] = best_score - last_ckpt = runner.meta['hook_msgs']['last_ckpt'] - runner.meta['hook_msgs']['best_ckpt'] = last_ckpt if self.best_ckpt_path and osp.isfile(self.best_ckpt_path): os.remove(self.best_ckpt_path) @@ -219,10 +219,12 @@ def _save_ckpt(self, runner, key_score): runner.work_dir, best_ckpt_name, create_symlink=False) self.best_ckpt_path = osp.join(runner.work_dir, best_ckpt_name) + runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path runner.logger.info( - f'Now best checkpoint is saved as {best_ckpt_name}.' 
- f'Best {self.key_indicator} is {best_score:0.4f} at {current}.' - ) + f'Now best checkpoint is saved as {best_ckpt_name}.') + runner.logger.info( + f'Best {self.key_indicator} is {best_score:0.4f} ' + f'at {cur_time} {cur_type}.') def evaluate(self, runner, results): """Evaluate the results. From 9923cdf0d08afde79c20c0433918d5d2957ad017 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Fri, 15 Jan 2021 23:36:03 +0800 Subject: [PATCH 13/16] fix unittest --- mmaction/core/evaluation/eval_hooks.py | 13 +++----- tests/test_runtime/test_eval_hook.py | 42 +++++++++++--------------- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index 1843fa385b..b2f1a92859 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -34,17 +34,13 @@ class EvalHook(Hook): ``top1_acc``, ``top5_acc``, ``mean_class_accuracy``, ``mean_average_precision``, ``mmit_mean_average_precision`` for action recognition dataset (RawframeDataset and VideoDataset). -<<<<<<< HEAD ``AR@AN``, ``auc`` for action localization dataset. (ActivityNetDataset). ``Recall@0.5@100``, ``AR@100``, ``mAP@0.5IOU`` for spatio-temporal action detection dataset - (AVADataset). Default: `top1_acc`. -======= - ``AR@AN``, ``auc`` for action localization dataset - (ActivityNetDataset). If ``save_best`` is ``auto``, the first key + (AVADataset). If ``save_best`` is ``auto``, the first key of the returned ``OrderedDict`` result will be used. The interval - of ``CheckpointHook`` should device EvalHook. Default: None. ->>>>>>> polish again + of ``EvalHook`` should be divisible by that of ``CheckpointHook``. + Default: 'top1_acc'. rule (str | None, optional): Comparison rule for best score. If set to None, it will infer a reasonable rule. Keys such as 'acc', 'top' .etc will be inferred by 'greater' rule. Keys contain 'loss' will @@ -273,7 +269,8 @@ class DistEvalHook(EvalHook): ``AR@AN``, ``auc`` for action localization dataset (ActivityNetDataset). If ``save_best`` is ``auto``, the first key of the returned ``OrderedDict`` result will be used. The interval - of ``CheckpointHook`` should device EvalHook. Default: None. + of ``EvalHook`` should be divisible of that in ``CheckpointHook``. + Default: None. rule (str | None, optional): Comparison rule for best score. If set to None, it will infer a reasonable rule. Keys such as 'acc', 'top' .etc will be inferred by 'greater' rule. 
Keys contain 'loss' will diff --git a/tests/test_runtime/test_eval_hook.py b/tests/test_runtime/test_eval_hook.py index f670ce9613..49e9ed4202 100644 --- a/tests/test_runtime/test_eval_hook.py +++ b/tests/test_runtime/test_eval_hook.py @@ -138,12 +138,11 @@ def test_eval_hook(): runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 8) - real_path = osp.join(tmpdir, 'epoch_4.pth') - link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( - real_path) - assert osp.exists(link_path) + ckpt_path) + assert osp.exists(ckpt_path) assert runner.meta['hook_msgs']['best_score'] == 7 # total_epochs = 8, return the best acc and corresponding epoch @@ -160,12 +159,11 @@ def test_eval_hook(): runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 8) - real_path = osp.join(tmpdir, 'epoch_4.pth') - link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( - real_path) - assert osp.exists(link_path) + ckpt_path) + assert osp.exists(ckpt_path) assert runner.meta['hook_msgs']['best_score'] == 7 # total_epochs = 8, return the best score and corresponding epoch @@ -180,12 +178,11 @@ def test_eval_hook(): runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 8) - real_path = osp.join(tmpdir, 'epoch_4.pth') - link_path = osp.join(tmpdir, 'best_score_epoch_4.pth') + ckpt_path = osp.join(tmpdir, 'best_score_epoch_4.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( - real_path) - assert osp.exists(link_path) + ckpt_path) + assert osp.exists(ckpt_path) assert runner.meta['hook_msgs']['best_score'] == 7 # total_epochs = 8, return the best score using less compare func @@ -200,12 +197,11 @@ def test_eval_hook(): runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 8) - real_path = osp.join(tmpdir, 'epoch_6.pth') - link_path = osp.join(tmpdir, 'best_acc_epoch_6.pth') + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_6.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( - real_path) - assert osp.exists(link_path) + ckpt_path) + assert osp.exists(ckpt_path) assert runner.meta['hook_msgs']['best_score'] == -3 # Test the EvalHook when resume happend @@ -219,12 +215,11 @@ def test_eval_hook(): runner.register_hook(eval_hook) runner.run([loader], [('train', 1)], 2) - real_path = osp.join(tmpdir, 'epoch_2.pth') - link_path = osp.join(tmpdir, 'best_acc_epoch_2.pth') + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_2.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( - real_path) - assert osp.exists(link_path) + ckpt_path) + assert osp.exists(ckpt_path) assert runner.meta['hook_msgs']['best_score'] == 4 resume_from = osp.join(tmpdir, 'latest.pth') @@ -237,12 +232,11 @@ def test_eval_hook(): runner.resume(resume_from) runner.run([loader], [('train', 1)], 8) - real_path = osp.join(tmpdir, 'epoch_4.pth') - link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') + ckpt_path = osp.join(tmpdir, 'best_acc_epoch_4.pth') assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath( - real_path) - assert osp.exists(link_path) + ckpt_path) + assert osp.exists(ckpt_path) assert runner.meta['hook_msgs']['best_score'] == 7 From cbfbf823df60632441fe0363ffa96e7a62b3db35 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Wed, 27 Jan 2021 15:58:22 +0800 Subject: [PATCH 14/16] BC --- 
...t_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py | 2 +- .../slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py | 2 +- .../slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py | 2 +- .../slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py | 2 +- .../slowonly_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py | 2 +- ...lowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py | 2 +- ...lowonly_omnisource_pretrained_r50_4x16x1_20e_ava_rgb.py | 2 +- mmaction/core/evaluation/eval_hooks.py | 7 +++++++ 8 files changed, 14 insertions(+), 7 deletions(-) diff --git a/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py b/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py index 01d23d024d..1366382a6d 100644 --- a/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py +++ b/configs/detection/ava/slowfast_context_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py @@ -159,7 +159,7 @@ total_epochs = 20 checkpoint_config = dict(interval=1) workflow = [('train', 1)] -evaluation = dict(interval=1, key_indicator='mAP@0.5IOU') +evaluation = dict(interval=1, save_best='mAP@0.5IOU') log_config = dict( interval=20, hooks=[ dict(type='TextLoggerHook'), diff --git a/configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py b/configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py index 7ca4172b51..47ca925ce9 100644 --- a/configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py +++ b/configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py @@ -158,7 +158,7 @@ total_epochs = 20 checkpoint_config = dict(interval=1) workflow = [('train', 1)] -evaluation = dict(interval=1, key_indicator='mAP@0.5IOU') +evaluation = dict(interval=1, save_best='mAP@0.5IOU') log_config = dict( interval=20, hooks=[ dict(type='TextLoggerHook'), diff --git a/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py b/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py index 6300ff1623..a81ec62234 100644 --- a/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py +++ b/configs/detection/ava/slowfast_kinetics_pretrained_r50_8x8x1_20e_ava_rgb.py @@ -158,7 +158,7 @@ total_epochs = 20 checkpoint_config = dict(interval=1) workflow = [('train', 1)] -evaluation = dict(interval=1, key_indicator='mAP@0.5IOU') +evaluation = dict(interval=1, save_best='mAP@0.5IOU') log_config = dict( interval=20, hooks=[ dict(type='TextLoggerHook'), diff --git a/configs/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py b/configs/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py index 745f31e44c..ecdbde0ccc 100644 --- a/configs/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py +++ b/configs/detection/ava/slowonly_kinetics_pretrained_r101_8x8x1_20e_ava_rgb.py @@ -143,7 +143,7 @@ total_epochs = 20 checkpoint_config = dict(interval=1) workflow = [('train', 1)] -evaluation = dict(interval=1, key_indicator='mAP@0.5IOU') +evaluation = dict(interval=1, save_best='mAP@0.5IOU') log_config = dict( interval=20, hooks=[ dict(type='TextLoggerHook'), diff --git a/configs/detection/ava/slowonly_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py b/configs/detection/ava/slowonly_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py index 55c6bd3454..e6edb13d7a 100644 --- a/configs/detection/ava/slowonly_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py +++ 
b/configs/detection/ava/slowonly_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py @@ -142,7 +142,7 @@ total_epochs = 20 checkpoint_config = dict(interval=1) workflow = [('train', 1)] -evaluation = dict(interval=1, key_indicator='mAP@0.5IOU') +evaluation = dict(interval=1, save_best='mAP@0.5IOU') log_config = dict( interval=20, hooks=[ dict(type='TextLoggerHook'), diff --git a/configs/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py b/configs/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py index 3d92ac41c5..5fa1b7e596 100644 --- a/configs/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py +++ b/configs/detection/ava/slowonly_omnisource_pretrained_r101_8x8x1_20e_ava_rgb.py @@ -142,7 +142,7 @@ total_epochs = 20 checkpoint_config = dict(interval=1) workflow = [('train', 1)] -evaluation = dict(interval=1, key_indicator='mAP@0.5IOU') +evaluation = dict(interval=1, save_best='mAP@0.5IOU') log_config = dict( interval=20, hooks=[ dict(type='TextLoggerHook'), diff --git a/configs/detection/ava/slowonly_omnisource_pretrained_r50_4x16x1_20e_ava_rgb.py b/configs/detection/ava/slowonly_omnisource_pretrained_r50_4x16x1_20e_ava_rgb.py index bbf13c89a3..283417df85 100644 --- a/configs/detection/ava/slowonly_omnisource_pretrained_r50_4x16x1_20e_ava_rgb.py +++ b/configs/detection/ava/slowonly_omnisource_pretrained_r50_4x16x1_20e_ava_rgb.py @@ -143,7 +143,7 @@ total_epochs = 20 checkpoint_config = dict(interval=1) workflow = [('train', 1)] -evaluation = dict(interval=1, key_indicator='mAP@0.5IOU') +evaluation = dict(interval=1, save_best='mAP@0.5IOU') log_config = dict( interval=20, hooks=[ dict(type='TextLoggerHook'), diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index b2f1a92859..f76e18f3ea 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -63,6 +63,13 @@ def __init__(self, save_best=None, rule=None, **eval_kwargs): + if 'key_indicator' in eval_kwargs: + raise RuntimeError( + '"key_indicator" is deprecated, ' + 'you need to use "save_best" instead. 
' + 'See https://github.com/open-mmlab/mmaction2/pull/395 for more info' # noqa: E501 + ) + if not isinstance(dataloader, DataLoader): raise TypeError(f'dataloader must be a pytorch DataLoader, ' f'but got {type(dataloader)}') From 629561a91fba1ccbc414c662ce81e98161263dcf Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Wed, 27 Jan 2021 16:28:36 +0800 Subject: [PATCH 15/16] add deprecated class --- mmaction/core/evaluation/eval_hooks.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index f76e18f3ea..a302040ebd 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -332,3 +332,21 @@ def _do_evaluate(self, runner): if self.save_best: self._save_ckpt(runner, key_score) + + +class EpochEvalHook(EvalHook): + """Deprecated class for ``EvalHook``.""" + + def __init__(self, *args, **kwargs): + warnings.warn('"EpochEvalHook" is deprecated, please switch to' + '"EvalHook"') + super().__init__(*args, **kwargs) + + +class DistEpochEvalHook(DistEvalHook): + """Deprecated class for ``DistEvalHook``.""" + + def __init__(self, *args, **kwargs): + warnings.warn('"DistEpochEvalHook" is deprecated, please switch to' + '"DistEvalHook"') + super().__init__(*args, **kwargs) From 45c85e9c7366a215aafd40b9aacbe4e01716fd18 Mon Sep 17 00:00:00 2001 From: dreamerlin <528557675@qq.com> Date: Wed, 27 Jan 2021 16:34:02 +0800 Subject: [PATCH 16/16] add deprecated info --- mmaction/core/evaluation/eval_hooks.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mmaction/core/evaluation/eval_hooks.py b/mmaction/core/evaluation/eval_hooks.py index a302040ebd..09524d6de2 100644 --- a/mmaction/core/evaluation/eval_hooks.py +++ b/mmaction/core/evaluation/eval_hooks.py @@ -338,8 +338,10 @@ class EpochEvalHook(EvalHook): """Deprecated class for ``EvalHook``.""" def __init__(self, *args, **kwargs): - warnings.warn('"EpochEvalHook" is deprecated, please switch to' - '"EvalHook"') + warnings.warn( + '"EpochEvalHook" is deprecated, please switch to' + '"EvalHook". See https://github.com/open-mmlab/mmaction2/pull/395 for more info' # noqa: E501 + ) super().__init__(*args, **kwargs) @@ -347,6 +349,8 @@ class DistEpochEvalHook(DistEvalHook): """Deprecated class for ``DistEvalHook``.""" def __init__(self, *args, **kwargs): - warnings.warn('"DistEpochEvalHook" is deprecated, please switch to' - '"DistEvalHook"') + warnings.warn( + '"DistEpochEvalHook" is deprecated, please switch to' + '"DistEvalHook". See https://github.com/open-mmlab/mmaction2/pull/395 for more info' # noqa: E501 + ) super().__init__(*args, **kwargs)