diff --git a/.github/workflows/ci-tests-full.yml b/.github/workflows/ci-tests-full.yml index 95b0fd37c71..44d86858497 100644 --- a/.github/workflows/ci-tests-full.yml +++ b/.github/workflows/ci-tests-full.yml @@ -71,7 +71,7 @@ jobs: run: | echo 'UNITTEST_TIMEOUT=--timeout=120' >> $GITHUB_ENV sudo apt update --fix-missing - sudo apt install -y ffmpeg + sudo apt install -y ffmpeg dvipng texlive-latex-extra texlive-fonts-recommended cm-super - name: Setup Windows if: ${{ runner.os == 'windows' }} run: | diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index fb2611973f0..a23902257e1 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -65,7 +65,7 @@ jobs: pip install -e . -U -q -r requirements/docs.txt -f https://download.pytorch.org/whl/torch_stable.html # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux sudo apt-get update - sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures + sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures texlive-fonts-recommended cm-super python --version pip --version pip list diff --git a/requirements/visual.txt b/requirements/visual.txt index 3c3bb1d622a..b4329b5206c 100644 --- a/requirements/visual.txt +++ b/requirements/visual.txt @@ -2,3 +2,4 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment matplotlib >=3.2.0, <=3.6.3 +SciencePlots >=2.0.0, <= 2.0.1 diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 7a56f7db610..b9008a38966 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -13,6 +13,7 @@ # limitations under the License. """Import utilities.""" import operator +import shutil from typing import Optional from lightning_utilities.core.imports import compare_version, package_available @@ -42,5 +43,8 @@ _PYSTOI_AVAILABLE: bool = package_available("pystoi") _FAST_BSS_EVAL_AVAILABLE: bool = package_available("fast_bss_eval") _MATPLOTLIB_AVAILABLE: bool = package_available("matplotlib") +_SCIENCEPLOT_AVAILABLE: bool = package_available("scienceplots") _MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing") _XLA_AVAILABLE: bool = package_available("torch_xla") + +_LATEX_AVAILABLE: bool = shutil.which("latex") is not None diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 922e4643ad4..7efdfd249a9 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -13,13 +13,13 @@ # limitations under the License. from itertools import product from math import ceil, floor, sqrt -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Generator, List, Optional, Sequence, Tuple, Union import numpy as np import torch from torch import Tensor -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.imports import _LATEX_AVAILABLE, _MATPLOTLIB_AVAILABLE, _SCIENCEPLOT_AVAILABLE if _MATPLOTLIB_AVAILABLE: import matplotlib @@ -27,10 +27,27 @@ _PLOT_OUT_TYPE = Tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]] _AX_TYPE = matplotlib.axes.Axes + + style_change = plt.style.context else: _PLOT_OUT_TYPE = Tuple[object, object] # type: ignore[misc] _AX_TYPE = object + from contextlib import contextmanager + + @contextmanager + def style_change(*args: Any, **kwargs: Any) -> Generator: + """Default no-ops decorator if matplotlib is not installed.""" + yield + + +if _SCIENCEPLOT_AVAILABLE: + import scienceplots # noqa: F401 + + _style = ["science", "no-latex"] + +_style = ["science"] if _SCIENCEPLOT_AVAILABLE and _LATEX_AVAILABLE else ["default"] + def _error_on_missing_matplotlib() -> None: """Raise error if matplotlib is not installed.""" @@ -40,6 +57,7 @@ def _error_on_missing_matplotlib() -> None: ) +@style_change(_style) def plot_single_or_multi_val( val: Union[Tensor, Sequence[Tensor]], ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type] @@ -152,6 +170,7 @@ def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> np.ndarray: # type: return axs[:nb] +@style_change(_style) def plot_confusion_matrix( confmat: Tensor, add_text: bool = True, @@ -217,6 +236,7 @@ def plot_confusion_matrix( return fig, axs +@style_change(_style) def plot_binary_roc_curve( tpr: Tensor, fpr: Tensor,