From 5888679e83cc32c499e1ad9def1f561020d17787 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 6 Apr 2023 19:06:57 +0200 Subject: [PATCH 1/5] imports --- src/torchmetrics/__init__.py | 28 +- src/torchmetrics/image/_deprecated.py | 258 ++++++++++++++++++ .../deprecations/root_class_imports.py | 21 ++ 3 files changed, 296 insertions(+), 11 deletions(-) create mode 100644 src/torchmetrics/image/_deprecated.py diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 1aedeb14c07..418471a5405 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -46,18 +46,24 @@ ) from torchmetrics.collections import MetricCollection # noqa: E402 from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality # noqa: E402 -from torchmetrics.image import ( # noqa: E402 - ErrorRelativeGlobalDimensionlessSynthesis, - MultiScaleStructuralSimilarityIndexMeasure, - PeakSignalNoiseRatio, - RelativeAverageSpectralError, - RootMeanSquaredErrorUsingSlidingWindow, - SpectralAngleMapper, - SpectralDistortionIndex, - StructuralSimilarityIndexMeasure, - TotalVariation, - UniversalImageQualityIndex, +from torchmetrics.image._deprecated import ( # noqa: E402 + _ErrorRelativeGlobalDimensionlessSynthesis as ErrorRelativeGlobalDimensionlessSynthesis, ) +from torchmetrics.image._deprecated import ( # noqa: E402 + _MultiScaleStructuralSimilarityIndexMeasure as MultiScaleStructuralSimilarityIndexMeasure, +) +from torchmetrics.image._deprecated import _PeakSignalNoiseRatio as PeakSignalNoiseRatio # noqa: E402 +from torchmetrics.image._deprecated import _RelativeAverageSpectralError as RelativeAverageSpectralError # noqa: E402 +from torchmetrics.image._deprecated import ( # noqa: E402 + _RootMeanSquaredErrorUsingSlidingWindow as RootMeanSquaredErrorUsingSlidingWindow, +) +from torchmetrics.image._deprecated import _SpectralAngleMapper as SpectralAngleMapper # noqa: E402 +from torchmetrics.image._deprecated import _SpectralDistortionIndex as SpectralDistortionIndex # noqa: E402 +from torchmetrics.image._deprecated import ( # noqa: E402 + _StructuralSimilarityIndexMeasure as StructuralSimilarityIndexMeasure, +) +from torchmetrics.image._deprecated import _TotalVariation as TotalVariation # noqa: E402 +from torchmetrics.image._deprecated import _UniversalImageQualityIndex as UniversalImageQualityIndex # noqa: E402 from torchmetrics.metric import Metric # noqa: E402 from torchmetrics.nominal import CramersV # noqa: E402 from torchmetrics.nominal import PearsonsContingencyCoefficient # noqa: E402 diff --git a/src/torchmetrics/image/_deprecated.py b/src/torchmetrics/image/_deprecated.py new file mode 100644 index 00000000000..bab0b7db86c --- /dev/null +++ b/src/torchmetrics/image/_deprecated.py @@ -0,0 +1,258 @@ +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union + +from typing_extensions import Literal + +from torchmetrics.image.d_lambda import SpectralDistortionIndex +from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis +from torchmetrics.image.psnr import PeakSignalNoiseRatio +from torchmetrics.image.rase import RelativeAverageSpectralError +from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow +from torchmetrics.image.sam import SpectralAngleMapper +from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure +from torchmetrics.image.tv import TotalVariation +from torchmetrics.image.uqi import UniversalImageQualityIndex +from torchmetrics.utilities.prints import _deprecated_root_import_class + + +class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionlessSynthesis): + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis() + >>> torch.round(ergas(preds, target)) + tensor(154.) + """ + + def __init__( + self, + ratio: Union[int, float] = 4, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("ErrorRelativeGlobalDimensionlessSynthesis", "image") + return super().__init__(ratio=ratio, reduction=reduction, **kwargs) + + +class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarityIndexMeasure): + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) + >>> ms_ssim(preds, target) + tensor(0.9627) + """ + + def __init__( + self, + gaussian_kernel: bool = True, + kernel_size: Union[int, Sequence[int]] = 11, + sigma: Union[float, Sequence[float]] = 1.5, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + data_range: Optional[Union[float, Tuple[float, float]]] = None, + k1: float = 0.01, + k2: float = 0.03, + betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + normalize: Literal["relu", "simple", None] = "relu", + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("MultiScaleStructuralSimilarityIndexMeasure", "image") + return super().__init__( + gaussian_kernel=gaussian_kernel, + kernel_size=kernel_size, + sigma=sigma, + reduction=reduction, + data_range=data_range, + k1=k1, + k2=k2, + betas=betas, + normalize=normalize, + **kwargs, + ) + + +class _PeakSignalNoiseRatio(PeakSignalNoiseRatio): + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> psnr = _PeakSignalNoiseRatio() + >>> preds = tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> psnr(preds, target) + tensor(2.5527) + """ + + def __init__( + self, + data_range: Optional[Union[float, Tuple[float, float]]] = None, + base: float = 10.0, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + dim: Optional[Union[int, Tuple[int, ...]]] = None, + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("PeakSignalNoiseRatio", "image") + return super().__init__(data_range=data_range, base=base, reduction=reduction, dim=dim, **kwargs) + + +class _RelativeAverageSpectralError(RelativeAverageSpectralError): + """Wrapper for deprecated import. + + >>> import torch + >>> g = torch.manual_seed(22) + >>> preds = torch.rand(4, 3, 16, 16) + >>> target = torch.rand(4, 3, 16, 16) + >>> rase = _RelativeAverageSpectralError() + >>> rase(preds, target) + tensor(5114.6641) + """ + + def __init__( + self, + window_size: int = 8, + **kwargs: Dict[str, Any], + ) -> None: + _deprecated_root_import_class("RelativeAverageSpectralError", "image") + return super().__init__(window_size=window_size, **kwargs) + + +class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWindow): + """Wrapper for deprecated import. + + >>> import torch + >>> g = torch.manual_seed(22) + >>> preds = torch.rand(4, 3, 16, 16) + >>> target = torch.rand(4, 3, 16, 16) + >>> rmse_sw = RootMeanSquaredErrorUsingSlidingWindow() + >>> rmse_sw(preds, target) + tensor(0.3999) + """ + + def __init__( + self, + window_size: int = 8, + **kwargs: Dict[str, Any], + ) -> None: + _deprecated_root_import_class("RootMeanSquaredErrorUsingSlidingWindow", "image") + return super().__init__(window_size=window_size, **kwargs) + + +class _SpectralAngleMapper(SpectralAngleMapper): + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) + >>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) + >>> sam = _SpectralAngleMapper() + >>> sam(preds, target) + tensor(0.5943) + """ + + def __init__( + self, + reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("SpectralAngleMapper", "image") + return super().__init__(reduction=reduction, **kwargs) + + +class _SpectralDistortionIndex(SpectralDistortionIndex): + """Wrapper for deprecated import. + + >>> import torch + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> sdi = _SpectralDistortionIndex() + >>> sdi(preds, target) + tensor(0.0234) + """ + + def __init__( + self, p: int = 1, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any + ) -> None: + _deprecated_root_import_class("SpectralDistortionIndex", "image") + return super().__init__(p=p, reduction=reduction, **kwargs) + + +class _StructuralSimilarityIndexMeasure(StructuralSimilarityIndexMeasure): + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([3, 3, 256, 256]) + >>> target = preds * 0.75 + >>> ssim = _StructuralSimilarityIndexMeasure(data_range=1.0) + >>> ssim(preds, target) + tensor(0.9219) + """ + + def __init__( + self, + gaussian_kernel: bool = True, + sigma: Union[float, Sequence[float]] = 1.5, + kernel_size: Union[int, Sequence[int]] = 11, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + data_range: Optional[Union[float, Tuple[float, float]]] = None, + k1: float = 0.01, + k2: float = 0.03, + return_full_image: bool = False, + return_contrast_sensitivity: bool = False, + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("StructuralSimilarityIndexMeasure", "image") + return super().__init__( + gaussian_kernel=gaussian_kernel, + sigma=sigma, + kernel_size=kernel_size, + reduction=reduction, + data_range=data_range, + k1=k1, + k2=k2, + return_full_image=return_full_image, + return_contrast_sensitivity=return_contrast_sensitivity, + **kwargs, + ) + + +class _TotalVariation(TotalVariation): + """Wrapper for deprecated import. + + >>> import torch + >>> _ = torch.manual_seed(42) + >>> tv = _TotalVariation() + >>> img = torch.rand(5, 3, 28, 28) + >>> tv(img) + tensor(7546.8018) + """ + + def __init__(self, reduction: Literal["mean", "sum", "none", None] = "sum", **kwargs: Any) -> None: + _deprecated_root_import_class("TotalVariation", "image") + return super().__init__(reduction=reduction, **kwargs) + + +class _UniversalImageQualityIndex(UniversalImageQualityIndex): + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> uqi = _UniversalImageQualityIndex() + >>> uqi(preds, target) + tensor(0.9216) + """ + + def __init__( + self, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + data_range: Optional[float] = None, + **kwargs: Any, + ) -> None: + _deprecated_root_import_class("UniversalImageQualityIndex", "image") + return super().__init__( + kernel_size=kernel_size, sigma=sigma, reduction=reduction, data_range=data_range, **kwargs + ) diff --git a/tests/unittests/deprecations/root_class_imports.py b/tests/unittests/deprecations/root_class_imports.py index 6d8ab1f64e7..b522cc40aeb 100644 --- a/tests/unittests/deprecations/root_class_imports.py +++ b/tests/unittests/deprecations/root_class_imports.py @@ -4,11 +4,21 @@ import pytest from torchmetrics import ( + ErrorRelativeGlobalDimensionlessSynthesis, + MultiScaleStructuralSimilarityIndexMeasure, + PeakSignalNoiseRatio, PermutationInvariantTraining, + RelativeAverageSpectralError, + RootMeanSquaredErrorUsingSlidingWindow, ScaleInvariantSignalDistortionRatio, ScaleInvariantSignalNoiseRatio, SignalDistortionRatio, SignalNoiseRatio, + SpectralAngleMapper, + SpectralDistortionIndex, + StructuralSimilarityIndexMeasure, + TotalVariation, + UniversalImageQualityIndex, ) from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio @@ -23,6 +33,17 @@ ScaleInvariantSignalNoiseRatio, SignalDistortionRatio, SignalNoiseRatio, + # Image + ErrorRelativeGlobalDimensionlessSynthesis, + MultiScaleStructuralSimilarityIndexMeasure, + PeakSignalNoiseRatio, + RelativeAverageSpectralError, + RootMeanSquaredErrorUsingSlidingWindow, + SpectralAngleMapper, + SpectralDistortionIndex, + StructuralSimilarityIndexMeasure, + TotalVariation, + UniversalImageQualityIndex, ], ) def test_import_from_root_package(metric_cls): From 39f8181d2bad49d9a034e703082362e92c6eb98c Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 7 Apr 2023 08:14:34 +0200 Subject: [PATCH 2/5] functional --- src/torchmetrics/functional/__init__.py | 28 +- .../functional/image/_deprecated.py | 255 ++++++++++++++++++ 2 files changed, 272 insertions(+), 11 deletions(-) create mode 100644 src/torchmetrics/functional/image/_deprecated.py diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index b48d0b7f586..f13b589629e 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -41,19 +41,25 @@ from torchmetrics.functional.classification.stat_scores import stat_scores from torchmetrics.functional.detection.modified_panoptic_quality import modified_panoptic_quality from torchmetrics.functional.detection.panoptic_quality import panoptic_quality -from torchmetrics.functional.image.d_lambda import spectral_distortion_index -from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis -from torchmetrics.functional.image.gradients import image_gradients -from torchmetrics.functional.image.psnr import peak_signal_noise_ratio -from torchmetrics.functional.image.rase import relative_average_spectral_error -from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window -from torchmetrics.functional.image.sam import spectral_angle_mapper +from torchmetrics.functional.image.d_lambda import _spectral_distortion_index as spectral_distortion_index +from torchmetrics.functional.image.ergas import ( + _error_relative_global_dimensionless_synthesis as error_relative_global_dimensionless_synthesis, +) +from torchmetrics.functional.image.gradients import _image_gradients as image_gradients +from torchmetrics.functional.image.psnr import _peak_signal_noise_ratio as peak_signal_noise_ratio +from torchmetrics.functional.image.rase import _relative_average_spectral_error as relative_average_spectral_error +from torchmetrics.functional.image.rmse_sw import ( + _root_mean_squared_error_using_sliding_window as root_mean_squared_error_using_sliding_window, +) +from torchmetrics.functional.image.sam import _spectral_angle_mapper as spectral_angle_mapper +from torchmetrics.functional.image.ssim import ( + _multiscale_structural_similarity_index_measure as multiscale_structural_similarity_index_measure, +) from torchmetrics.functional.image.ssim import ( - multiscale_structural_similarity_index_measure, - structural_similarity_index_measure, + _structural_similarity_index_measure as structural_similarity_index_measure, ) -from torchmetrics.functional.image.tv import total_variation -from torchmetrics.functional.image.uqi import universal_image_quality_index +from torchmetrics.functional.image.tv import _total_variation as total_variation +from torchmetrics.functional.image.uqi import _universal_image_quality_index as universal_image_quality_index from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix from torchmetrics.functional.nominal.pearson import ( pearsons_contingency_coefficient, diff --git a/src/torchmetrics/functional/image/_deprecated.py b/src/torchmetrics/functional/image/_deprecated.py new file mode 100644 index 00000000000..1b6eb0d9519 --- /dev/null +++ b/src/torchmetrics/functional/image/_deprecated.py @@ -0,0 +1,255 @@ +from typing import Optional, Sequence, Tuple, Union + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.image.d_lambda import spectral_distortion_index +from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis +from torchmetrics.functional.image.gradients import image_gradients +from torchmetrics.functional.image.psnr import peak_signal_noise_ratio +from torchmetrics.functional.image.rase import relative_average_spectral_error +from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window +from torchmetrics.functional.image.sam import spectral_angle_mapper +from torchmetrics.functional.image.ssim import ( + multiscale_structural_similarity_index_measure, + structural_similarity_index_measure, +) +from torchmetrics.functional.image.tv import total_variation +from torchmetrics.functional.image.uqi import universal_image_quality_index +from torchmetrics.utilities.prints import _deprecated_root_import_func + + +def _spectral_distortion_index( + preds: Tensor, + target: Tensor, + p: int = 1, + reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", +) -> Tensor: + """Wrapper for deprecated import. + + >>> import torch + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand([16, 3, 16, 16]) + >>> target = torch.rand([16, 3, 16, 16]) + >>> _spectral_distortion_index(preds, target) + tensor(0.0234) + """ + _deprecated_root_import_func("spectral_distortion_index", "image") + return spectral_distortion_index(preds=preds, target=target, p=p, reduction=reduction) + + +def _error_relative_global_dimensionless_synthesis( + preds: Tensor, + target: Tensor, + ratio: Union[int, float] = 4, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", +) -> Tensor: + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> ergds = _error_relative_global_dimensionless_synthesis(preds, target) + >>> torch.round(ergds) + tensor(154.) + """ + _deprecated_root_import_func("error_relative_global_dimensionless_synthesis", "image") + return error_relative_global_dimensionless_synthesis(preds=preds, target=target, ratio=ratio, reduction=reduction) + + +def _image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: + """Wrapper for deprecated import. + + >>> import torch + >>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32) + >>> image = torch.reshape(image, (1, 1, 5, 5)) + >>> dy, dx = _image_gradients(image) + >>> dy[0, 0, :, :] + tensor([[5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [5., 5., 5., 5., 5.], + [0., 0., 0., 0., 0.]]) + """ + _deprecated_root_import_func("image_gradients", "image") + return image_gradients(img=img) + + +def _peak_signal_noise_ratio( + preds: Tensor, + target: Tensor, + data_range: Optional[Union[float, Tuple[float, float]]] = None, + base: float = 10.0, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + dim: Optional[Union[int, Tuple[int, ...]]] = None, +) -> Tensor: + """Wrapper for deprecated import. + + >>> from torch import tensor + >>> pred = tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> _peak_signal_noise_ratio(pred, target) + tensor(2.5527) + """ + _deprecated_root_import_func("peak_signal_noise_ratio", "image") + return peak_signal_noise_ratio( + preds=preds, target=target, data_range=data_range, base=base, reduction=reduction, dim=dim + ) + + +def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: int = 8) -> Tensor: + """Wrapper for deprecated import. + + >>> import torch + >>> g = torch.manual_seed(22) + >>> preds = torch.rand(4, 3, 16, 16) + >>> target = torch.rand(4, 3, 16, 16) + >>> _relative_average_spectral_error(preds, target) + tensor(5114.6641) + """ + _deprecated_root_import_func("relative_average_spectral_error", "image") + return relative_average_spectral_error(preds=preds, target=target, window_size=window_size) + + +def _root_mean_squared_error_using_sliding_window( + preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False +) -> Union[Optional[Tensor], Tuple[Optional[Tensor], Tensor]]: + """Wrapper for deprecated import. + + >>> import torch + >>> g = torch.manual_seed(22) + >>> preds = torch.rand(4, 3, 16, 16) + >>> target = torch.rand(4, 3, 16, 16) + >>> _root_mean_squared_error_using_sliding_window(preds, target) + tensor(0.3999) + """ + _deprecated_root_import_func("root_mean_squared_error_using_sliding_window", "image") + return root_mean_squared_error_using_sliding_window( + preds=preds, target=target, window_size=window_size, return_rmse_map=return_rmse_map + ) + + +def _spectral_angle_mapper( + preds: Tensor, + target: Tensor, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", +) -> Tensor: + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42)) + >>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123)) + >>> _spectral_angle_mapper(preds, target) + tensor(0.5943) + """ + _deprecated_root_import_func("spectral_angle_mapper", "image") + return spectral_angle_mapper(preds=preds, target=target, reduction=reduction) + + +def _multiscale_structural_similarity_index_measure( + preds: Tensor, + target: Tensor, + gaussian_kernel: bool = True, + sigma: Union[float, Sequence[float]] = 1.5, + kernel_size: Union[int, Sequence[int]] = 11, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + data_range: Optional[Union[float, Tuple[float, float]]] = None, + k1: float = 0.01, + k2: float = 0.03, + betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), + normalize: Optional[Literal["relu", "simple"]] = "relu", +) -> Tensor: + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42)) + >>> target = preds * 0.75 + >>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0) + tensor(0.9627) + """ + _deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image") + return multiscale_structural_similarity_index_measure( + preds=preds, + target=target, + gaussian_kernel=gaussian_kernel, + sigma=sigma, + kernel_size=kernel_size, + reduction=reduction, + data_range=data_range, + k1=k1, + k2=k2, + betas=betas, + normalize=normalize, + ) + + +def _structural_similarity_index_measure( + preds: Tensor, + target: Tensor, + gaussian_kernel: bool = True, + sigma: Union[float, Sequence[float]] = 1.5, + kernel_size: Union[int, Sequence[int]] = 11, + reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", + data_range: Optional[Union[float, Tuple[float, float]]] = None, + k1: float = 0.01, + k2: float = 0.03, + return_full_image: bool = False, + return_contrast_sensitivity: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([3, 3, 256, 256]) + >>> target = preds * 0.75 + >>> _structural_similarity_index_measure(preds, target) + tensor(0.9219) + """ + _deprecated_root_import_func("spectral_angle_mapper", "image") + return structural_similarity_index_measure( + preds=preds, + target=target, + gaussian_kernel=gaussian_kernel, + sigma=sigma, + kernel_size=kernel_size, + reduction=reduction, + data_range=data_range, + k1=k1, + k2=k2, + return_full_image=return_full_image, + return_contrast_sensitivity=return_contrast_sensitivity, + ) + + +def _total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor: + """Wrapper for deprecated import. + + >>> import torch + >>> _ = torch.manual_seed(42) + >>> img = torch.rand(5, 3, 28, 28) + >>> _total_variation(img) + tensor(7546.8018) + """ + _deprecated_root_import_func("total_variation", "image") + return total_variation(img=img, reduction=reduction) + + +def _universal_image_quality_index( + preds: Tensor, + target: Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean", + data_range: Optional[float] = None, +) -> Tensor: + """Wrapper for deprecated import. + + >>> import torch + >>> preds = torch.rand([16, 1, 16, 16]) + >>> target = preds * 0.75 + >>> _universal_image_quality_index(preds, target) + tensor(0.9216) + """ + _deprecated_root_import_func("universal_image_quality_index", "image") + return universal_image_quality_index( + preds=preds, target=target, kernel_size=kernel_size, sigma=sigma, reduction=reduction, data_range=data_range + ) From 9658295f62d4e092ac3765872ae7b01e9e24c850 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 7 Apr 2023 08:35:00 +0200 Subject: [PATCH 3/5] fix --- src/torchmetrics/functional/__init__.py | 28 +++++++++++++------------ 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index f13b589629e..5e026180f27 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -41,25 +41,27 @@ from torchmetrics.functional.classification.stat_scores import stat_scores from torchmetrics.functional.detection.modified_panoptic_quality import modified_panoptic_quality from torchmetrics.functional.detection.panoptic_quality import panoptic_quality -from torchmetrics.functional.image.d_lambda import _spectral_distortion_index as spectral_distortion_index -from torchmetrics.functional.image.ergas import ( +from torchmetrics.functional.image._deprecated import ( _error_relative_global_dimensionless_synthesis as error_relative_global_dimensionless_synthesis, ) -from torchmetrics.functional.image.gradients import _image_gradients as image_gradients -from torchmetrics.functional.image.psnr import _peak_signal_noise_ratio as peak_signal_noise_ratio -from torchmetrics.functional.image.rase import _relative_average_spectral_error as relative_average_spectral_error -from torchmetrics.functional.image.rmse_sw import ( - _root_mean_squared_error_using_sliding_window as root_mean_squared_error_using_sliding_window, -) -from torchmetrics.functional.image.sam import _spectral_angle_mapper as spectral_angle_mapper -from torchmetrics.functional.image.ssim import ( +from torchmetrics.functional.image._deprecated import _image_gradients as image_gradients +from torchmetrics.functional.image._deprecated import ( _multiscale_structural_similarity_index_measure as multiscale_structural_similarity_index_measure, ) -from torchmetrics.functional.image.ssim import ( +from torchmetrics.functional.image._deprecated import _peak_signal_noise_ratio as peak_signal_noise_ratio +from torchmetrics.functional.image._deprecated import ( + _relative_average_spectral_error as relative_average_spectral_error, +) +from torchmetrics.functional.image._deprecated import ( + _root_mean_squared_error_using_sliding_window as root_mean_squared_error_using_sliding_window, +) +from torchmetrics.functional.image._deprecated import _spectral_angle_mapper as spectral_angle_mapper +from torchmetrics.functional.image._deprecated import _spectral_distortion_index as spectral_distortion_index +from torchmetrics.functional.image._deprecated import ( _structural_similarity_index_measure as structural_similarity_index_measure, ) -from torchmetrics.functional.image.tv import _total_variation as total_variation -from torchmetrics.functional.image.uqi import _universal_image_quality_index as universal_image_quality_index +from torchmetrics.functional.image._deprecated import _total_variation as total_variation +from torchmetrics.functional.image._deprecated import _universal_image_quality_index as universal_image_quality_index from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix from torchmetrics.functional.nominal.pearson import ( pearsons_contingency_coefficient, From f8a20e2d82d6ce0b9da6aceb3027de475d1cb122 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 7 Apr 2023 09:21:34 +0200 Subject: [PATCH 4/5] all --- src/torchmetrics/functional/image/__init__.py | 34 ++++++++++---- src/torchmetrics/image/__init__.py | 47 +++++++++++++------ 2 files changed, 56 insertions(+), 25 deletions(-) diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index 6c6d73cfa6d..df50bf7e31b 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -11,16 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.functional.image.d_lambda import spectral_distortion_index # noqa: F401 -from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis # noqa: F401 -from torchmetrics.functional.image.gradients import image_gradients # noqa: F401 -from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401 -from torchmetrics.functional.image.rase import relative_average_spectral_error # noqa: F401 -from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window # noqa: F401 -from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401 -from torchmetrics.functional.image.ssim import ( # noqa: F401 +from torchmetrics.functional.image.d_lambda import spectral_distortion_index +from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis +from torchmetrics.functional.image.gradients import image_gradients +from torchmetrics.functional.image.psnr import peak_signal_noise_ratio +from torchmetrics.functional.image.rase import relative_average_spectral_error +from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window +from torchmetrics.functional.image.sam import spectral_angle_mapper +from torchmetrics.functional.image.ssim import ( multiscale_structural_similarity_index_measure, structural_similarity_index_measure, ) -from torchmetrics.functional.image.tv import total_variation # noqa: F401 -from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401 +from torchmetrics.functional.image.tv import total_variation +from torchmetrics.functional.image.uqi import universal_image_quality_index + +__all__ = [ + "spectral_distortion_index", + "error_relative_global_dimensionless_synthesis", + "image_gradients", + "peak_signal_noise_ratio", + "relative_average_spectral_error", + "root_mean_squared_error_using_sliding_window", + "spectral_angle_mapper", + "multiscale_structural_similarity_index_measure", + "structural_similarity_index_measure", + "total_variation", + "universal_image_quality_index", +] diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index 2e1959db954..df9d1f75efc 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -11,25 +11,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torchmetrics.image.d_lambda import SpectralDistortionIndex # noqa: F401 -from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis # noqa: F401 -from torchmetrics.image.psnr import PeakSignalNoiseRatio # noqa: F401 -from torchmetrics.image.rase import RelativeAverageSpectralError # noqa: F401 -from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow # noqa: F401 -from torchmetrics.image.sam import SpectralAngleMapper # noqa: F401 -from torchmetrics.image.ssim import ( # noqa: F401 - MultiScaleStructuralSimilarityIndexMeasure, - StructuralSimilarityIndexMeasure, -) -from torchmetrics.image.uqi import UniversalImageQualityIndex # noqa: F401 +from torchmetrics.image.d_lambda import SpectralDistortionIndex +from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis +from torchmetrics.image.psnr import PeakSignalNoiseRatio +from torchmetrics.image.rase import RelativeAverageSpectralError +from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow +from torchmetrics.image.sam import SpectralAngleMapper +from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure +from torchmetrics.image.tv import TotalVariation +from torchmetrics.image.uqi import UniversalImageQualityIndex from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_FIDELITY_AVAILABLE +__all__ = [ + "SpectralDistortionIndex", + "ErrorRelativeGlobalDimensionlessSynthesis", + "PeakSignalNoiseRatio", + "RelativeAverageSpectralError", + "RootMeanSquaredErrorUsingSlidingWindow", + "SpectralAngleMapper", + "MultiScaleStructuralSimilarityIndexMeasure", + "StructuralSimilarityIndexMeasure", + "UniversalImageQualityIndex", + "TotalVariation", +] + if _TORCH_FIDELITY_AVAILABLE: - from torchmetrics.image.fid import FrechetInceptionDistance # noqa: F401 - from torchmetrics.image.inception import InceptionScore # noqa: F401 - from torchmetrics.image.kid import KernelInceptionDistance # noqa: F401 + from torchmetrics.image.fid import FrechetInceptionDistance + from torchmetrics.image.inception import InceptionScore + from torchmetrics.image.kid import KernelInceptionDistance + + __all__ += [ + "FrechetInceptionDistance", + "InceptionScore", + "KernelInceptionDistance", + ] if _LPIPS_AVAILABLE: from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity # noqa: F401 -from torchmetrics.image.tv import TotalVariation # noqa: F401 + __all__.append("LearnedPerceptualImagePatchSimilarity") From 53f96c38523d99d0d0cd851744e75b935841f61f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Apr 2023 13:02:20 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/deprecations/root_class_imports.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/deprecations/root_class_imports.py b/tests/unittests/deprecations/root_class_imports.py index a70a31a72fa..99d846982d9 100644 --- a/tests/unittests/deprecations/root_class_imports.py +++ b/tests/unittests/deprecations/root_class_imports.py @@ -5,10 +5,10 @@ from torchmetrics import ( ErrorRelativeGlobalDimensionlessSynthesis, - MultiScaleStructuralSimilarityIndexMeasure, - PeakSignalNoiseRatio, ModifiedPanopticQuality, + MultiScaleStructuralSimilarityIndexMeasure, PanopticQuality, + PeakSignalNoiseRatio, PermutationInvariantTraining, RelativeAverageSpectralError, RootMeanSquaredErrorUsingSlidingWindow,