diff --git a/mmcv/runner/epoch_based_runner.py b/mmcv/runner/epoch_based_runner.py index 1e1de295ed..f2c57f7aa8 100644 --- a/mmcv/runner/epoch_based_runner.py +++ b/mmcv/runner/epoch_based_runner.py @@ -42,10 +42,13 @@ def train(self, data_loader, **kwargs): self.mode = 'train' self.data_loader = data_loader self._max_iters = self._max_epochs * len(self.data_loader) + setattr(self.model, '_epoch', self._epoch) self.call_hook('before_train_epoch') time.sleep(2) # Prevent possible deadlock during epoch transition for i, data_batch in enumerate(self.data_loader): self._inner_iter = i + setattr(self.model, '_iter', self._iter) + setattr(self.model, '_inner_iter', self._inner_iter) self.call_hook('before_train_iter') self.run_iter(data_batch, train_mode=True, **kwargs) self.call_hook('after_train_iter') @@ -63,6 +66,7 @@ def val(self, data_loader, **kwargs): time.sleep(2) # Prevent possible deadlock during epoch transition for i, data_batch in enumerate(self.data_loader): self._inner_iter = i + setattr(self.model, '_inner_iter', self._inner_iter) self.call_hook('before_val_iter') self.run_iter(data_batch, train_mode=False) self.call_hook('after_val_iter') @@ -105,6 +109,7 @@ def run(self, data_loaders, workflow, max_epochs=None, **kwargs): self._max_epochs) self.call_hook('before_run') + setattr(self.model, '_max_epochs', self._max_epochs) while self.epoch < self._max_epochs: for i, flow in enumerate(workflow): mode, epochs = flow diff --git a/mmcv/runner/iter_based_runner.py b/mmcv/runner/iter_based_runner.py index 75133d5ec4..b39010247e 100644 --- a/mmcv/runner/iter_based_runner.py +++ b/mmcv/runner/iter_based_runner.py @@ -56,6 +56,8 @@ def train(self, data_loader, **kwargs): self.data_loader = data_loader self._epoch = data_loader.epoch data_batch = next(data_loader) + setattr(self.model, '_iter', self._iter) + setattr(self.model, '_inner_iter', self._inner_iter) self.call_hook('before_train_iter') outputs = self.model.train_step(data_batch, self.optimizer, **kwargs) if not isinstance(outputs, dict): @@ -73,6 +75,7 @@ def val(self, data_loader, **kwargs): self.mode = 'val' self.data_loader = data_loader data_batch = next(data_loader) + setattr(self.model, '_inner_iter', self._inner_iter) self.call_hook('before_val_iter') outputs = self.model.val_step(data_batch, **kwargs) if not isinstance(outputs, dict): @@ -116,6 +119,7 @@ def run(self, data_loaders, workflow, max_iters=None, **kwargs): self.call_hook('before_epoch') + setattr(self.model, '_max_iters', self._max_iters) while self.iter < self._max_iters: for i, flow in enumerate(workflow): self._inner_iter = 0 diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index abe5bb037a..b6a66063f9 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -9,6 +9,7 @@ import pytest import torch import torch.nn as nn +from torch.utils.data import DataLoader from mmcv.parallel import MMDataParallel from mmcv.runner import (RUNNERS, EpochBasedRunner, IterBasedRunner, @@ -32,6 +33,46 @@ def val_step(self): pass +class ExampleModel(Model): + + def __init__(self, by_epoch=True, max_time=None): + super().__init__() + self.by_epoch = by_epoch + self.max_time = max_time + if by_epoch: + self.record = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14]] + else: + self.record = [0, 1, 2, 3, 4] + self.val_iter_step = 0 + + def train_step(self, x, optimizer=None, **kwargs): + if self.by_epoch: + assert self._max_epochs == self.max_time + assert self._inner_iter == self._iter % len(self.record[0]) + value = self.record[self._epoch][self._inner_iter] + assert x + self._epoch * 5 == value + else: + assert not hasattr(self, '_epoch') + assert self._max_iters == self.max_time + value = self.record[self._iter] + assert x == value + return {} + + def val_step(self, x, optimizer=None, **kwargs): + if self.by_epoch: + assert self._max_epochs == self.max_time + value = self.record[self._epoch][self._inner_iter] + assert x + self._epoch * 5 == value + else: + assert not hasattr(self, '_epoch') + assert self._max_iters == self.max_time + value = self.record[self.val_iter_step] + self.val_iter_step += 1 + assert x == value + return {} + + def test_build_runner(): temp_root = tempfile.gettempdir() dir_name = ''.join( @@ -281,3 +322,17 @@ def test_register_timer_hook(runner_class): runner.register_timer_hook(timer_config) assert len(runner.hooks) == 2 assert isinstance(runner.hooks[1], IterTimerHook) + + +@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values()) +def test_check_epoch(runner_class): + loader = DataLoader([0, 1, 2, 3, 4]) + if runner_class == IterBasedRunner: + model = ExampleModel(by_epoch=False, max_time=5) + runner = runner_class( + model=model, logger=logging.getLogger(), max_iters=5) + else: + model = ExampleModel(max_time=3) + runner = runner_class( + model=model, logger=logging.getLogger(), max_epochs=3) + runner.run([loader, loader], [('train', 2), ('val', 1)])