Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamerlin committed Dec 25, 2020
1 parent 897d333 commit 34db925
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions mmaction/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self,
assert isinstance(save_best, str) or save_best is None
self.save_best = save_best
self.eval_kwargs = eval_kwargs
self.initial_epoch_flag = True
self.initial_flag = True

if self.save_best is not None:
self._init_rule(rule, self.save_best)
Expand Down Expand Up @@ -125,21 +125,21 @@ def before_train_iter(self, runner):
"""Evaluate the model only at the start of training by iteration."""
if self.by_epoch:
return
if not self.initial_epoch_flag:
if not self.initial_flag:
return
if self.start is not None and runner.iter >= self.start:
self.after_train_iter(runner)
self.initial_epoch_flag = False
self.initial_flag = False

def before_train_epoch(self, runner):
"""Evaluate the model only at the start of training by epoch."""
if not self.by_epoch:
return
if not self.initial_epoch_flag:
if not self.initial_flag:
return
if self.start is not None and runner.epoch >= self.start:
self.after_train_epoch(runner)
self.initial_epoch_flag = False
self.initial_flag = False

def after_train_iter(self, runner):
"""Called after every training iter to evaluate the results."""
Expand Down Expand Up @@ -183,7 +183,8 @@ def evaluation_flag(self, runner):
# No evaluation if start is larger than the current time.
return False
else:
# Evaluation only at epochs 3, 5, 7... if start==3 and interval==2
# Evaluation only at epochs/iters 3, 5, 7...
# if start==3 and interval==2
if (current + 1 - self.start) % self.interval:
return False
return True
Expand All @@ -192,7 +193,7 @@ def _save_ckpt(self, runner, key_score):
if self.by_epoch:
current = f'epoch_{runner.epoch + 1}'
else:
current = f'iter_{runner.epoch + 1}'
current = f'iter_{runner.iter + 1}'

best_score = runner.meta['hook_msgs'].get(
'best_score', self.init_value_map[self.rule])
Expand Down

0 comments on commit 34db925

Please sign in to comment.