Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for rich progress bar #9559

Merged
merged 5 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 61 additions & 30 deletions pytorch_lightning/callbacks/progress/rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, Optional
from typing import Optional, Union

from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities import _RICH_AVAILABLE

if _RICH_AVAILABLE:
from rich.console import Console, RenderableType
from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
from rich.style import Style
from rich.text import Text

class CustomTimeColumn(ProgressColumn):

# Only refresh twice a second to prevent jitter
max_refresh = 0.5

def __init__(self, style: Union[str, Style]):
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
self.style = style
super().__init__()

def render(self, task) -> Text:
elapsed = task.finished_time if task.finished else task.elapsed
remaining = task.time_remaining
elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}")
return Text(f"{elapsed_delta} {remaining_delta}", style=self.style)

class BatchesProcessedColumn(ProgressColumn):
def __init__(self, style: Union[str, Style]):
self.style = style
super().__init__()

def render(self, task) -> RenderableType:
return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}")
return Text(f"{int(task.completed)}/{task.total}", style=self.style)

class ProcessingSpeedColumn(ProgressColumn):
def __init__(self, style: Union[str, Style]):
self.style = style
super().__init__()

def render(self, task) -> RenderableType:
task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00"
return Text.from_markup(f"[progress.data.speed] {task_speed}it/s")
return Text(f"{task_speed}it/s", style=self.style)

class MetricsTextColumn(ProgressColumn):
"""A column containing text."""
Expand Down Expand Up @@ -71,19 +85,29 @@ def render(self, task) -> Text:
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
else:
metrics = self._trainer.progress_bar_metrics

# todo: what even is v_num
metrics.pop("v_num", None)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved

for k, v in metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
text = Text.from_markup(_text, style=None, justify="left")
return text


STYLES: Dict[str, str] = {
"train": "red",
"sanity_check": "yellow",
"validate": "yellow",
"test": "yellow",
"predict": "yellow",
}
@dataclass
class RichProgressBarTheme:
"""Styles to associate to different base components.

https://rich.readthedocs.io/en/stable/style.html
"""

text_color: str = "white"
progress_bar_complete: Union[str, Style] = "#6206E0"
progress_bar_finished: Union[str, Style] = "#6206E0"
batch_process: str = "white"
time: str = "grey54"
processing_speed: str = "grey70"
Comment on lines -80 to +108
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SeanNaren @kaushikb11

noob question: Why did we introduce this new class RichProgressBarTheme and pass an instance to RichProgressBar? I'm asking because if we didn't have this class and passed colours (or a dictionary of colours) directly to RichProgressBar, colour customisation would be easier for users in my opinion. (One less thing to remember :])

from pytorch_lightning.callbacks import RichProgressBar
-from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme

progress_bar = RichProgressBar(
-    theme=RichProgressBarTheme(
        description="green_yellow",
        progress_bar="green1",
        progress_bar_finished="green1",
        batch_progress="green_yellow",
        time="grey82",
        processing_speed="grey82",
        metrics="grey82",
-    )
)

FYI, I'm looking at the example in the blog post: https://devblog.pytorchlightning.ai/super-charged-progress-bars-with-rich-lightning-669653d6ab97



class RichProgressBar(ProgressBarBase):
Expand All @@ -104,13 +128,18 @@ class RichProgressBar(ProgressBarBase):

Args:
refresh_rate: the number of updates per second, must be strictly positive
theme: Contains styles used to stylize the progress bar.

Raises:
ImportError:
If required `rich` package is not installed on the device.
"""

def __init__(self, refresh_rate: float = 1.0):
def __init__(
self,
refresh_rate: float = 1.0,
theme: RichProgressBarTheme = RichProgressBarTheme(),
):
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
if not _RICH_AVAILABLE:
raise ImportError(
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
Expand All @@ -126,6 +155,7 @@ def __init__(self, refresh_rate: float = 1.0):
self.test_progress_bar_id: Optional[int] = None
self.predict_progress_bar_id: Optional[int] = None
self.console = Console(record=True)
self.theme = theme

@property
def refresh_rate(self) -> int:
Expand All @@ -147,39 +177,36 @@ def enable(self) -> None:

@property
def sanity_check_description(self) -> str:
return "[Validation Sanity Check]"
return "Validation Sanity Check"

@property
def validation_description(self) -> str:
return "[Validation]"
return "Validation"

@property
def test_description(self) -> str:
return "[Testing]"
return "Testing"

@property
def predict_description(self) -> str:
return "[Predicting]"
return "Predicting"

def setup(self, trainer, pl_module, stage):
self.progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
BatchesProcessedColumn(),
"[",
CustomTimeColumn(),
ProcessingSpeedColumn(),
BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished),
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module, stage),
"]",
console=self.console,
refresh_per_second=self.refresh_rate,
).__enter__()

def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_sanity_progress_bar_id = self.progress.add_task(
f"[{STYLES['sanity_check']}]{self.sanity_check_description}",
f"[{self.theme.text_color}]{self.sanity_check_description}",
total=trainer.num_sanity_val_steps,
)

Expand All @@ -201,15 +228,15 @@ def on_train_epoch_start(self, trainer, pl_module):
train_description = self._get_train_description(trainer.current_epoch)

self.main_progress_bar_id = self.progress.add_task(
f"[{STYLES['train']}]{train_description}",
f"[{self.theme.text_color}]{train_description}",
total=total_batches,
)

def on_validation_epoch_start(self, trainer, pl_module):
super().on_validation_epoch_start(trainer, pl_module)
if self._total_val_batches > 0:
self.val_progress_bar_id = self.progress.add_task(
f"[{STYLES['validate']}]{self.validation_description}",
f"[{self.theme.text_color}]{self.validation_description}",
total=self._total_val_batches,
)

Expand All @@ -221,14 +248,14 @@ def on_validation_epoch_end(self, trainer, pl_module):
def on_test_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
self.test_progress_bar_id = self.progress.add_task(
f"[{STYLES['test']}]{self.test_description}",
f"[{self.theme.text_color}]{self.test_description}",
total=self.total_test_batches,
)

def on_predict_epoch_start(self, trainer, pl_module):
super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self.progress.add_task(
f"[{STYLES['predict']}]{self.predict_description}",
f"[{self.theme.text_color}]{self.predict_description}",
total=self.total_predict_batches,
)

Expand Down Expand Up @@ -261,7 +288,7 @@ def _should_update(self, current, total) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def _get_train_description(self, current_epoch: int) -> str:
train_description = f"[Epoch {current_epoch}]"
train_description = f"Epoch {current_epoch}"
if len(self.validation_description) > len(train_description):
# Padding is required to avoid flickering due of uneven lengths of "Epoch X"
# and "Validation" Bar description
Expand All @@ -273,3 +300,7 @@ def _get_train_description(self, current_epoch: int) -> str:

def teardown(self, trainer, pl_module, stage):
self.progress.__exit__(None, None, None)

def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
if isinstance(exception, KeyboardInterrupt):
self.progress.stop()
60 changes: 57 additions & 3 deletions tests/callbacks/test_rich_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import DEFAULT

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


@RunIf(rich=True)
def test_rich_progress_bar_callback():

trainer = Trainer(callbacks=RichProgressBar())

progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
Expand All @@ -36,7 +37,6 @@ def test_rich_progress_bar_callback():
@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
def test_rich_progress_bar(progress_update, tmpdir):

model = BoringModel()

trainer = Trainer(
Expand All @@ -58,7 +58,61 @@ def test_rich_progress_bar(progress_update, tmpdir):


def test_rich_progress_bar_import_error():

if not _RICH_AVAILABLE:
with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."):
Trainer(callbacks=RichProgressBar())


@RunIf(rich=True)
def test_rich_progress_bar_custom_theme(tmpdir):
"""Test to ensure that custom theme styles are used."""
with mock.patch.multiple(
"pytorch_lightning.callbacks.progress.rich_progress",
BarColumn=DEFAULT,
BatchesProcessedColumn=DEFAULT,
CustomTimeColumn=DEFAULT,
ProcessingSpeedColumn=DEFAULT,
) as mocks:

theme = RichProgressBarTheme()

progress_bar = RichProgressBar(theme=theme)
progress_bar.setup(Trainer(tmpdir), BoringModel(), stage=None)

assert progress_bar.theme == theme
args, kwargs = mocks["BarColumn"].call_args
assert kwargs["complete_style"] == theme.progress_bar_complete
assert kwargs["finished_style"] == theme.progress_bar_finished

args, kwargs = mocks["BatchesProcessedColumn"].call_args
assert kwargs["style"] == theme.batch_process

args, kwargs = mocks["CustomTimeColumn"].call_args
assert kwargs["style"] == theme.time

args, kwargs = mocks["ProcessingSpeedColumn"].call_args
assert kwargs["style"] == theme.processing_speed


@RunIf(rich=True)
def test_rich_progress_bar_keyboard_interrupt(tmpdir):
"""Test to ensure that when the user keyboard interrupts, we close the progress bar."""

class TestModel(BoringModel):
def on_train_start(self) -> None:
raise KeyboardInterrupt

model = TestModel()

with mock.patch(
"pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop:
progress_bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
callbacks=progress_bar,
)

trainer.fit(model)
mock_progress_stop.assert_called_once()