Skip to content

Commit

Permalink
Change plot style (#1582)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka <[email protected]>
  • Loading branch information
SkafteNicki and Borda authored Mar 3, 2023
1 parent 128166d commit 2a8f074
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests-full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
run: |
echo 'UNITTEST_TIMEOUT=--timeout=120' >> $GITHUB_ENV
sudo apt update --fix-missing
sudo apt install -y ffmpeg
sudo apt install -y ffmpeg dvipng texlive-latex-extra texlive-fonts-recommended cm-super
- name: Setup Windows
if: ${{ runner.os == 'windows' }}
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
pip install -e . -U -q -r requirements/docs.txt -f https://download.pytorch.org/whl/torch_stable.html
# install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
sudo apt-get update
sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures
sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures texlive-fonts-recommended cm-super
python --version
pip --version
pip list
Expand Down
1 change: 1 addition & 0 deletions requirements/visual.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

matplotlib >=3.2.0, <=3.6.3
SciencePlots >=2.0.0, <= 2.0.1
4 changes: 4 additions & 0 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Import utilities."""
import operator
import shutil
from typing import Optional

from lightning_utilities.core.imports import compare_version, package_available
Expand Down Expand Up @@ -42,5 +43,8 @@
_PYSTOI_AVAILABLE: bool = package_available("pystoi")
_FAST_BSS_EVAL_AVAILABLE: bool = package_available("fast_bss_eval")
_MATPLOTLIB_AVAILABLE: bool = package_available("matplotlib")
_SCIENCEPLOT_AVAILABLE: bool = package_available("scienceplots")
_MULTIPROCESSING_AVAILABLE: bool = package_available("multiprocessing")
_XLA_AVAILABLE: bool = package_available("torch_xla")

_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
24 changes: 22 additions & 2 deletions src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,41 @@
# limitations under the License.
from itertools import product
from math import ceil, floor, sqrt
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Generator, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.imports import _LATEX_AVAILABLE, _MATPLOTLIB_AVAILABLE, _SCIENCEPLOT_AVAILABLE

if _MATPLOTLIB_AVAILABLE:
import matplotlib
import matplotlib.pyplot as plt

_PLOT_OUT_TYPE = Tuple[plt.Figure, Union[matplotlib.axes.Axes, np.ndarray]]
_AX_TYPE = matplotlib.axes.Axes

style_change = plt.style.context
else:
_PLOT_OUT_TYPE = Tuple[object, object] # type: ignore[misc]
_AX_TYPE = object

from contextlib import contextmanager

@contextmanager
def style_change(*args: Any, **kwargs: Any) -> Generator:
"""Default no-ops decorator if matplotlib is not installed."""
yield


if _SCIENCEPLOT_AVAILABLE:
import scienceplots # noqa: F401

_style = ["science", "no-latex"]

_style = ["science"] if _SCIENCEPLOT_AVAILABLE and _LATEX_AVAILABLE else ["default"]


def _error_on_missing_matplotlib() -> None:
"""Raise error if matplotlib is not installed."""
Expand All @@ -40,6 +57,7 @@ def _error_on_missing_matplotlib() -> None:
)


@style_change(_style)
def plot_single_or_multi_val(
val: Union[Tensor, Sequence[Tensor]],
ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type]
Expand Down Expand Up @@ -152,6 +170,7 @@ def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> np.ndarray: # type:
return axs[:nb]


@style_change(_style)
def plot_confusion_matrix(
confmat: Tensor,
add_text: bool = True,
Expand Down Expand Up @@ -217,6 +236,7 @@ def plot_confusion_matrix(
return fig, axs


@style_change(_style)
def plot_binary_roc_curve(
tpr: Tensor,
fpr: Tensor,
Expand Down

0 comments on commit 2a8f074

Please sign in to comment.