diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c09c2f3216c1..6a1d26cfd1acb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -104,7 +104,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005)) -- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929)) +- Added Rich Progress Bar: + * Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929)) + * Improvements for rich progress bar ([#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559)) - Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080)) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 1f36e86978e92..6bc2f4f7d8a43 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from datetime import timedelta -from typing import Dict, Optional +from typing import Optional, Union from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities import _RICH_AVAILABLE +Style = None if _RICH_AVAILABLE: from rich.console import Console, RenderableType - from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn + from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn + from rich.style import Style from rich.text import Text class CustomTimeColumn(ProgressColumn): @@ -27,21 +30,33 @@ class CustomTimeColumn(ProgressColumn): # Only refresh twice a second to prevent jitter max_refresh = 0.5 + def __init__(self, style: Union[str, Style]) -> None: + self.style = style + super().__init__() + def render(self, task) -> Text: elapsed = task.finished_time if task.finished else task.elapsed remaining = task.time_remaining elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed))) remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining))) - return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}") + return Text(f"{elapsed_delta} • {remaining_delta}", style=self.style) class BatchesProcessedColumn(ProgressColumn): + def __init__(self, style: Union[str, Style]): + self.style = style + super().__init__() + def render(self, task) -> RenderableType: - return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}") + return Text(f"{int(task.completed)}/{task.total}", style=self.style) class ProcessingSpeedColumn(ProgressColumn): + def __init__(self, style: Union[str, Style]): + self.style = style + super().__init__() + def render(self, task) -> RenderableType: task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00" - return Text.from_markup(f"[progress.data.speed] {task_speed}it/s") + return Text(f"{task_speed}it/s", style=self.style) class MetricsTextColumn(ProgressColumn): """A column containing text.""" @@ -71,19 +86,26 @@ def render(self, task) -> Text: metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module) else: metrics = self._trainer.progress_bar_metrics + for k, v in metrics.items(): _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " text = Text.from_markup(_text, style=None, justify="left") return text -STYLES: Dict[str, str] = { - "train": "red", - "sanity_check": "yellow", - "validate": "yellow", - "test": "yellow", - "predict": "yellow", -} +@dataclass +class RichProgressBarTheme: + """Styles to associate to different base components. + + https://rich.readthedocs.io/en/stable/style.html + """ + + text_color: str = "white" + progress_bar_complete: Union[str, Style] = "#6206E0" + progress_bar_finished: Union[str, Style] = "#6206E0" + batch_process: str = "white" + time: str = "grey54" + processing_speed: str = "grey70" class RichProgressBar(ProgressBarBase): @@ -104,13 +126,18 @@ class RichProgressBar(ProgressBarBase): Args: refresh_rate: the number of updates per second, must be strictly positive + theme: Contains styles used to stylize the progress bar. Raises: ImportError: If required `rich` package is not installed on the device. """ - def __init__(self, refresh_rate: float = 1.0): + def __init__( + self, + refresh_rate: float = 1.0, + theme: RichProgressBarTheme = RichProgressBarTheme(), + ) -> None: if not _RICH_AVAILABLE: raise ImportError( "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`." @@ -126,6 +153,7 @@ def __init__(self, refresh_rate: float = 1.0): self.test_progress_bar_id: Optional[int] = None self.predict_progress_bar_id: Optional[int] = None self.console = Console(record=True) + self.theme = theme @property def refresh_rate(self) -> int: @@ -147,31 +175,28 @@ def enable(self) -> None: @property def sanity_check_description(self) -> str: - return "[Validation Sanity Check]" + return "Validation Sanity Check" @property def validation_description(self) -> str: - return "[Validation]" + return "Validation" @property def test_description(self) -> str: - return "[Testing]" + return "Testing" @property def predict_description(self) -> str: - return "[Predicting]" + return "Predicting" def setup(self, trainer, pl_module, stage): self.progress = Progress( - SpinnerColumn(), TextColumn("[progress.description]{task.description}"), - BarColumn(), - BatchesProcessedColumn(), - "[", - CustomTimeColumn(), - ProcessingSpeedColumn(), + BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished), + BatchesProcessedColumn(style=self.theme.batch_process), + CustomTimeColumn(style=self.theme.time), + ProcessingSpeedColumn(style=self.theme.processing_speed), MetricsTextColumn(trainer, pl_module, stage), - "]", console=self.console, refresh_per_second=self.refresh_rate, ).__enter__() @@ -179,7 +204,7 @@ def setup(self, trainer, pl_module, stage): def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_sanity_progress_bar_id = self.progress.add_task( - f"[{STYLES['sanity_check']}]{self.sanity_check_description}", + f"[{self.theme.text_color}]{self.sanity_check_description}", total=trainer.num_sanity_val_steps, ) @@ -201,7 +226,7 @@ def on_train_epoch_start(self, trainer, pl_module): train_description = self._get_train_description(trainer.current_epoch) self.main_progress_bar_id = self.progress.add_task( - f"[{STYLES['train']}]{train_description}", + f"[{self.theme.text_color}]{train_description}", total=total_batches, ) @@ -209,7 +234,7 @@ def on_validation_epoch_start(self, trainer, pl_module): super().on_validation_epoch_start(trainer, pl_module) if self._total_val_batches > 0: self.val_progress_bar_id = self.progress.add_task( - f"[{STYLES['validate']}]{self.validation_description}", + f"[{self.theme.text_color}]{self.validation_description}", total=self._total_val_batches, ) @@ -221,14 +246,14 @@ def on_validation_epoch_end(self, trainer, pl_module): def on_test_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) self.test_progress_bar_id = self.progress.add_task( - f"[{STYLES['test']}]{self.test_description}", + f"[{self.theme.text_color}]{self.test_description}", total=self.total_test_batches, ) def on_predict_epoch_start(self, trainer, pl_module): super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar_id = self.progress.add_task( - f"[{STYLES['predict']}]{self.predict_description}", + f"[{self.theme.text_color}]{self.predict_description}", total=self.total_predict_batches, ) @@ -261,7 +286,7 @@ def _should_update(self, current, total) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def _get_train_description(self, current_epoch: int) -> str: - train_description = f"[Epoch {current_epoch}]" + train_description = f"Epoch {current_epoch}" if len(self.validation_description) > len(train_description): # Padding is required to avoid flickering due of uneven lengths of "Epoch X" # and "Validation" Bar description @@ -273,3 +298,7 @@ def _get_train_description(self, current_epoch: int) -> str: def teardown(self, trainer, pl_module, stage): self.progress.__exit__(None, None, None) + + def on_exception(self, trainer, pl_module, exception: BaseException) -> None: + if isinstance(exception, KeyboardInterrupt): + self.progress.stop() diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 687f1f679b858..ce25c929b0547 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock +from unittest.mock import DEFAULT import pytest from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar +from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -24,7 +26,6 @@ @RunIf(rich=True) def test_rich_progress_bar_callback(): - trainer = Trainer(callbacks=RichProgressBar()) progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] @@ -36,7 +37,6 @@ def test_rich_progress_bar_callback(): @RunIf(rich=True) @mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") def test_rich_progress_bar(progress_update, tmpdir): - model = BoringModel() trainer = Trainer( @@ -58,7 +58,61 @@ def test_rich_progress_bar(progress_update, tmpdir): def test_rich_progress_bar_import_error(): - if not _RICH_AVAILABLE: with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."): Trainer(callbacks=RichProgressBar()) + + +@RunIf(rich=True) +def test_rich_progress_bar_custom_theme(tmpdir): + """Test to ensure that custom theme styles are used.""" + with mock.patch.multiple( + "pytorch_lightning.callbacks.progress.rich_progress", + BarColumn=DEFAULT, + BatchesProcessedColumn=DEFAULT, + CustomTimeColumn=DEFAULT, + ProcessingSpeedColumn=DEFAULT, + ) as mocks: + + theme = RichProgressBarTheme() + + progress_bar = RichProgressBar(theme=theme) + progress_bar.setup(Trainer(tmpdir), BoringModel(), stage=None) + + assert progress_bar.theme == theme + args, kwargs = mocks["BarColumn"].call_args + assert kwargs["complete_style"] == theme.progress_bar_complete + assert kwargs["finished_style"] == theme.progress_bar_finished + + args, kwargs = mocks["BatchesProcessedColumn"].call_args + assert kwargs["style"] == theme.batch_process + + args, kwargs = mocks["CustomTimeColumn"].call_args + assert kwargs["style"] == theme.time + + args, kwargs = mocks["ProcessingSpeedColumn"].call_args + assert kwargs["style"] == theme.processing_speed + + +@RunIf(rich=True) +def test_rich_progress_bar_keyboard_interrupt(tmpdir): + """Test to ensure that when the user keyboard interrupts, we close the progress bar.""" + + class TestModel(BoringModel): + def on_train_start(self) -> None: + raise KeyboardInterrupt + + model = TestModel() + + with mock.patch( + "pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True + ) as mock_progress_stop: + progress_bar = RichProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + callbacks=progress_bar, + ) + + trainer.fit(model) + mock_progress_stop.assert_called_once()