diff --git a/CHANGELOG.md b/CHANGELOG.md
index 96a24f3dd1ba8..331185ee82403 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -261,6 +261,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed running sanity check with `RichProgressBar` ([#10913](https://github.com/PyTorchLightning/pytorch-lightning/pull/10913))
+
+
 - Fixed support for `CombinedLoader` while checking for warning raised with eval dataloaders ([#10994](https://github.com/PyTorchLightning/pytorch-lightning/pull/10994))
 
 
diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py
index 50a017a72934c..fa4d925800d07 100644
--- a/pytorch_lightning/callbacks/progress/rich_progress.py
+++ b/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -334,7 +334,8 @@ def on_sanity_check_start(self, trainer, pl_module):
 
     def on_sanity_check_end(self, trainer, pl_module):
         super().on_sanity_check_end(trainer, pl_module)
-        self._update(self.val_sanity_progress_bar_id, visible=False)
+        if self.progress is not None:
+            self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False)
         self.refresh()
 
     def on_train_epoch_start(self, trainer, pl_module):
diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py
index 5b0123e3060d3..f2e75006f7ecb 100644
--- a/tests/callbacks/test_rich_progress_bar.py
+++ b/tests/callbacks/test_rich_progress_bar.py
@@ -201,3 +201,24 @@ def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, e
     trainer.fit(model)
 
     assert progress_update.call_count == expected_call_count
+
+
+@RunIf(rich=True)
+@pytest.mark.parametrize("limit_val_batches", (1, 5))
+def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
+    model = BoringModel()
+
+    progress_bar = RichProgressBar()
+    num_sanity_val_steps = 3
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=num_sanity_val_steps,
+        limit_train_batches=1,
+        limit_val_batches=limit_val_batches,
+        max_epochs=1,
+        callbacks=progress_bar,
+    )
+
+    trainer.fit(model)
+    assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)
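
For context, a minimal sketch of the scenario the new test exercises: running the pre-training sanity check with `RichProgressBar` attached, which goes through the `on_sanity_check_end` path guarded by this patch. `Trainer` and `RichProgressBar` are taken from the diff above; the `BoringModel` import path is an assumption (it is the helper used in the test suite) and may differ between versions.

```python
# Sketch only, not part of the patch.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
from tests.helpers import BoringModel  # assumed location of the test helper model

trainer = Trainer(
    num_sanity_val_steps=3,   # drives the sanity-check progress task updated in on_sanity_check_end
    limit_train_batches=1,
    limit_val_batches=2,
    max_epochs=1,
    callbacks=RichProgressBar(),
)

# With the guard on self.progress, the sanity-check task is hidden cleanly
# once the check finishes and training proceeds with the regular bars.
trainer.fit(BoringModel())
```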