diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ea7677e59c..25022f9a595 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1581](https://github.com/Lightning-AI/metrics/pull/1581), [#1585](https://github.com/Lightning-AI/metrics/pull/1585), [#1593](https://github.com/Lightning-AI/metrics/pull/1593), + [#1600](https://github.com/Lightning-AI/metrics/pull/1600), ) diff --git a/requirements/docs.txt b/requirements/docs.txt index 4dd5da84959..1209e6d16b0 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -16,3 +16,4 @@ sphinx-copybutton>=0.3 -r visual.txt -r audio.txt -r detection.txt +-r image.txt diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 352f7d0a647..cd0fe874a60 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from copy import deepcopy -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Sequence, Union import numpy as np import torch @@ -22,7 +22,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_info -from torchmetrics.utilities.imports import _SCIPY_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _SCIPY_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["FrechetInceptionDistance.plot"] if _TORCH_FIDELITY_AVAILABLE: from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3 @@ -31,7 +35,7 @@ class _FeatureExtractorInceptionV3(Module): pass - __doctest_skip__ = ["FrechetInceptionDistance", "FID"] + __doctest_skip__ = ["FrechetInceptionDistance", "FrechetInceptionDistance.plot"] if _SCIPY_AVAILABLE: @@ -303,3 +307,54 @@ def reset(self) -> None: self.real_features_num_samples = real_features_num_samples else: super().reset() + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.image.fid import FrechetInceptionDistance + >>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + >>> metric = FrechetInceptionDistance(feature=64) + >>> metric.update(imgs_dist1, real=True) + >>> metric.update(imgs_dist2, real=False) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.image.fid import FrechetInceptionDistance + >>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8) + >>> metric = FrechetInceptionDistance(feature=64) + >>> values = [ ] + >>> for _ in range(3): + ... metric.update(imgs_dist1(), real=True) + ... metric.update(imgs_dist2(), real=False) + ... values.append(metric.compute()) + ... metric.reset() + >>> fig_, ax_ = metric.plot(values) + + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 9bf8cbe511f..8ce90fb777d 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -21,9 +21,14 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -__doctest_requires__ = {("InceptionScore", "IS"): ["torch_fidelity"]} +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["InceptionScore.plot"] + + +__doctest_requires__ = {("InceptionScore", "InceptionScore.plot"): ["torch_fidelity"]} class InceptionScore(Metric): @@ -162,3 +167,46 @@ def compute(self) -> Tuple[Tensor, Tensor]: # return mean and std return kl.mean(), kl.std() + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.image.inception import InceptionScore + >>> metric = InceptionScore() + >>> metric.update(torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8)) + >>> fig_, ax_ = metric.plot() # the returned plot only shows the mean value by default + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.image.inception import InceptionScore + >>> metric = InceptionScore() + >>> values = [ ] + >>> for _ in range(3): + ... # we index by 0 such that only the mean value is plotted + ... values.append(metric(torch.randint(0, 255, (50, 3, 299, 299), dtype=torch.uint8))[0]) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute()[0] # by default we select the mean to plot + return self._plot(val, ax) diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index d52dfbfe0b8..b7b9906185e 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -21,9 +21,13 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -__doctest_requires__ = {("KernelInceptionDistance", "KID"): ["torch_fidelity"]} +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["KernelInceptionDistance.plot"] + +__doctest_requires__ = {("KernelInceptionDistance", "KernelInceptionDistance.plot"): ["torch_fidelity"]} def maximum_mean_discrepancy(k_xx: Tensor, k_xy: Tensor, k_yy: Tensor) -> Tensor: @@ -274,3 +278,53 @@ def reset(self) -> None: self._defaults["real_features"] = value else: super().reset() + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.image.kid import KernelInceptionDistance + >>> imgs_dist1 = torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8) + >>> metric = KernelInceptionDistance(subsets=3, subset_size=20) + >>> metric.update(imgs_dist1, real=True) + >>> metric.update(imgs_dist2, real=False) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.image.kid import KernelInceptionDistance + >>> imgs_dist1 = lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8) + >>> imgs_dist2 = lambda: torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8) + >>> metric = KernelInceptionDistance(subsets=3, subset_size=20) + >>> values = [ ] + >>> for _ in range(3): + ... metric.update(imgs_dist1(), real=True) + ... metric.update(imgs_dist2(), real=False) + ... values.append(metric.compute()[0]) + ... metric.reset() + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute()[0] # by default we select the mean to plot + return self._plot(val, ax) diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index ae5216a1d0a..2e7133a27e5 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, List +from typing import Any, List, Optional, Sequence, Union import torch from torch import Tensor @@ -21,7 +21,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout -from torchmetrics.utilities.imports import _LPIPS_AVAILABLE +from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity.plot"] if _LPIPS_AVAILABLE: from lpips import LPIPS as _LPIPS @@ -30,13 +34,13 @@ def _download_lpips() -> None: _LPIPS(pretrained=True, net="vgg") if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_lpips): - __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"] + __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LearnedPerceptualImagePatchSimilarity.plot"] else: class _LPIPS(Module): pass - __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"] + __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LearnedPerceptualImagePatchSimilarity.plot"] class NoTrainLpips(_LPIPS): @@ -167,3 +171,44 @@ def compute(self) -> Tensor: if self.reduction == "sum": return self.sum_scores return None + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + >>> metric = LearnedPerceptualImagePatchSimilarity() + >>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity + >>> metric = LearnedPerceptualImagePatchSimilarity() + >>> values = [ ] + >>> for _ in range(3): + ... values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/rase.py b/src/torchmetrics/image/rase.py index e2b112a4671..7ed80c19afc 100644 --- a/src/torchmetrics/image/rase.py +++ b/src/torchmetrics/image/rase.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Sequence, Union import torch from torch import Tensor @@ -20,6 +20,11 @@ from torchmetrics.functional.image.rase import relative_average_spectral_error from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["RelativeAverageSpectralError.plot"] class RelativeAverageSpectralError(Metric): @@ -85,3 +90,45 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return relative_average_spectral_error(preds, target, self.window_size) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import RelativeAverageSpectralError + >>> metric = RelativeAverageSpectralError() + >>> metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import RelativeAverageSpectralError + >>> metric = RelativeAverageSpectralError() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/rmse_sw.py b/src/torchmetrics/image/rmse_sw.py index 9f61b490ef3..4352ddca2e6 100644 --- a/src/torchmetrics/image/rmse_sw.py +++ b/src/torchmetrics/image/rmse_sw.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch from torch import Tensor from torchmetrics.functional.image.rmse_sw import _rmse_sw_compute, _rmse_sw_update from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["RootMeanSquaredErrorUsingSlidingWindow.plot"] class RootMeanSquaredErrorUsingSlidingWindow(Metric): @@ -85,3 +90,44 @@ def compute(self) -> Optional[Tensor]: assert self.rmse_map is not None rmse, _ = _rmse_sw_compute(self.rmse_val_sum, self.rmse_map, self.total_images) return rmse + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import RootMeanSquaredErrorUsingSlidingWindow + >>> metric = RootMeanSquaredErrorUsingSlidingWindow() + >>> metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import RootMeanSquaredErrorUsingSlidingWindow + >>> metric = RootMeanSquaredErrorUsingSlidingWindow() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/image/tv.py b/src/torchmetrics/image/tv.py index 9d30148cd33..40b92297ba0 100644 --- a/src/torchmetrics/image/tv.py +++ b/src/torchmetrics/image/tv.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union import torch from torch import Tensor, tensor @@ -20,6 +20,11 @@ from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["TotalVariation.plot"] class TotalVariation(Metric): @@ -86,3 +91,44 @@ def compute(self) -> Tensor: """Compute final total variation.""" score = dim_zero_cat(self.score) if self.reduction is None or self.reduction == "none" else self.score return _total_variation_compute(score, self.num_elements, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import TotalVariation + >>> metric = TotalVariation() + >>> metric.update(torch.rand(5, 3, 28, 28)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import TotalVariation + >>> metric = TotalVariation() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(5, 3, 28, 28))) + >>> fig_, ax_ = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index c672b3a06e6..e6843bad18f 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -53,11 +53,18 @@ from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio from torchmetrics.image import ( ErrorRelativeGlobalDimensionlessSynthesis, + FrechetInceptionDistance, + InceptionScore, + KernelInceptionDistance, + LearnedPerceptualImagePatchSimilarity, MultiScaleStructuralSimilarityIndexMeasure, PeakSignalNoiseRatio, + RelativeAverageSpectralError, + RootMeanSquaredErrorUsingSlidingWindow, SpectralAngleMapper, SpectralDistortionIndex, StructuralSimilarityIndexMeasure, + TotalVariation, UniversalImageQualityIndex, ) from torchmetrics.nominal import CramersV, PearsonsContingencyCoefficient, TheilsU, TschuprowsT @@ -244,13 +251,26 @@ _multilabel_randint_input, id="multilabel average precision", ), + pytest.param(TotalVariation, _image_input, None, id="total variation"), + pytest.param( + RootMeanSquaredErrorUsingSlidingWindow, + _image_input, + _image_input, + id="root mean squared error using sliding window", + ), + pytest.param(RelativeAverageSpectralError, _image_input, _image_input, id="relative average spectral error"), + pytest.param( + LearnedPerceptualImagePatchSimilarity, + lambda: torch.rand(10, 3, 100, 100), + lambda: torch.rand(10, 3, 100, 100), + id="learned perceptual image patch similarity", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5]) -def test_single_multi_val_plot_methods(metric_class: object, preds: Callable, target: Callable, num_vals: int): +def test_plot_methods(metric_class: object, preds: Callable, target: Callable, num_vals: int): """Test the plot method of metrics that only output a single tensor scalar.""" metric = metric_class() - input = (lambda: (preds(),)) if target is None else lambda: (preds(), target()) if num_vals == 1: @@ -266,6 +286,64 @@ def test_single_multi_val_plot_methods(metric_class: object, preds: Callable, ta assert isinstance(ax, matplotlib.axes.Axes) +@pytest.mark.parametrize( + ("metric_class", "preds", "target", "index_0"), + [ + pytest.param( + partial(KernelInceptionDistance, feature=64, subsets=3, subset_size=20), + lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8), + lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8), + True, + id="kernel inception distance", + ), + pytest.param( + partial(FrechetInceptionDistance, feature=64), + lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8), + lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8), + False, + id="frechet inception distance", + ), + pytest.param( + partial(InceptionScore, feature=64), + lambda: torch.randint(0, 255, (30, 3, 299, 299), dtype=torch.uint8), + None, + True, + id="inception score", + ), + ], +) +@pytest.mark.parametrize("num_vals", [1, 2]) +def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0, num_vals): + """Test the plot method of metrics that only output a single tensor scalar. + + This takes care of FID, KID and inception score image metrics as these have a slightly different call and update + signature than other metrics. + """ + metric = metric_class() + + if num_vals == 1: + if target is None: + metric.update(preds()) + else: + metric.update(preds(), real=True) + metric.update(target(), real=False) + fig, ax = metric.plot() + else: + vals = [] + for _ in range(num_vals): + if target is None: + vals.append(metric(preds())[0]) + else: + metric.update(preds(), real=True) + metric.update(target(), real=False) + vals.append(metric.compute() if not index_0 else metric.compute()[0]) + metric.reset() + fig, ax = metric.plot(vals) + + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes) + + @pytest.mark.parametrize( ("metric_class", "preds", "target", "labels"), [