Commit 1fc046c

Fix _should_reload_dl_epoch causing inconsistent validation dataloader reloading (#11036)


Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
5 people authored Dec 28, 2021
1 parent ca9b25d commit 1fc046c
Showing 7 changed files with 143 additions and 67 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -369,6 +369,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))


- Fixed dataloaders not getting reloaded the correct amount of times when setting `reload_dataloaders_every_n_epochs` and `check_val_every_n_epoch` ([#10948](https://github.com/PyTorchLightning/pytorch-lightning/pull/10948))


## [1.5.7] - 2021-12-21

### Fixed
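
For context, the two user-facing `Trainer` arguments whose interaction this fix addresses can be combined like this — a minimal usage sketch, not taken from the diff:

```python
from pytorch_lightning import Trainer

# Reload train/val dataloaders every 2 epochs while also running
# validation only every 2 epochs. Before this fix, the two settings
# could drift apart because reloads were tied to absolute epoch
# multiples rather than to the epoch of the last actual reload.
trainer = Trainer(
    reload_dataloaders_every_n_epochs=2,
    check_val_every_n_epoch=2,
    max_epochs=10,
)
```
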
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -188,7 +188,7 @@ def _reload_evaluation_dataloaders(self) -> None:
"""Reloads dataloaders if necessary."""
if self.trainer.testing:
self.trainer.reset_test_dataloader()
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_val_dl:
self.trainer.reset_val_dataloader()

def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -203,7 +203,7 @@ def on_advance_start(self) -> None: # type: ignore[override]
model = self.trainer.lightning_module

# reset train dataloader
if not self._is_fresh_start_epoch and self.trainer._should_reload_dl_epoch:
if not self._is_fresh_start_epoch and self.trainer._should_reload_train_dl:
self.trainer.reset_train_dataloader(model)
self._is_fresh_start_epoch = False

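
The two loop changes above consume separate predicates for the train and validation dataloaders. A simplified control-flow sketch, assuming only the attribute names shown in the diff (`_should_reload_train_dl`, `_should_reload_val_dl`):

```python
# Simplified sketch of where the split predicates are consulted.
def on_advance_start(trainer, model, is_fresh_start_epoch: bool) -> None:
    # Train dataloader: skip the reload on the very first epoch, then
    # reload whenever enough epochs have passed since the last reload.
    if not is_fresh_start_epoch and trainer._should_reload_train_dl:
        trainer.reset_train_dataloader(model)


def reload_evaluation_dataloaders(trainer) -> None:
    # Validation dataloader: reload if it was never loaded, or if it
    # is stale according to its own last-reload epoch.
    if trainer.val_dataloaders is None or trainer._should_reload_val_dl:
        trainer.reset_val_dataloader()
```
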
22 changes: 22 additions & 0 deletions pytorch_lightning/trainer/data_loading.py
@@ -48,6 +48,7 @@ class TrainerDataLoadingMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
val_check_interval: float
reload_dataloaders_every_n_epochs: int
tpu_local_core_rank: int
train_dataloader: DataLoader
limit_train_batches: Union[int, float]
@@ -67,7 +68,22 @@ class TrainerDataLoadingMixin(ABC):
distributed_sampler_kwargs: dict
accelerator: Accelerator
call_hook: Callable
current_epoch: int
_accelerator_connector: AcceleratorConnector
_last_train_dl_reload_epoch: int
_last_val_dl_reload_epoch: int

@property
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
n_epochs = self.reload_dataloaders_every_n_epochs
return n_epochs and (self.current_epoch - self._last_train_dl_reload_epoch >= n_epochs)

@property
def _should_reload_val_dl(self) -> bool:
"""Check if validation dataloader should be reloaded."""
n_epochs = self.reload_dataloaders_every_n_epochs
return n_epochs and (self.current_epoch - self._last_val_dl_reload_epoch >= n_epochs)

def _worker_check(self, dataloader: DataLoader, name: str) -> None:
if not isinstance(dataloader, DataLoader):
@@ -278,6 +294,9 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
category=PossibleUserWarning,
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch

def _reset_eval_dataloader(
self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
@@ -369,6 +388,9 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
RunningStage.VALIDATING, model=pl_module
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_val_dl_reload_epoch = self.current_epoch

def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the test dataloader and determines the number of batches.
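
The predicate and bookkeeping added above can be exercised standalone. `ReloadTracker` below is a hypothetical stand-in, not a Lightning class; it mirrors only the attributes visible in this diff:

```python
class ReloadTracker:
    """Minimal stand-in for the mixin's new reload bookkeeping (sketch)."""

    def __init__(self, reload_every_n: int):
        self.reload_dataloaders_every_n_epochs = reload_every_n
        self.current_epoch = 0
        # float("-inf") guarantees the very first check requests a reload.
        self._last_train_dl_reload_epoch = float("-inf")

    @property
    def _should_reload_train_dl(self) -> bool:
        n = self.reload_dataloaders_every_n_epochs
        return bool(n) and (self.current_epoch - self._last_train_dl_reload_epoch >= n)

    def reset_train_dataloader(self) -> None:
        self._last_train_dl_reload_epoch = self.current_epoch


tracker = ReloadTracker(reload_every_n=2)
reload_epochs = []
for epoch in range(6):
    tracker.current_epoch = epoch
    if tracker._should_reload_train_dl:
        tracker.reset_train_dataloader()
        reload_epochs.append(epoch)
print(reload_epochs)  # [0, 2, 4]
```
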
10 changes: 4 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -658,6 +658,8 @@ def _setup_on_init(self, num_sanity_val_steps: int) -> None:
self.num_val_batches = []
self.test_dataloaders = None
self.val_dataloaders = None
self._last_train_dl_reload_epoch = float("-inf")
self._last_val_dl_reload_epoch = float("-inf")

self.num_predict_batches = []

@@ -743,6 +745,8 @@ def _fit_impl(
self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
self._last_train_dl_reload_epoch = float("-inf")
self._last_val_dl_reload_epoch = float("-inf")

# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
@@ -1963,12 +1967,6 @@ def progress_bar_dict(self) -> dict:
return self.progress_bar_callback.get_metrics(self, ref_model)
return self.progress_bar_metrics

@property
def _should_reload_dl_epoch(self) -> bool:
"""Check if dataloader should be reloaded in the current epoch."""
n_epochs = self.reload_dataloaders_every_n_epochs
return n_epochs and (not self.current_epoch % n_epochs)

@property
def enable_validation(self) -> bool:
"""Check if we should run validation during training."""
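
The deleted property is the root of the bug: `not self.current_epoch % n_epochs` reloads on every epoch that is a multiple of `n_epochs`, regardless of when a dataloader was actually last loaded (for example, during the sanity check, or on an off-cycle validation epoch). The replacement compares against the recorded last-reload epoch instead. A minimal side-by-side sketch, not library code:

```python
def should_reload_old(current_epoch: int, n_epochs: int) -> bool:
    # Old predicate: tied to absolute epoch multiples.
    return bool(n_epochs) and not current_epoch % n_epochs


def should_reload_new(current_epoch: int, last_reload_epoch: float, n_epochs: int) -> bool:
    # New predicate: relative to the last actual reload.
    return bool(n_epochs) and current_epoch - last_reload_epoch >= n_epochs


# n_epochs=2, and the val dataloader was last reloaded at epoch 3
# (because validation did not run at epoch 2):
print(should_reload_old(4, 2))     # True  -> reloads one epoch too early
print(should_reload_new(4, 3, 2))  # False -> waits until epoch 5
```
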
1 change: 0 additions & 1 deletion tests/models/test_hooks.py
@@ -870,7 +870,6 @@ def call(hook, fn, *args, **kwargs):
dict(name="setup", kwargs=dict(stage="fit")),
dict(name="val_dataloader"),
dict(name="train_dataloader"),
dict(name="val_dataloader"),
dict(name="on_save_checkpoint", args=(ANY,)),
dict(name="teardown", kwargs=dict(stage="fit")),
]
170 changes: 112 additions & 58 deletions tests/trainer/test_dataloaders.py
@@ -1012,17 +1012,12 @@ def test_dataloaders_load_only_once(tmpdir):
assert tracker.mock_calls == [call.val_dataloader(), call.train_dataloader()]


def test_dataloaders_load_only_once_val_interval(tmpdir):
def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
model = BoringModel()

# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=10,
limit_val_batches=10,
val_check_interval=0.3,
reload_dataloaders_every_n_epochs=1,
max_epochs=3,
default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3
)

tracker = Mock()
@@ -1035,34 +1030,33 @@ def test_dataloaders_load_only_once_val_interval(tmpdir):
tracker.attach_mock(model.test_dataloader, "test_dataloader")

trainer.fit(model)
trainer.test(model)

# verify the sequence
expected_sequence = [
call.val_dataloader(),
call.train_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.train_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.train_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.test_dataloader(),
]
expected_sequence = [call.train_dataloader(), call.val_dataloader()]
assert tracker.mock_calls == expected_sequence


def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
model = BoringModel()
@pytest.mark.parametrize("n", [1, 2])
def test_dataloaders_load_every_n_epochs(tmpdir, n):
train_reload_epochs, val_reload_epochs = [], []

class TestModel(BoringModel):
def train_dataloader(self):
train_reload_epochs.append(self.current_epoch)
return super().train_dataloader()

def val_dataloader(self):
val_reload_epochs.append(self.current_epoch)
return super().val_dataloader()

model = TestModel()

# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3
default_root_dir=tmpdir,
limit_train_batches=0.3,
limit_val_batches=0.3,
reload_dataloaders_every_n_epochs=n,
max_epochs=5,
)

tracker = Mock()
@@ -1075,44 +1069,113 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
tracker.attach_mock(model.test_dataloader, "test_dataloader")

trainer.fit(model)
trainer.test(model)

# Verify the sequence
expected_sequence = [call.val_dataloader(), call.train_dataloader()] # Sanity check first
if n == 1:
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 4
elif n == 2:
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
expected_sequence += [call.test_dataloader()]

# verify the sequence
expected_sequence = [call.train_dataloader(), call.val_dataloader()]
assert tracker.mock_calls == expected_sequence

# Verify epoch of reloads
if n == 1:
assert train_reload_epochs == [0, 1, 2, 3, 4]
assert val_reload_epochs == [0, 1, 2, 3, 4]
elif n == 2:
assert train_reload_epochs == [0, 2, 4]
assert val_reload_epochs == [0, 2, 4]

@pytest.mark.parametrize("n", [1, 2])
def test_dataloaders_load_every_n_epochs(tmpdir, n):
model = BoringModel()

@pytest.mark.parametrize(
"n, train_reload_epochs_expect, val_reload_epochs_expect",
[
# Sanity check at epoch 0 creates a validation dataloader, but validation is
# checked (and in this case reloaded) every n epochs starting from epoch n-1
(3, [0, 2, 4, 6, 8], [0, 2, 5, 8]),
(5, [0, 2, 4, 6, 8], [0, 4, 9]),
],
)
def test_dataloaders_load_every_n_epochs_infrequent_val(
tmpdir, n, train_reload_epochs_expect, val_reload_epochs_expect
):
"""Test dataloader reload behavior when infrequently checking validation set (via check_val_every_n_epoch)"""
train_reload_epochs, val_reload_epochs = [], []

class TestModel(BoringModel):
def train_dataloader(self):
train_reload_epochs.append(self.current_epoch)
return super().train_dataloader()

def val_dataloader(self):
val_reload_epochs.append(self.current_epoch)
return super().val_dataloader()

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=0.3,
limit_val_batches=0.3,
reload_dataloaders_every_n_epochs=n,
check_val_every_n_epoch=n,
reload_dataloaders_every_n_epochs=2,
max_epochs=10,
)
model.test_dataloader = Mock(wraps=model.test_dataloader)

trainer.fit(model)
trainer.test(model)

# Verify epoch of reloads
assert train_reload_epochs == train_reload_epochs_expect
assert val_reload_epochs == val_reload_epochs_expect

model.test_dataloader.assert_called_once()


def test_dataloaders_load_every_n_epochs_frequent_val(tmpdir):
"""Test dataloader reload behavior when frequently checking validation set (via val_check_interval)"""
train_reload_epochs, val_reload_epochs, val_check_epochs = [], [], []

class TestModel(BoringModel):
def train_dataloader(self):
train_reload_epochs.append(self.current_epoch)
return super().train_dataloader()

def val_dataloader(self):
val_reload_epochs.append(self.current_epoch)
return super().val_dataloader()

def validation_epoch_end(self, outputs):
val_check_epochs.append(self.current_epoch)
return super().validation_epoch_end(outputs)

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=0.3,
limit_val_batches=0.3,
val_check_interval=0.3,
reload_dataloaders_every_n_epochs=1,
max_epochs=3,
)

tracker = Mock()
model.train_dataloader = Mock(wraps=model.train_dataloader)
model.val_dataloader = Mock(wraps=model.val_dataloader)
model.test_dataloader = Mock(wraps=model.test_dataloader)

tracker.attach_mock(model.train_dataloader, "train_dataloader")
tracker.attach_mock(model.val_dataloader, "val_dataloader")
tracker.attach_mock(model.test_dataloader, "test_dataloader")

trainer.fit(model)
trainer.test(model)

# verify the sequence
expected_sequence = [call.val_dataloader()]
if n == 1:
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 3
elif n == 2:
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
expected_sequence += [call.test_dataloader()]
assert tracker.mock_calls == expected_sequence
# Verify epoch of reloads
assert train_reload_epochs == [0, 1, 2]
assert val_reload_epochs == [0, 1, 2]
model.test_dataloader.assert_called_once()

# Verify validation happens 3 times per epoch + 1 for sanity check
assert val_check_epochs == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]


@pytest.mark.parametrize("n", ["test", -1])
@@ -1159,15 +1222,6 @@ def validation_step(self, batch, batch_idx):
expected_calls = [
call.train_dataloader(),
call.val_dataloader(),
# This has subsequent calls to val_dataloader
# because the training loop runs the evaluation loop,
# which reloads the val dataloader again.
# We cannot yet rely on trainer.current_epoch=0 to skip reloading
# the val dataloader on the first epoch because this only tracks the training epoch
# meaning multiple passes through the validation data within a single training epoch
# would not have the dataloader reloaded.
# This breaks the assumption behind reload_dataloaders_every_n_epochs=1
call.val_dataloader(),
call.train_dataloader(),
call.val_dataloader(),
call.train_dataloader(),
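
The reload epochs asserted in `test_dataloaders_load_every_n_epochs_infrequent_val` above follow mechanically from the new predicate. A small simulation, assuming (as the test's inline comment states) that the sanity check loads the validation dataloader at epoch 0 and that validation then runs after every `check_val_every_n_epoch`-th epoch; `simulate_val_reloads` is an illustrative helper, not a Lightning API:

```python
def simulate_val_reloads(check_val_every_n_epoch: int, reload_every: int, max_epochs: int) -> list:
    reload_epochs = [0]  # sanity check loads the val dataloader at epoch 0
    last_reload = 0
    for epoch in range(max_epochs):
        runs_validation = (epoch + 1) % check_val_every_n_epoch == 0
        if runs_validation and epoch - last_reload >= reload_every:
            reload_epochs.append(epoch)
            last_reload = epoch
    return reload_epochs


print(simulate_val_reloads(3, 2, 10))  # [0, 2, 5, 8] -- matches the n=3 case
print(simulate_val_reloads(5, 2, 10))  # [0, 4, 9]    -- matches the n=5 case
```
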
