From adc4a597ee3758059423efcf2de685557470a1ff Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 6 Jan 2022 15:52:08 +0100 Subject: [PATCH] Fix rich bug --- pytorch_lightning/callbacks/progress/rich_progress.py | 3 +-- pytorch_lightning/callbacks/rich_model_summary.py | 2 +- tests/callbacks/test_rich_model_summary.py | 2 +- tests/callbacks/test_rich_progress_bar.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index f983ccdaab12e..7306c36aec7b4 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -17,7 +17,6 @@ from typing import Any, Dict, Optional, Union from pytorch_lightning.callbacks.progress.base import ProgressBarBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RICH_AVAILABLE Task, Style = None, None @@ -231,7 +230,7 @@ def __init__( console_kwargs: Optional[Dict[str, Any]] = None, ) -> None: if not _RICH_AVAILABLE: - raise MisconfigurationException( + raise ModuleNotFoundError( "`RichProgressBar` requires `rich` >= 10.2.2. Install it by running `pip install -U rich`." ) diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index cce4eb316a3b0..14c078a273ece 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -61,7 +61,7 @@ class RichModelSummary(ModelSummary): def __init__(self, max_depth: int = 1) -> None: if not _RICH_AVAILABLE: raise ModuleNotFoundError( - "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`." + "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install -U rich`." ) super().__init__(max_depth) diff --git a/tests/callbacks/test_rich_model_summary.py b/tests/callbacks/test_rich_model_summary.py index 5ab091bd01445..88c5f9ab531f0 100644 --- a/tests/callbacks/test_rich_model_summary.py +++ b/tests/callbacks/test_rich_model_summary.py @@ -35,7 +35,7 @@ def test_rich_model_summary_callback(): def test_rich_progress_bar_import_error(): if not _RICH_AVAILABLE: - with pytest.raises(ImportError, match="`RichModelSummary` requires `rich` to be installed."): + with pytest.raises(ModuleNotFoundError, match="`RichModelSummary` requires `rich` to be installed."): Trainer(callbacks=RichModelSummary()) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index f2e75006f7ecb..d9a9ad8c3c726 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -85,7 +85,7 @@ def predict_dataloader(self): def test_rich_progress_bar_import_error(): if not _RICH_AVAILABLE: - with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` >= 10.2.2."): + with pytest.raises(ModuleNotFoundError, match="`RichProgressBar` requires `rich` >= 10.2.2."): Trainer(callbacks=RichProgressBar())