[Improvement] save best ckpt during training #464

Closed · wants to merge 9 commits
6 changes: 4 additions & 2 deletions docs/changelog.md
@@ -13,12 +13,14 @@
- Support training and testing for Spatio-Temporal Action Detection ([#351](https://github.com/open-mmlab/mmaction2/pull/351))
- Fix CI due to pip upgrade ([#454](https://github.com/open-mmlab/mmaction2/pull/454))
- Add markdown lint in pre-commit hook ([#255](https://github.com/open-mmlab/mmaction2/pull/225))
- Use title case in modelzoo statistics. ([#456](https://github.com/open-mmlab/mmaction2/pull/456))
- Use title case in modelzoo statistics ([#456](https://github.com/open-mmlab/mmaction2/pull/456))
- Add FAQ documents for easy troubleshooting. ([#413](https://github.com/open-mmlab/mmaction2/pull/413), [#420](https://github.com/open-mmlab/mmaction2/pull/420), [#439](https://github.com/open-mmlab/mmaction2/pull/439))
- Save best checkpoint during training. ([#464](https://github.com/open-mmlab/mmaction2/pull/464))

**Bug and Typo Fixes**

- Fix typo in default argument of BaseHead. ([#446](https://github.com/open-mmlab/mmaction2/pull/446))
- Fix typo in default argument of BaseHead ([#446](https://github.com/open-mmlab/mmaction2/pull/446))
- Fix potential bug about `output_config` overwrite ([#463](https://github.com/open-mmlab/mmaction2/pull/463))

**ModelZoo**

53 changes: 36 additions & 17 deletions mmaction/core/evaluation/eval_hooks.py
@@ -1,3 +1,4 @@
import os
import os.path as osp
import warnings
from math import inf
@@ -28,6 +29,8 @@ class EpochEvalHook(Hook):
interval (int): Evaluation interval (by epochs). Default: 1.
save_best (bool): Whether to save best checkpoint during evaluation.
Default: True.
save_best_ckpt (bool): Whether to save the best checkpoint in work_dir.
Default: False.
key_indicator (str | None): Key indicator to measure the best
checkpoint during evaluation when ``save_best`` is set to True.
Options are the evaluation metrics to the test dataset. e.g.,
@@ -53,6 +56,7 @@ def __init__(self,
start=None,
interval=1,
save_best=True,
save_best_ckpt=False,
key_indicator='top1_acc',
rule=None,
**eval_kwargs):
@@ -93,6 +97,8 @@ def __init__(self,
self.start = start
self.eval_kwargs = eval_kwargs
self.save_best = save_best
self.save_best_ckpt = save_best_ckpt
self._cur_best_ckpt_path = None
self.key_indicator = key_indicator
self.rule = rule

@@ -132,6 +138,30 @@ def evaluation_flag(self, runner):
return False
return True

def _do_save_best(self, key_score, json_path, current_ckpt_path, runner):
if (self.save_best and key_score is not None
and self.compare_func(key_score, self.best_score)):
self.best_score = key_score
self.logger.info(
f'Now best checkpoint is epoch_{runner.epoch + 1}.pth')
self.best_json['best_score'] = self.best_score
self.best_json['best_ckpt'] = current_ckpt_path
self.best_json['key_indicator'] = self.key_indicator
mmcv.dump(self.best_json, json_path)

if self.save_best_ckpt:
# remove previous best ckpt
if self._cur_best_ckpt_path and \
osp.isfile(self._cur_best_ckpt_path):
os.remove(self._cur_best_ckpt_path)

# save current checkpoint in work_dir
# checkpoint name 'best_{best_score}_{epoch_id}.pth'
cur_best_ckpt_name = 'best_{:.4f}_{}.pth'.format(
key_score, runner.epoch + 1)
runner.save_checkpoint(runner.work_dir, cur_best_ckpt_name)
self._cur_best_ckpt_path = osp.join(runner.work_dir,
cur_best_ckpt_name)

def after_train_epoch(self, runner):
"""Called after every training epoch to evaluate the results."""
if not self.evaluation_flag(runner):
@@ -150,14 +180,7 @@ def after_train_epoch(self, runner):
from mmaction.apis import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader)
key_score = self.evaluate(runner, results)
if (self.save_best and self.compare_func(key_score, self.best_score)):
self.best_score = key_score
self.logger.info(
f'Now best checkpoint is epoch_{runner.epoch + 1}.pth')
self.best_json['best_score'] = self.best_score
self.best_json['best_ckpt'] = current_ckpt_path
self.best_json['key_indicator'] = self.key_indicator
mmcv.dump(self.best_json, json_path)
self._do_save_best(key_score, json_path, current_ckpt_path, runner)

def evaluate(self, runner, results):
"""Evaluate the results.
@@ -197,6 +220,8 @@ class DistEpochEvalHook(EpochEvalHook):
interval (int): Evaluation interval (by epochs). Default: 1.
save_best (bool): Whether to save best checkpoint during evaluation.
Default: True.
save_best_ckpt (bool): Whether to save the best checkpoint in work_dir.
Default: False.
key_indicator (str | None): Key indicator to measure the best
checkpoint during evaluation when ``save_best`` is set to True.
Options are the evaluation metrics to the test dataset. e.g.,
@@ -221,6 +246,7 @@ def __init__(self,
start=None,
interval=1,
save_best=True,
save_best_ckpt=False,
key_indicator='top1_acc',
rule=None,
tmpdir=None,
@@ -231,6 +257,7 @@
start=start,
interval=interval,
save_best=save_best,
save_best_ckpt=save_best_ckpt,
key_indicator=key_indicator,
rule=rule,
**eval_kwargs)
@@ -266,12 +293,4 @@ def after_train_epoch(self, runner):
if runner.rank == 0:
print('\n')
key_score = self.evaluate(runner, results)
if (self.save_best and key_score is not None
and self.compare_func(key_score, self.best_score)):
self.best_score = key_score
self.logger.info(
f'Now best checkpoint is epoch_{runner.epoch + 1}.pth')
self.best_json['best_score'] = self.best_score
self.best_json['best_ckpt'] = current_ckpt_path
self.best_json['key_indicator'] = self.key_indicator
mmcv.dump(self.best_json, json_path)
self._do_save_best(key_score, json_path, current_ckpt_path, runner)
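For reference, a minimal sketch of how the new flag could be enabled from a training config. This assumes the usual mmaction2 wiring, where the `evaluation` dict of a config is forwarded as keyword arguments to `EpochEvalHook` / `DistEpochEvalHook`; the concrete values below are illustrative and not part of this diff.

# Hypothetical config excerpt (a sketch, not code from this PR).
evaluation = dict(
    interval=1,                # run validation after every epoch
    save_best=True,            # keep recording the best score/epoch in the json record (existing behaviour)
    save_best_ckpt=True,       # new in this PR: also keep best_{score}_{epoch}.pth in work_dir
    key_indicator='top1_acc',  # metric that decides which epoch is best
    rule='greater')            # a higher top1_acc counts as better

With such a setup the hook keeps exactly one best_*.pth file in work_dir, removing the previous best checkpoint whenever a better epoch appears.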
2 changes: 1 addition & 1 deletion mmaction/models/builder.py
@@ -13,7 +13,7 @@
# Define an empty registry and building func, so that can import
DETECTORS = Registry('detector')

def bulid_detector(cfg, train_cfg, test_cfg):
def build_detector(cfg, train_cfg, test_cfg):
pass


2 changes: 1 addition & 1 deletion mmaction/models/heads/roi_head.py
@@ -85,5 +85,5 @@ def simple_test_bboxes(self,
return det_bboxes, det_labels
else:
# Just define an empty class, so that __init__ can import it.
class AVARoIHead(StandardRoIHead):
class AVARoIHead:
pass
6 changes: 5 additions & 1 deletion mmaction/models/roi_extractors/single_straight3d.py
@@ -2,7 +2,11 @@

import torch
import torch.nn as nn
from mmcv.ops import RoIAlign, RoIPool

try:
from mmcv.ops import RoIAlign, RoIPool
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmcv-full to use RoIAlign and RoIPool')

try:
import mmdet # noqa
6 changes: 4 additions & 2 deletions tests/test_runtime/test_eval_hook.py
@@ -247,8 +247,8 @@ def test_eval_hook():
assert best_json['key_indicator'] == 'acc'

data_loader = DataLoader(EvalDataset(), batch_size=1)
eval_hook = EpochEvalHook(data_loader, key_indicator='acc')
with tempfile.TemporaryDirectory() as tmpdir:
eval_hook = EpochEvalHook(data_loader, key_indicator='acc')
logger = get_logger('test_eval')
runner = EpochBasedRunner(
model=model,
@@ -270,7 +270,8 @@

resume_from = osp.join(tmpdir, 'latest.pth')
loader = DataLoader(ExampleDataset(), batch_size=1)
eval_hook = EpochEvalHook(data_loader, key_indicator='acc')
eval_hook = EpochEvalHook(
data_loader, key_indicator='acc', save_best_ckpt=True)
runner = EpochBasedRunner(
model=model,
batch_processor=None,
@@ -289,6 +290,7 @@
assert best_json['best_ckpt'] == osp.realpath(real_path)
assert best_json['best_score'] == 7
assert best_json['key_indicator'] == 'acc'
assert osp.isfile(osp.join(tmpdir, 'best_7.0000_4.pth'))


@patch('mmaction.apis.single_gpu_test', MagicMock)
1 change: 1 addition & 0 deletions tools/data/activitynet/download_videos.sh
@@ -4,6 +4,7 @@
conda env create -f environment.yml
source activate activitynet
pip install --upgrade youtube-dl
pip install mmcv

DATA_DIR="../../../data/ActivityNet"
python download.py
2 changes: 1 addition & 1 deletion tools/data/gym/environment.yml
@@ -1,4 +1,4 @@
name: kinetics
name: gym
channels:
- anaconda
- menpo
17 changes: 11 additions & 6 deletions tools/test.py
@@ -102,15 +102,20 @@ def main():

# Load output_config from cfg
output_config = cfg.get('output_config', {})
# Overwrite output_config from args.out
output_config = Config._merge_a_into_b(dict(out=args.out), output_config)
if args.out:
# Overwrite output_config from args.out
output_config = Config._merge_a_into_b(
dict(out=args.out), output_config)

# Load eval_config from cfg
eval_config = cfg.get('eval_config', {})
# Overwrite eval_config from args.eval
eval_config = Config._merge_a_into_b(dict(metrics=args.eval), eval_config)
# Add options from args.eval_options
eval_config = Config._merge_a_into_b(args.eval_options, eval_config)
if args.eval:
# Overwrite eval_config from args.eval
eval_config = Config._merge_a_into_b(
dict(metrics=args.eval), eval_config)
if args.eval_options:
# Add options from args.eval_options
eval_config = Config._merge_a_into_b(args.eval_options, eval_config)

assert output_config or eval_config, \
('Please specify at least one operation (save or eval the '
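The guards above matter because `Config._merge_a_into_b` overwrites keys in the config-side dict with whatever the CLI supplies, including `None` when a flag is omitted. A tiny illustration with assumed values (not part of this diff):

# Sketch of the behaviour the guard prevents; the paths below are made up.
from mmcv import Config

output_config = dict(out='work_dirs/exp/results.pkl')  # value coming from the config file
args_out = None                                         # the user did not pass --out

# The old, unconditional merge would clobber the configured path:
#   Config._merge_a_into_b(dict(out=None), output_config) -> {'out': None}
if args_out:  # guarded merge, mirroring the patched tools/test.py
    output_config = Config._merge_a_into_b(dict(out=args_out), output_config)

print(output_config)  # {'out': 'work_dirs/exp/results.pkl'} is preserved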