From 06fd8ae7a583abc0a3dcd5452f3e65080a8b8d25 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 3 Dec 2021 10:55:19 +0530 Subject: [PATCH 1/3] Fix sanity check for RichProgressBar --- .../callbacks/progress/rich_progress.py | 3 ++- tests/callbacks/test_rich_progress_bar.py | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 50a017a72934c..fa4d925800d07 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -334,7 +334,8 @@ def on_sanity_check_start(self, trainer, pl_module): def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) - self._update(self.val_sanity_progress_bar_id, visible=False) + if self.progress is not None: + self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False) self.refresh() def on_train_epoch_start(self, trainer, pl_module): diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 5b0123e3060d3..44d4c52706627 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -201,3 +201,24 @@ def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, e trainer.fit(model) assert progress_update.call_count == expected_call_count + + +@RunIf(rich=True) +def test_rich_progress_bar_stages(tmpdir): + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=1, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + limit_predict_batches=1, + max_epochs=1, + callbacks=RichProgressBar(), + ) + + trainer.fit(model) + trainer.validate(model) + trainer.test(model) + trainer.predict(model) From cd4bb4b5cfdbc3ac29109a16a8aa8c03a80e71e5 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 3 Dec 2021 10:58:27 +0530 Subject: [PATCH 2/3] Update changelog --- CHANGELOG.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ed0408bfbb15..b7280e0608c10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -231,9 +231,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870)) - -- - +- Fixed running sanity check with `RichProgressBar` ([#10913](https://github.com/PyTorchLightning/pytorch-lightning/pull/10913)) ## [1.5.4] - 2021-11-30 From 5a2203fb543c288cbaa6dc6116d77accb75c9021 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 3 Dec 2021 16:36:41 +0530 Subject: [PATCH 3/3] Update test --- tests/callbacks/test_rich_progress_bar.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 44d4c52706627..f2e75006f7ecb 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -204,21 +204,21 @@ def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, e @RunIf(rich=True) -def test_rich_progress_bar_stages(tmpdir): +@pytest.mark.parametrize("limit_val_batches", (1, 5)) +def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int): model = BoringModel() + progress_bar = RichProgressBar() + num_sanity_val_steps = 3 + trainer = Trainer( default_root_dir=tmpdir, - num_sanity_val_steps=1, + num_sanity_val_steps=num_sanity_val_steps, limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, - limit_predict_batches=1, + limit_val_batches=limit_val_batches, max_epochs=1, - callbacks=RichProgressBar(), + callbacks=progress_bar, ) trainer.fit(model) - trainer.validate(model) - trainer.test(model) - trainer.predict(model) + assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)