From 2850524b61dc426d5786fd05604de70727022105 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 27 Feb 2023 16:18:36 +0100 Subject: [PATCH] Plot method for aggregation + refactor tests (#1485) * starting point * fix aggregation methods * update testing * fix confusion matrix * changelog * cleaning * drop unused plot_options --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- CHANGELOG.md | 2 + src/torchmetrics/aggregation.py | 187 +++++++++- .../classification/confusion_matrix.py | 82 ++++- src/torchmetrics/metric.py | 1 + src/torchmetrics/utilities/plot.py | 19 +- tests/unittests/utilities/test_plot.py | 326 ++++++++---------- 6 files changed, 412 insertions(+), 205 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96994f8daad..f5a5ef0fa38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for deterministic evaluation on GPU for metrics that uses `torch.cumsum` operator ([#1499](https://github.com/Lightning-AI/metrics/pull/1499)) +- Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485)) + ### Changed - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370)) diff --git a/src/torchmetrics/aggregation.py b/src/torchmetrics/aggregation.py index 10d04f4f5f9..845020e3272 100644 --- a/src/torchmetrics/aggregation.py +++ b/src/torchmetrics/aggregation.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Any, Callable, List, Union +from typing import Any, Callable, List, Optional, Sequence, Union import torch from torch import Tensor @@ -20,6 +19,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SumMetric.plot", "MeanMetric.plot", "MaxMetric.plot", "MinMetric.plot"] class BaseAggregator(Metric): @@ -44,7 +48,7 @@ class BaseAggregator(Metric): value: Tensor is_differentiable = None higher_is_better = None - full_state_update = False + full_state_update: bool = False def __init__( self, @@ -128,7 +132,7 @@ class MaxMetric(BaseAggregator): tensor(3.) """ - full_state_update = True + full_state_update: bool = True def __init__( self, @@ -153,6 +157,49 @@ def update(self, value: Union[float, Tensor]) -> None: if value.numel(): # make sure tensor not empty self.value = torch.max(self.value, torch.max(value)) + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. 
+            ax: A matplotlib axis object. If provided, will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> from torchmetrics import MaxMetric
+            >>> metric = MaxMetric()
+            >>> metric.update([1, 2, 3])
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> from torchmetrics import MaxMetric
+            >>> metric = MaxMetric()
+            >>> values = []
+            >>> for i in range(10):
+            ...     values.append(metric(i))
+            >>> fig_, ax_ = metric.plot(values)
+        """
+        val = val or self.compute()
+        fig, ax = plot_single_or_multi_val(
+            val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__
+        )
+        return fig, ax
+
 
 class MinMetric(BaseAggregator):
     """Aggregate a stream of value into their minimum value.
@@ -189,7 +236,7 @@ class MinMetric(BaseAggregator):
         tensor(1.)
     """
 
-    full_state_update = True
+    full_state_update: bool = True
 
     def __init__(
         self,
@@ -214,6 +261,49 @@ def update(self, value: Union[float, Tensor]) -> None:
         if value.numel():  # make sure tensor not empty
             self.value = torch.min(self.value, torch.min(value))
 
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> from torchmetrics import MinMetric
+            >>> metric = MinMetric()
+            >>> metric.update([1, 2, 3])
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> from torchmetrics import MinMetric
+            >>> metric = MinMetric()
+            >>> values = []
+            >>> for i in range(10):
+            ...     values.append(metric(i))
+            >>> fig_, ax_ = metric.plot(values)
+        """
+        val = val or self.compute()
+        fig, ax = plot_single_or_multi_val(
+            val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__
+        )
+        return fig, ax
+
 
 class SumMetric(BaseAggregator):
     """Aggregate a stream of value into their sum.
@@ -273,6 +363,50 @@ def update(self, value: Union[float, Tensor]) -> None:
         if value.numel():
             self.value += value.sum()
 
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> from torchmetrics import SumMetric
+            >>> metric = SumMetric()
+            >>> metric.update([1, 2, 3])
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> from torchmetrics import SumMetric
+            >>> metric = SumMetric()
+            >>> values = []
+            >>> for i in range(10):
+            ...     values.append(metric([i, i+1]))
+            >>> fig_, ax_ = metric.plot(values)
+        """
+        val = val or self.compute()
+        fig, ax = plot_single_or_multi_val(
+            val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__
+        )
+        return fig, ax
+
 
 class CatMetric(BaseAggregator):
     """Concatenate a stream of values.
@@ -407,3 +541,46 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0
     def compute(self) -> Tensor:
         """Compute the aggregated value."""
         return self.value / self.weight
+
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> from torchmetrics import MeanMetric
+            >>> metric = MeanMetric()
+            >>> metric.update([1, 2, 3])
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> from torchmetrics import MeanMetric
+            >>> metric = MeanMetric()
+            >>> values = []
+            >>> for i in range(10):
+            ...     values.append(metric([i, i+1]))
+            >>> fig_, ax_ = metric.plot(values)
+        """
+        val = val or self.compute()
+        fig, ax = plot_single_or_multi_val(
+            val, ax=ax, higher_is_better=self.higher_is_better, name=self.__class__.__name__
+        )
+        return fig, ax
diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py
index d761b4ed0fd..8d50b2b70f4 100644
--- a/src/torchmetrics/classification/confusion_matrix.py
+++ b/src/torchmetrics/classification/confusion_matrix.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from torch import Tensor
@@ -40,7 +40,11 @@
 from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix
 
 if not _MATPLOTLIB_AVAILABLE:
-    __doctest_skip__ = ["MulticlassConfusionMatrix.plot"]
+    __doctest_skip__ = [
+        "BinaryConfusionMatrix.plot",
+        "MulticlassConfusionMatrix.plot",
+        "MultilabelConfusionMatrix.plot",
+    ]
 
 
 class BinaryConfusionMatrix(Metric):
@@ -126,6 +130,39 @@ def compute(self) -> Tensor:
         """Compute confusion matrix."""
         return _binary_confusion_matrix_compute(self.confmat, self.normalize)
 
+    def plot(
+        self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            add_text: whether the value of each cell should be added to the plot
+            labels: a list of strings; if provided, will be added to the plot to indicate the different classes
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import randint
+            >>> from torchmetrics.classification import BinaryConfusionMatrix
+            >>> metric = BinaryConfusionMatrix()
+            >>> metric.update(randint(2, (20,)), randint(2, (20,)))
+            >>> fig_, ax_ = metric.plot()
+        """
+        val = val or self.compute()
+        if not isinstance(val, Tensor):
+            raise TypeError(f"Expected val to be a single tensor but got {val}")
+        fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
+        return fig, ax
+
 
 class MulticlassConfusionMatrix(Metric):
     r"""Compute the `confusion matrix`_ for multiclass tasks.
@@ -231,12 +268,16 @@ def compute(self) -> Tensor:
         """Compute confusion matrix."""
         return _multiclass_confusion_matrix_compute(self.confmat, self.normalize)
 
-    def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
+    def plot(
+        self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
+    ) -> _PLOT_OUT_TYPE:
         """Plot a single or multiple values from the metric.
 
         Args:
             val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                 If no value is provided, will automatically call `metric.compute` and plot that result.
+            add_text: whether the value of each cell should be added to the plot
+            labels: a list of strings; if provided, will be added to the plot to indicate the different classes
 
         Returns:
             Figure and Axes object
@@ -257,7 +298,7 @@ def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
         val = val or self.compute()
         if not isinstance(val, Tensor):
             raise TypeError(f"Expected val to be a single tensor but got {val}")
-        fig, ax = plot_confusion_matrix(val)
+        fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
         return fig, ax
 
 
@@ -351,6 +392,39 @@ def compute(self) -> Tensor:
         """Compute confusion matrix."""
         return _multilabel_confusion_matrix_compute(self.confmat, self.normalize)
 
+    def plot(
+        self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            add_text: whether the value of each cell should be added to the plot
+            labels: a list of strings; if provided, will be added to the plot to indicate the different classes
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import randint
+            >>> from torchmetrics.classification import MultilabelConfusionMatrix
+            >>> metric = MultilabelConfusionMatrix(num_labels=3)
+            >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
+            >>> fig_, ax_ = metric.plot()
+        """
+        val = val or self.compute()
+        if not isinstance(val, Tensor):
+            raise TypeError(f"Expected val to be a single tensor but got {val}")
+        fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
+        return fig, ax
+
 
 class ConfusionMatrix:
     r"""Compute the `confusion matrix`_.
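Usage note for reviewers: a minimal sketch of how the extended confusion-matrix plotting API above is meant to be called. This assumes the patch is applied; the class choice, input shapes, label strings, and output filename are illustrative only and not part of the diff.

    import torch
    from torchmetrics.classification import MulticlassConfusionMatrix

    metric = MulticlassConfusionMatrix(num_classes=3)
    metric.update(torch.randint(3, (20,)), torch.randint(3, (20,)))

    # `add_text` toggles the per-cell counts; `labels` names the classes on both axes.
    fig, ax = metric.plot(add_text=True, labels=["cat", "dog", "bird"])
    fig.savefig("confmat.png")
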
diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py
index b004be9a2f2..fb033563400 100644
--- a/src/torchmetrics/metric.py
+++ b/src/torchmetrics/metric.py
@@ -79,6 +79,7 @@ class Metric(Module, ABC):
     is_differentiable: Optional[bool] = None
     higher_is_better: Optional[bool] = None
     full_state_update: Optional[bool] = None
+    plot_options: Dict[str, Union[str, float]] = {}
 
     def __init__(
         self,
diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py
index 4dd96b0934f..922e4643ad4 100644
--- a/src/torchmetrics/utilities/plot.py
+++ b/src/torchmetrics/utilities/plot.py
@@ -155,7 +155,7 @@ def trim_axs(axs: Union[_AX_TYPE, np.ndarray], nb: int) -> np.ndarray:  # type:
 def plot_confusion_matrix(
     confmat: Tensor,
     add_text: bool = True,
-    labels: Optional[List[str]] = None,
+    labels: Optional[List[Union[int, str]]] = None,
 ) -> _PLOT_OUT_TYPE:
     """Plot an confusion matrix.
@@ -166,7 +166,7 @@
         confmat: the confusion matrix. Either should be an [N,N] matrix in the binary and multiclass cases or an
             [N, 2, 2] matrix for multilabel classification
         add_text: if text should be added to each cell with the given value
-        labels: labels to add the the x and y axis
+        labels: labels to add to the x- and y-axis
 
     Returns:
         A tuple consisting of the figure and respective ax objects (or array of ax objects) of the generated figure
@@ -188,16 +188,19 @@
             "Expected number of elements in arg `labels` to match number of labels in confmat but "
             f"got {len(labels)} and {n_classes}"
         )
-    labels: Union[List[int], List[str]] = labels if labels is not None else np.arange(n_classes).tolist()
+
+    if confmat.ndim == 3:
+        fig_label = labels or np.arange(nb)
+        labels = list(map(str, range(n_classes)))
+    else:
+        fig_label = None
+        labels = labels or np.arange(n_classes).tolist()
 
     fig, axs = plt.subplots(nrows=rows, ncols=cols)
     axs = trim_axs(axs, nb)
     for i in range(nb):
-        if rows != 1 and cols != 1:
-            ax = axs[i]
-            ax.set_title(f"Label {i}", fontsize=15)
-        else:
-            ax = axs
+        ax = axs[i] if rows != 1 and cols != 1 else axs
+        if fig_label is not None:
+            ax.set_title(f"Label {fig_label[i]}", fontsize=15)
         ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
         ax.set_xlabel("True class", fontsize=15)
         ax.set_ylabel("Predicted class", fontsize=15)
diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py
index 3c6969fb8ad..5ba80cab4f6 100644
--- a/tests/unittests/utilities/test_plot.py
+++ b/tests/unittests/utilities/test_plot.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
from functools import partial +from typing import Callable import matplotlib import matplotlib.pyplot as plt @@ -19,265 +20,214 @@ import pytest import torch -from torchmetrics.functional import ( - multiscale_structural_similarity_index_measure, - peak_signal_noise_ratio, - scale_invariant_signal_distortion_ratio, - scale_invariant_signal_noise_ratio, - signal_distortion_ratio, - signal_noise_ratio, - spectral_angle_mapper, - structural_similarity_index_measure, - universal_image_quality_index, +from torchmetrics.aggregation import MaxMetric, MeanMetric, MinMetric, SumMetric +from torchmetrics.audio import ( + ScaleInvariantSignalDistortionRatio, + ScaleInvariantSignalNoiseRatio, + ShortTimeObjectiveIntelligibility, + SignalDistortionRatio, + SignalNoiseRatio, ) -from torchmetrics.functional.audio import short_time_objective_intelligibility -from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality -from torchmetrics.functional.audio.pit import permutation_invariant_training -from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy -from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc -from torchmetrics.functional.classification.confusion_matrix import ( - binary_confusion_matrix, - multiclass_confusion_matrix, - multilabel_confusion_matrix, +from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality +from torchmetrics.audio.pit import PermutationInvariantTraining +from torchmetrics.classification import ( + BinaryAccuracy, + BinaryAUROC, + BinaryConfusionMatrix, + BinaryROC, + MulticlassAccuracy, + MulticlassAUROC, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix, ) -from torchmetrics.functional.classification.roc import binary_roc -from torchmetrics.functional.image.d_lambda import spectral_distortion_index -from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis -from torchmetrics.utilities.plot import plot_binary_roc_curve, plot_confusion_matrix, plot_single_or_multi_val +from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio +from torchmetrics.image import ( + ErrorRelativeGlobalDimensionlessSynthesis, + MultiScaleStructuralSimilarityIndexMeasure, + PeakSignalNoiseRatio, + SpectralAngleMapper, + SpectralDistortionIndex, + StructuralSimilarityIndexMeasure, + UniversalImageQualityIndex, +) +from torchmetrics.regression import MeanSquaredError + +_rand_input = lambda: torch.rand(10) +_binary_randint_input = lambda: torch.randint(2, (10,)) +_multiclass_randint_input = lambda: torch.randint(3, (10,)) +_multiclass_randn_input = lambda: torch.randn(10, 3).softmax(dim=-1) +_multilabel_randint_input = lambda: torch.randint(2, (10, 3)) +_audio_input = lambda: torch.randn(8000) +_image_input = lambda: torch.rand([8, 3, 16, 16]) @pytest.mark.parametrize( - ("metric", "preds", "target"), + ("metric_class", "preds", "target"), [ - # Accuracy + pytest.param(BinaryAccuracy, _rand_input, _binary_randint_input, id="binary accuracy"), pytest.param( - binary_accuracy, - lambda: torch.rand(100), - lambda: torch.randint(2, (100,)), - id="binary", + partial(MulticlassAccuracy, num_classes=3), + _multiclass_randint_input, + _multiclass_randint_input, + id="multiclass accuracy", ), - pytest.param(binary_accuracy, lambda: torch.rand(100), lambda: torch.randint(2, (100,)), id="binary"), pytest.param( - partial(multiclass_accuracy, num_classes=3), - lambda: torch.randint(3, (100,)), - lambda: torch.randint(3, (100,)), - 
id="multiclass", + partial(MulticlassAccuracy, num_classes=3, average=None), + _multiclass_randint_input, + _multiclass_randint_input, + id="multiclass accuracy and average=None", ), + # AUROC pytest.param( - partial(multiclass_accuracy, num_classes=3, average=None), - lambda: torch.randint(3, (100,)), - lambda: torch.randint(3, (100,)), - id="multiclass and average=None", + BinaryAUROC, + _rand_input, + _binary_randint_input, + id="binary auroc", ), - # AUROC pytest.param( - binary_auroc, - lambda: torch.nn.functional.softmax(torch.randn(100, 2), dim=1)[:, 1], - lambda: torch.randint(2, (100,)), - id="binary", + partial(MulticlassAUROC, num_classes=3), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass auroc", ), pytest.param( - partial(multiclass_auroc, num_classes=3), - lambda: torch.randn(100, 3), - lambda: torch.randint(3, (100,)), - id="multiclass", + partial(MulticlassAUROC, num_classes=3, average=None), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass auroc and average=None", ), pytest.param( - partial(multiclass_auroc, num_classes=3, average=None), - lambda: torch.randn(100, 3), - lambda: torch.randint(3, (100,)), - id="multiclass and average=None", + BinaryROC, + _rand_input, + _binary_randint_input, ), pytest.param( - partial(spectral_distortion_index), - lambda: torch.rand([16, 3, 16, 16]), - lambda: torch.rand([16, 3, 16, 16]), + SpectralDistortionIndex, + _image_input, + _image_input, id="spectral distortion index", ), pytest.param( - partial(error_relative_global_dimensionless_synthesis), - lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), - lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), + ErrorRelativeGlobalDimensionlessSynthesis, + _image_input, + _image_input, id="error relative global dimensionless synthesis", ), pytest.param( - partial(peak_signal_noise_ratio), + PeakSignalNoiseRatio, lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]]), lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]]), id="peak signal noise ratio", ), pytest.param( - partial(spectral_angle_mapper), - lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)), - lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)), + SpectralAngleMapper, + _image_input, + _image_input, id="spectral angle mapper", ), pytest.param( - partial(structural_similarity_index_measure), - lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), - lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, + StructuralSimilarityIndexMeasure, + _image_input, + _image_input, id="structural similarity index_measure", ), pytest.param( - partial(multiscale_structural_similarity_index_measure), - lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), - lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, + MultiScaleStructuralSimilarityIndexMeasure, + lambda: torch.rand(3, 3, 180, 180), + lambda: torch.rand(3, 3, 180, 180), id="multiscale structural similarity index measure", ), pytest.param( - partial(universal_image_quality_index), - lambda: torch.rand([16, 1, 16, 16]), - lambda: torch.rand([16, 1, 16, 16]) * 0.75, + UniversalImageQualityIndex, + _image_input, + _image_input, id="universal image quality index", ), pytest.param( - partial(perceptual_evaluation_speech_quality, fs=8000, mode="nb"), - lambda: torch.randn(8000), - lambda: torch.randn(8000), + partial(PerceptualEvaluationSpeechQuality, fs=8000, mode="nb"), + _audio_input, + _audio_input, 
id="perceptual_evaluation_speech_quality",
         ),
+        pytest.param(SignalDistortionRatio, _audio_input, _audio_input, id="signal_distortion_ratio"),
         pytest.param(
-            partial(signal_distortion_ratio),
-            lambda: torch.randn(8000),
-            lambda: torch.randn(8000),
-            id="signal_distortion_ratio",
-        ),
-        pytest.param(
-            partial(scale_invariant_signal_distortion_ratio),
-            lambda: torch.randn(5),
-            lambda: torch.randn(5),
-            id="scale_invariant_signal_distortion_ratio",
+            ScaleInvariantSignalDistortionRatio, _rand_input, _rand_input, id="scale_invariant_signal_distortion_ratio"
         ),
+        pytest.param(SignalNoiseRatio, _rand_input, _rand_input, id="signal_noise_ratio"),
+        pytest.param(ScaleInvariantSignalNoiseRatio, _rand_input, _rand_input, id="scale_invariant_signal_noise_ratio"),
         pytest.param(
-            partial(signal_noise_ratio),
-            lambda: torch.randn(4),
-            lambda: torch.randn(4),
-            id="signal_noise_ratio",
-        ),
-        pytest.param(
-            partial(scale_invariant_signal_noise_ratio),
-            lambda: torch.randn(4),
-            lambda: torch.randn(4),
-            id="scale_invariant_signal_noise_ratio",
-        ),
-        pytest.param(
-            partial(short_time_objective_intelligibility, fs=8000, extended=False),
-            lambda: torch.randn(8000),
-            lambda: torch.randn(8000),
+            partial(ShortTimeObjectiveIntelligibility, fs=8000, extended=False),
+            _audio_input,
+            _audio_input,
             id="short_time_objective_intelligibility",
         ),
-    ],
-)
-@pytest.mark.parametrize("num_vals", [1, 5, 10])
-def test_single_multi_val_plotter(metric, preds, target, num_vals):
-    vals = []
-    for i in range(num_vals):
-        vals.append(metric(preds(), target()))
-    vals = vals[0] if i == 1 else vals
-    fig, ax = plot_single_or_multi_val(vals)
-    assert isinstance(fig, plt.Figure)
-    assert isinstance(ax, matplotlib.axes.Axes)
-
-
-@pytest.mark.parametrize(
-    ("metric", "preds", "target"),
-    [
         pytest.param(
-            partial(permutation_invariant_training, metric_func=scale_invariant_signal_noise_ratio, eval_func="max"),
+            partial(PermutationInvariantTraining, metric_func=scale_invariant_signal_noise_ratio, eval_func="max"),
             lambda: torch.randn(3, 2, 5),
             lambda: torch.randn(3, 2, 5),
             id="permutation_invariant_training",
-        )
+        ),
+        pytest.param(MeanSquaredError, _rand_input, _rand_input, id="mean squared error"),
+        pytest.param(SumMetric, _rand_input, None, id="sum metric"),
+        pytest.param(MeanMetric, _rand_input, None, id="mean metric"),
+        pytest.param(MinMetric, _rand_input, None, id="min metric"),
+        pytest.param(MaxMetric, _rand_input, None, id="max metric"),
     ],
 )
-@pytest.mark.parametrize("num_vals", [1, 5, 10])
-def test_single_multi_val_plotter_pit(metric, preds, target, num_vals):
-    vals = []
-    for i in range(num_vals):
-        vals.append(metric(preds(), target())[0])
-    vals = vals[0] if i == 1 else vals
-    fig, ax = plot_single_or_multi_val(vals)
-    assert isinstance(fig, plt.Figure)
-    assert isinstance(ax, matplotlib.axes.Axes)
+@pytest.mark.parametrize("num_vals", [1, 5])
+def test_single_multi_val_plot_methods(metric_class: object, preds: Callable, target: Callable, num_vals: int):
+    """Test the plot method of metrics that only output a single tensor scalar."""
+    metric = metric_class()
+    input = (lambda: (preds(),)) if target is None else lambda: (preds(), target())
 
+    if num_vals == 1:
+        metric.update(*input())
+        fig, ax = metric.plot()
+    else:
+        vals = []
+        for _ in range(num_vals):
+            vals.append(metric(*input()))
+        fig, ax = metric.plot(vals)
 
-@pytest.mark.parametrize(
-    ("metric", "preds", "target"),
-    [
-        pytest.param(
-            binary_confusion_matrix,
-            torch.rand(
-                100,
-            ),
-            torch.randint(2, (100,)),
-            id="binary",
), - pytest.param( - partial(multiclass_confusion_matrix, num_classes=3), - torch.randint(3, (100,)), - torch.randint(3, (100,)), - id="multiclass", - ), - pytest.param( - partial(multilabel_confusion_matrix, num_labels=3), - torch.randint(2, (100, 3)), - torch.randint(2, (100, 3)), - id="multilabel", - ), - ], -) -def test_confusion_matrix_plotter(metric, preds, target): - confmat = metric(preds, target) - fig, axs = plot_confusion_matrix(confmat) assert isinstance(fig, plt.Figure) - cond1 = isinstance(axs, matplotlib.axes.Axes) - cond2 = isinstance(axs, np.ndarray) and all(isinstance(a, matplotlib.axes.Axes) for a in axs) - assert cond1 or cond2 + assert isinstance(ax, matplotlib.axes.Axes) @pytest.mark.parametrize( - ("metric", "preds", "target", "labels"), + ("metric_class", "preds", "target", "labels"), [ pytest.param( - binary_confusion_matrix, - torch.rand( - 100, - ), - torch.randint(2, (100,)), + BinaryConfusionMatrix, + _rand_input, + _binary_randint_input, ["cat", "dog"], - id="binary", + id="binary confusion matrix", ), pytest.param( - partial(multiclass_confusion_matrix, num_classes=3), - torch.randint(3, (100,)), - torch.randint(3, (100,)), + partial(MulticlassConfusionMatrix, num_classes=3), + _multiclass_randint_input, + _multiclass_randint_input, ["cat", "dog", "bird"], - id="multiclass", + id="multiclass confusion matrix", + ), + pytest.param( + partial(MultilabelConfusionMatrix, num_labels=3), + _multilabel_randint_input, + _multilabel_randint_input, + ["cat", "dog", "bird"], + id="multilabel confusion matrix", ), ], ) -def test_confusion_matrix_plotter_with_labels(metric, preds, target, labels): - confmat = metric(preds, target) - fig, axs = plot_confusion_matrix(confmat, labels=labels) +@pytest.mark.parametrize("use_labels", [False, True]) +def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_labels): + """Test confusion matrix that uses specialized plot function.""" + metric = metric_class() + metric.update(preds(), target()) + labels = labels if use_labels else None + fig, axs = metric.plot(add_text=True, labels=labels) assert isinstance(fig, plt.Figure) cond1 = isinstance(axs, matplotlib.axes.Axes) cond2 = isinstance(axs, np.ndarray) and all(isinstance(a, matplotlib.axes.Axes) for a in axs) assert cond1 or cond2 - - -@pytest.mark.parametrize( - ("metric", "preds", "target"), - [ - pytest.param( - binary_roc, - lambda: torch.nn.functional.softmax(torch.randn(100, 2), dim=1)[:, 1], - lambda: torch.randint(2, (100,)), - id="binary", - ) - ], -) -def test_binary_roc_curve_plotter(metric, preds, target): - tpr, fpr, thresholds = metric(preds(), target()) - fig, ax = plot_binary_roc_curve(tpr, fpr) - assert isinstance(fig, plt.Figure) - assert isinstance(ax, matplotlib.axes.Axes)
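
Closing usage note for the aggregation half of this patch: a minimal sketch of the new `.plot()` method on aggregation metrics, mirroring what the tests above exercise. This assumes the patch is applied and `matplotlib` is installed; the random data and filename are illustrative only.

    import torch
    from torchmetrics.aggregation import MeanMetric

    metric = MeanMetric()

    # Collect one result per "step" via forward(), then plot the whole trajectory.
    values = [metric(torch.rand(10)) for _ in range(10)]
    fig, ax = metric.plot(values)

    # Or plot only the final aggregated value; plot() falls back to compute().
    fig2, ax2 = metric.plot()
    fig.savefig("mean_metric.png")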