Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add epoch and iter information in model #1115

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mmcv/runner/epoch_based_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ def train(self, data_loader, **kwargs):
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
setattr(self.model, '_epoch', self._epoch)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
setattr(self.model, '_iter', self._iter)
setattr(self.model, '_inner_iter', self._inner_iter)
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
Expand All @@ -63,6 +66,7 @@ def val(self, data_loader, **kwargs):
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
setattr(self.model, '_inner_iter', self._inner_iter)
self.call_hook('before_val_iter')
self.run_iter(data_batch, train_mode=False)
self.call_hook('after_val_iter')
Expand Down Expand Up @@ -105,6 +109,7 @@ def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
self._max_epochs)
self.call_hook('before_run')

setattr(self.model, '_max_epochs', self._max_epochs)
while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
Expand Down
4 changes: 4 additions & 0 deletions mmcv/runner/iter_based_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def train(self, data_loader, **kwargs):
self.data_loader = data_loader
self._epoch = data_loader.epoch
data_batch = next(data_loader)
setattr(self.model, '_iter', self._iter)
setattr(self.model, '_inner_iter', self._inner_iter)
self.call_hook('before_train_iter')
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
if not isinstance(outputs, dict):
Expand All @@ -73,6 +75,7 @@ def val(self, data_loader, **kwargs):
self.mode = 'val'
self.data_loader = data_loader
data_batch = next(data_loader)
setattr(self.model, '_inner_iter', self._inner_iter)
self.call_hook('before_val_iter')
outputs = self.model.val_step(data_batch, **kwargs)
if not isinstance(outputs, dict):
Expand Down Expand Up @@ -116,6 +119,7 @@ def run(self, data_loaders, workflow, max_iters=None, **kwargs):

self.call_hook('before_epoch')

setattr(self.model, '_max_iters', self._max_iters)
while self.iter < self._max_iters:
for i, flow in enumerate(workflow):
self._inner_iter = 0
Expand Down
55 changes: 55 additions & 0 deletions tests/test_runner/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.parallel import MMDataParallel
from mmcv.runner import (RUNNERS, EpochBasedRunner, IterBasedRunner,
Expand All @@ -32,6 +33,46 @@ def val_step(self):
pass


class ExampleModel(Model):
    """Model stub that asserts the runner injected epoch/iter attributes.

    The runner under test is expected to set ``_epoch``, ``_iter``,
    ``_inner_iter`` and ``_max_epochs`` / ``_max_iters`` on the model; each
    ``train_step`` / ``val_step`` call cross-checks those counters against
    the data value it receives.
    """

    def __init__(self, by_epoch=True, max_time=None):
        super().__init__()
        self.by_epoch = by_epoch
        self.max_time = max_time
        if by_epoch:
            # Three epochs of five samples: record[e][i] == e * 5 + i.
            self.record = [[e * 5 + i for i in range(5)] for e in range(3)]
        else:
            # Iter-based: a single flat run of five samples.
            self.record = [i for i in range(5)]
        # Manual counter for val iterations in iter-based mode (the runner
        # does not set ``_iter`` for val steps there).
        self.val_iter_step = 0

    def train_step(self, x, optimizer=None, **kwargs):
        if not self.by_epoch:
            # Iter-based runner must not attach an ``_epoch`` attribute.
            assert not hasattr(self, '_epoch')
            assert self._max_iters == self.max_time
            assert x == self.record[self._iter]
            return {}
        assert self._max_epochs == self.max_time
        # ``_inner_iter`` restarts each epoch while ``_iter`` is global.
        assert self._inner_iter == self._iter % len(self.record[0])
        expected = self.record[self._epoch][self._inner_iter]
        assert x + self._epoch * 5 == expected
        return {}

    def val_step(self, x, optimizer=None, **kwargs):
        if not self.by_epoch:
            assert not hasattr(self, '_epoch')
            assert self._max_iters == self.max_time
            expected = self.record[self.val_iter_step]
            self.val_iter_step += 1
            assert x == expected
            return {}
        assert self._max_epochs == self.max_time
        expected = self.record[self._epoch][self._inner_iter]
        assert x + self._epoch * 5 == expected
        return {}


def test_build_runner():
temp_root = tempfile.gettempdir()
dir_name = ''.join(
Expand Down Expand Up @@ -281,3 +322,17 @@ def test_register_timer_hook(runner_class):
runner.register_timer_hook(timer_config)
assert len(runner.hooks) == 2
assert isinstance(runner.hooks[1], IterTimerHook)


@pytest.mark.parametrize('runner_class', RUNNERS.module_dict.values())
def test_check_epoch(runner_class):
    """Run each registered runner and let ExampleModel verify that the
    epoch/iter bookkeeping attributes were propagated to the model."""
    loader = DataLoader([0, 1, 2, 3, 4])
    epoch_based = runner_class != IterBasedRunner
    if epoch_based:
        # Epoch-based: 3 epochs over a 5-sample loader.
        model = ExampleModel(max_time=3)
        runner = runner_class(
            model=model, logger=logging.getLogger(), max_epochs=3)
    else:
        # Iter-based: exactly 5 training iterations in total.
        model = ExampleModel(by_epoch=False, max_time=5)
        runner = runner_class(
            model=model, logger=logging.getLogger(), max_iters=5)
    runner.run([loader, loader], [('train', 2), ('val', 1)])