Skip to content

Commit

Permalink
imports: deprecate from pkg root [3/n] Image (#1696)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Apr 11, 2023
1 parent 60e4177 commit 1056c26
Show file tree
Hide file tree
Showing 7 changed files with 627 additions and 48 deletions.
28 changes: 17 additions & 11 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,24 @@
from torchmetrics.collections import MetricCollection # noqa: E402
from torchmetrics.detection._deprecated import _ModifiedPanopticQuality as ModifiedPanopticQuality # noqa: E402
from torchmetrics.detection._deprecated import _PanopticQuality as PanopticQuality # noqa: E402
from torchmetrics.image import ( # noqa: E402
ErrorRelativeGlobalDimensionlessSynthesis,
MultiScaleStructuralSimilarityIndexMeasure,
PeakSignalNoiseRatio,
RelativeAverageSpectralError,
RootMeanSquaredErrorUsingSlidingWindow,
SpectralAngleMapper,
SpectralDistortionIndex,
StructuralSimilarityIndexMeasure,
TotalVariation,
UniversalImageQualityIndex,
from torchmetrics.image._deprecated import ( # noqa: E402
_ErrorRelativeGlobalDimensionlessSynthesis as ErrorRelativeGlobalDimensionlessSynthesis,
)
from torchmetrics.image._deprecated import ( # noqa: E402
_MultiScaleStructuralSimilarityIndexMeasure as MultiScaleStructuralSimilarityIndexMeasure,
)
from torchmetrics.image._deprecated import _PeakSignalNoiseRatio as PeakSignalNoiseRatio # noqa: E402
from torchmetrics.image._deprecated import _RelativeAverageSpectralError as RelativeAverageSpectralError # noqa: E402
from torchmetrics.image._deprecated import ( # noqa: E402
_RootMeanSquaredErrorUsingSlidingWindow as RootMeanSquaredErrorUsingSlidingWindow,
)
from torchmetrics.image._deprecated import _SpectralAngleMapper as SpectralAngleMapper # noqa: E402
from torchmetrics.image._deprecated import _SpectralDistortionIndex as SpectralDistortionIndex # noqa: E402
from torchmetrics.image._deprecated import ( # noqa: E402
_StructuralSimilarityIndexMeasure as StructuralSimilarityIndexMeasure,
)
from torchmetrics.image._deprecated import _TotalVariation as TotalVariation # noqa: E402
from torchmetrics.image._deprecated import _UniversalImageQualityIndex as UniversalImageQualityIndex # noqa: E402
from torchmetrics.metric import Metric # noqa: E402
from torchmetrics.nominal import CramersV # noqa: E402
from torchmetrics.nominal import PearsonsContingencyCoefficient # noqa: E402
Expand Down
32 changes: 20 additions & 12 deletions src/torchmetrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,27 @@
from torchmetrics.functional.classification.stat_scores import stat_scores
from torchmetrics.functional.detection._deprecated import _modified_panoptic_quality as modified_panoptic_quality
from torchmetrics.functional.detection._deprecated import _panoptic_quality as panoptic_quality
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.rase import relative_average_spectral_error
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
from torchmetrics.functional.image._deprecated import (
_error_relative_global_dimensionless_synthesis as error_relative_global_dimensionless_synthesis,
)
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.functional.image._deprecated import _image_gradients as image_gradients
from torchmetrics.functional.image._deprecated import (
_multiscale_structural_similarity_index_measure as multiscale_structural_similarity_index_measure,
)
from torchmetrics.functional.image._deprecated import _peak_signal_noise_ratio as peak_signal_noise_ratio
from torchmetrics.functional.image._deprecated import (
_relative_average_spectral_error as relative_average_spectral_error,
)
from torchmetrics.functional.image._deprecated import (
_root_mean_squared_error_using_sliding_window as root_mean_squared_error_using_sliding_window,
)
from torchmetrics.functional.image._deprecated import _spectral_angle_mapper as spectral_angle_mapper
from torchmetrics.functional.image._deprecated import _spectral_distortion_index as spectral_distortion_index
from torchmetrics.functional.image._deprecated import (
_structural_similarity_index_measure as structural_similarity_index_measure,
)
from torchmetrics.functional.image._deprecated import _total_variation as total_variation
from torchmetrics.functional.image._deprecated import _universal_image_quality_index as universal_image_quality_index
from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix
from torchmetrics.functional.nominal.pearson import (
pearsons_contingency_coefficient,
Expand Down
34 changes: 24 additions & 10 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.image.d_lambda import spectral_distortion_index # noqa: F401
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis # noqa: F401
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.rase import relative_average_spectral_error # noqa: F401
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window # noqa: F401
from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401
from torchmetrics.functional.image.ssim import ( # noqa: F401
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.rase import relative_average_spectral_error
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.tv import total_variation # noqa: F401
from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index

__all__ = [
"spectral_distortion_index",
"error_relative_global_dimensionless_synthesis",
"image_gradients",
"peak_signal_noise_ratio",
"relative_average_spectral_error",
"root_mean_squared_error_using_sliding_window",
"spectral_angle_mapper",
"multiscale_structural_similarity_index_measure",
"structural_similarity_index_measure",
"total_variation",
"universal_image_quality_index",
]
255 changes: 255 additions & 0 deletions src/torchmetrics/functional/image/_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
from typing import Optional, Sequence, Tuple, Union

from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.rase import relative_average_spectral_error
from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
from torchmetrics.functional.image.sam import spectral_angle_mapper
from torchmetrics.functional.image.ssim import (
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.tv import total_variation
from torchmetrics.functional.image.uqi import universal_image_quality_index
from torchmetrics.utilities.prints import _deprecated_root_import_func


def _spectral_distortion_index(
preds: Tensor,
target: Tensor,
p: int = 1,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
"""Wrapper for deprecated import.
>>> import torch
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 16, 16])
>>> target = torch.rand([16, 3, 16, 16])
>>> _spectral_distortion_index(preds, target)
tensor(0.0234)
"""
_deprecated_root_import_func("spectral_distortion_index", "image")
return spectral_distortion_index(preds=preds, target=target, p=p, reduction=reduction)


def _error_relative_global_dimensionless_synthesis(
preds: Tensor,
target: Tensor,
ratio: Union[int, float] = 4,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([16, 1, 16, 16], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> ergds = _error_relative_global_dimensionless_synthesis(preds, target)
>>> torch.round(ergds)
tensor(154.)
"""
_deprecated_root_import_func("error_relative_global_dimensionless_synthesis", "image")
return error_relative_global_dimensionless_synthesis(preds=preds, target=target, ratio=ratio, reduction=reduction)


def _image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]:
"""Wrapper for deprecated import.
>>> import torch
>>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32)
>>> image = torch.reshape(image, (1, 1, 5, 5))
>>> dy, dx = _image_gradients(image)
>>> dy[0, 0, :, :]
tensor([[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[5., 5., 5., 5., 5.],
[0., 0., 0., 0., 0.]])
"""
_deprecated_root_import_func("image_gradients", "image")
return image_gradients(img=img)


def _peak_signal_noise_ratio(
preds: Tensor,
target: Tensor,
data_range: Optional[Union[float, Tuple[float, float]]] = None,
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
dim: Optional[Union[int, Tuple[int, ...]]] = None,
) -> Tensor:
"""Wrapper for deprecated import.
>>> from torch import tensor
>>> pred = tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = tensor([[3.0, 2.0], [1.0, 0.0]])
>>> _peak_signal_noise_ratio(pred, target)
tensor(2.5527)
"""
_deprecated_root_import_func("peak_signal_noise_ratio", "image")
return peak_signal_noise_ratio(
preds=preds, target=target, data_range=data_range, base=base, reduction=reduction, dim=dim
)


def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: int = 8) -> Tensor:
"""Wrapper for deprecated import.
>>> import torch
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> _relative_average_spectral_error(preds, target)
tensor(5114.6641)
"""
_deprecated_root_import_func("relative_average_spectral_error", "image")
return relative_average_spectral_error(preds=preds, target=target, window_size=window_size)


def _root_mean_squared_error_using_sliding_window(
preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False
) -> Union[Optional[Tensor], Tuple[Optional[Tensor], Tensor]]:
"""Wrapper for deprecated import.
>>> import torch
>>> g = torch.manual_seed(22)
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> _root_mean_squared_error_using_sliding_window(preds, target)
tensor(0.3999)
"""
_deprecated_root_import_func("root_mean_squared_error_using_sliding_window", "image")
return root_mean_squared_error_using_sliding_window(
preds=preds, target=target, window_size=window_size, return_rmse_map=return_rmse_map
)


def _spectral_angle_mapper(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(42))
>>> target = torch.rand([16, 3, 16, 16], generator=torch.manual_seed(123))
>>> _spectral_angle_mapper(preds, target)
tensor(0.5943)
"""
_deprecated_root_import_func("spectral_angle_mapper", "image")
return spectral_angle_mapper(preds=preds, target=target, reduction=reduction)


def _multiscale_structural_similarity_index_measure(
preds: Tensor,
target: Tensor,
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[Union[float, Tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
betas: Tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
normalize: Optional[Literal["relu", "simple"]] = "relu",
) -> Tensor:
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([3, 3, 256, 256], generator=torch.manual_seed(42))
>>> target = preds * 0.75
>>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9627)
"""
_deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image")
return multiscale_structural_similarity_index_measure(
preds=preds,
target=target,
gaussian_kernel=gaussian_kernel,
sigma=sigma,
kernel_size=kernel_size,
reduction=reduction,
data_range=data_range,
k1=k1,
k2=k2,
betas=betas,
normalize=normalize,
)


def _structural_similarity_index_measure(
preds: Tensor,
target: Tensor,
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
data_range: Optional[Union[float, Tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
return_full_image: bool = False,
return_contrast_sensitivity: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> _structural_similarity_index_measure(preds, target)
tensor(0.9219)
"""
_deprecated_root_import_func("spectral_angle_mapper", "image")
return structural_similarity_index_measure(
preds=preds,
target=target,
gaussian_kernel=gaussian_kernel,
sigma=sigma,
kernel_size=kernel_size,
reduction=reduction,
data_range=data_range,
k1=k1,
k2=k2,
return_full_image=return_full_image,
return_contrast_sensitivity=return_contrast_sensitivity,
)


def _total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor:
"""Wrapper for deprecated import.
>>> import torch
>>> _ = torch.manual_seed(42)
>>> img = torch.rand(5, 3, 28, 28)
>>> _total_variation(img)
tensor(7546.8018)
"""
_deprecated_root_import_func("total_variation", "image")
return total_variation(img=img, reduction=reduction)


def _universal_image_quality_index(
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[float] = None,
) -> Tensor:
"""Wrapper for deprecated import.
>>> import torch
>>> preds = torch.rand([16, 1, 16, 16])
>>> target = preds * 0.75
>>> _universal_image_quality_index(preds, target)
tensor(0.9216)
"""
_deprecated_root_import_func("universal_image_quality_index", "image")
return universal_image_quality_index(
preds=preds, target=target, kernel_size=kernel_size, sigma=sigma, reduction=reduction, data_range=data_range
)
Loading

0 comments on commit 1056c26

Please sign in to comment.