add by_epoch
dreamerlin committed Dec 19, 2020
1 parent db2a67a commit ad039ab
Showing 6 changed files with 109 additions and 69 deletions.
6 changes: 3 additions & 3 deletions docs/tutorials/customize_runtime.md
@@ -191,7 +191,7 @@ We support many other learning rate schedule [here](https://github.com/open-mmla

## Customize Workflow

By default, we recommend using `EpochEvalHook` to perform evaluation after each training epoch, but the `val` workflow can still be used as an alternative.
By default, we recommend using `EvalHook` to perform evaluation after each training epoch, but the `val` workflow can still be used as an alternative.

Workflow is a list of (phase, epochs) pairs specifying the running order and the number of epochs for each phase. By default it is set to

@@ -213,7 +213,7 @@ so that 1 epoch for training and 1 epoch for validation will be run iteratively.

1. The parameters of the model will not be updated during a val epoch.
2. Keyword `total_epochs` in the config only controls the number of training epochs and will not affect the validation workflow.
3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EpochEvalHook`, because `EpochEvalHook` is called by `after_train_epoch` and the validation workflow only affects hooks called through `after_val_epoch`.
3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EvalHook`, because `EvalHook` is called by `after_train_epoch` and the validation workflow only affects hooks called through `after_val_epoch`.
Therefore, the only difference between `[('train', 1), ('val', 1)]` and `[('train', 1)]` is that the runner will calculate losses on the validation set after each training epoch, as the sketch below illustrates.
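
A minimal sketch of such a config (the epoch count and the surrounding config are placeholders, not prescribed by this commit):

```python
# Alternate one training epoch with one validation epoch; the val epoch
# only computes losses on the validation set and updates no parameters.
workflow = [('train', 1), ('val', 1)]
total_epochs = 50  # counts training epochs only
```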

## Customize Hooks
@@ -344,7 +344,7 @@ log_config = dict(

#### Evaluation config

The config of `evaluation` will be used to initialize the [`EpochEvalHook`](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_hooks.py#L12).
The config of `evaluation` will be used to initialize the [`EvalHook`](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_hooks.py#L12).
Except for the key `interval`, other arguments such as `metrics` will be passed to `dataset.evaluate()`, for example:

```python
# A representative snippet; the available metrics depend on the dataset.
evaluation = dict(
    interval=1, metrics=['top_k_accuracy', 'mean_class_accuracy'])
```
6 changes: 3 additions & 3 deletions mmaction/apis/train.py
@@ -6,8 +6,8 @@
build_optimizer)
from mmcv.runner.hooks import Fp16OptimizerHook

from ..core import (DistEpochEvalHook, EpochEvalHook,
OmniSourceDistSamplerSeedHook, OmniSourceRunner)
from ..core import (DistEvalHook, EvalHook, OmniSourceDistSamplerSeedHook,
OmniSourceRunner)
from ..datasets import build_dataloader, build_dataset
from ..utils import get_root_logger

@@ -128,7 +128,7 @@ def train_model(model,
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('val_dataloader', {}))
val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
eval_hook = DistEpochEvalHook if distributed else EpochEvalHook
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

if cfg.resume_from:
13 changes: 6 additions & 7 deletions mmaction/core/evaluation/__init__.py
@@ -5,13 +5,12 @@
mmit_mean_average_precision, pairwise_temporal_iou,
softmax, top_k_accuracy)
from .eval_detection import ActivityNetLocalization
from .eval_hooks import DistEpochEvalHook, EpochEvalHook
from .eval_hooks import DistEvalHook, EvalHook

__all__ = [
'DistEpochEvalHook', 'EpochEvalHook', 'top_k_accuracy',
'mean_class_accuracy', 'confusion_matrix', 'mean_average_precision',
'get_weighted_score', 'average_recall_at_avg_proposals',
'pairwise_temporal_iou', 'average_precision_at_temporal_iou',
'ActivityNetLocalization', 'softmax', 'interpolated_precision_recall',
'mmit_mean_average_precision'
'DistEvalHook', 'EvalHook', 'top_k_accuracy', 'mean_class_accuracy',
'confusion_matrix', 'mean_average_precision', 'get_weighted_score',
'average_recall_at_avg_proposals', 'pairwise_temporal_iou',
'average_precision_at_temporal_iou', 'ActivityNetLocalization', 'softmax',
'interpolated_precision_recall', 'mmit_mean_average_precision'
]
102 changes: 72 additions & 30 deletions mmaction/core/evaluation/eval_hooks.py
@@ -9,11 +9,11 @@
from mmaction.utils import get_root_logger


class EpochEvalHook(Hook):
"""Non-Distributed evaluation hook based on epochs.
class EvalHook(Hook):
"""Non-Distributed evaluation hook.
Notes:
If new arguments are added for EpochEvalHook, tools/test.py,
If new arguments are added for EvalHook, tools/test.py,
tools/eval_metric.py may be affected.
This hook will regularly perform evaluation in a given interval when
@@ -25,7 +25,10 @@ class EpochEvalHook(Hook):
evaluation before the training starts if ``start`` <= the resuming
epoch. If None, whether to evaluate is merely decided by
``interval``. Default: None.
interval (int): Evaluation interval (by epochs). Default: 1.
interval (int): Evaluation interval. Default: 1.
by_epoch (bool): Whether to perform evaluation by epoch or by
iteration. If True, evaluation is performed after each epoch;
otherwise, after each iteration. Default: True.
save_best (str | None, optional): If a metric is specified, it would
measure the best checkpoint during evaluation. The information
about the best checkpoint would be saved in best.json.
@@ -53,6 +56,7 @@ def __init__(self,
dataloader,
start=None,
interval=1,
by_epoch=True,
save_best=None,
rule=None,
**eval_kwargs):
@@ -63,6 +67,8 @@ def __init__(self,
if interval <= 0:
raise ValueError(f'interval must be positive, but got {interval}')

assert isinstance(by_epoch, bool)

if start is not None and start < 0:
warnings.warn(
f'The evaluation start epoch {start} is smaller than 0, '
@@ -71,6 +77,7 @@
self.dataloader = dataloader
self.interval = interval
self.start = start
self.by_epoch = by_epoch

assert isinstance(save_best, str) or save_best is None
self.save_best = save_best
@@ -116,45 +123,77 @@ def before_run(self, runner):
runner.meta = dict()
runner.meta.setdefault('hook_msgs', dict())

def before_train_iter(self, runner):
"""Evaluate the model only at the start of training by iteration."""
if self.by_epoch:
return
if not self.initial_epoch_flag:
return
if self.start is not None and runner.iter >= self.start:
self.after_train_iter(runner)
self.initial_epoch_flag = False

def before_train_epoch(self, runner):
"""Evaluate the model only at the start of training."""
"""Evaluate the model only at the start of training by epoch."""
if not self.by_epoch:
return
if not self.initial_epoch_flag:
return
if self.start is not None and runner.epoch >= self.start:
self.after_train_epoch(runner)
self.initial_epoch_flag = False

def after_train_iter(self, runner):
"""Called after every training iter to evaluate the results."""
self._do_evaluate(runner)

def after_train_epoch(self, runner):
"""Called after every training epoch to evaluate the results."""
self._do_evaluate(runner)

def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
if not self.evaluation_flag(runner):
return

from mmaction.apis import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
self._save_ckpt(runner, key_score)

def evaluation_flag(self, runner):
"""Judge whether to perform_evaluation after this epoch.
"""Judge whether to perform_evaluation.
Returns:
bool: The flag indicating whether to perform evaluation.
"""
if self.by_epoch:
current = runner.epoch
check_time = self.every_n_epochs
else:
current = runner.iter
check_time = self.every_n_iters

if self.start is None:
if not self.every_n_epochs(runner, self.interval):
# No evaluation during the interval epochs.
if not check_time(runner, self.interval):
# No evaluation during the interval.
return False
elif (runner.epoch + 1) < self.start:
# No evaluation if start is larger than the current epoch.
elif (current + 1) < self.start:
# No evaluation if start is larger than the current epoch or iteration.
return False
else:
# Evaluation only at epochs 3, 5, 7... if start==3 and interval==2
if (runner.epoch + 1 - self.start) % self.interval:
if (current + 1 - self.start) % self.interval:
return False
return True
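
To make the schedule concrete, here is a self-contained sketch of the check in isolation (the numbers are illustrative, not from the commit):

```python
# With start=3 and interval=2, evaluation runs at epochs 3, 5, 7, ...
# (or at iterations 3, 5, 7, ... when by_epoch=False).
start, interval = 3, 2
for current in range(10):  # stands in for runner.epoch or runner.iter
    run_eval = (current + 1) >= start and (current + 1 - start) % interval == 0
    print(current + 1, run_eval)
```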

def after_train_epoch(self, runner):
"""Called after every training epoch to evaluate the results."""
if not self.evaluation_flag(runner):
return

from mmaction.apis import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader)
key_score = self.evaluate(runner, results)
if self.save_best:
self._save_ckpt(runner, key_score)

def _save_ckpt(self, runner, key_score):
if self.by_epoch:
current = f'epoch_{runner.epoch + 1}'
else:
current = f'iter_{runner.iter + 1}'

best_score = runner.meta['hook_msgs'].get(
'best_score', self.init_value_map[self.rule])
if self.compare_func(key_score, best_score):
@@ -165,9 +204,8 @@ def _save_ckpt(self, runner, key_score):
mmcv.symlink(
last_ckpt,
osp.join(runner.work_dir, f'best_{self.key_indicator}.pth'))
self.logger.info(
f'Now best checkpoint is epoch_{runner.epoch + 1}.pth.'
f'Best {self.key_indicator} is {best_score:0.4f}')
self.logger.info(f'Now best checkpoint is {current}.pth. '
f'Best {self.key_indicator} is {best_score:0.4f}')

def evaluate(self, runner, results):
"""Evaluate the results.
@@ -190,8 +228,8 @@ def evaluate(self, runner, results):
return None


class DistEpochEvalHook(EpochEvalHook):
"""Distributed evaluation hook based on epochs.
class DistEvalHook(EvalHook):
"""Distributed evaluation hook.
This hook will regularly perform evaluation in a given interval when
performing in distributed environment.
@@ -202,7 +240,10 @@ class DistEpochEvalHook(EpochEvalHook):
evaluation before the training starts if ``start`` <= the resuming
epoch. If None, whether to evaluate is merely decided by
``interval``. Default: None.
interval (int): Evaluation interval (by epochs). Default: 1.
interval (int): Evaluation interval. Default: 1.
by_epoch (bool): Whether to perform evaluation by epoch or by
iteration. If True, evaluation is performed after each epoch;
otherwise, after each iteration. Default: True.
save_best (str | None, optional): If a metric is specified, it would
measure the best checkpoint during evaluation. The information
about the best checkpoint would be saved in best.json.
@@ -229,6 +270,7 @@ def __init__(self,
dataloader,
start=None,
interval=1,
by_epoch=True,
save_best=None,
rule=None,
tmpdir=None,
@@ -238,14 +280,14 @@
dataloader,
start=start,
interval=interval,
by_epoch=by_epoch,
save_best=save_best,
rule=rule,
**eval_kwargs)
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect

def after_train_epoch(self, runner):
"""Called after each training epoch to evaluate the model."""
def _do_evaluate(self, runner):
if not self.evaluation_flag(runner):
return

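For reference, a minimal usage sketch of the new `by_epoch` flag; this is an illustration under assumptions, not part of the commit, and `distributed`, `val_dataloader`, and `runner` stand in for objects built elsewhere:

```python
from mmaction.core import DistEvalHook, EvalHook

# Sketch only: `distributed`, `val_dataloader`, and `runner` are assumed
# to come from the usual build_dataloader / build_runner setup.
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(
    eval_hook(
        val_dataloader,
        interval=5000,   # counted in iterations rather than epochs...
        by_epoch=False,  # ...because evaluation is per-iteration here
        metrics='top_k_accuracy'))  # forwarded to dataset.evaluate()
```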