diff --git a/CHANGELOG.md b/CHANGELOG.md index b26bcc71c74..9fcd0e8f1cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for plotting of metrics through `.plot()` method ( [#1328](https://github.com/Lightning-AI/metrics/pull/1328), - [#1481](https://github.com/Lightning-AI/metrics/pull/1481) + [#1481](https://github.com/Lightning-AI/metrics/pull/1481), + [#1480](https://github.com/Lightning-AI/metrics/pull/1480) ) diff --git a/examples/plotting.py b/examples/plotting.py index cc85c091ebf..c613f4937d3 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -222,8 +222,147 @@ def confusion_matrix_example(): return fig, ax -if __name__ == "__main__": +def spectral_distortion_index_example(): + """Plot spectral distortion index example example.""" + from torchmetrics.image.d_lambda import SpectralDistortionIndex + + p = lambda: torch.rand([16, 3, 16, 16]) + t = lambda: torch.rand([16, 3, 16, 16]) + + # plot single value + metric = SpectralDistortionIndex() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = SpectralDistortionIndex() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +def error_relative_global_dimensionless_synthesis(): + """Plot error relative global dimensionless synthesis example.""" + from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis + + p = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + t = lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + + # plot single value + metric = ErrorRelativeGlobalDimensionlessSynthesis() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = ErrorRelativeGlobalDimensionlessSynthesis() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +def peak_signal_noise_ratio(): + """Plot peak signal noise ratio example.""" + from torchmetrics.image.psnr import PeakSignalNoiseRatio + + p = lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + t = lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + + # plot single value + metric = PeakSignalNoiseRatio() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = PeakSignalNoiseRatio() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +def spectral_angle_mapper(): + """Plot spectral angle mapper example.""" + from torchmetrics.image.sam import SpectralAngleMapper + + p = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) + t = lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) + + # plot single value + metric = SpectralAngleMapper() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = SpectralAngleMapper() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +def structural_similarity_index_measure(): + """Plot structural similarity index measure example.""" + from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure + + p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + t = lambda: p() * 0.75 + + # plot single value + metric = StructuralSimilarityIndexMeasure() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = StructuralSimilarityIndexMeasure() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +def multiscale_structural_similarity_index_measure(): + """Plot multiscale structural similarity index measure example.""" + from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure + + p = lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + t = lambda: p() * 0.75 + + # plot single value + metric = MultiScaleStructuralSimilarityIndexMeasure() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = MultiScaleStructuralSimilarityIndexMeasure() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + return fig, ax + + +def universal_image_quality_index(): + """Plot universal image quality index example.""" + from torchmetrics.image.uqi import UniversalImageQualityIndex + + p = lambda: torch.rand([16, 1, 16, 16]) + t = lambda: p() * 0.75 + + # plot single value + metric = UniversalImageQualityIndex() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = UniversalImageQualityIndex() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +if __name__ == "__main__": metrics_func = { "accuracy": accuracy_example, "pesq": pesq_example, @@ -235,6 +374,13 @@ def confusion_matrix_example(): "stoi": stoi_example, "mean_squared_error": mean_squared_error_example, "confusion_matrix": confusion_matrix_example, + "spectral_distortion_index": spectral_distortion_index_example, + "error_relative_global_dimensionless_synthesis": error_relative_global_dimensionless_synthesis, + "peak_signal_noise_ratio": peak_signal_noise_ratio, + "spectral_angle_mapper": spectral_angle_mapper, + "structural_similarity_index_measure": structural_similarity_index_measure, + "multiscale_structural_similarity_index_measure": multiscale_structural_similarity_index_measure, + "universal_image_quality_index": universal_image_quality_index, } parser = argparse.ArgumentParser(description="Example script for plotting metrics.") diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 7960029a095..3ec993f6b6d 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -414,8 +414,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index ed37d9e451d..35b11782f55 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -239,8 +239,7 @@ def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE: If no value is provided, will automatically call `metric.compute` and plot that result. Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/image/d_lambda.py b/src/torchmetrics/image/d_lambda.py index 20367cf2201..4386e23f849 100644 --- a/src/torchmetrics/image/d_lambda.py +++ b/src/torchmetrics/image/d_lambda.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -20,6 +20,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SpectralDistortionIndex.plot"] class SpectralDistortionIndex(Metric): @@ -95,3 +100,59 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _spectral_distortion_index_compute(preds, target, self.p, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics import SpectralDistortionIndex + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> metric = SpectralDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, + ax=ax, + higher_is_better=self.higher_is_better, + name=self.__class__.__name__, + lower_bound=0.0, + upper_bound=1.0, + ) + return fig, ax diff --git a/src/torchmetrics/image/ergas.py b/src/torchmetrics/image/ergas.py index 051a42bc38d..1c71a42540e 100644 --- a/src/torchmetrics/image/ergas.py +++ b/src/torchmetrics/image/ergas.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Union +from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -21,6 +21,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["ErrorRelativeGlobalDimensionlessSynthesis.plot"] class ErrorRelativeGlobalDimensionlessSynthesis(Metric): @@ -94,3 +99,57 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _ergas_compute(preds, target, self.ratio, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis + >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> metric = ErrorRelativeGlobalDimensionlessSynthesis() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import ErrorRelativeGlobalDimensionlessSynthesis + >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> metric = ErrorRelativeGlobalDimensionlessSynthesis() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, + ax=ax, + higher_is_better=self.higher_is_better, + name=self.__class__.__name__, + lower_bound=0.0, + upper_bound=1.0, + ) + return fig, ax diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index dc23664375d..a0dd7097f1f 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -20,6 +20,11 @@ from torchmetrics.functional.image.psnr import _psnr_compute, _psnr_update from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["PeakSignalNoiseRatio.plot"] class PeakSignalNoiseRatio(Metric): @@ -70,6 +75,7 @@ class PeakSignalNoiseRatio(Metric): is_differentiable: bool = True higher_is_better: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 10.0} min_target: Tensor max_target: Tensor @@ -135,3 +141,52 @@ def compute(self) -> Tensor: sum_squared_error = torch.cat([values.flatten() for values in self.sum_squared_error]) total = torch.cat([values.flatten() for values in self.total]) return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import PeakSignalNoiseRatio + >>> metric = PeakSignalNoiseRatio() + >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import PeakSignalNoiseRatio + >>> metric = PeakSignalNoiseRatio() + >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/src/torchmetrics/image/sam.py b/src/torchmetrics/image/sam.py index c0e064c2ebb..1d1611a5afd 100644 --- a/src/torchmetrics/image/sam.py +++ b/src/torchmetrics/image/sam.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List +from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -20,6 +20,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SpectralAngleMapper.plot"] class SpectralAngleMapper(Metric): @@ -92,3 +97,57 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _sam_compute(preds, target, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting single value + >>> import torch + >>> from torchmetrics import SpectralAngleMapper + >>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) + >>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) + >>> metric = SpectralAngleMapper() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import SpectralAngleMapper + >>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) + >>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) + >>> metric = SpectralAngleMapper() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, + ax=ax, + higher_is_better=self.higher_is_better, + name=self.__class__.__name__, + lower_bound=0.0, + upper_bound=1.0, + ) + return fig, ax diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index 0751a0e9f26..6a1d41d36aa 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -20,6 +20,11 @@ from torchmetrics.functional.image.ssim import _multiscale_ssim_update, _ssim_check_inputs, _ssim_update from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["StructuralSimilarityIndexMeasure.plot", "MultiScaleStructuralSimilarityIndexMeasure.plot"] class StructuralSimilarityIndexMeasure(Metric): @@ -72,6 +77,7 @@ class StructuralSimilarityIndexMeasure(Metric): higher_is_better: bool = True is_differentiable: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -160,6 +166,55 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor]]: return similarity + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import StructuralSimilarityIndexMeasure + >>> preds = torch.rand([3, 3, 256, 256]) + >>> target = preds * 0.75 + >>> metric = StructuralSimilarityIndexMeasure(data_range=1.0) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import StructuralSimilarityIndexMeasure + >>> preds = torch.rand([3, 3, 256, 256]) + >>> target = preds * 0.75 + >>> metric = StructuralSimilarityIndexMeasure(data_range=1.0) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax + class MultiScaleStructuralSimilarityIndexMeasure(Metric): """Compute `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure, which is a generalization of @@ -219,6 +274,7 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric): higher_is_better: bool = True is_differentiable: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -308,3 +364,52 @@ def compute(self) -> Tensor: return self.similarity else: return self.similarity / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure + >>> import torch + >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure + >>> import torch + >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/src/torchmetrics/image/uqi.py b/src/torchmetrics/image/uqi.py index 0a40004ebbd..d48fe54a76f 100644 --- a/src/torchmetrics/image/uqi.py +++ b/src/torchmetrics/image/uqi.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -20,6 +20,11 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["UniversalImageQualityIndex.plot"] class UniversalImageQualityIndex(Metric): @@ -64,6 +69,7 @@ class UniversalImageQualityIndex(Metric): is_differentiable: bool = True higher_is_better: bool = True full_state_update: bool = False + plot_options = {"lower_bound": 0.0, "upper_bound": 1.0} preds: List[Tensor] target: List[Tensor] @@ -101,3 +107,52 @@ def compute(self) -> Tensor: preds = dim_zero_cat(self.preds) target = dim_zero_cat(self.target) return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction, self.data_range) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics import UniversalImageQualityIndex + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> metric = UniversalImageQualityIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics import UniversalImageQualityIndex + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> metric = UniversalImageQualityIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/src/torchmetrics/regression/mse.py b/src/torchmetrics/regression/mse.py index 524b3e86499..1b97c3b391e 100644 --- a/src/torchmetrics/regression/mse.py +++ b/src/torchmetrics/regression/mse.py @@ -93,8 +93,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 11a757f6d1e..f9d251c49ea 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -20,10 +20,15 @@ import torch from torchmetrics.functional import ( + multiscale_structural_similarity_index_measure, + peak_signal_noise_ratio, scale_invariant_signal_distortion_ratio, scale_invariant_signal_noise_ratio, signal_distortion_ratio, signal_noise_ratio, + spectral_angle_mapper, + structural_similarity_index_measure, + universal_image_quality_index, ) from torchmetrics.functional.audio import short_time_objective_intelligibility from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality @@ -34,18 +39,15 @@ multiclass_confusion_matrix, multilabel_confusion_matrix, ) +from torchmetrics.functional.image.d_lambda import spectral_distortion_index +from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis from torchmetrics.utilities.plot import plot_confusion_matrix, plot_single_or_multi_val @pytest.mark.parametrize( ("metric", "preds", "target"), [ - pytest.param( - binary_accuracy, - lambda: torch.rand(100), - lambda: torch.randint(2, (100,)), - id="binary", - ), + pytest.param(binary_accuracy, lambda: torch.rand(100), lambda: torch.randint(2, (100,)), id="binary"), pytest.param( partial(multiclass_accuracy, num_classes=3), lambda: torch.randint(3, (100,)), @@ -58,6 +60,48 @@ lambda: torch.randint(3, (100,)), id="multiclass and average=None", ), + pytest.param( + partial(spectral_distortion_index), + lambda: torch.rand([16, 3, 16, 16]), + lambda: torch.rand([16, 3, 16, 16]), + id="spectral distortion index", + ), + pytest.param( + partial(error_relative_global_dimensionless_synthesis), + lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), + lambda: torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)), + id="error relative global dimensionless synthesis", + ), + pytest.param( + partial(peak_signal_noise_ratio), + lambda: torch.tensor([[0.0, 1.0], [2.0, 3.0]]), + lambda: torch.tensor([[3.0, 2.0], [1.0, 0.0]]), + id="peak signal noise ratio", + ), + pytest.param( + partial(spectral_angle_mapper), + lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)), + lambda: torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)), + id="spectral angle mapper", + ), + pytest.param( + partial(structural_similarity_index_measure), + lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), + lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, + id="structural similarity index_measure", + ), + pytest.param( + partial(multiscale_structural_similarity_index_measure), + lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)), + lambda: torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) * 0.75, + id="multiscale structural similarity index measure", + ), + pytest.param( + partial(universal_image_quality_index), + lambda: torch.rand([16, 1, 16, 16]), + lambda: torch.rand([16, 1, 16, 16]) * 0.75, + id="universal image quality index", + ), pytest.param( partial(perceptual_evaluation_speech_quality, fs=8000, mode="nb"), lambda: torch.randn(8000),