[Refactor] refactor EvalHook #395

Merged · 16 commits · Jan 27, 2021

@@ -159,7 +159,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
-evaluation = dict(interval=1, key_indicator='[email protected]')
+evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),

@@ -158,7 +158,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
-evaluation = dict(interval=1, key_indicator='[email protected]')
+evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),

@@ -158,7 +158,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
-evaluation = dict(interval=1, key_indicator='[email protected]')
+evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),

@@ -143,7 +143,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
-evaluation = dict(interval=1, key_indicator='[email protected]')
+evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),

@@ -142,7 +142,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
-evaluation = dict(interval=1, key_indicator='[email protected]')
+evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),

@@ -142,7 +142,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
-evaluation = dict(interval=1, key_indicator='[email protected]')
+evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),

@@ -143,7 +143,7 @@
total_epochs = 20
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
-evaluation = dict(interval=1, key_indicator='[email protected]')
+evaluation = dict(interval=1, save_best='[email protected]')
log_config = dict(
interval=20, hooks=[
dict(type='TextLoggerHook'),

6 changes: 3 additions & 3 deletions docs/tutorials/7_customize_runtime.md

@@ -191,7 +191,7 @@ We support many other learning rate schedule [here](https://github.com/open-mmla

## Customize Workflow

-By default, we recommend that users use `EpochEvalHook` to run evaluation after each training epoch, but the `val` workflow can still be used as an alternative.
+By default, we recommend that users use `EvalHook` to run evaluation after each training epoch, but the `val` workflow can still be used as an alternative.
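
A rough sketch of the two alternatives (the metric key given to `save_best` is dataset-specific, so `'top1_acc'` below is only an illustrative placeholder):

```python
# Alternative 1 (recommended): register EvalHook via the `evaluation` config,
# so evaluation runs on the validation set after each training epoch.
# `save_best` is optional; it keeps the checkpoint with the best value of the
# named metric ('top1_acc' is an illustrative key, not taken from this PR).
evaluation = dict(interval=1, save_best='top1_acc')

# Alternative 2: add a 'val' phase to the workflow instead; this only computes
# losses on the validation set after each training epoch.
workflow = [('train', 1), ('val', 1)]
```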

Workflow is a list of (phase, epochs) to specify the running order and epochs. By default it is set to be
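
Judging from the configs touched in this PR, the default is a single training phase; the collapsed part of this hunk presumably switches to the two-phase variant that the next line refers to:

```python
workflow = [('train', 1)]               # default: training only
workflow = [('train', 1), ('val', 1)]   # variant with an additional val phase
```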

@@ -213,7 +213,7 @@ so that 1 epoch for training and 1 epoch for validation will be run iteratively.

1. The parameters of the model will not be updated during the val epoch.
2. Keyword `total_epochs` in the config only controls the number of training epochs and will not affect the validation workflow.
-3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EpochEvalHook`, because `EpochEvalHook` is called by `after_train_epoch` and the validation workflow only affects hooks that are called through `after_val_epoch`.
+3. Workflows `[('train', 1), ('val', 1)]` and `[('train', 1)]` will not change the behavior of `EvalHook`, because `EvalHook` is called by `after_train_epoch` and the validation workflow only affects hooks that are called through `after_val_epoch`.
Therefore, the only difference between `[('train', 1), ('val', 1)]` and `[('train', 1)]` is that the runner will calculate losses on the validation set after each training epoch.
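
To illustrate note 3, here is a minimal sketch of a custom hook built on the standard mmcv `Hook` interface (`ToyHook` is a hypothetical name, not part of this PR): `after_train_epoch` fires after every training epoch regardless of the workflow, whereas `after_val_epoch` fires only when the workflow contains a `('val', 1)` phase.

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class ToyHook(Hook):
    """Hypothetical hook showing which callback each workflow setting triggers."""

    def after_train_epoch(self, runner):
        # Always runs after a training epoch; this is the slot EvalHook uses,
        # so evaluation happens even with workflow = [('train', 1)].
        runner.logger.info(f'train epoch {runner.epoch} finished')

    def after_val_epoch(self, runner):
        # Runs only when the workflow includes a ('val', 1) phase.
        runner.logger.info('val epoch finished')
```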

## Customize Hooks
@@ -344,7 +344,7 @@ log_config = dict(

#### Evaluation config

-The config of `evaluation` will be used to initialize the [`EpochEvalHook`](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_hooks.py#L12).
+The config of `evaluation` will be used to initialize the [`EvalHook`](https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_hooks.py#L12).
Except for the key `interval`, other arguments such as `metrics` will be passed to `dataset.evaluate()`.
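
The doc's own example is not reproduced in this excerpt; a minimal sketch of such an `evaluation` config (the metric names are illustrative, borrowed from the metrics exported by `mmaction.core.evaluation` further down this diff) could look like:

```python
evaluation = dict(
    interval=5,  # evaluate every 5 epochs; consumed by EvalHook itself
    metrics=['top_k_accuracy', 'mean_class_accuracy'])  # forwarded to dataset.evaluate()
```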

6 changes: 3 additions & 3 deletions mmaction/apis/train.py

@@ -6,8 +6,8 @@
build_optimizer)
from mmcv.runner.hooks import Fp16OptimizerHook

-from ..core import (DistEpochEvalHook, EpochEvalHook,
-                    OmniSourceDistSamplerSeedHook, OmniSourceRunner)
+from ..core import (DistEvalHook, EvalHook, OmniSourceDistSamplerSeedHook,
+                    OmniSourceRunner)
from ..datasets import build_dataloader, build_dataset
from ..utils import PreciseBNHook, get_root_logger

@@ -143,7 +143,7 @@ def train_model(model,
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('val_dataloader', {}))
val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
-eval_hook = DistEpochEvalHook if distributed else EpochEvalHook
+eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

if cfg.resume_from:

13 changes: 6 additions & 7 deletions mmaction/core/evaluation/__init__.py

@@ -5,13 +5,12 @@
mmit_mean_average_precision, pairwise_temporal_iou,
softmax, top_k_accuracy)
from .eval_detection import ActivityNetLocalization
-from .eval_hooks import DistEpochEvalHook, EpochEvalHook
+from .eval_hooks import DistEvalHook, EvalHook

__all__ = [
-    'DistEpochEvalHook', 'EpochEvalHook', 'top_k_accuracy',
-    'mean_class_accuracy', 'confusion_matrix', 'mean_average_precision',
-    'get_weighted_score', 'average_recall_at_avg_proposals',
-    'pairwise_temporal_iou', 'average_precision_at_temporal_iou',
-    'ActivityNetLocalization', 'softmax', 'interpolated_precision_recall',
-    'mmit_mean_average_precision'
+    'DistEvalHook', 'EvalHook', 'top_k_accuracy', 'mean_class_accuracy',
+    'confusion_matrix', 'mean_average_precision', 'get_weighted_score',
+    'average_recall_at_avg_proposals', 'pairwise_temporal_iou',
+    'average_precision_at_temporal_iou', 'ActivityNetLocalization', 'softmax',
+    'interpolated_precision_recall', 'mmit_mean_average_precision'
]