
Fix _should_reload_dl_epoch causing inconsistent validation dataloader reloading #11036

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -369,6 +369,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))


- Fixed dataloaders not getting reloaded the correct number of times when setting `reload_dataloaders_every_n_epochs` and `check_val_every_n_epoch` ([#10948](https://github.com/PyTorchLightning/pytorch-lightning/pull/10948))


## [1.5.7] - 2021-12-21

### Fixed
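For context on the entry above, here is a minimal reproduction sketch (not part of the PR): `DemoModel`, `RandomDataset`, and the printed messages are illustrative assumptions, while the Trainer flags mirror the ones exercised in the tests further down. With `reload_dataloaders_every_n_epochs=2` and `check_val_every_n_epoch=3`, the old shared modulo check could skip validation-dataloader reloads that the requested interval implies.

```python
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Tiny synthetic dataset so the example is self-contained."""

    def __len__(self):
        return 32

    def __getitem__(self, idx):
        return torch.randn(4)


class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Logs every train-dataloader rebuild so the reload schedule is visible.
        print(f"train dataloader reloaded at epoch {self.current_epoch}")
        return DataLoader(RandomDataset(), batch_size=8)

    def val_dataloader(self):
        # Logs every val-dataloader rebuild; before this fix, some of these
        # rebuilds were skipped when validation ran on epochs that were not
        # multiples of reload_dataloaders_every_n_epochs.
        print(f"val dataloader reloaded at epoch {self.current_epoch}")
        return DataLoader(RandomDataset(), batch_size=8)


trainer = pl.Trainer(
    max_epochs=10,
    limit_train_batches=2,
    limit_val_batches=2,
    reload_dataloaders_every_n_epochs=2,
    check_val_every_n_epoch=3,  # validation runs on epochs 2, 5, 8
)
trainer.fit(DemoModel())
```

With the fix, the validation dataloader is rebuilt on every validation epoch that is at least two epochs past the previous rebuild (2, 5 and 8), instead of only on even epochs.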
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -188,7 +188,7 @@ def _reload_evaluation_dataloaders(self) -> None:
"""Reloads dataloaders if necessary."""
if self.trainer.testing:
self.trainer.reset_test_dataloader()
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_val_dl:
self.trainer.reset_val_dataloader()

def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -203,7 +203,7 @@ def on_advance_start(self) -> None: # type: ignore[override]
model = self.trainer.lightning_module

# reset train dataloader
if not self._is_fresh_start_epoch and self.trainer._should_reload_dl_epoch:
if not self._is_fresh_start_epoch and self.trainer._should_reload_train_dl:
self.trainer.reset_train_dataloader(model)
self._is_fresh_start_epoch = False

22 changes: 22 additions & 0 deletions pytorch_lightning/trainer/data_loading.py
@@ -48,6 +48,7 @@ class TrainerDataLoadingMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
val_check_interval: float
reload_dataloaders_every_n_epochs: int
tpu_local_core_rank: int
train_dataloader: DataLoader
limit_train_batches: Union[int, float]
@@ -67,7 +68,22 @@ class TrainerDataLoadingMixin(ABC):
distributed_sampler_kwargs: dict
accelerator: Accelerator
call_hook: Callable
current_epoch: int
_accelerator_connector: AcceleratorConnector
_last_train_dl_reload_epoch: int
_last_val_dl_reload_epoch: int

@property
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
n_epochs = self.reload_dataloaders_every_n_epochs
return n_epochs and (self.current_epoch - self._last_train_dl_reload_epoch >= n_epochs)

@property
def _should_reload_val_dl(self) -> bool:
"""Check if validation dataloader should be reloaded."""
n_epochs = self.reload_dataloaders_every_n_epochs
return n_epochs and (self.current_epoch - self._last_val_dl_reload_epoch >= n_epochs)

def _worker_check(self, dataloader: DataLoader, name: str) -> None:
if not isinstance(dataloader, DataLoader):
@@ -278,6 +294,9 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
category=PossibleUserWarning,
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch

def _reset_eval_dataloader(
self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
@@ -369,6 +388,9 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
RunningStage.VALIDATING, model=pl_module
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_val_dl_reload_epoch = self.current_epoch

def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the test dataloader and determines the number of batches.

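The core of the change above: instead of a single modulo test on `current_epoch`, each dataloader type now remembers the epoch of its last rebuild (`_last_train_dl_reload_epoch` / `_last_val_dl_reload_epoch`) and reloads once the configured number of epochs has elapsed. Below is a standalone sketch of that decision, with a hypothetical `should_reload` helper standing in for the new properties.

```python
def should_reload(current_epoch: int, last_reload_epoch: float, n_epochs: int) -> bool:
    # Mirrors _should_reload_train_dl / _should_reload_val_dl: reload only if the
    # feature is enabled and at least n_epochs have passed since the last reload.
    return bool(n_epochs) and (current_epoch - last_reload_epoch >= n_epochs)


last_val_reload = float("-inf")  # same initial value the Trainer sets in _setup_on_init
for epoch in range(10):
    runs_validation = (epoch + 1) % 3 == 0  # e.g. check_val_every_n_epoch=3 -> epochs 2, 5, 8
    if runs_validation and should_reload(epoch, last_val_reload, n_epochs=2):
        last_val_reload = epoch  # reset_val_dataloader() records the reload epoch
        print(f"validation dataloader reloaded at epoch {epoch}")
# Prints epochs 2, 5 and 8: every validation epoch at least 2 epochs after the previous reload.
```

Aside from the extra sanity-check reload at epoch 0, this matches the expectation encoded in `test_dataloaders_load_every_n_epochs_infrequent_val` further down.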
10 changes: 4 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -658,6 +658,8 @@ def _setup_on_init(self, num_sanity_val_steps: int) -> None:
self.num_val_batches = []
self.test_dataloaders = None
self.val_dataloaders = None
self._last_train_dl_reload_epoch = float("-inf")
self._last_val_dl_reload_epoch = float("-inf")

self.num_predict_batches = []

@@ -743,6 +745,8 @@ def _fit_impl(
self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True
self._last_train_dl_reload_epoch = float("-inf")
self._last_val_dl_reload_epoch = float("-inf")

# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
@@ -1963,12 +1967,6 @@ def progress_bar_dict(self) -> dict:
return self.progress_bar_callback.get_metrics(self, ref_model)
return self.progress_bar_metrics

@property
def _should_reload_dl_epoch(self) -> bool:
"""Check if dataloader should be reloaded in the current epoch."""
n_epochs = self.reload_dataloaders_every_n_epochs
return n_epochs and (not self.current_epoch % n_epochs)

@property
def enable_validation(self) -> bool:
"""Check if we should run validation during training."""
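For contrast, the property deleted above decided reloads from the absolute epoch number alone (`not self.current_epoch % n_epochs`) and was shared by the train and validation loops. A quick hypothetical illustration of why that is inconsistent when validation does not run every epoch:

```python
# Removed condition: reload whenever the current epoch is a multiple of n_epochs.
# With reload_dataloaders_every_n_epochs=2 and validation running on epochs 2, 5, 8
# (check_val_every_n_epoch=3), the reload at epoch 5 is silently skipped:
n_epochs = 2
val_epochs = [2, 5, 8]
old_reloads = [epoch for epoch in val_epochs if not epoch % n_epochs]
print(old_reloads)  # [2, 8] -- with the per-dataloader tracking above this becomes [2, 5, 8]
```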
1 change: 0 additions & 1 deletion tests/models/test_hooks.py
@@ -870,7 +870,6 @@ def call(hook, fn, *args, **kwargs):
dict(name="setup", kwargs=dict(stage="fit")),
dict(name="val_dataloader"),
dict(name="train_dataloader"),
dict(name="val_dataloader"),
dict(name="on_save_checkpoint", args=(ANY,)),
dict(name="teardown", kwargs=dict(stage="fit")),
]
170 changes: 112 additions & 58 deletions tests/trainer/test_dataloaders.py
@@ -1012,17 +1012,12 @@ def test_dataloaders_load_only_once(tmpdir):
assert tracker.mock_calls == [call.val_dataloader(), call.train_dataloader()]


def test_dataloaders_load_only_once_val_interval(tmpdir):
def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
model = BoringModel()

# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=10,
limit_val_batches=10,
val_check_interval=0.3,
reload_dataloaders_every_n_epochs=1,
max_epochs=3,
default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3
)

tracker = Mock()
@@ -1035,34 +1030,33 @@ def test_dataloaders_load_only_once_val_interval(tmpdir):
tracker.attach_mock(model.test_dataloader, "test_dataloader")

trainer.fit(model)
trainer.test(model)

# verify the sequence
expected_sequence = [
call.val_dataloader(),
call.train_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.train_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.train_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.val_dataloader(),
call.test_dataloader(),
]
expected_sequence = [call.train_dataloader(), call.val_dataloader()]
assert tracker.mock_calls == expected_sequence


def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
model = BoringModel()
@pytest.mark.parametrize("n", [1, 2])
def test_dataloaders_load_every_n_epochs(tmpdir, n):
train_reload_epochs, val_reload_epochs = [], []

class TestModel(BoringModel):
def train_dataloader(self):
train_reload_epochs.append(self.current_epoch)
return super().train_dataloader()

def val_dataloader(self):
val_reload_epochs.append(self.current_epoch)
return super().val_dataloader()

model = TestModel()

# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3
default_root_dir=tmpdir,
limit_train_batches=0.3,
limit_val_batches=0.3,
reload_dataloaders_every_n_epochs=n,
max_epochs=5,
)

tracker = Mock()
@@ -1075,44 +1069,113 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
tracker.attach_mock(model.test_dataloader, "test_dataloader")

trainer.fit(model)
trainer.test(model)

# Verify the sequence
expected_sequence = [call.val_dataloader(), call.train_dataloader()] # Sanity check first
if n == 1:
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 4
elif n == 2:
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
expected_sequence += [call.test_dataloader()]

# verify the sequence
expected_sequence = [call.train_dataloader(), call.val_dataloader()]
assert tracker.mock_calls == expected_sequence

# Verify epoch of reloads
if n == 1:
assert train_reload_epochs == [0, 1, 2, 3, 4]
assert val_reload_epochs == [0, 1, 2, 3, 4]
elif n == 2:
assert train_reload_epochs == [0, 2, 4]
assert val_reload_epochs == [0, 2, 4]

@pytest.mark.parametrize("n", [1, 2])
def test_dataloaders_load_every_n_epochs(tmpdir, n):
model = BoringModel()

@pytest.mark.parametrize(
"n, train_reload_epochs_expect, val_reload_epochs_expect",
[
# Sanity check at epoch 0 creates a validation dataloader, but validation is
# checked (and in this case reloaded) every n epochs starting from epoch n-1
(3, [0, 2, 4, 6, 8], [0, 2, 5, 8]),
(5, [0, 2, 4, 6, 8], [0, 4, 9]),
],
)
def test_dataloaders_load_every_n_epochs_infrequent_val(
tmpdir, n, train_reload_epochs_expect, val_reload_epochs_expect
):
"""Test dataloader reload behavior when infrequently checking validation set (via check_val_every_n_epoch)"""
train_reload_epochs, val_reload_epochs = [], []

class TestModel(BoringModel):
def train_dataloader(self):
train_reload_epochs.append(self.current_epoch)
return super().train_dataloader()

def val_dataloader(self):
val_reload_epochs.append(self.current_epoch)
return super().val_dataloader()

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=0.3,
limit_val_batches=0.3,
reload_dataloaders_every_n_epochs=n,
check_val_every_n_epoch=n,
reload_dataloaders_every_n_epochs=2,
max_epochs=10,
)
model.test_dataloader = Mock(wraps=model.test_dataloader)

trainer.fit(model)
trainer.test(model)

# Verify epoch of reloads
assert train_reload_epochs == train_reload_epochs_expect
assert val_reload_epochs == val_reload_epochs_expect

model.test_dataloader.assert_called_once()


def test_dataloaders_load_every_n_epochs_frequent_val(tmpdir):
"""Test dataloader reload behavior when frequently checking validation set (via val_check_interval)"""
train_reload_epochs, val_reload_epochs, val_check_epochs = [], [], []

class TestModel(BoringModel):
def train_dataloader(self):
train_reload_epochs.append(self.current_epoch)
return super().train_dataloader()

def val_dataloader(self):
val_reload_epochs.append(self.current_epoch)
return super().val_dataloader()

def validation_epoch_end(self, outputs):
val_check_epochs.append(self.current_epoch)
return super().validation_epoch_end(outputs)

model = TestModel()

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=0.3,
limit_val_batches=0.3,
val_check_interval=0.3,
reload_dataloaders_every_n_epochs=1,
max_epochs=3,
)

tracker = Mock()
model.train_dataloader = Mock(wraps=model.train_dataloader)
model.val_dataloader = Mock(wraps=model.val_dataloader)
model.test_dataloader = Mock(wraps=model.test_dataloader)

tracker.attach_mock(model.train_dataloader, "train_dataloader")
tracker.attach_mock(model.val_dataloader, "val_dataloader")
tracker.attach_mock(model.test_dataloader, "test_dataloader")

trainer.fit(model)
trainer.test(model)

# verify the sequence
expected_sequence = [call.val_dataloader()]
if n == 1:
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 3
elif n == 2:
expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
expected_sequence += [call.test_dataloader()]
assert tracker.mock_calls == expected_sequence
# Verify epoch of reloads
assert train_reload_epochs == [0, 1, 2]
assert val_reload_epochs == [0, 1, 2]
model.test_dataloader.assert_called_once()

# Verify validation happens 3 times per epoch + 1 for sanity check
assert val_check_epochs == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]


@pytest.mark.parametrize("n", ["test", -1])
@@ -1159,15 +1222,6 @@ def validation_step(self, batch, batch_idx):
expected_calls = [
call.train_dataloader(),
call.val_dataloader(),
# This has subsequent calls to val_dataloader
# because the training loop runs the evaluation loop,
# which reloads the val dataloader again.
# We cannot yet rely on trainer.current_epoch=0 to skip reloading
# the val dataloader on the first epoch because this only tracks the training epoch
# meaning multiple passes through the validation data within a single training epoch
# would not have the dataloader reloaded.
# This breaks the assumption behind reload_dataloaders_every_n_epochs=1
call.val_dataloader(),
call.train_dataloader(),
call.val_dataloader(),
call.train_dataloader(),