Skip to content

Commit

Permalink
[Refactor] refactor EvalHook (#395)
Browse files Browse the repository at this point in the history
* polish eval hooks

* update unittest

* add by_epoch

* correct unittest

* correct unittest

* use runner.hook

* add comment

* polish

* polish again

* fix

* polish

* polish logger info

* fix unittest

* BC

* add deprecated class

* add deprecated info
  • Loading branch information
dreamerlin authored Jan 27, 2021
1 parent 944c5b1 commit 4240105
Show file tree
Hide file tree
Showing 18 changed files with 422 additions and 384 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, key_indicator='[email protected]')
evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, key_indicator='[email protected]')
evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, key_indicator='[email protected]')
evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, key_indicator='[email protected]')
evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, key_indicator='[email protected]')
evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, key_indicator='[email protected]')
evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, key_indicator='[email protected]')
evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials/7_customize_runtime.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ We support many other learning rate schedule [here](https://github.com/open-mmla

## Customize Workflow

By default, we recommend users to use `EpochEvalHook` to do evaluation after training epoch, but they can still use `val` workflow as an alternative.
By default, we recommend users to use `EvalHook` to do evaluation after training epoch, but they can still use `val` workflow as an alternative.

Workflow is a list of (phase, epochs) to specify the running order and epochs. By default it is set to be

Expand All @@ -213,7 +213,7 @@ so that 1 epoch for training and 1 epoch for validation will be run iteratively.

1. The parameters of model will not be updated during val epoch.
2. Keyword `total_epochs` in the config only controls the number of training epochs and will not affect the validation workflow.
3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EpochEvalHook` because `EpochEvalHook` is called by `after_train_epoch` and validation workflow only affect hooks that are called through `after_val_epoch`.
3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EvalHook` because `EvalHook` is called by `after_train_epoch` and validation workflow only affect hooks that are called through `after_val_epoch`.
Therefore, the only difference between `[('train', 1), ('val', 1)]` and ``[('train', 1)]`` is that the runner will calculate losses on validation set after each training epoch.

## Customize Hooks
Expand Down Expand Up @@ -344,7 +344,7 @@ log_config = dict(

#### Evaluation config

The config of `evaluation` will be used to initialize the [`EpochEvalHook`](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_hooks.py#L12).
The config of `evaluation` will be used to initialize the [`EvalHook`](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_hooks.py#L12).
Except the key `interval`, other arguments such as `metrics` will be passed to the `dataset.evaluate()`

```python
Expand Down
6 changes: 3 additions & 3 deletions mmaction/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
build_optimizer)
from mmcv.runner.hooks import Fp16OptimizerHook

from ..core import (DistEpochEvalHook, EpochEvalHook,
OmniSourceDistSamplerSeedHook, OmniSourceRunner)
from ..core import (DistEvalHook, EvalHook, OmniSourceDistSamplerSeedHook,
OmniSourceRunner)
from ..datasets import build_dataloader, build_dataset
from ..utils import PreciseBNHook, get_root_logger

Expand Down Expand Up @@ -143,7 +143,7 @@ def train_model(model,
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('val_dataloader', {}))
val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
eval_hook = DistEpochEvalHook if distributed else EpochEvalHook
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

if cfg.resume_from:
Expand Down
13 changes: 6 additions & 7 deletions mmaction/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
mmit_mean_average_precision, pairwise_temporal_iou,
softmax, top_k_accuracy)
from .eval_detection import ActivityNetLocalization
from .eval_hooks import DistEpochEvalHook, EpochEvalHook
from .eval_hooks import DistEvalHook, EvalHook

__all__ = [
'DistEpochEvalHook', 'EpochEvalHook', 'top_k_accuracy',
'mean_class_accuracy', 'confusion_matrix', 'mean_average_precision',
'get_weighted_score', 'average_recall_at_avg_proposals',
'pairwise_temporal_iou', 'average_precision_at_temporal_iou',
'ActivityNetLocalization', 'softmax', 'interpolated_precision_recall',
'mmit_mean_average_precision'
'DistEvalHook', 'EvalHook', 'top_k_accuracy', 'mean_class_accuracy',
'confusion_matrix', 'mean_average_precision', 'get_weighted_score',
'average_recall_at_avg_proposals', 'pairwise_temporal_iou',
'average_precision_at_temporal_iou', 'ActivityNetLocalization', 'softmax',
'interpolated_precision_recall', 'mmit_mean_average_precision'
]
Loading

0 comments on commit 4240105

Please sign in to comment.