diff --git a/CHANGELOG.md b/CHANGELOG.md
index a6c7bbe4c6b40..09623b6b52e25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -369,6 +369,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))

+- Fixed dataloaders not getting reloaded the correct number of times when setting `reload_dataloaders_every_n_epochs` and `check_val_every_n_epoch` ([#10948](https://github.com/PyTorchLightning/pytorch-lightning/pull/10948))
+
+
 ## [1.5.7] - 2021-12-21

 ### Fixed
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 8614034889a19..fa3cc1233b2a1 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -188,7 +188,7 @@ def _reload_evaluation_dataloaders(self) -> None:
         """Reloads dataloaders if necessary."""
         if self.trainer.testing:
             self.trainer.reset_test_dataloader()
-        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
+        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_val_dl:
             self.trainer.reset_val_dataloader()

     def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 697f1e3b0d840..09493f57a59e7 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -203,7 +203,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]
         model = self.trainer.lightning_module

         # reset train dataloader
-        if not self._is_fresh_start_epoch and self.trainer._should_reload_dl_epoch:
+        if not self._is_fresh_start_epoch and self.trainer._should_reload_train_dl:
             self.trainer.reset_train_dataloader(model)
         self._is_fresh_start_epoch = False

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 890fe17a259b4..76dc4ce989c66 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -48,6 +48,7 @@ class TrainerDataLoadingMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     val_check_interval: float
+    reload_dataloaders_every_n_epochs: int
     tpu_local_core_rank: int
     train_dataloader: DataLoader
     limit_train_batches: Union[int, float]
@@ -67,7 +68,22 @@ class TrainerDataLoadingMixin(ABC):
     distributed_sampler_kwargs: dict
     accelerator: Accelerator
     call_hook: Callable
+    current_epoch: int
     _accelerator_connector: AcceleratorConnector
+    _last_train_dl_reload_epoch: int
+    _last_val_dl_reload_epoch: int
+
+    @property
+    def _should_reload_train_dl(self) -> bool:
+        """Check if train dataloader should be reloaded."""
+        n_epochs = self.reload_dataloaders_every_n_epochs
+        return n_epochs and (self.current_epoch - self._last_train_dl_reload_epoch >= n_epochs)
+
+    @property
+    def _should_reload_val_dl(self) -> bool:
+        """Check if validation dataloader should be reloaded."""
+        n_epochs = self.reload_dataloaders_every_n_epochs
+        return n_epochs and (self.current_epoch - self._last_val_dl_reload_epoch >= n_epochs)

     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         if not isinstance(dataloader, DataLoader):
@@ -278,6 +294,9 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
                     category=PossibleUserWarning,
                 )

+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_train_dl_reload_epoch = self.current_epoch
+
     def _reset_eval_dataloader(
         self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
     ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
@@ -369,6 +388,9 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
             RunningStage.VALIDATING, model=pl_module
         )

+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_val_dl_reload_epoch = self.current_epoch
+
     def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
         """Resets the test dataloader and determines the number of batches.
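To make the `data_loading.py` change concrete, here is a minimal standalone sketch (illustrative only, not Lightning source; the two helper functions are hypothetical names) contrasting the removed modulo-based check with the new last-reload tracking. With `check_val_every_n_epoch=3` and `reload_dataloaders_every_n_epochs=2`, validation runs at epochs 2, 5, and 8; the modulo policy would skip the reload at epoch 5 (since `5 % 2 != 0`), while tracking the last reload epoch catches it.

```python
# Illustrative sketch, not Lightning source: compare the two reload policies.

def should_reload_modulo(current_epoch: int, n_epochs: int) -> bool:
    # old policy (removed from trainer.py below): reload on multiples of n
    return bool(n_epochs) and current_epoch % n_epochs == 0

def should_reload_tracked(current_epoch: int, last_reload_epoch: float, n_epochs: int) -> bool:
    # new policy (data_loading.py above): reload once >= n epochs elapsed since the last reload
    return bool(n_epochs) and current_epoch - last_reload_epoch >= n_epochs

last_reload = float("-inf")  # mirrors the trainer's initial _last_val_dl_reload_epoch
for epoch in range(10):
    if (epoch + 1) % 3 == 0:  # validation runs here: check_val_every_n_epoch=3
        if not should_reload_modulo(epoch, 2):
            print(f"epoch {epoch}: modulo policy would keep a stale val dataloader")
        if should_reload_tracked(epoch, last_reload, 2):
            last_reload = epoch
            print(f"epoch {epoch}: tracked policy reloads the val dataloader")
```

Running this prints a reload at every validation epoch (2, 5, 8) under the tracked policy, while the modulo policy misses epoch 5, which is exactly the mismatch the PR fixes.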
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b2ee41bef8ca8..b9ae7f2dc2036 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -658,6 +658,8 @@ def _setup_on_init(self, num_sanity_val_steps: int) -> None:
         self.num_val_batches = []
         self.test_dataloaders = None
         self.val_dataloaders = None
+        self._last_train_dl_reload_epoch = float("-inf")
+        self._last_val_dl_reload_epoch = float("-inf")

         self.num_predict_batches = []

@@ -743,6 +745,8 @@ def _fit_impl(
         self.state.fn = TrainerFn.FITTING
         self.state.status = TrainerStatus.RUNNING
         self.training = True
+        self._last_train_dl_reload_epoch = float("-inf")
+        self._last_val_dl_reload_epoch = float("-inf")

         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(train_dataloaders, LightningDataModule):
@@ -1963,12 +1967,6 @@ def progress_bar_dict(self) -> dict:
             return self.progress_bar_callback.get_metrics(self, ref_model)
         return self.progress_bar_metrics

-    @property
-    def _should_reload_dl_epoch(self) -> bool:
-        """Check if dataloader should be reloaded in the current epoch."""
-        n_epochs = self.reload_dataloaders_every_n_epochs
-        return n_epochs and (not self.current_epoch % n_epochs)
-
     @property
     def enable_validation(self) -> bool:
         """Check if we should run validation during training."""
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 2d37d0fd17b4e..acc12151739db 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -870,7 +870,6 @@ def call(hook, fn, *args, **kwargs):
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="val_dataloader"),
         dict(name="train_dataloader"),
-        dict(name="val_dataloader"),
         dict(name="on_save_checkpoint", args=(ANY,)),
         dict(name="teardown", kwargs=dict(stage="fit")),
     ]
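The reworked tests below record dataloader hook calls in order with `unittest.mock`. For readers unfamiliar with the pattern, this is a self-contained sketch (no Lightning involved; the `Loaders` class is hypothetical) of how `Mock(wraps=...)` plus `attach_mock` lets a single tracker assert the interleaving of several hooks:

```python
from unittest.mock import Mock, call

class Loaders:
    def train_dataloader(self):
        return "train"

    def val_dataloader(self):
        return "val"

model = Loaders()
tracker = Mock()
# wraps= keeps the real behaviour while recording every call
model.train_dataloader = Mock(wraps=model.train_dataloader)
model.val_dataloader = Mock(wraps=model.val_dataloader)
# attaching both mocks to one parent gives a single, ordered call list
tracker.attach_mock(model.train_dataloader, "train_dataloader")
tracker.attach_mock(model.val_dataloader, "val_dataloader")

model.val_dataloader()
model.train_dataloader()
assert tracker.mock_calls == [call.val_dataloader(), call.train_dataloader()]
```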
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index eebf9bb992615..47013c9e3f387 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -1012,17 +1012,12 @@ def test_dataloaders_load_only_once(tmpdir):
     assert tracker.mock_calls == [call.val_dataloader(), call.train_dataloader()]


-def test_dataloaders_load_only_once_val_interval(tmpdir):
+def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
     model = BoringModel()

     # logger file to get meta
     trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=10,
-        limit_val_batches=10,
-        val_check_interval=0.3,
-        reload_dataloaders_every_n_epochs=1,
-        max_epochs=3,
+        default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3
     )

     tracker = Mock()
@@ -1035,34 +1030,33 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
     tracker.attach_mock(model.test_dataloader, "test_dataloader")

     trainer.fit(model)
-    trainer.test(model)

     # verify the sequence
-    expected_sequence = [
-        call.val_dataloader(),
-        call.train_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.train_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.train_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.val_dataloader(),
-        call.test_dataloader(),
-    ]
+    expected_sequence = [call.train_dataloader(), call.val_dataloader()]
     assert tracker.mock_calls == expected_sequence


-def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
-    model = BoringModel()
+@pytest.mark.parametrize("n", [1, 2])
+def test_dataloaders_load_every_n_epochs(tmpdir, n):
+    train_reload_epochs, val_reload_epochs = [], []
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            train_reload_epochs.append(self.current_epoch)
+            return super().train_dataloader()
+
+        def val_dataloader(self):
+            val_reload_epochs.append(self.current_epoch)
+            return super().val_dataloader()
+
+    model = TestModel()

-    # logger file to get meta
     trainer = Trainer(
-        default_root_dir=tmpdir, limit_train_batches=0.3, limit_val_batches=0.3, num_sanity_val_steps=0, max_epochs=3
+        default_root_dir=tmpdir,
+        limit_train_batches=0.3,
+        limit_val_batches=0.3,
+        reload_dataloaders_every_n_epochs=n,
+        max_epochs=5,
     )

     tracker = Mock()
@@ -1075,44 +1069,113 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
     tracker.attach_mock(model.test_dataloader, "test_dataloader")

     trainer.fit(model)
+    trainer.test(model)
+
+    # Verify the sequence
+    expected_sequence = [call.val_dataloader(), call.train_dataloader()]  # Sanity check first
+    if n == 1:
+        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 4
+    elif n == 2:
+        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
+    expected_sequence += [call.test_dataloader()]

-    # verify the sequence
-    expected_sequence = [call.train_dataloader(), call.val_dataloader()]
     assert tracker.mock_calls == expected_sequence

+    # Verify epoch of reloads
+    if n == 1:
+        assert train_reload_epochs == [0, 1, 2, 3, 4]
+        assert val_reload_epochs == [0, 1, 2, 3, 4]
+    elif n == 2:
+        assert train_reload_epochs == [0, 2, 4]
+        assert val_reload_epochs == [0, 2, 4]

-@pytest.mark.parametrize("n", [1, 2])
-def test_dataloaders_load_every_n_epochs(tmpdir, n):
-    model = BoringModel()
+
+@pytest.mark.parametrize(
+    "n, train_reload_epochs_expect, val_reload_epochs_expect",
+    [
+        # Sanity check at epoch 0 creates a validation dataloader, but validation is
+        # checked (and in this case reloaded) every n epochs starting from epoch n-1
+        (3, [0, 2, 4, 6, 8], [0, 2, 5, 8]),
+        (5, [0, 2, 4, 6, 8], [0, 4, 9]),
+    ],
+)
+def test_dataloaders_load_every_n_epochs_infrequent_val(
+    tmpdir, n, train_reload_epochs_expect, val_reload_epochs_expect
+):
+    """Test dataloader reload behavior when infrequently checking validation set (via check_val_every_n_epoch)"""
+    train_reload_epochs, val_reload_epochs = [], []
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            train_reload_epochs.append(self.current_epoch)
+            return super().train_dataloader()
+
+        def val_dataloader(self):
+            val_reload_epochs.append(self.current_epoch)
+            return super().val_dataloader()
+
+    model = TestModel()

     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=0.3,
         limit_val_batches=0.3,
-        reload_dataloaders_every_n_epochs=n,
+        check_val_every_n_epoch=n,
+        reload_dataloaders_every_n_epochs=2,
+        max_epochs=10,
+    )
+
+    model.test_dataloader = Mock(wraps=model.test_dataloader)
+
+    trainer.fit(model)
+    trainer.test(model)
+
+    # Verify epoch of reloads
+    assert train_reload_epochs == train_reload_epochs_expect
+    assert val_reload_epochs == val_reload_epochs_expect
+
+    model.test_dataloader.assert_called_once()
+
+
+def test_dataloaders_load_every_n_epochs_frequent_val(tmpdir):
+    """Test dataloader reload behavior when frequently checking validation set (via val_check_interval)"""
+    train_reload_epochs, val_reload_epochs, val_check_epochs = [], [], []
+
+    class TestModel(BoringModel):
+        def train_dataloader(self):
+            train_reload_epochs.append(self.current_epoch)
+            return super().train_dataloader()
+
+        def val_dataloader(self):
+            val_reload_epochs.append(self.current_epoch)
+            return super().val_dataloader()
+
+        def validation_epoch_end(self, outputs):
+            val_check_epochs.append(self.current_epoch)
+            return super().validation_epoch_end(outputs)
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=0.3,
+        limit_val_batches=0.3,
+        val_check_interval=0.3,
+        reload_dataloaders_every_n_epochs=1,
         max_epochs=3,
     )

-    tracker = Mock()
-    model.train_dataloader = Mock(wraps=model.train_dataloader)
-    model.val_dataloader = Mock(wraps=model.val_dataloader)
     model.test_dataloader = Mock(wraps=model.test_dataloader)

-    tracker.attach_mock(model.train_dataloader, "train_dataloader")
-    tracker.attach_mock(model.val_dataloader, "val_dataloader")
-    tracker.attach_mock(model.test_dataloader, "test_dataloader")
-
     trainer.fit(model)
     trainer.test(model)

-    # verify the sequence
-    expected_sequence = [call.val_dataloader()]
-    if n == 1:
-        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 3
-    elif n == 2:
-        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
-    expected_sequence += [call.test_dataloader()]
-    assert tracker.mock_calls == expected_sequence
+    # Verify epoch of reloads
+    assert train_reload_epochs == [0, 1, 2]
+    assert val_reload_epochs == [0, 1, 2]
     model.test_dataloader.assert_called_once()
+
+    # Verify validation happens 3 times per epoch + 1 for sanity check
+    assert val_check_epochs == [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]


 @pytest.mark.parametrize("n", ["test", -1])
@@ -1159,15 +1222,6 @@ def validation_step(self, batch, batch_idx):
         expected_calls = [
             call.train_dataloader(),
             call.val_dataloader(),
-            # This has subsequent calls to val_dataloader
-            # because the training loop runs the evaluation loop,
-            # which reloads the val dataloader again.
-            # We cannot yet rely on trainer.current_epoch=0 to skip reloading
-            # the val dataloader on the first epoch because this only tracks the training epoch
-            # meaning multiple passes through the validation data within a single training epoch
-            # would not have the dataloader reloaded.
-            # This breaks the assumption behind reload_dataloaders_every_n_epochs=1
-            call.val_dataloader(),
             call.train_dataloader(),
             call.val_dataloader(),
             call.train_dataloader(),
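For illustration, a hypothetical end-user configuration that exercises the fixed path (the `LightningModule` and its data are assumed, not part of the PR). With the settings below, validation runs at epochs 2, 5, and 8, and after this patch each of those runs gets a freshly reloaded val dataloader, matching the `[0, 2, 5, 8]` reload epochs asserted in `test_dataloaders_load_every_n_epochs_infrequent_val` (epoch 0 being the sanity check):

```python
# Hypothetical usage; MyModel and its dataloaders are assumed, not part of the PR.
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=10,
    check_val_every_n_epoch=3,            # validate at epochs 2, 5, 8 (0-indexed)
    reload_dataloaders_every_n_epochs=2,  # reload once >= 2 epochs since the last reload
)
# trainer.fit(MyModel())
```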