Skip to content

Commit

Permalink
use cp instead symlink
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin committed Jan 9, 2021
1 parent 0fbaf09 commit c2cd29c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
24 changes: 17 additions & 7 deletions mmcv/runner/hooks/eval.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import os.path as osp
import warnings
from math import inf

from torch.utils.data import DataLoader

from mmcv import symlink
from mmcv.engine import multi_gpu_test, single_gpu_test
from mmcv.runner import Hook

Expand Down Expand Up @@ -49,7 +49,7 @@ class EvalHook(Hook):

rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
init_value_map = {'greater': -inf, 'less': inf}
greater_keys = ['acc', 'top', 'AR@', 'auc', 'precision']
greater_keys = ['acc', 'top', 'AR@', 'auc', 'precision', 'mAP']
less_keys = ['loss']

def __init__(self,
Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(self,
self.initial_flag = True

if self.save_best is not None:
self.best_ckpt_path = None
self._init_rule(rule, self.save_best)

def _init_rule(self, rule, key_indicator):
Expand Down Expand Up @@ -191,8 +192,10 @@ def evaluation_flag(self, runner):
def _save_ckpt(self, runner, key_score):
if self.by_epoch:
current = f'epoch_{runner.epoch + 1}'
cur_type, cur_time = 'epoch', runner.epoch + 1
else:
current = f'iter_{runner.iter + 1}'
cur_type, cur_time = 'iter', runner.iter + 1

best_score = runner.meta['hook_msgs'].get(
'best_score', self.init_value_map[self.rule])
Expand All @@ -201,12 +204,19 @@ def _save_ckpt(self, runner, key_score):
runner.meta['hook_msgs']['best_score'] = best_score
last_ckpt = runner.meta['hook_msgs']['last_ckpt']
runner.meta['hook_msgs']['best_ckpt'] = last_ckpt
symlink(
last_ckpt,
osp.join(runner.work_dir, f'best_{self.key_indicator}.pth'))

if self.best_ckpt_path and osp.isfile(self.best_ckpt_path):
os.remove(self.best_ckpt_path)

best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
runner.save_checkpoint(
runner.work_dir, best_ckpt_name, create_symlink=False)
self.best_ckpt_path = osp.join(runner.work_dir, best_ckpt_name)
runner.logger.info(
f'Now best checkpoint is saved as {best_ckpt_name}.')
runner.logger.info(
f'Now best checkpoint is {current}.pth.'
f'Best {self.key_indicator} is {best_score:0.4f}')
f'Best {self.key_indicator} is {best_score:0.4f} '
f'at {cur_time} {cur_type}.')

def evaluate(self, runner, results):
"""Evaluate the results.
Expand Down
12 changes: 6 additions & 6 deletions tests/test_runner/test_eval_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def test_eval_hook():
runner.run([loader], [('train', 1)], 8)

real_path = osp.join(tmpdir, 'epoch_4.pth')
link_path = osp.join(tmpdir, 'best_acc.pth')
link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
real_path)
Expand All @@ -161,7 +161,7 @@ def test_eval_hook():
runner.run([loader], [('train', 1)], 8)

real_path = osp.join(tmpdir, 'epoch_4.pth')
link_path = osp.join(tmpdir, 'best_acc.pth')
link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
real_path)
Expand All @@ -181,7 +181,7 @@ def test_eval_hook():
runner.run([loader], [('train', 1)], 8)

real_path = osp.join(tmpdir, 'epoch_4.pth')
link_path = osp.join(tmpdir, 'best_score.pth')
link_path = osp.join(tmpdir, 'best_score_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
real_path)
Expand All @@ -201,7 +201,7 @@ def test_eval_hook():
runner.run([loader], [('train', 1)], 8)

real_path = osp.join(tmpdir, 'epoch_6.pth')
link_path = osp.join(tmpdir, 'best_acc.pth')
link_path = osp.join(tmpdir, 'best_acc_epoch_6.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
real_path)
Expand All @@ -220,7 +220,7 @@ def test_eval_hook():
runner.run([loader], [('train', 1)], 2)

real_path = osp.join(tmpdir, 'epoch_2.pth')
link_path = osp.join(tmpdir, 'best_acc.pth')
link_path = osp.join(tmpdir, 'best_acc_epoch_2.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
real_path)
Expand All @@ -238,7 +238,7 @@ def test_eval_hook():
runner.run([loader], [('train', 1)], 8)

real_path = osp.join(tmpdir, 'epoch_4.pth')
link_path = osp.join(tmpdir, 'best_acc.pth')
link_path = osp.join(tmpdir, 'best_acc_epoch_4.pth')

assert runner.meta['hook_msgs']['best_ckpt'] == osp.realpath(
real_path)
Expand Down

0 comments on commit c2cd29c

Please sign in to comment.