From 2a6dd63ac5ac45df5c43f46e30af53de30ae8324 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 4 Dec 2023 21:32:04 +0000 Subject: [PATCH 01/32] added new matric "spatial distortion index" --- src/torchmetrics/functional/image/__init__.py | 2 + src/torchmetrics/functional/image/d_s.py | 246 ++++++++++++ src/torchmetrics/image/__init__.py | 2 + src/torchmetrics/image/d_s.py | 191 +++++++++ tests/unittests/helpers/testers.py | 27 +- tests/unittests/image/test_d_s.py | 365 ++++++++++++++++++ 6 files changed, 830 insertions(+), 3 deletions(-) create mode 100644 src/torchmetrics/functional/image/d_s.py create mode 100644 src/torchmetrics/image/d_s.py create mode 100644 tests/unittests/image/test_d_s.py diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index 329b33b66fe..a265a9934d5 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.functional.image.d_lambda import spectral_distortion_index +from torchmetrics.functional.image.d_s import spatial_distortion_index from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis from torchmetrics.functional.image.gradients import image_gradients from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity @@ -31,6 +32,7 @@ __all__ = [ "spectral_distortion_index", + "spatial_distortion_index", "error_relative_global_dimensionless_synthesis", "image_gradients", "peak_signal_noise_ratio", diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py new file mode 100644 index 00000000000..db2b74d9111 --- /dev/null +++ b/src/torchmetrics/functional/image/d_s.py @@ -0,0 +1,246 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Tuple + +import torch +from kornia.filters import filter2d +from torch import Tensor +from torchvision.transforms.functional import resize +from typing_extensions import Literal + +from torchmetrics.functional.image.uqi import universal_image_quality_index +from torchmetrics.utilities.distributed import reduce + + +def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: + """Update and returns variables required to compute Spatial Distortion Index. + + Args: + preds: High resolution multispectral image. + target: A dictionary containing the following keys: + + - ``'ms'``: low resolution multispectral image. + - ``'pan'``: high resolution panchromatic image. + - ``'pan_lr'``: (optional) low resolution panchromatic image. + + Return: + A tuple of Tensors containing ``preds`` and ``target``. + + Raises: + TypeError: + If ``preds`` and ``target`` don't have the same data type. + ValueError: + If ``preds`` and ``target`` don't have ``BxCxHxW shape``. 
+ ValueError: + If ``preds`` and ``target`` don't have the same batch and channel sizes. + ValueError: + If ``target`` doesn't have ``ms`` and ``pan``. + + """ + if len(preds.shape) != 4: + raise ValueError(f"Expected `preds` to have BxCxHxW shape. Got preds: {preds.shape}.") + if "ms" not in target or "pan" not in target: + raise ValueError(f"Expected `target` to have keys ('ms', 'pan'). Got target: {target.keys()}") + for name, t in target.items(): + if preds.dtype != t.dtype: + raise TypeError( + f"Expected `preds` and `{name}` to have the same data type. " + "Got preds: {preds.dtype} and {name}: {t.dtype}." + ) + for name, t in target.items(): + if len(t.shape) != 4: + raise ValueError(f"Expected `{name}` to have BxCxHxW shape. Got {name}: {t.shape}.") + for name, t in target.items(): + if preds.shape[:2] != t.shape[:2]: + raise ValueError( + f"Expected `preds` and `{name}` to have same batch and channel sizes. " + "Got preds: {preds.shape} and {name}: {t.shape}." + ) + return preds, target + + +def _spatial_distortion_index_compute( + preds: Tensor, + target: Dict[str, Tensor], + p: int = 1, + ws: int = 7, + reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", +) -> Tensor: + """Compute Spatial Distortion Index (SpatialDistortionIndex_). + + Args: + preds: High resolution multispectral image. + target: A dictionary containing the following keys: + + - ``'ms'``: low resolution multispectral image. + - ``'pan'``: high resolution panchromatic image. + - ``'pan_lr'``: (optional) low resolution panchromatic image. + + p: Order of the norm applied on the difference. + ws: Window size of the filter applied to degrade the high resolution panchromatic image. + reduction: A method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + Return: + Tensor with SpatialDistortionIndex score + + Raises: + ValueError: + If ``preds`` and ``pan`` don't have the same dimension. + ValueError: + If ``ms`` and ``pan_lr`` don't have the same dimension. + ValueError: + If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``. + + Example: + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand([16, 3, 32, 32]) + >>> target = { + >>> 'ms': torch.rand([16, 3, 16, 16]), + >>> 'pan': torch.rand([16, 3, 32, 32]), + >>> } + >>> preds, target = _spatial_distortion_index_update(preds, target) + >>> _spatial_distortion_index_compute(preds, target) + tensor(0.0051) + + """ + length = preds.shape[1] + + ms = target["ms"] + pan = target["pan"] + pan_lr = target["pan_lr"] if "pan_lr" in target else None + + preds_h, preds_w = preds.shape[-2:] + ms_h, ms_w = ms.shape[-2:] + pan_h, pan_w = pan.shape[-2:] + if preds_h != pan_h: + raise ValueError(f"Expected `preds` and `pan` to have the same height. Got preds: {preds_h} and pan: {pan_h}") + if preds_w != pan_w: + raise ValueError(f"Expected `preds` and `pan` to have the same width. Got preds: {preds_w} and pan: {pan_w}") + if preds_h % ms_h != 0: + raise ValueError( + f"Expected height of `preds` to be multiple of height of `ms`. Got preds: {preds_h} and ms: {ms_h}." + ) + if preds_w % ms_w != 0: + raise ValueError( + f"Expected width of `preds` to be multiple of width of `ms`. Got preds: {preds_w} and ms: {ms_w}." + ) + if pan_h % ms_h != 0: + raise ValueError( + f"Expected height of `pan` to be multiple of height of `ms`. Got preds: {pan_h} and ms: {ms_h}." 
+ ) + if pan_w % ms_w != 0: + raise ValueError(f"Expected width of `pan` to be multiple of width of `ms`. Got preds: {pan_w} and ms: {ms_w}.") + if ws >= ms_h or ws >= ms_w: + raise ValueError(f"Expected `ws` to be smaller than dimension of `ms`. Got ws: {ws}.") + + if pan_lr is not None: + pan_lr_h, pan_lr_w = pan_lr.shape[-2:] + if pan_lr_h != ms_h: + raise ValueError( + f"Expected `ms` and `pan_lr` to have the same height. Got ms: {ms_h} and pan_lr: {pan_lr_h}." + ) + if pan_lr_w != ms_w: + raise ValueError( + f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}." + ) + + pan_degraded = pan_lr + if pan_degraded is None: + kernel = torch.ones(size=(1, ws, ws)) + pan_degraded = filter2d(pan, kernel, border_type="replicate", normalized=True) + pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) + + m1 = torch.zeros(length, device=preds.device) + m2 = torch.zeros(length, device=preds.device) + + for i in range(length): + m1[i] = universal_image_quality_index(ms[:, i : i + 1], pan_degraded[:, i : i + 1]) + m2[i] = universal_image_quality_index(preds[:, i : i + 1], pan[:, i : i + 1]) + diff = (m1 - m2).abs() ** p + return reduce(diff, reduction) ** (1 / p) + + +def spatial_distortion_index( + preds: Tensor, + target: Dict[str, Tensor], + p: int = 1, + ws: int = 7, + reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", +) -> Tensor: + """Calculate `Spatial Distortion Index`_ (SpatialDistortionIndex_) also known as D_s. + + Metric is used to compare the spatial distortion between two images. + + Args: + preds: High resolution multispectral image. + target: A dictionary containing the following keys: + + - ``'ms'``: low resolution multispectral image. + - ``'pan'``: high resolution panchromatic image. + - ``'pan_lr'``: (optional) low resolution panchromatic image. + + p: Order of the norm applied on the difference. + ws: Window size of the filter applied to degrade the high resolution panchromatic image. + reduction: A method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + Return: + Tensor with SpatialDistortionIndex score + + Raises: + TypeError: + If ``preds`` and ``target`` don't have the same data type. + ValueError: + If ``preds`` and ``target`` don't have ``BxCxHxW shape``. + ValueError: + If ``preds`` and ``target`` don't have the same batch and channel sizes. + ValueError: + If ``target`` doesn't have ``ms`` and ``pan``. + ValueError: + If ``preds`` and ``pan`` don't have the same dimension. + ValueError: + If ``ms`` and ``pan_lr`` don't have the same dimension. + ValueError: + If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``. + ValueError: + If ``p`` is not a positive integer. + ValueError: + If ``ws`` is not a positive integer. + + Example: + >>> from torchmetrics.functional.image import spatial_distortion_index + >>> _ = torch.manual_seed(42) + >>> preds = torch.rand([16, 3, 32, 32]) + >>> target = { + >>> 'ms': torch.rand([16, 3, 16, 16]), + >>> 'pan': torch.rand([16, 3, 32, 32]), + >>> } + >>> spatial_distortion_index(preds, target) + tensor(0.0051) + + """ + if not isinstance(p, int) or p <= 0: + raise ValueError(f"Expected `p` to be a positive integer. Got p: {p}.") + if not isinstance(ws, int) or ws <= 0: + raise ValueError(f"Expected `ws` to be a positive integer. 
Got ws: {ws}.") + preds, target = _spatial_distortion_index_update(preds, target) + return _spatial_distortion_index_compute(preds, target, p, ws, reduction) diff --git a/src/torchmetrics/image/__init__.py b/src/torchmetrics/image/__init__.py index 1defa78bbf5..f0c6252881a 100644 --- a/src/torchmetrics/image/__init__.py +++ b/src/torchmetrics/image/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.image.d_lambda import SpectralDistortionIndex +from torchmetrics.image.d_s import SpatialDistortionIndex from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance from torchmetrics.image.psnr import PeakSignalNoiseRatio @@ -30,6 +31,7 @@ __all__ = [ "SpectralDistortionIndex", + "SpatialDistortionIndex", "ErrorRelativeGlobalDimensionlessSynthesis", "PeakSignalNoiseRatio", "PeakSignalNoiseRatioWithBlockedEffect", diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py new file mode 100644 index 00000000000..c9bc9d83c67 --- /dev/null +++ b/src/torchmetrics/image/d_s.py @@ -0,0 +1,191 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Sequence, Union + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.image.d_s import _spatial_distortion_index_compute, _spatial_distortion_index_update +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SpatialDistortionIndex.plot"] + + +class SpatialDistortionIndex(Metric): + """Compute Spatial Distortion Index (SpatialDistortionIndex_) also now as D_s. + + The metric is used to compare the spatial distortion between two images. + + As input to ``forward`` and ``update`` the metric accepts the following input + + - ``preds`` (:class:`~torch.Tensor`): High resolution multispectral image of shape ``(N,C,H,W)``. + - ``target`` (:class:`~Dict`): A dictionary containing the following keys: + - ``ms`` (:class:`~torch.Tensor`): Low resolution multispectral image of shape ``(N,C,H',W')``. + - ``pan`` (:class:`~torch.Tensor`): High resolution panchromatic image of shape ``(N,C,H,W)``. + - ``pan_lr`` (:class:`~torch.Tensor`): Low resolution panchromatic image of shape ``(N,C,H',W')``. + + where H and W must be multiple of H' and W'. + + As output of `forward` and `compute` the metric returns the following output + + - ``sdi`` (:class:`~torch.Tensor`): if ``reduction!='none'`` returns float scalar tensor with average SDI value + over sample else returns tensor of shape ``(N,)`` with SDI values per sample + + Args: + p: Order of the norm applied on the difference. 
+ ws: Window size of the filter applied to degrade the high resolution panchromatic image. + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.image import SpatialDistortionIndex + >>> preds = torch.rand([16, 3, 32, 32]) + >>> target = { + >>> 'ms': torch.rand([16, 3, 16, 16]), + >>> 'pan': torch.rand([16, 3, 32, 32]), + >>> } + >>> sdi = SpatialDistortionIndex() + >>> sdi(preds, target) + tensor(0.0051) + + """ + + higher_is_better: bool = True + is_differentiable: bool = True + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + preds: List[Tensor] + ms: List[Tensor] + pan: List[Tensor] + pan_lr: List[Tensor] + + def __init__( + self, + p: int = 1, + ws: int = 7, + reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + rank_zero_warn( + "Metric `SpatialDistortionIndex` will save all targets and" + " predictions in buffer. For large datasets this may lead" + " to large memory footprint." + ) + + if not isinstance(p, int) or p <= 0: + raise ValueError(f"Expected `p` to be a positive integer. Got p: {p}.") + self.p = p + if not isinstance(ws, int) or ws <= 0: + raise ValueError(f"Expected `ws` to be a positive integer. Got ws: {ws}.") + self.ws = ws + allowed_reductions = ("elementwise_mean", "sum", "none") + if reduction not in allowed_reductions: + raise ValueError(f"Expected argument `reduction` be one of {allowed_reductions} but got {reduction}") + self.reduction = reduction + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("ms", default=[], dist_reduce_fx="cat") + self.add_state("pan", default=[], dist_reduce_fx="cat") + self.add_state("pan_lr", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with preds and target.""" + preds, target = _spatial_distortion_index_update(preds, target) + self.preds.append(preds) + self.ms.append(target["ms"]) + self.pan.append(target["pan"]) + if "pan_lr" in target: + self.pan_lr.append(target["pan_lr"]) + + def compute(self) -> Tensor: + """Compute and returns spatial distortion index.""" + preds = dim_zero_cat(self.preds) + ms = dim_zero_cat(self.ms) + pan = dim_zero_cat(self.pan) + pan_lr = dim_zero_cat(self.pan_lr) if len(self.pan_lr) > 0 else None + target = { + "ms": ms, + "pan": pan, + **({"pan_lr": pan_lr} if pan_lr is not None else {}), + } + return _spatial_distortion_index_compute(preds, target, self.p, self.ws, self.reduction) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. 
plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.image import SpatialDistortionIndex + >>> preds = torch.rand([16, 3, 32, 32]) + >>> target = { + >>> 'ms': torch.rand([16, 3, 16, 16]), + >>> 'pan': torch.rand([16, 3, 32, 32]), + >>> } + >>> metric = SpatialDistortionIndex() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> _ = torch.manual_seed(42) + >>> from torchmetrics.image import SpatialDistortionIndex + >>> preds = torch.rand([16, 3, 32, 32]) + >>> target = { + >>> 'ms': torch.rand([16, 3, 16, 16]), + >>> 'pan': torch.rand([16, 3, 32, 32]), + >>> } + >>> metric = SpatialDistortionIndex() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index ff315df6ec1..15e201e2bb7 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -209,6 +209,8 @@ def _class_test( total_preds = [item for sublist in preds for item in sublist] if isinstance(target, Tensor): total_target = torch.cat([target[i] for i in range(num_batches)]).cpu() + elif isinstance(target, list) and len(target) > 0 and isinstance(target[0], dict): + total_target = {k: torch.cat([t[k] for t in target]) for k in target[0]} else: total_target = [item for sublist in target for item in sublist] @@ -228,7 +230,7 @@ def _class_test( def _functional_test( preds: Union[Tensor, list], - target: Union[Tensor, list], + target: Union[Tensor, list, List[Dict[str, Tensor]]], metric_functional: Callable, reference_metric: Callable, metric_args: Optional[dict] = None, @@ -264,6 +266,13 @@ def _functional_test( preds = preds.to(device) if isinstance(target, Tensor): target = target.to(device) + elif isinstance(target, list): + for i, target_dict in enumerate(target): + if isinstance(target_dict, dict): + for k in target_dict: + if isinstance(target_dict[k], Tensor): + target[i][k] = target_dict[k].to(device) + kwargs_update = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} for i in range(num_batches // 2): @@ -286,7 +295,7 @@ def _assert_dtype_support( metric_module: Optional[Metric], metric_functional: Optional[Callable], preds: Tensor, - target: Tensor, + target: Union[Tensor, List[Dict[str, Tensor]]], device: str = "cpu", dtype: torch.dtype = torch.half, **kwargs_update: Any, @@ -305,7 +314,18 @@ def _assert_dtype_support( """ y_hat = preds[0].to(dtype=dtype, device=device) if preds[0].is_floating_point() else preds[0].to(device) - y = target[0].to(dtype=dtype, device=device) if target[0].is_floating_point() else target[0].to(device) + y = ( + target[0].to(dtype=dtype, device=device) + if isinstance(target[0], Tensor) and target[0].is_floating_point() + else { + k: target[0][k].to(dtype=dtype, device=device) + if target[0][k].is_floating_point() + else target[0][k].to(device) + for k in target[0] + } + if isinstance(target[0], dict) + else target[0].to(device) + ) kwargs_update = { k: (v[0].to(dtype=dtype) if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items() @@ -422,6 +442,7 @@ def run_class_metric_test( check_dist_sync_on_step=check_dist_sync_on_step, check_batch=check_batch, atol=atol, + device="cuda" if 
torch.cuda.is_available() else "cpu", fragment_kwargs=fragment_kwargs, check_scriptable=check_scriptable, check_state_dict=check_state_dict, diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py new file mode 100644 index 00000000000..04020c4ca76 --- /dev/null +++ b/tests/unittests/image/test_d_s.py @@ -0,0 +1,365 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Dict, List, NamedTuple + +import numpy as np +import pytest +import torch +from scipy.ndimage import uniform_filter +from skimage.transform import resize +from torch import Tensor +from torchmetrics.functional.image.d_s import spatial_distortion_index +from torchmetrics.functional.image.uqi import universal_image_quality_index +from torchmetrics.image.d_s import SpatialDistortionIndex + +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + + +class _Input(NamedTuple): + preds: Tensor + target: List[Dict[str, Tensor]] + p: int + ws: int + + +_inputs = [] +for size, channel, p, r, ws, pan_lr_exists, dtype in [ + (12, 3, 1, 16, 3, False, torch.float), + (13, 1, 3, 8, 5, False, torch.float32), + (14, 1, 4, 4, 5, True, torch.double), + (15, 3, 1, 2, 3, True, torch.float64), +]: + preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size * r, size * r, dtype=dtype) + ms = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) + pan = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size * r, size * r, dtype=dtype) + pan_lr = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) + _inputs.append( + _Input( + preds=preds, + target=[ + { + "ms": ms[i], + "pan": pan[i], + **({"pan_lr": pan_lr[i]} if pan_lr_exists else {}), + } + for i in range(NUM_BATCHES) + ], + p=p, + ws=ws, + ) + ) + + +def _baseline_d_s( + preds: np.ndarray, ms: np.ndarray, pan: np.ndarray, pan_lr: np.ndarray = None, p: int = 1, ws: int = 7 +) -> float: + """NumPy based implementation of Spatial Distortion Index, which uses UQI of TorchMetrics.""" + pan_degraded = pan_lr + if pan_degraded is None: + try: + pan_degraded = uniform_filter(pan, size=ws, axes=[1, 2]) + except TypeError: + pan_degraded = np.array( + [[uniform_filter(pan[i, ..., j], size=ws) for j in range(pan.shape[-1])] for i in range(len(pan))] + ).transpose((0, 2, 3, 1)) + pan_degraded = np.array([resize(img, ms.shape[1:3], anti_aliasing=False) for img in pan_degraded]) + + length = preds.shape[-1] + m1 = np.zeros(length, dtype=np.float32) + m2 = np.zeros(length, dtype=np.float32) + + # Convert target and preds to Torch Tensors, pass them to metrics UQI + # this is mainly because reference repo (sewar) uses uniform distribution + # in their implementation of UQI, and we use gaussian distribution + # and they have different default values for some kwargs like window size. 
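+    # At this point the baseline arrays are channel-last (N, H, W, C); permute them to
+    # channel-first (N, C, H, W) so torchmetrics' universal_image_quality_index gets the layout it expects.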
+ ms = torch.from_numpy(ms).permute(0, 3, 1, 2) + pan = torch.from_numpy(pan).permute(0, 3, 1, 2) + preds = torch.from_numpy(preds).permute(0, 3, 1, 2) + pan_degraded = torch.from_numpy(pan_degraded).permute(0, 3, 1, 2) + for i in range(length): + m1[i] = universal_image_quality_index(ms[:, i : i + 1], pan_degraded[:, i : i + 1]) + m2[i] = universal_image_quality_index(preds[:, i : i + 1], pan[:, i : i + 1]) + diff = np.abs(m1 - m2) ** p + return np.mean(diff) ** (1 / p) + + +def _np_d_s(preds, target, p, ws): + np_preds = preds.permute(0, 2, 3, 1).cpu().numpy() + assert isinstance(target, dict), f"Expected `target` to be dict. Got {type(target)}." + assert "ms" in target, "Expected `target` to contain 'ms'." + np_ms = target["ms"].permute(0, 2, 3, 1).cpu().numpy() + assert "pan" in target, "Expected `target` to contain 'pan'." + np_pan = target["pan"].permute(0, 2, 3, 1).cpu().numpy() + np_pan_lr = target["pan_lr"].permute(0, 2, 3, 1).cpu().numpy() if "pan_lr" in target else None + + return _baseline_d_s( + np_preds, + np_ms, + np_pan, + np_pan_lr, + p=p, + ws=ws, + ) + + +@pytest.mark.parametrize( + "preds, target, p, ws", + [(i.preds, i.target, i.p, i.ws) for i in _inputs], +) +class TestSpatialDistortionIndex(MetricTester): + """Test class for `SpatialDistortionIndex` metric.""" + + atol = 3e-6 + + @pytest.mark.parametrize("ddp", [True, False]) + def test_d_s(self, preds, target, p, ws, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp, + preds, + target, + SpatialDistortionIndex, + partial(_np_d_s, p=p, ws=ws), + metric_args={"p": p, "ws": ws}, + ) + + def test_d_s_functional(self, preds, target, p, ws): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds, + target, + spatial_distortion_index, + partial(_np_d_s, p=p, ws=ws), + metric_args={"p": p, "ws": ws}, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + def test_d_s_half_gpu(self, preds, target, p, ws): + """Test dtype support of the metric on GPU.""" + self.run_precision_test_gpu(preds, target, SpatialDistortionIndex, spatial_distortion_index, {"p": p, "ws": ws}) + + +@pytest.mark.parametrize( + ("preds", "target", "p", "ws", "match"), + [ + ( + [1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + 1, + 3, + "Expected `preds` to have BxCxHxW shape.*", + ), # len(preds.shape) + ([1, 1, 16, 16], {}, 1, 7, r"Expected `target` to have keys \('ms', 'pan'\).*"), # target.keys() + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4]}, + 1, + 3, + r"Expected `target` to have keys \('ms', 'pan'\).*", + ), # target.keys() + ( + [1, 1, 16, 16], + {"pan": [1, 1, 16, 16]}, + 1, + 3, + r"Expected `target` to have keys \('ms', 'pan'\).*", + ), # target.keys() + ( + [1, 1, 16, 16], + {"ms": [1, 4, 4], "pan": [1, 1, 16, 16]}, + 1, + 3, + "Expected `ms` to have BxCxHxW shape.*", + ), # len(target.shape) + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 16, 16]}, + 1, + 3, + "Expected `pan` to have BxCxHxW shape.*", + ), # len(target.shape) + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 4, 4]}, + 1, + 3, + "Expected `pan_lr` to have BxCxHxW shape.*", + ), # len(target.shape) + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + 0, + 3, + "Expected `p` to be a positive integer. Got p: 0.", + ), # invalid p + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + -1, + 3, + "Expected `p` to be a positive integer. 
Got p: -1.", + ), # invalid p + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + 1, + 0, + "Expected `ws` to be a positive integer. Got ws: 0.", + ), # invalid ws + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + 1, + -1, + "Expected `ws` to be a positive integer. Got ws: -1.", + ), # invalid ws + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 17, 16]}, + 1, + 3, + "Expected `preds` and `pan` to have the same height.*", + ), # invalid pan_h + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 17]}, + 1, + 3, + "Expected `preds` and `pan` to have the same width.*", + ), # invalid pan_w + ( + [1, 1, 16, 16], + {"ms": [1, 1, 5, 4], "pan": [1, 1, 16, 16]}, + 1, + 3, + "Expected height of `preds` to be multiple of height of `ms`.*", + ), # invalid ms_h + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 5], "pan": [1, 1, 16, 16]}, + 1, + 3, + "Expected width of `preds` to be multiple of width of `ms`.*", + ), # invalid ms_w + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 1, 5, 4]}, + 1, + 3, + "Expected `ms` and `pan_lr` to have the same height.*", + ), # invalid pan_lr_h + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 1, 4, 5]}, + 1, + 3, + "Expected `ms` and `pan_lr` to have the same width.*", + ), # invalid pan_lr_w + ( + [1, 1, 16, 16], + {"ms": [1, 2, 4, 4], "pan": [1, 1, 16, 16]}, + 1, + 3, + "Expected `preds` and `ms` to have same batch and channel.*", + ), # invalid ms.shape + ( + [1, 1, 16, 16], + {"ms": [2, 1, 4, 4], "pan": [1, 1, 16, 16]}, + 1, + 3, + "Expected `preds` and `ms` to have same batch and channel.*", + ), # invalid ms.shape + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 2, 16, 16]}, + 1, + 3, + "Expected `preds` and `pan` to have same batch and channel.*", + ), # invalid pan.shape + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [2, 1, 16, 16]}, + 1, + 3, + "Expected `preds` and `pan` to have same batch and channel.*", + ), # invalid pan.shape + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 2, 4, 4]}, + 1, + 3, + "Expected `preds` and `pan_lr` to have same batch and channel.*", + ), # invalid pan_lr.shape + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [2, 1, 4, 4]}, + 1, + 3, + "Expected `preds` and `pan_lr` to have same batch and channel.*", + ), # invalid pan_lr.shape + ( + [1, 1, 16, 16], + {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + 1, + 5, + "Expected `ws` to be smaller than dimension of `ms`.*", + ), # invalid ws + ], +) +def test_d_s_invalid_inputs(preds, target, p, ws, match): + """Test that invalid input raises the correct errors.""" + preds_t = torch.rand(preds) + target_t = {name: torch.rand(t) for name, t in target.items()} + with pytest.raises(ValueError, match=match): + spatial_distortion_index(preds_t, target_t, p, ws) + + +@pytest.mark.parametrize( + ("target", "match"), + [ + ( + { + "ms": torch.rand((1, 1, 4, 4), dtype=torch.float64), + "pan": torch.rand((1, 1, 16, 16)), + }, + "Expected `preds` and `ms` to have the same data type.*", + ), + ( + { + "ms": torch.rand((1, 1, 4, 4)), + "pan": torch.rand((1, 1, 16, 16), dtype=torch.float64), + }, + "Expected `preds` and `pan` to have the same data type.*", + ), + ( + { + "ms": torch.rand((1, 1, 4, 4)), + "pan": torch.rand((1, 1, 16, 16)), + "pan_lr": torch.rand((1, 1, 4, 4), dtype=torch.float64), + }, + "Expected `preds` and `pan_lr` to have the same data type.*", + ), + ], +) +def test_d_s_invalid_type(target, 
match): + """Test that error is raised on different dtypes.""" + preds_t = torch.rand((1, 1, 16, 16)) + with pytest.raises(TypeError, match=match): + spatial_distortion_index(preds_t, target, p=1, ws=7) From 6bc62bf27ccac97104d0a227a25c4f03eb55a9f2 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 4 Dec 2023 21:46:00 +0000 Subject: [PATCH 02/32] added missing docs --- CHANGELOG.md | 3 +++ .../source/image/spatial_distortion_index.rst | 21 +++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 docs/source/image/spatial_distortion_index.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index d156f89bb08..187b97662e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `aggregate`` argument to retrieval metrics ([#2220](https://github.com/Lightning-AI/torchmetrics/pull/2220)) +- Added `SpatialDistortionIndex` metric to image domain ([#2260](https://github.com/Lightning-AI/torchmetrics/pull/2260)) + + ### Changed - Changed minimum supported Pytorch version from 1.8 to 1.10 ([#2145](https://github.com/Lightning-AI/torchmetrics/pull/2145)) diff --git a/docs/source/image/spatial_distortion_index.rst b/docs/source/image/spatial_distortion_index.rst new file mode 100644 index 00000000000..86af7b3c96e --- /dev/null +++ b/docs/source/image/spatial_distortion_index.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Spatial Distortion Index + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Image + +.. include:: ../links.rst + +######################### +Spatial Distortion Index +######################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.image.SpatialDistortionIndex + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.image.spatial_distortion_index From fed2f4a3375b19f9dbdfc374891c5c7b9bfc6026 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 4 Dec 2023 22:10:38 +0000 Subject: [PATCH 03/32] moved kornia from image_test.txt to image.txt --- requirements/image.txt | 1 + requirements/image_test.txt | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/image.txt b/requirements/image.txt index 4d15d7b79ba..d9bda03d0a9 100644 --- a/requirements/image.txt +++ b/requirements/image.txt @@ -4,3 +4,4 @@ scipy >1.0.0, <1.11.0 torchvision >=0.8, <0.17.0 torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing +kornia >=0.6.7, <0.7.1 diff --git a/requirements/image_test.txt b/requirements/image_test.txt index ae48f31ab5a..010e4bb6a54 100644 --- a/requirements/image_test.txt +++ b/requirements/image_test.txt @@ -1,8 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -scikit-image >=0.19.0, <=0.21.0 -kornia >=0.6.7, <0.7.1 +scikit-image >=0.19.0, <=0.21. 
pytorch-msssim ==1.0.0 sewar >=0.4.4, <=0.4.6 numpy <1.25.0 From aab1a48145ebc2e61d439aab93cf07559d01eba1 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 4 Dec 2023 22:31:18 +0000 Subject: [PATCH 04/32] fixed typo in version in requirements --- requirements/image_test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/image_test.txt b/requirements/image_test.txt index 010e4bb6a54..f10ace1580c 100644 --- a/requirements/image_test.txt +++ b/requirements/image_test.txt @@ -1,7 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -scikit-image >=0.19.0, <=0.21. +scikit-image >=0.19.0, <=0.21.0 pytorch-msssim ==1.0.0 sewar >=0.4.4, <=0.4.6 numpy <1.25.0 From 440266b8c32ff5e6e05afbd0d4a7072412265172 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 4 Dec 2023 23:17:16 +0000 Subject: [PATCH 05/32] fixed docstrings --- src/torchmetrics/functional/image/d_s.py | 16 ++++++++-------- src/torchmetrics/image/d_s.py | 20 ++++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index db2b74d9111..d1f8b2a4f28 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -111,12 +111,12 @@ def _spatial_distortion_index_compute( >>> _ = torch.manual_seed(42) >>> preds = torch.rand([16, 3, 32, 32]) >>> target = { - >>> 'ms': torch.rand([16, 3, 16, 16]), - >>> 'pan': torch.rand([16, 3, 32, 32]), - >>> } + ... 'ms': torch.rand([16, 3, 16, 16]), + ... 'pan': torch.rand([16, 3, 32, 32]), + ... } >>> preds, target = _spatial_distortion_index_update(preds, target) >>> _spatial_distortion_index_compute(preds, target) - tensor(0.0051) + tensor(0.0090) """ length = preds.shape[1] @@ -231,11 +231,11 @@ def spatial_distortion_index( >>> _ = torch.manual_seed(42) >>> preds = torch.rand([16, 3, 32, 32]) >>> target = { - >>> 'ms': torch.rand([16, 3, 16, 16]), - >>> 'pan': torch.rand([16, 3, 32, 32]), - >>> } + ... 'ms': torch.rand([16, 3, 16, 16]), + ... 'pan': torch.rand([16, 3, 32, 32]), + ... } >>> spatial_distortion_index(preds, target) - tensor(0.0051) + tensor(0.0090) """ if not isinstance(p, int) or p <= 0: diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index c9bc9d83c67..4a40a9c75fe 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -65,12 +65,12 @@ class SpatialDistortionIndex(Metric): >>> from torchmetrics.image import SpatialDistortionIndex >>> preds = torch.rand([16, 3, 32, 32]) >>> target = { - >>> 'ms': torch.rand([16, 3, 16, 16]), - >>> 'pan': torch.rand([16, 3, 32, 32]), - >>> } + ... 'ms': torch.rand([16, 3, 16, 16]), + ... 'pan': torch.rand([16, 3, 32, 32]), + ... } >>> sdi = SpatialDistortionIndex() >>> sdi(preds, target) - tensor(0.0051) + tensor(0.0090) """ @@ -162,9 +162,9 @@ def plot( >>> from torchmetrics.image import SpatialDistortionIndex >>> preds = torch.rand([16, 3, 32, 32]) >>> target = { - >>> 'ms': torch.rand([16, 3, 16, 16]), - >>> 'pan': torch.rand([16, 3, 32, 32]), - >>> } + ... 'ms': torch.rand([16, 3, 16, 16]), + ... 'pan': torch.rand([16, 3, 32, 32]), + ... 
} >>> metric = SpatialDistortionIndex() >>> metric.update(preds, target) >>> fig_, ax_ = metric.plot() @@ -178,9 +178,9 @@ def plot( >>> from torchmetrics.image import SpatialDistortionIndex >>> preds = torch.rand([16, 3, 32, 32]) >>> target = { - >>> 'ms': torch.rand([16, 3, 16, 16]), - >>> 'pan': torch.rand([16, 3, 32, 32]), - >>> } + ... 'ms': torch.rand([16, 3, 16, 16]), + ... 'pan': torch.rand([16, 3, 32, 32]), + ... } >>> metric = SpatialDistortionIndex() >>> values = [ ] >>> for _ in range(10): From 3fe20a98cf66bbd9978086dc57d670c32bc32c53 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Tue, 5 Dec 2023 23:18:13 +0000 Subject: [PATCH 06/32] changed kornia to lazy import --- src/torchmetrics/functional/image/d_s.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index d1f8b2a4f28..800cafdb234 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -15,7 +15,6 @@ from typing import Dict, Tuple import torch -from kornia.filters import filter2d from torch import Tensor from torchvision.transforms.functional import resize from typing_extensions import Literal @@ -162,6 +161,7 @@ def _spatial_distortion_index_compute( pan_degraded = pan_lr if pan_degraded is None: + from kornia.filters import filter2d kernel = torch.ones(size=(1, ws, ws)) pan_degraded = filter2d(pan, kernel, border_type="replicate", normalized=True) pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) From fe9a83577129232c4d3ec7dcfd13185fa0753e27 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 5 Dec 2023 23:19:37 +0000 Subject: [PATCH 07/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/d_s.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 800cafdb234..fb5c2bdecbc 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -162,6 +162,7 @@ def _spatial_distortion_index_compute( pan_degraded = pan_lr if pan_degraded is None: from kornia.filters import filter2d + kernel = torch.ones(size=(1, ws, ws)) pan_degraded = filter2d(pan, kernel, border_type="replicate", normalized=True) pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) From 15eb07980c80a6042f4d428a52665386557288e0 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Tue, 5 Dec 2023 23:45:10 +0000 Subject: [PATCH 08/32] changed torchvision to lazy import --- src/torchmetrics/functional/image/d_s.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index fb5c2bdecbc..d1b6dd5ac13 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -16,7 +16,6 @@ import torch from torch import Tensor -from torchvision.transforms.functional import resize from typing_extensions import Literal from torchmetrics.functional.image.uqi import universal_image_quality_index @@ -162,7 +161,7 @@ def _spatial_distortion_index_compute( pan_degraded = pan_lr if pan_degraded is None: from kornia.filters import filter2d - + from torchvision.transforms.functional import resize kernel = torch.ones(size=(1, ws, ws)) pan_degraded = filter2d(pan, kernel, 
border_type="replicate", normalized=True) pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) From 4035bfeec7b01f30fe2a62a5758b3d52310458b5 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Tue, 5 Dec 2023 23:46:10 +0000 Subject: [PATCH 09/32] fix style --- src/torchmetrics/functional/image/d_s.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index d1b6dd5ac13..743a7fecfcf 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -162,6 +162,7 @@ def _spatial_distortion_index_compute( if pan_degraded is None: from kornia.filters import filter2d from torchvision.transforms.functional import resize + kernel = torch.ones(size=(1, ws, ws)) pan_degraded = filter2d(pan, kernel, border_type="replicate", normalized=True) pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) From b4fbb08eb3646ededc9358481de6f33f53f7f1bf Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Wed, 6 Dec 2023 00:15:50 +0000 Subject: [PATCH 10/32] fix type hint --- src/torchmetrics/functional/image/d_s.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 743a7fecfcf..ae0f6f8e4a2 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -22,7 +22,7 @@ from torchmetrics.utilities.distributed import reduce -def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]: +def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]: """Update and returns variables required to compute Spatial Distortion Index. Args: From dace7f6ef3de768886632d6384e6c910cbf2145b Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Wed, 6 Dec 2023 01:46:27 +0000 Subject: [PATCH 11/32] fix missing link in doc --- docs/source/links.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/links.rst b/docs/source/links.rst index 4ee095dd659..eafa1a5ffa7 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -84,6 +84,7 @@ .. _MultiScaleSSIM: https://ece.uwaterloo.ca/~z70wang/publications/msssim .. _UniversalImageQualityIndex: https://ieeexplore.ieee.org/abstract/document/995823 .. _SpectralDistortionIndex: https://www.semanticscholar.org/paper/Multispectral-and-panchromatic-data-fusion-without-Alparone-Aiazzi/b6db12e3785326577cb95fd743fecbf5bc66c7c9 +.. _SpatialDistortionIndex: https://www.semanticscholar.org/paper/Multispectral-and-panchromatic-data-fusion-without-Alparone-Aiazzi/b6db12e3785326577cb95fd743fecbf5bc66c7c9 .. _RelativeAverageSpectralError: https://www.semanticscholar.org/paper/Data-Fusion.-Definitions-and-Architectures-Fusion-Wald/51b2b81e5124b3bb7ec53517a5dd64d8e348cadf .. _WMAPE: https://en.wikipedia.org/wiki/WMAPE .. 
_CER: https://rechtsprechung-im-ostseeraum.archiv.uni-greifswald.de/word-error-rate-character-error-rate-how-to-evaluate-a-model From 2fa495106f96c051eb8a16c77807e1b8dac2bfab Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Thu, 7 Dec 2023 23:18:47 +0000 Subject: [PATCH 12/32] fix type hint --- src/torchmetrics/image/d_s.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 4a40a9c75fe..7a88c27ab8f 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union from torch import Tensor from typing_extensions import Literal @@ -114,7 +114,7 @@ def __init__( self.add_state("pan", default=[], dist_reduce_fx="cat") self.add_state("pan_lr", default=[], dist_reduce_fx="cat") - def update(self, preds: Tensor, target: Tensor) -> None: + def update(self, preds: Tensor, target: Dict[str, Tensor]) -> None: """Update state with preds and target.""" preds, target = _spatial_distortion_index_update(preds, target) self.preds.append(preds) From 075cf50f7f7d992367c57d414a273f39a69cd435 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Fri, 8 Dec 2023 00:18:56 +0000 Subject: [PATCH 13/32] fix mypy error --- src/torchmetrics/functional/image/d_s.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index ae0f6f8e4a2..d21d8276a1f 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -167,6 +167,8 @@ def _spatial_distortion_index_compute( pan_degraded = filter2d(pan, kernel, border_type="replicate", normalized=True) pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) + assert pan_degraded is not None + m1 = torch.zeros(length, device=preds.device) m2 = torch.zeros(length, device=preds.device) From d6fccf63356849da8e80b07589603303e4820f1d Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Fri, 8 Dec 2023 00:24:25 +0000 Subject: [PATCH 14/32] remove dependence of kornia --- requirements/image.txt | 1 - requirements/image_test.txt | 1 + src/torchmetrics/functional/image/d_s.py | 6 +++--- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements/image.txt b/requirements/image.txt index d9bda03d0a9..4d15d7b79ba 100644 --- a/requirements/image.txt +++ b/requirements/image.txt @@ -4,4 +4,3 @@ scipy >1.0.0, <1.11.0 torchvision >=0.8, <0.17.0 torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing -kornia >=0.6.7, <0.7.1 diff --git a/requirements/image_test.txt b/requirements/image_test.txt index f10ace1580c..ae48f31ab5a 100644 --- a/requirements/image_test.txt +++ b/requirements/image_test.txt @@ -2,6 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment scikit-image >=0.19.0, <=0.21.0 +kornia >=0.6.7, <0.7.1 pytorch-msssim ==1.0.0 sewar >=0.4.4, <=0.4.6 numpy <1.25.0 diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index d21d8276a1f..27002658f6b 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -160,11 +160,11 @@ def _spatial_distortion_index_compute( pan_degraded = pan_lr if pan_degraded is None: - 
from kornia.filters import filter2d from torchvision.transforms.functional import resize - kernel = torch.ones(size=(1, ws, ws)) - pan_degraded = filter2d(pan, kernel, border_type="replicate", normalized=True) + from torchmetrics.functional.image.helper import _uniform_filter + + pan_degraded = _uniform_filter(pan, window_size=ws) pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) assert pan_degraded is not None From 9df37821de1597799022c3460798672445069fe6 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Fri, 8 Dec 2023 23:21:21 +0000 Subject: [PATCH 15/32] fixed ruff error --- src/torchmetrics/functional/image/d_s.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 27002658f6b..f3349f13448 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -158,16 +158,15 @@ def _spatial_distortion_index_compute( f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}." ) - pan_degraded = pan_lr - if pan_degraded is None: + if pan_lr is None: from torchvision.transforms.functional import resize from torchmetrics.functional.image.helper import _uniform_filter pan_degraded = _uniform_filter(pan, window_size=ws) pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) - - assert pan_degraded is not None + else: + pan_degraded = pan_lr m1 = torch.zeros(length, device=preds.device) m2 = torch.zeros(length, device=preds.device) From 0f8a1e68a86dde3859e57be2157c9f701ef5cd1f Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 11 Dec 2023 15:54:32 +0000 Subject: [PATCH 16/32] Update docs/source/image/spatial_distortion_index.rst Co-authored-by: Nicki Skafte Detlefsen --- docs/source/image/spatial_distortion_index.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/image/spatial_distortion_index.rst b/docs/source/image/spatial_distortion_index.rst index 86af7b3c96e..f7d95d0c0df 100644 --- a/docs/source/image/spatial_distortion_index.rst +++ b/docs/source/image/spatial_distortion_index.rst @@ -5,9 +5,9 @@ .. include:: ../links.rst -######################### +######################## Spatial Distortion Index -######################### +######################## Module Interface ________________ From caf9dea33a8a66993ff7ee294ec6bc1140d5fd62 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 11 Dec 2023 15:55:53 +0000 Subject: [PATCH 17/32] Update src/torchmetrics/functional/image/d_s.py Co-authored-by: Nicki Skafte Detlefsen --- src/torchmetrics/functional/image/d_s.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index f3349f13448..7e933fe9c7b 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -159,6 +159,8 @@ def _spatial_distortion_index_compute( ) if pan_lr is None: + if not _TORCHVISION_AVAILABLE: + raise ValueError("When `pan_lr` is not provided as input to metric Spatial distortion index, torchvision should be installed. 
Please install with `pip install torchvision` or `pip install torchmetrics[image]`.") from torchvision.transforms.functional import resize from torchmetrics.functional.image.helper import _uniform_filter From 7f38b0add0288158bdccd791ae6568c11512dc93 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:56:56 +0000 Subject: [PATCH 18/32] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/d_s.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 7e933fe9c7b..1ee6ba5b64c 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -160,7 +160,9 @@ def _spatial_distortion_index_compute( if pan_lr is None: if not _TORCHVISION_AVAILABLE: - raise ValueError("When `pan_lr` is not provided as input to metric Spatial distortion index, torchvision should be installed. Please install with `pip install torchvision` or `pip install torchmetrics[image]`.") + raise ValueError( + "When `pan_lr` is not provided as input to metric Spatial distortion index, torchvision should be installed. Please install with `pip install torchvision` or `pip install torchmetrics[image]`." + ) from torchvision.transforms.functional import resize from torchmetrics.functional.image.helper import _uniform_filter From 00f29a004b92be5af9a19e20526ee633d73d9553 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 11 Dec 2023 15:57:49 +0000 Subject: [PATCH 19/32] Update src/torchmetrics/image/d_s.py Co-authored-by: Nicki Skafte Detlefsen --- src/torchmetrics/image/d_s.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 7a88c27ab8f..0adbd065965 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -129,11 +129,8 @@ def compute(self) -> Tensor: ms = dim_zero_cat(self.ms) pan = dim_zero_cat(self.pan) pan_lr = dim_zero_cat(self.pan_lr) if len(self.pan_lr) > 0 else None - target = { - "ms": ms, - "pan": pan, - **({"pan_lr": pan_lr} if pan_lr is not None else {}), - } + target = {"ms": ms, "pan": pan} + target.update({"pan_lr": pan_lr} if pan_lr is not None else {}) return _spatial_distortion_index_compute(preds, target, self.p, self.ws, self.reduction) def plot( From c64915b61ad01be28cf70e35995ff8979e3cf307 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 11 Dec 2023 18:08:31 +0000 Subject: [PATCH 20/32] fix style --- src/torchmetrics/functional/image/d_s.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 1ee6ba5b64c..a08b02ca52f 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -20,6 +20,7 @@ from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.utilities.distributed import reduce +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -161,7 +162,8 @@ def _spatial_distortion_index_compute( if pan_lr is None: if not _TORCHVISION_AVAILABLE: raise ValueError( - "When `pan_lr` is not provided as input to metric Spatial distortion index, torchvision should be installed. 
Please install with `pip install torchvision` or `pip install torchmetrics[image]`." + "When `pan_lr` is not provided as input to metric Spatial distortion index, torchvision should be " + "installed. Please install with `pip install torchvision` or `pip install torchmetrics[image]`." ) from torchvision.transforms.functional import resize From f169369d83164e22812abc374902453b463e8806 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 11 Dec 2023 20:15:10 +0000 Subject: [PATCH 21/32] moved checking of tensor input to update --- src/torchmetrics/functional/image/d_s.py | 83 +++++++++++++----------- 1 file changed, 46 insertions(+), 37 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index a08b02ca52f..138cfcbbd55 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -46,6 +46,12 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - If ``preds`` and ``target`` don't have the same batch and channel sizes. ValueError: If ``target`` doesn't have ``ms`` and ``pan``. + ValueError: + If ``preds`` and ``pan`` don't have the same dimension. + ValueError: + If ``ms`` and ``pan_lr`` don't have the same dimension. + ValueError: + If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``. """ if len(preds.shape) != 4: @@ -67,6 +73,44 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - f"Expected `preds` and `{name}` to have same batch and channel sizes. " "Got preds: {preds.shape} and {name}: {t.shape}." ) + + ms = target["ms"] + pan = target["pan"] + pan_lr = target["pan_lr"] if "pan_lr" in target else None + + preds_h, preds_w = preds.shape[-2:] + ms_h, ms_w = ms.shape[-2:] + pan_h, pan_w = pan.shape[-2:] + if preds_h != pan_h: + raise ValueError(f"Expected `preds` and `pan` to have the same height. Got preds: {preds_h} and pan: {pan_h}") + if preds_w != pan_w: + raise ValueError(f"Expected `preds` and `pan` to have the same width. Got preds: {preds_w} and pan: {pan_w}") + if preds_h % ms_h != 0: + raise ValueError( + f"Expected height of `preds` to be multiple of height of `ms`. Got preds: {preds_h} and ms: {ms_h}." + ) + if preds_w % ms_w != 0: + raise ValueError( + f"Expected width of `preds` to be multiple of width of `ms`. Got preds: {preds_w} and ms: {ms_w}." + ) + if pan_h % ms_h != 0: + raise ValueError( + f"Expected height of `pan` to be multiple of height of `ms`. Got preds: {pan_h} and ms: {ms_h}." + ) + if pan_w % ms_w != 0: + raise ValueError(f"Expected width of `pan` to be multiple of width of `ms`. Got preds: {pan_w} and ms: {ms_w}.") + + if pan_lr is not None: + pan_lr_h, pan_lr_w = pan_lr.shape[-2:] + if pan_lr_h != ms_h: + raise ValueError( + f"Expected `ms` and `pan_lr` to have the same height. Got ms: {ms_h} and pan_lr: {pan_lr_h}." + ) + if pan_lr_w != ms_w: + raise ValueError( + f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}." + ) + return preds, target @@ -99,12 +143,8 @@ def _spatial_distortion_index_compute( Tensor with SpatialDistortionIndex score Raises: - ValueError: - If ``preds`` and ``pan`` don't have the same dimension. - ValueError: - If ``ms`` and ``pan_lr`` don't have the same dimension. - ValueError: - If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``. + ValueError + If ``ws`` is smaller than dimension of ``ms``. 
Example: >>> _ = torch.manual_seed(42) @@ -124,41 +164,10 @@ def _spatial_distortion_index_compute( pan = target["pan"] pan_lr = target["pan_lr"] if "pan_lr" in target else None - preds_h, preds_w = preds.shape[-2:] ms_h, ms_w = ms.shape[-2:] - pan_h, pan_w = pan.shape[-2:] - if preds_h != pan_h: - raise ValueError(f"Expected `preds` and `pan` to have the same height. Got preds: {preds_h} and pan: {pan_h}") - if preds_w != pan_w: - raise ValueError(f"Expected `preds` and `pan` to have the same width. Got preds: {preds_w} and pan: {pan_w}") - if preds_h % ms_h != 0: - raise ValueError( - f"Expected height of `preds` to be multiple of height of `ms`. Got preds: {preds_h} and ms: {ms_h}." - ) - if preds_w % ms_w != 0: - raise ValueError( - f"Expected width of `preds` to be multiple of width of `ms`. Got preds: {preds_w} and ms: {ms_w}." - ) - if pan_h % ms_h != 0: - raise ValueError( - f"Expected height of `pan` to be multiple of height of `ms`. Got preds: {pan_h} and ms: {ms_h}." - ) - if pan_w % ms_w != 0: - raise ValueError(f"Expected width of `pan` to be multiple of width of `ms`. Got preds: {pan_w} and ms: {ms_w}.") if ws >= ms_h or ws >= ms_w: raise ValueError(f"Expected `ws` to be smaller than dimension of `ms`. Got ws: {ws}.") - if pan_lr is not None: - pan_lr_h, pan_lr_w = pan_lr.shape[-2:] - if pan_lr_h != ms_h: - raise ValueError( - f"Expected `ms` and `pan_lr` to have the same height. Got ms: {ms_h} and pan_lr: {pan_lr_h}." - ) - if pan_lr_w != ms_w: - raise ValueError( - f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}." - ) - if pan_lr is None: if not _TORCHVISION_AVAILABLE: raise ValueError( From cd9c4dc004207c35eea8ecff6deea8067850d898 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Thu, 14 Dec 2023 21:17:46 +0000 Subject: [PATCH 22/32] Update src/torchmetrics/functional/image/d_s.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/functional/image/d_s.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 138cfcbbd55..9bd8165eb6e 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -61,8 +61,8 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - for name, t in target.items(): if preds.dtype != t.dtype: raise TypeError( - f"Expected `preds` and `{name}` to have the same data type. " - "Got preds: {preds.dtype} and {name}: {t.dtype}." + f"Expected `preds` and `{name}` to have the same data type." + " Got preds: {preds.dtype} and {name}: {t.dtype}." 
) for name, t in target.items(): if len(t.shape) != 4: From c6087d6251510da5a333656be9d7ef0c329c21ba Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Thu, 14 Dec 2023 21:18:35 +0000 Subject: [PATCH 23/32] Update src/torchmetrics/functional/image/d_s.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/functional/image/d_s.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 9bd8165eb6e..4d2858219d3 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -117,7 +117,7 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - def _spatial_distortion_index_compute( preds: Tensor, target: Dict[str, Tensor], - p: int = 1, + norm_order: int = 1, ws: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", ) -> Tensor: From af5288bc540793d7ac3e0a95b8d8cedbb165d039 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Thu, 14 Dec 2023 21:18:47 +0000 Subject: [PATCH 24/32] Update src/torchmetrics/functional/image/d_s.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/functional/image/d_s.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 4d2858219d3..c9cc0ba08dc 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -197,7 +197,7 @@ def spatial_distortion_index( preds: Tensor, target: Dict[str, Tensor], p: int = 1, - ws: int = 7, + window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", ) -> Tensor: """Calculate `Spatial Distortion Index`_ (SpatialDistortionIndex_) also known as D_s. From f7debf03311d945cf4f7a0e42f7807f960ad07eb Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Thu, 14 Dec 2023 21:28:26 +0000 Subject: [PATCH 25/32] changed `ws` to `window_size` --- src/torchmetrics/functional/image/d_s.py | 26 +++++----- src/torchmetrics/image/d_s.py | 12 ++--- tests/unittests/image/test_d_s.py | 61 +++++++++++++----------- 3 files changed, 53 insertions(+), 46 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index c9cc0ba08dc..7f21bb47cf1 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -117,8 +117,8 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - def _spatial_distortion_index_compute( preds: Tensor, target: Dict[str, Tensor], - norm_order: int = 1, - ws: int = 7, + p: int = 1, + window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", ) -> Tensor: """Compute Spatial Distortion Index (SpatialDistortionIndex_). @@ -132,7 +132,7 @@ def _spatial_distortion_index_compute( - ``'pan_lr'``: (optional) low resolution panchromatic image. p: Order of the norm applied on the difference. - ws: Window size of the filter applied to degrade the high resolution panchromatic image. + window_size: Window size of the filter applied to degrade the high resolution panchromatic image. reduction: A method to reduce metric score over labels. - ``'elementwise_mean'``: takes the mean (default) @@ -144,7 +144,7 @@ def _spatial_distortion_index_compute( Raises: ValueError - If ``ws`` is smaller than dimension of ``ms``. 
+ If ``window_size`` is smaller than dimension of ``ms``. Example: >>> _ = torch.manual_seed(42) @@ -165,8 +165,10 @@ def _spatial_distortion_index_compute( pan_lr = target["pan_lr"] if "pan_lr" in target else None ms_h, ms_w = ms.shape[-2:] - if ws >= ms_h or ws >= ms_w: - raise ValueError(f"Expected `ws` to be smaller than dimension of `ms`. Got ws: {ws}.") + if window_size >= ms_h or window_size >= ms_w: + raise ValueError( + f"Expected `window_size` to be smaller than dimension of `ms`. Got window_size: {window_size}." + ) if pan_lr is None: if not _TORCHVISION_AVAILABLE: @@ -178,7 +180,7 @@ def _spatial_distortion_index_compute( from torchmetrics.functional.image.helper import _uniform_filter - pan_degraded = _uniform_filter(pan, window_size=ws) + pan_degraded = _uniform_filter(pan, window_size=window_size) pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False) else: pan_degraded = pan_lr @@ -213,7 +215,7 @@ def spatial_distortion_index( - ``'pan_lr'``: (optional) low resolution panchromatic image. p: Order of the norm applied on the difference. - ws: Window size of the filter applied to degrade the high resolution panchromatic image. + window_size: Window size of the filter applied to degrade the high resolution panchromatic image. reduction: A method to reduce metric score over labels. - ``'elementwise_mean'``: takes the mean (default) @@ -241,7 +243,7 @@ def spatial_distortion_index( ValueError: If ``p`` is not a positive integer. ValueError: - If ``ws`` is not a positive integer. + If ``window_size`` is not a positive integer. Example: >>> from torchmetrics.functional.image import spatial_distortion_index @@ -257,7 +259,7 @@ def spatial_distortion_index( """ if not isinstance(p, int) or p <= 0: raise ValueError(f"Expected `p` to be a positive integer. Got p: {p}.") - if not isinstance(ws, int) or ws <= 0: - raise ValueError(f"Expected `ws` to be a positive integer. Got ws: {ws}.") + if not isinstance(window_size, int) or window_size <= 0: + raise ValueError(f"Expected `window_size` to be a positive integer. Got window_size: {window_size}.") preds, target = _spatial_distortion_index_update(preds, target) - return _spatial_distortion_index_compute(preds, target, p, ws, reduction) + return _spatial_distortion_index_compute(preds, target, p, window_size, reduction) diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 0adbd065965..6f8ff406a31 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -50,7 +50,7 @@ class SpatialDistortionIndex(Metric): Args: p: Order of the norm applied on the difference. - ws: Window size of the filter applied to degrade the high resolution panchromatic image. + window_size: Window size of the filter applied to degrade the high resolution panchromatic image. reduction: a method to reduce metric score over labels. - ``'elementwise_mean'``: takes the mean (default) @@ -88,7 +88,7 @@ class SpatialDistortionIndex(Metric): def __init__( self, p: int = 1, - ws: int = 7, + window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any, ) -> None: @@ -102,9 +102,9 @@ def __init__( if not isinstance(p, int) or p <= 0: raise ValueError(f"Expected `p` to be a positive integer. Got p: {p}.") self.p = p - if not isinstance(ws, int) or ws <= 0: - raise ValueError(f"Expected `ws` to be a positive integer. 
Got ws: {ws}.") - self.ws = ws + if not isinstance(window_size, int) or window_size <= 0: + raise ValueError(f"Expected `window_size` to be a positive integer. Got window_size: {window_size}.") + self.window_size = window_size allowed_reductions = ("elementwise_mean", "sum", "none") if reduction not in allowed_reductions: raise ValueError(f"Expected argument `reduction` be one of {allowed_reductions} but got {reduction}") @@ -131,7 +131,7 @@ def compute(self) -> Tensor: pan_lr = dim_zero_cat(self.pan_lr) if len(self.pan_lr) > 0 else None target = {"ms": ms, "pan": pan} target.update({"pan_lr": pan_lr} if pan_lr is not None else {}) - return _spatial_distortion_index_compute(preds, target, self.p, self.ws, self.reduction) + return _spatial_distortion_index_compute(preds, target, self.p, self.window_size, self.reduction) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index 04020c4ca76..215ce53dba6 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -36,11 +36,11 @@ class _Input(NamedTuple): preds: Tensor target: List[Dict[str, Tensor]] p: int - ws: int + window_size: int _inputs = [] -for size, channel, p, r, ws, pan_lr_exists, dtype in [ +for size, channel, p, r, window_size, pan_lr_exists, dtype in [ (12, 3, 1, 16, 3, False, torch.float), (13, 1, 3, 8, 5, False, torch.float32), (14, 1, 4, 4, 5, True, torch.double), @@ -62,22 +62,25 @@ class _Input(NamedTuple): for i in range(NUM_BATCHES) ], p=p, - ws=ws, + window_size=window_size, ) ) def _baseline_d_s( - preds: np.ndarray, ms: np.ndarray, pan: np.ndarray, pan_lr: np.ndarray = None, p: int = 1, ws: int = 7 + preds: np.ndarray, ms: np.ndarray, pan: np.ndarray, pan_lr: np.ndarray = None, p: int = 1, window_size: int = 7 ) -> float: """NumPy based implementation of Spatial Distortion Index, which uses UQI of TorchMetrics.""" pan_degraded = pan_lr if pan_degraded is None: try: - pan_degraded = uniform_filter(pan, size=ws, axes=[1, 2]) + pan_degraded = uniform_filter(pan, size=window_size, axes=[1, 2]) except TypeError: pan_degraded = np.array( - [[uniform_filter(pan[i, ..., j], size=ws) for j in range(pan.shape[-1])] for i in range(len(pan))] + [ + [uniform_filter(pan[i, ..., j], size=window_size) for j in range(pan.shape[-1])] + for i in range(len(pan)) + ] ).transpose((0, 2, 3, 1)) pan_degraded = np.array([resize(img, ms.shape[1:3], anti_aliasing=False) for img in pan_degraded]) @@ -100,7 +103,7 @@ def _baseline_d_s( return np.mean(diff) ** (1 / p) -def _np_d_s(preds, target, p, ws): +def _np_d_s(preds, target, p, window_size): np_preds = preds.permute(0, 2, 3, 1).cpu().numpy() assert isinstance(target, dict), f"Expected `target` to be dict. Got {type(target)}." assert "ms" in target, "Expected `target` to contain 'ms'." 
@@ -115,13 +118,13 @@ def _np_d_s(preds, target, p, ws): np_pan, np_pan_lr, p=p, - ws=ws, + window_size=window_size, ) @pytest.mark.parametrize( - "preds, target, p, ws", - [(i.preds, i.target, i.p, i.ws) for i in _inputs], + "preds, target, p, window_size", + [(i.preds, i.target, i.p, i.window_size) for i in _inputs], ) class TestSpatialDistortionIndex(MetricTester): """Test class for `SpatialDistortionIndex` metric.""" @@ -129,35 +132,37 @@ class TestSpatialDistortionIndex(MetricTester): atol = 3e-6 @pytest.mark.parametrize("ddp", [True, False]) - def test_d_s(self, preds, target, p, ws, ddp): + def test_d_s(self, preds, target, p, window_size, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp, preds, target, SpatialDistortionIndex, - partial(_np_d_s, p=p, ws=ws), - metric_args={"p": p, "ws": ws}, + partial(_np_d_s, p=p, window_size=window_size), + metric_args={"p": p, "window_size": window_size}, ) - def test_d_s_functional(self, preds, target, p, ws): + def test_d_s_functional(self, preds, target, p, window_size): """Test functional implementation of metric.""" self.run_functional_metric_test( preds, target, spatial_distortion_index, - partial(_np_d_s, p=p, ws=ws), - metric_args={"p": p, "ws": ws}, + partial(_np_d_s, p=p, window_size=window_size), + metric_args={"p": p, "window_size": window_size}, ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_d_s_half_gpu(self, preds, target, p, ws): + def test_d_s_half_gpu(self, preds, target, p, window_size): """Test dtype support of the metric on GPU.""" - self.run_precision_test_gpu(preds, target, SpatialDistortionIndex, spatial_distortion_index, {"p": p, "ws": ws}) + self.run_precision_test_gpu( + preds, target, SpatialDistortionIndex, spatial_distortion_index, {"p": p, "window_size": window_size} + ) @pytest.mark.parametrize( - ("preds", "target", "p", "ws", "match"), + ("preds", "target", "p", "window_size", "match"), [ ( [1, 16, 16], @@ -221,15 +226,15 @@ def test_d_s_half_gpu(self, preds, target, p, ws): {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, 1, 0, - "Expected `ws` to be a positive integer. Got ws: 0.", - ), # invalid ws + "Expected `window_size` to be a positive integer. Got window_size: 0.", + ), # invalid window_size ( [1, 1, 16, 16], {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, 1, -1, - "Expected `ws` to be a positive integer. Got ws: -1.", - ), # invalid ws + "Expected `window_size` to be a positive integer. 
Got window_size: -1.", + ), # invalid window_size ( [1, 1, 16, 16], {"ms": [1, 1, 4, 4], "pan": [1, 1, 17, 16]}, @@ -319,16 +324,16 @@ def test_d_s_half_gpu(self, preds, target, p, ws): {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, 1, 5, - "Expected `ws` to be smaller than dimension of `ms`.*", - ), # invalid ws + "Expected `window_size` to be smaller than dimension of `ms`.*", + ), # invalid window_size ], ) -def test_d_s_invalid_inputs(preds, target, p, ws, match): +def test_d_s_invalid_inputs(preds, target, p, window_size, match): """Test that invalid input raises the correct errors.""" preds_t = torch.rand(preds) target_t = {name: torch.rand(t) for name, t in target.items()} with pytest.raises(ValueError, match=match): - spatial_distortion_index(preds_t, target_t, p, ws) + spatial_distortion_index(preds_t, target_t, p, window_size) @pytest.mark.parametrize( @@ -362,4 +367,4 @@ def test_d_s_invalid_type(target, match): """Test that error is raised on different dtypes.""" preds_t = torch.rand((1, 1, 16, 16)) with pytest.raises(TypeError, match=match): - spatial_distortion_index(preds_t, target, p=1, ws=7) + spatial_distortion_index(preds_t, target, p=1, window_size=7) From 3137b47214e1f8ea802de3fe8cfc8a4520b4f80b Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Thu, 14 Dec 2023 21:39:16 +0000 Subject: [PATCH 26/32] changed `p` to `norm_order` --- src/torchmetrics/functional/image/d_s.py | 20 ++++----- src/torchmetrics/image/d_s.py | 12 ++--- tests/unittests/image/test_d_s.py | 57 ++++++++++++++---------- 3 files changed, 49 insertions(+), 40 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 7f21bb47cf1..46cbeeab318 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -117,7 +117,7 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - def _spatial_distortion_index_compute( preds: Tensor, target: Dict[str, Tensor], - p: int = 1, + norm_order: int = 1, window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", ) -> Tensor: @@ -131,7 +131,7 @@ def _spatial_distortion_index_compute( - ``'pan'``: high resolution panchromatic image. - ``'pan_lr'``: (optional) low resolution panchromatic image. - p: Order of the norm applied on the difference. + norm_order: Order of the norm applied on the difference. window_size: Window size of the filter applied to degrade the high resolution panchromatic image. reduction: A method to reduce metric score over labels. @@ -191,14 +191,14 @@ def _spatial_distortion_index_compute( for i in range(length): m1[i] = universal_image_quality_index(ms[:, i : i + 1], pan_degraded[:, i : i + 1]) m2[i] = universal_image_quality_index(preds[:, i : i + 1], pan[:, i : i + 1]) - diff = (m1 - m2).abs() ** p - return reduce(diff, reduction) ** (1 / p) + diff = (m1 - m2).abs() ** norm_order + return reduce(diff, reduction) ** (1 / norm_order) def spatial_distortion_index( preds: Tensor, target: Dict[str, Tensor], - p: int = 1, + norm_order: int = 1, window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", ) -> Tensor: @@ -214,7 +214,7 @@ def spatial_distortion_index( - ``'pan'``: high resolution panchromatic image. - ``'pan_lr'``: (optional) low resolution panchromatic image. - p: Order of the norm applied on the difference. + norm_order: Order of the norm applied on the difference. 
window_size: Window size of the filter applied to degrade the high resolution panchromatic image. reduction: A method to reduce metric score over labels. @@ -241,7 +241,7 @@ def spatial_distortion_index( ValueError: If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``. ValueError: - If ``p`` is not a positive integer. + If ``norm_order`` is not a positive integer. ValueError: If ``window_size`` is not a positive integer. @@ -257,9 +257,9 @@ def spatial_distortion_index( tensor(0.0090) """ - if not isinstance(p, int) or p <= 0: - raise ValueError(f"Expected `p` to be a positive integer. Got p: {p}.") + if not isinstance(norm_order, int) or norm_order <= 0: + raise ValueError(f"Expected `norm_order` to be a positive integer. Got norm_order: {norm_order}.") if not isinstance(window_size, int) or window_size <= 0: raise ValueError(f"Expected `window_size` to be a positive integer. Got window_size: {window_size}.") preds, target = _spatial_distortion_index_update(preds, target) - return _spatial_distortion_index_compute(preds, target, p, window_size, reduction) + return _spatial_distortion_index_compute(preds, target, norm_order, window_size, reduction) diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 6f8ff406a31..0f23ab0712b 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -49,7 +49,7 @@ class SpatialDistortionIndex(Metric): over sample else returns tensor of shape ``(N,)`` with SDI values per sample Args: - p: Order of the norm applied on the difference. + norm_order: Order of the norm applied on the difference. window_size: Window size of the filter applied to degrade the high resolution panchromatic image. reduction: a method to reduce metric score over labels. @@ -87,7 +87,7 @@ class SpatialDistortionIndex(Metric): def __init__( self, - p: int = 1, + norm_order: int = 1, window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any, @@ -99,9 +99,9 @@ def __init__( " to large memory footprint." ) - if not isinstance(p, int) or p <= 0: - raise ValueError(f"Expected `p` to be a positive integer. Got p: {p}.") - self.p = p + if not isinstance(norm_order, int) or norm_order <= 0: + raise ValueError(f"Expected `norm_order` to be a positive integer. Got norm_order: {norm_order}.") + self.norm_order = norm_order if not isinstance(window_size, int) or window_size <= 0: raise ValueError(f"Expected `window_size` to be a positive integer. 
Got window_size: {window_size}.") self.window_size = window_size @@ -131,7 +131,7 @@ def compute(self) -> Tensor: pan_lr = dim_zero_cat(self.pan_lr) if len(self.pan_lr) > 0 else None target = {"ms": ms, "pan": pan} target.update({"pan_lr": pan_lr} if pan_lr is not None else {}) - return _spatial_distortion_index_compute(preds, target, self.p, self.window_size, self.reduction) + return _spatial_distortion_index_compute(preds, target, self.norm_order, self.window_size, self.reduction) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index 215ce53dba6..0164dc733d5 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -35,12 +35,12 @@ class _Input(NamedTuple): preds: Tensor target: List[Dict[str, Tensor]] - p: int + norm_order: int window_size: int _inputs = [] -for size, channel, p, r, window_size, pan_lr_exists, dtype in [ +for size, channel, norm_order, r, window_size, pan_lr_exists, dtype in [ (12, 3, 1, 16, 3, False, torch.float), (13, 1, 3, 8, 5, False, torch.float32), (14, 1, 4, 4, 5, True, torch.double), @@ -61,14 +61,19 @@ class _Input(NamedTuple): } for i in range(NUM_BATCHES) ], - p=p, + norm_order=norm_order, window_size=window_size, ) ) def _baseline_d_s( - preds: np.ndarray, ms: np.ndarray, pan: np.ndarray, pan_lr: np.ndarray = None, p: int = 1, window_size: int = 7 + preds: np.ndarray, + ms: np.ndarray, + pan: np.ndarray, + pan_lr: np.ndarray = None, + norm_order: int = 1, + window_size: int = 7, ) -> float: """NumPy based implementation of Spatial Distortion Index, which uses UQI of TorchMetrics.""" pan_degraded = pan_lr @@ -99,11 +104,11 @@ def _baseline_d_s( for i in range(length): m1[i] = universal_image_quality_index(ms[:, i : i + 1], pan_degraded[:, i : i + 1]) m2[i] = universal_image_quality_index(preds[:, i : i + 1], pan[:, i : i + 1]) - diff = np.abs(m1 - m2) ** p - return np.mean(diff) ** (1 / p) + diff = np.abs(m1 - m2) ** norm_order + return np.mean(diff) ** (1 / norm_order) -def _np_d_s(preds, target, p, window_size): +def _np_d_s(preds, target, norm_order, window_size): np_preds = preds.permute(0, 2, 3, 1).cpu().numpy() assert isinstance(target, dict), f"Expected `target` to be dict. Got {type(target)}." assert "ms" in target, "Expected `target` to contain 'ms'." 
@@ -117,14 +122,14 @@ def _np_d_s(preds, target, p, window_size): np_ms, np_pan, np_pan_lr, - p=p, + norm_order=norm_order, window_size=window_size, ) @pytest.mark.parametrize( - "preds, target, p, window_size", - [(i.preds, i.target, i.p, i.window_size) for i in _inputs], + "preds, target, norm_order, window_size", + [(i.preds, i.target, i.norm_order, i.window_size) for i in _inputs], ) class TestSpatialDistortionIndex(MetricTester): """Test class for `SpatialDistortionIndex` metric.""" @@ -132,37 +137,41 @@ class TestSpatialDistortionIndex(MetricTester): atol = 3e-6 @pytest.mark.parametrize("ddp", [True, False]) - def test_d_s(self, preds, target, p, window_size, ddp): + def test_d_s(self, preds, target, norm_order, window_size, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp, preds, target, SpatialDistortionIndex, - partial(_np_d_s, p=p, window_size=window_size), - metric_args={"p": p, "window_size": window_size}, + partial(_np_d_s, norm_order=norm_order, window_size=window_size), + metric_args={"norm_order": norm_order, "window_size": window_size}, ) - def test_d_s_functional(self, preds, target, p, window_size): + def test_d_s_functional(self, preds, target, norm_order, window_size): """Test functional implementation of metric.""" self.run_functional_metric_test( preds, target, spatial_distortion_index, - partial(_np_d_s, p=p, window_size=window_size), - metric_args={"p": p, "window_size": window_size}, + partial(_np_d_s, norm_order=norm_order, window_size=window_size), + metric_args={"norm_order": norm_order, "window_size": window_size}, ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_d_s_half_gpu(self, preds, target, p, window_size): + def test_d_s_half_gpu(self, preds, target, norm_order, window_size): """Test dtype support of the metric on GPU.""" self.run_precision_test_gpu( - preds, target, SpatialDistortionIndex, spatial_distortion_index, {"p": p, "window_size": window_size} + preds, + target, + SpatialDistortionIndex, + spatial_distortion_index, + {"norm_order": norm_order, "window_size": window_size}, ) @pytest.mark.parametrize( - ("preds", "target", "p", "window_size", "match"), + ("preds", "target", "norm_order", "window_size", "match"), [ ( [1, 16, 16], @@ -212,14 +221,14 @@ def test_d_s_half_gpu(self, preds, target, p, window_size): {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, 0, 3, - "Expected `p` to be a positive integer. Got p: 0.", + "Expected `norm_order` to be a positive integer. Got norm_order: 0.", ), # invalid p ( [1, 1, 16, 16], {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, -1, 3, - "Expected `p` to be a positive integer. Got p: -1.", + "Expected `norm_order` to be a positive integer. 
Got norm_order: -1.", ), # invalid p ( [1, 1, 16, 16], @@ -328,12 +337,12 @@ def test_d_s_half_gpu(self, preds, target, p, window_size): ), # invalid window_size ], ) -def test_d_s_invalid_inputs(preds, target, p, window_size, match): +def test_d_s_invalid_inputs(preds, target, norm_order, window_size, match): """Test that invalid input raises the correct errors.""" preds_t = torch.rand(preds) target_t = {name: torch.rand(t) for name, t in target.items()} with pytest.raises(ValueError, match=match): - spatial_distortion_index(preds_t, target_t, p, window_size) + spatial_distortion_index(preds_t, target_t, norm_order, window_size) @pytest.mark.parametrize( @@ -367,4 +376,4 @@ def test_d_s_invalid_type(target, match): """Test that error is raised on different dtypes.""" preds_t = torch.rand((1, 1, 16, 16)) with pytest.raises(TypeError, match=match): - spatial_distortion_index(preds_t, target, p=1, window_size=7) + spatial_distortion_index(preds_t, target, norm_order=1, window_size=7) From c9a506c6ed91019f63e36e2fe4b146b3ad636016 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Thu, 14 Dec 2023 21:41:17 +0000 Subject: [PATCH 27/32] Update src/torchmetrics/functional/image/d_s.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/functional/image/d_s.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 46cbeeab318..2058c4023a9 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -70,8 +70,8 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - for name, t in target.items(): if preds.shape[:2] != t.shape[:2]: raise ValueError( - f"Expected `preds` and `{name}` to have same batch and channel sizes. " - "Got preds: {preds.shape} and {name}: {t.shape}." + f"Expected `preds` and `{name}` to have the same batch and channel sizes." + " Got preds: {preds.shape} and {name}: {t.shape}." 
) ms = target["ms"] From 3b40466e98d34965c12ad78248c03e2175bc5c60 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Thu, 14 Dec 2023 22:57:17 +0000 Subject: [PATCH 28/32] fix assert regex in tests --- tests/unittests/image/test_d_s.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index 0164dc733d5..7625190a15f 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -291,42 +291,42 @@ def test_d_s_half_gpu(self, preds, target, norm_order, window_size): {"ms": [1, 2, 4, 4], "pan": [1, 1, 16, 16]}, 1, 3, - "Expected `preds` and `ms` to have same batch and channel.*", + "Expected `preds` and `ms` to have the same batch and channel.*", ), # invalid ms.shape ( [1, 1, 16, 16], {"ms": [2, 1, 4, 4], "pan": [1, 1, 16, 16]}, 1, 3, - "Expected `preds` and `ms` to have same batch and channel.*", + "Expected `preds` and `ms` to have the same batch and channel.*", ), # invalid ms.shape ( [1, 1, 16, 16], {"ms": [1, 1, 4, 4], "pan": [1, 2, 16, 16]}, 1, 3, - "Expected `preds` and `pan` to have same batch and channel.*", + "Expected `preds` and `pan` to have the same batch and channel.*", ), # invalid pan.shape ( [1, 1, 16, 16], {"ms": [1, 1, 4, 4], "pan": [2, 1, 16, 16]}, 1, 3, - "Expected `preds` and `pan` to have same batch and channel.*", + "Expected `preds` and `pan` to have the same batch and channel.*", ), # invalid pan.shape ( [1, 1, 16, 16], {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 2, 4, 4]}, 1, 3, - "Expected `preds` and `pan_lr` to have same batch and channel.*", + "Expected `preds` and `pan_lr` to have the same batch and channel.*", ), # invalid pan_lr.shape ( [1, 1, 16, 16], {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [2, 1, 4, 4]}, 1, 3, - "Expected `preds` and `pan_lr` to have same batch and channel.*", + "Expected `preds` and `pan_lr` to have the same batch and channel.*", ), # invalid pan_lr.shape ( [1, 1, 16, 16], From af4b50a0ffab426776105316bb02e82d82c9362c Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Fri, 15 Dec 2023 23:09:10 +0000 Subject: [PATCH 29/32] changed `_update` and `_compute` functions to take `ms`, `pan` and `pan_lr` as arguments --- src/torchmetrics/functional/image/d_s.py | 141 +++++++++-------- src/torchmetrics/image/d_s.py | 9 +- tests/unittests/image/test_d_s.py | 190 +++++++++++++++-------- 3 files changed, 198 insertions(+), 142 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 2058c4023a9..4a97a3cf151 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Tuple +from typing import Tuple import torch from torch import Tensor @@ -23,29 +23,27 @@ from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE -def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]: +def _spatial_distortion_index_update( + preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Tensor = None +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Update and returns variables required to compute Spatial Distortion Index. Args: preds: High resolution multispectral image. - target: A dictionary containing the following keys: - - - ``'ms'``: low resolution multispectral image. - - ``'pan'``: high resolution panchromatic image. 
- - ``'pan_lr'``: (optional) low resolution panchromatic image. + ms: Low resolution multispectral image. + pan: High resolution panchromatic image. + pan_lr: Low resolution panchromatic image. Return: - A tuple of Tensors containing ``preds`` and ``target``. + A tuple of Tensors containing ``preds``, ``ms``, ``pan`` and ``pan_lr``. Raises: TypeError: - If ``preds`` and ``target`` don't have the same data type. - ValueError: - If ``preds`` and ``target`` don't have ``BxCxHxW shape``. + If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same data type. ValueError: - If ``preds`` and ``target`` don't have the same batch and channel sizes. + If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have ``BxCxHxW shape``. ValueError: - If ``target`` doesn't have ``ms`` and ``pan``. + If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same batch and channel sizes. ValueError: If ``preds`` and ``pan`` don't have the same dimension. ValueError: @@ -56,27 +54,40 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - """ if len(preds.shape) != 4: raise ValueError(f"Expected `preds` to have BxCxHxW shape. Got preds: {preds.shape}.") - if "ms" not in target or "pan" not in target: - raise ValueError(f"Expected `target` to have keys ('ms', 'pan'). Got target: {target.keys()}") - for name, t in target.items(): - if preds.dtype != t.dtype: - raise TypeError( - f"Expected `preds` and `{name}` to have the same data type." - " Got preds: {preds.dtype} and {name}: {t.dtype}." - ) - for name, t in target.items(): - if len(t.shape) != 4: - raise ValueError(f"Expected `{name}` to have BxCxHxW shape. Got {name}: {t.shape}.") - for name, t in target.items(): - if preds.shape[:2] != t.shape[:2]: - raise ValueError( - f"Expected `preds` and `{name}` to have the same batch and channel sizes." - " Got preds: {preds.shape} and {name}: {t.shape}." - ) - - ms = target["ms"] - pan = target["pan"] - pan_lr = target["pan_lr"] if "pan_lr" in target else None + if preds.dtype != ms.dtype: + raise TypeError( + f"Expected `preds` and `ms` to have the same data type. Got preds: {preds.dtype} and ms: {ms.dtype}." + ) + if preds.dtype != pan.dtype: + raise TypeError( + f"Expected `preds` and `pan` to have the same data type. Got preds: {preds.dtype} and pan: {pan.dtype}." + ) + if pan_lr is not None and preds.dtype != pan_lr.dtype: + raise TypeError( + f"Expected `preds` and `pan_lr` to have the same data type." + f" Got preds: {preds.dtype} and pan_lr: {pan_lr.dtype}." + ) + if len(ms.shape) != 4: + raise ValueError(f"Expected `ms` to have BxCxHxW shape. Got ms: {ms.shape}.") + if len(pan.shape) != 4: + raise ValueError(f"Expected `pan` to have BxCxHxW shape. Got pan: {pan.shape}.") + if pan_lr is not None and len(pan_lr.shape) != 4: + raise ValueError(f"Expected `pan_lr` to have BxCxHxW shape. Got pan_lr: {pan_lr.shape}.") + if preds.shape[:2] != ms.shape[:2]: + raise ValueError( + f"Expected `preds` and `ms` to have the same batch and channel sizes." + f" Got preds: {preds.shape} and ms: {ms.shape}." + ) + if preds.shape[:2] != pan.shape[:2]: + raise ValueError( + f"Expected `preds` and `pan` to have the same batch and channel sizes." + f" Got preds: {preds.shape} and pan: {pan.shape}." + ) + if pan_lr is not None and preds.shape[:2] != pan_lr.shape[:2]: + raise ValueError( + f"Expected `preds` and `pan_lr` to have the same batch and channel sizes." + f" Got preds: {preds.shape} and pan_lr: {pan_lr.shape}." 
+ ) preds_h, preds_w = preds.shape[-2:] ms_h, ms_w = ms.shape[-2:] @@ -111,12 +122,14 @@ def _spatial_distortion_index_update(preds: Tensor, target: Dict[str, Tensor]) - f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}." ) - return preds, target + return preds, ms, pan, pan_lr def _spatial_distortion_index_compute( preds: Tensor, - target: Dict[str, Tensor], + ms: Tensor, + pan: Tensor, + pan_lr: Tensor = None, norm_order: int = 1, window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", @@ -125,12 +138,9 @@ def _spatial_distortion_index_compute( Args: preds: High resolution multispectral image. - target: A dictionary containing the following keys: - - - ``'ms'``: low resolution multispectral image. - - ``'pan'``: high resolution panchromatic image. - - ``'pan_lr'``: (optional) low resolution panchromatic image. - + ms: Low resolution multispectral image. + pan: High resolution panchromatic image. + pan_lr: Low resolution panchromatic image. norm_order: Order of the norm applied on the difference. window_size: Window size of the filter applied to degrade the high resolution panchromatic image. reduction: A method to reduce metric score over labels. @@ -149,21 +159,15 @@ def _spatial_distortion_index_compute( Example: >>> _ = torch.manual_seed(42) >>> preds = torch.rand([16, 3, 32, 32]) - >>> target = { - ... 'ms': torch.rand([16, 3, 16, 16]), - ... 'pan': torch.rand([16, 3, 32, 32]), - ... } - >>> preds, target = _spatial_distortion_index_update(preds, target) - >>> _spatial_distortion_index_compute(preds, target) + >>> ms = torch.rand([16, 3, 16, 16]) + >>> pan = torch.rand([16, 3, 32, 32]) + >>> preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan) + >>> _spatial_distortion_index_compute(preds, ms, pan, pan_lr) tensor(0.0090) """ length = preds.shape[1] - ms = target["ms"] - pan = target["pan"] - pan_lr = target["pan_lr"] if "pan_lr" in target else None - ms_h, ms_w = ms.shape[-2:] if window_size >= ms_h or window_size >= ms_w: raise ValueError( @@ -197,7 +201,9 @@ def _spatial_distortion_index_compute( def spatial_distortion_index( preds: Tensor, - target: Dict[str, Tensor], + ms: Tensor, + pan: Tensor, + pan_lr: Tensor = None, norm_order: int = 1, window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", @@ -208,12 +214,9 @@ def spatial_distortion_index( Args: preds: High resolution multispectral image. - target: A dictionary containing the following keys: - - - ``'ms'``: low resolution multispectral image. - - ``'pan'``: high resolution panchromatic image. - - ``'pan_lr'``: (optional) low resolution panchromatic image. - + ms: Low resolution multispectral image. + pan: High resolution panchromatic image. + pan_lr: Low resolution panchromatic image. norm_order: Order of the norm applied on the difference. window_size: Window size of the filter applied to degrade the high resolution panchromatic image. reduction: A method to reduce metric score over labels. @@ -227,13 +230,11 @@ def spatial_distortion_index( Raises: TypeError: - If ``preds`` and ``target`` don't have the same data type. - ValueError: - If ``preds`` and ``target`` don't have ``BxCxHxW shape``. + If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same data type. ValueError: - If ``preds`` and ``target`` don't have the same batch and channel sizes. + If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have ``BxCxHxW shape``. 
ValueError: - If ``target`` doesn't have ``ms`` and ``pan``. + If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same batch and channel sizes. ValueError: If ``preds`` and ``pan`` don't have the same dimension. ValueError: @@ -249,11 +250,9 @@ def spatial_distortion_index( >>> from torchmetrics.functional.image import spatial_distortion_index >>> _ = torch.manual_seed(42) >>> preds = torch.rand([16, 3, 32, 32]) - >>> target = { - ... 'ms': torch.rand([16, 3, 16, 16]), - ... 'pan': torch.rand([16, 3, 32, 32]), - ... } - >>> spatial_distortion_index(preds, target) + >>> ms = torch.rand([16, 3, 16, 16]) + >>> pan = torch.rand([16, 3, 32, 32]) + >>> spatial_distortion_index(preds, ms, pan) tensor(0.0090) """ @@ -261,5 +260,5 @@ def spatial_distortion_index( raise ValueError(f"Expected `norm_order` to be a positive integer. Got norm_order: {norm_order}.") if not isinstance(window_size, int) or window_size <= 0: raise ValueError(f"Expected `window_size` to be a positive integer. Got window_size: {window_size}.") - preds, target = _spatial_distortion_index_update(preds, target) - return _spatial_distortion_index_compute(preds, target, norm_order, window_size, reduction) + preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan, pan_lr) + return _spatial_distortion_index_compute(preds, ms, pan, pan_lr, norm_order, window_size, reduction) diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 0f23ab0712b..16930c46f0a 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -116,7 +116,10 @@ def __init__( def update(self, preds: Tensor, target: Dict[str, Tensor]) -> None: """Update state with preds and target.""" - preds, target = _spatial_distortion_index_update(preds, target) + ms = target["ms"] if "ms" in target else None + pan = target["pan"] if "pan" in target else None + pan_lr = target["pan_lr"] if "pan_lr" in target else None + preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan, pan_lr) self.preds.append(preds) self.ms.append(target["ms"]) self.pan.append(target["pan"]) @@ -131,7 +134,9 @@ def compute(self) -> Tensor: pan_lr = dim_zero_cat(self.pan_lr) if len(self.pan_lr) > 0 else None target = {"ms": ms, "pan": pan} target.update({"pan_lr": pan_lr} if pan_lr is not None else {}) - return _spatial_distortion_index_compute(preds, target, self.norm_order, self.window_size, self.reduction) + return _spatial_distortion_index_compute( + preds, ms, pan, pan_lr, self.norm_order, self.window_size, self.reduction + ) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/tests/unittests/image/test_d_s.py b/tests/unittests/image/test_d_s.py index 7625190a15f..8cc412b694c 100644 --- a/tests/unittests/image/test_d_s.py +++ b/tests/unittests/image/test_d_s.py @@ -35,6 +35,9 @@ class _Input(NamedTuple): preds: Tensor target: List[Dict[str, Tensor]] + ms: Tensor + pan: Tensor + pan_lr: Tensor norm_order: int window_size: int @@ -61,6 +64,9 @@ class _Input(NamedTuple): } for i in range(NUM_BATCHES) ], + ms=ms, + pan=pan, + pan_lr=pan_lr if pan_lr_exists else None, norm_order=norm_order, window_size=window_size, ) @@ -108,14 +114,18 @@ def _baseline_d_s( return np.mean(diff) ** (1 / norm_order) -def _np_d_s(preds, target, norm_order, window_size): +def _np_d_s(preds, target, pan=None, pan_lr=None, norm_order=1, window_size=7): np_preds = preds.permute(0, 2, 3, 1).cpu().numpy() - assert isinstance(target, dict), f"Expected `target` to be dict. 
Got {type(target)}." - assert "ms" in target, "Expected `target` to contain 'ms'." - np_ms = target["ms"].permute(0, 2, 3, 1).cpu().numpy() - assert "pan" in target, "Expected `target` to contain 'pan'." - np_pan = target["pan"].permute(0, 2, 3, 1).cpu().numpy() - np_pan_lr = target["pan_lr"].permute(0, 2, 3, 1).cpu().numpy() if "pan_lr" in target else None + if isinstance(target, dict): + assert "ms" in target, "Expected `target` to contain 'ms'." + np_ms = target["ms"].permute(0, 2, 3, 1).cpu().numpy() + assert "pan" in target, "Expected `target` to contain 'pan'." + np_pan = target["pan"].permute(0, 2, 3, 1).cpu().numpy() + np_pan_lr = target["pan_lr"].permute(0, 2, 3, 1).cpu().numpy() if "pan_lr" in target else None + else: + np_ms = target.permute(0, 2, 3, 1).cpu().numpy() + np_pan = pan.permute(0, 2, 3, 1).cpu().numpy() + np_pan_lr = pan_lr.permute(0, 2, 3, 1).cpu().numpy() if pan_lr is not None else None return _baseline_d_s( np_preds, @@ -127,9 +137,16 @@ def _np_d_s(preds, target, norm_order, window_size): ) +def _invoke_spatial_distortion_index(preds, target, ms, pan, pan_lr, norm_order, window_size): + ms = target["ms"] if "ms" in target else ms + pan = target["pan"] if "pan" in target else pan + pan_lr = target["pan_lr"] if "pan_lr" in target else pan_lr + return spatial_distortion_index(preds, ms, pan, pan_lr, norm_order, window_size) + + @pytest.mark.parametrize( - "preds, target, norm_order, window_size", - [(i.preds, i.target, i.norm_order, i.window_size) for i in _inputs], + "preds, target, ms, pan, pan_lr, norm_order, window_size", + [(i.preds, i.target, i.ms, i.pan, i.pan_lr, i.norm_order, i.window_size) for i in _inputs], ) class TestSpatialDistortionIndex(MetricTester): """Test class for `SpatialDistortionIndex` metric.""" @@ -137,7 +154,7 @@ class TestSpatialDistortionIndex(MetricTester): atol = 3e-6 @pytest.mark.parametrize("ddp", [True, False]) - def test_d_s(self, preds, target, norm_order, window_size, ddp): + def test_d_s(self, preds, target, ms, pan, pan_lr, norm_order, window_size, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp, @@ -148,232 +165,267 @@ def test_d_s(self, preds, target, norm_order, window_size, ddp): metric_args={"norm_order": norm_order, "window_size": window_size}, ) - def test_d_s_functional(self, preds, target, norm_order, window_size): + def test_d_s_functional(self, preds, target, ms, pan, pan_lr, norm_order, window_size): """Test functional implementation of metric.""" self.run_functional_metric_test( preds, - target, + ms, spatial_distortion_index, partial(_np_d_s, norm_order=norm_order, window_size=window_size), metric_args={"norm_order": norm_order, "window_size": window_size}, + fragment_kwargs=True, + pan=pan, + pan_lr=pan_lr, ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_d_s_half_gpu(self, preds, target, norm_order, window_size): + def test_d_s_half_gpu(self, preds, target, ms, pan, pan_lr, norm_order, window_size): """Test dtype support of the metric on GPU.""" self.run_precision_test_gpu( preds, target, SpatialDistortionIndex, - spatial_distortion_index, + partial( + _invoke_spatial_distortion_index, + ms=ms, + pan=pan, + pan_lr=pan_lr, + norm_order=norm_order, + window_size=window_size, + ), {"norm_order": norm_order, "window_size": window_size}, ) @pytest.mark.parametrize( - ("preds", "target", "norm_order", "window_size", "match"), + ("preds", "ms", "pan", "pan_lr", "norm_order", "window_size", "match"), [ ( [1, 16, 16], - {"ms": [1, 1, 4, 4], 
"pan": [1, 1, 16, 16]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + None, 1, 3, "Expected `preds` to have BxCxHxW shape.*", ), # len(preds.shape) - ([1, 1, 16, 16], {}, 1, 7, r"Expected `target` to have keys \('ms', 'pan'\).*"), # target.keys() - ( - [1, 1, 16, 16], - {"ms": [1, 1, 4, 4]}, - 1, - 3, - r"Expected `target` to have keys \('ms', 'pan'\).*", - ), # target.keys() ( [1, 1, 16, 16], - {"pan": [1, 1, 16, 16]}, - 1, - 3, - r"Expected `target` to have keys \('ms', 'pan'\).*", - ), # target.keys() - ( + [1, 4, 4], [1, 1, 16, 16], - {"ms": [1, 4, 4], "pan": [1, 1, 16, 16]}, + None, 1, 3, "Expected `ms` to have BxCxHxW shape.*", ), # len(target.shape) ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 16, 16]}, + [1, 1, 4, 4], + [1, 16, 16], + None, 1, 3, "Expected `pan` to have BxCxHxW shape.*", ), # len(target.shape) ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 4, 4]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + [1, 4, 4], 1, 3, "Expected `pan_lr` to have BxCxHxW shape.*", ), # len(target.shape) ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + None, 0, 3, "Expected `norm_order` to be a positive integer. Got norm_order: 0.", ), # invalid p ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + None, -1, 3, "Expected `norm_order` to be a positive integer. Got norm_order: -1.", ), # invalid p ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + None, 1, 0, "Expected `window_size` to be a positive integer. Got window_size: 0.", ), # invalid window_size ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + None, 1, -1, "Expected `window_size` to be a positive integer. 
Got window_size: -1.", ), # invalid window_size ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 17, 16]}, + [1, 1, 4, 4], + [1, 1, 17, 16], + None, 1, 3, "Expected `preds` and `pan` to have the same height.*", ), # invalid pan_h ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 17]}, + [1, 1, 4, 4], + [1, 1, 16, 17], + None, 1, 3, "Expected `preds` and `pan` to have the same width.*", ), # invalid pan_w ( [1, 1, 16, 16], - {"ms": [1, 1, 5, 4], "pan": [1, 1, 16, 16]}, + [1, 1, 5, 4], + [1, 1, 16, 16], + None, 1, 3, "Expected height of `preds` to be multiple of height of `ms`.*", ), # invalid ms_h ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 5], "pan": [1, 1, 16, 16]}, + [1, 1, 4, 5], + [1, 1, 16, 16], + None, 1, 3, "Expected width of `preds` to be multiple of width of `ms`.*", ), # invalid ms_w ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 1, 5, 4]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + [1, 1, 5, 4], 1, 3, "Expected `ms` and `pan_lr` to have the same height.*", ), # invalid pan_lr_h ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 1, 4, 5]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + [1, 1, 4, 5], 1, 3, "Expected `ms` and `pan_lr` to have the same width.*", ), # invalid pan_lr_w ( [1, 1, 16, 16], - {"ms": [1, 2, 4, 4], "pan": [1, 1, 16, 16]}, + [1, 2, 4, 4], + [1, 1, 16, 16], + None, 1, 3, "Expected `preds` and `ms` to have the same batch and channel.*", ), # invalid ms.shape ( [1, 1, 16, 16], - {"ms": [2, 1, 4, 4], "pan": [1, 1, 16, 16]}, + [2, 1, 4, 4], + [1, 1, 16, 16], + None, 1, 3, "Expected `preds` and `ms` to have the same batch and channel.*", ), # invalid ms.shape ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 2, 16, 16]}, + [1, 1, 4, 4], + [1, 2, 16, 16], + None, 1, 3, "Expected `preds` and `pan` to have the same batch and channel.*", ), # invalid pan.shape ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [2, 1, 16, 16]}, + [1, 1, 4, 4], + [2, 1, 16, 16], + None, 1, 3, "Expected `preds` and `pan` to have the same batch and channel.*", ), # invalid pan.shape ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [1, 2, 4, 4]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + [1, 2, 4, 4], 1, 3, "Expected `preds` and `pan_lr` to have the same batch and channel.*", ), # invalid pan_lr.shape ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16], "pan_lr": [2, 1, 4, 4]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + [2, 1, 4, 4], 1, 3, "Expected `preds` and `pan_lr` to have the same batch and channel.*", ), # invalid pan_lr.shape ( [1, 1, 16, 16], - {"ms": [1, 1, 4, 4], "pan": [1, 1, 16, 16]}, + [1, 1, 4, 4], + [1, 1, 16, 16], + None, 1, 5, "Expected `window_size` to be smaller than dimension of `ms`.*", ), # invalid window_size ], ) -def test_d_s_invalid_inputs(preds, target, norm_order, window_size, match): +def test_d_s_invalid_inputs(preds, ms, pan, pan_lr, norm_order, window_size, match): """Test that invalid input raises the correct errors.""" preds_t = torch.rand(preds) - target_t = {name: torch.rand(t) for name, t in target.items()} + ms_t = torch.rand(ms) + pan_t = torch.rand(pan) + pan_lr_t = torch.rand(pan_lr) if pan_lr is not None else None with pytest.raises(ValueError, match=match): - spatial_distortion_index(preds_t, target_t, norm_order, window_size) + spatial_distortion_index(preds_t, ms_t, pan_t, pan_lr_t, norm_order, window_size) @pytest.mark.parametrize( - ("target", "match"), + ("ms", "pan", "pan_lr", "match"), [ ( - { - "ms": torch.rand((1, 1, 4, 4), dtype=torch.float64), - "pan": torch.rand((1, 
1, 16, 16)), - }, + torch.rand((1, 1, 4, 4), dtype=torch.float64), + torch.rand((1, 1, 16, 16)), + None, "Expected `preds` and `ms` to have the same data type.*", ), ( - { - "ms": torch.rand((1, 1, 4, 4)), - "pan": torch.rand((1, 1, 16, 16), dtype=torch.float64), - }, + torch.rand((1, 1, 4, 4)), + torch.rand((1, 1, 16, 16), dtype=torch.float64), + None, "Expected `preds` and `pan` to have the same data type.*", ), ( - { - "ms": torch.rand((1, 1, 4, 4)), - "pan": torch.rand((1, 1, 16, 16)), - "pan_lr": torch.rand((1, 1, 4, 4), dtype=torch.float64), - }, + torch.rand((1, 1, 4, 4)), + torch.rand((1, 1, 16, 16)), + torch.rand((1, 1, 4, 4), dtype=torch.float64), "Expected `preds` and `pan_lr` to have the same data type.*", ), ], ) -def test_d_s_invalid_type(target, match): +def test_d_s_invalid_type(ms, pan, pan_lr, match): """Test that error is raised on different dtypes.""" preds_t = torch.rand((1, 1, 16, 16)) with pytest.raises(TypeError, match=match): - spatial_distortion_index(preds_t, target, norm_order=1, window_size=7) + spatial_distortion_index(preds_t, ms, pan, pan_lr, norm_order=1, window_size=7) From 115ead34cc88fd5ca150b3965d933f3926f22c88 Mon Sep 17 00:00:00 2001 From: Vincent Chan Date: Mon, 18 Dec 2023 20:01:25 +0000 Subject: [PATCH 30/32] fix type hint --- src/torchmetrics/functional/image/d_s.py | 10 +++++----- src/torchmetrics/image/d_s.py | 24 +++++++++++++++++++++--- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 4a97a3cf151..33f64217839 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Optional, Tuple import torch from torch import Tensor @@ -24,8 +24,8 @@ def _spatial_distortion_index_update( - preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Tensor = None -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Optional[Tensor] = None +) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: """Update and returns variables required to compute Spatial Distortion Index. Args: @@ -129,7 +129,7 @@ def _spatial_distortion_index_compute( preds: Tensor, ms: Tensor, pan: Tensor, - pan_lr: Tensor = None, + pan_lr: Optional[Tensor] = None, norm_order: int = 1, window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", @@ -203,7 +203,7 @@ def spatial_distortion_index( preds: Tensor, ms: Tensor, pan: Tensor, - pan_lr: Tensor = None, + pan_lr: Optional[Tensor] = None, norm_order: int = 1, window_size: int = 7, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 16930c46f0a..257b2d0e17b 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -115,9 +115,27 @@ def __init__( self.add_state("pan_lr", default=[], dist_reduce_fx="cat") def update(self, preds: Tensor, target: Dict[str, Tensor]) -> None: - """Update state with preds and target.""" - ms = target["ms"] if "ms" in target else None - pan = target["pan"] if "pan" in target else None + """Update state with preds and target. + + Args: + preds: High resolution multispectral image. + target: A dictionary containing the following keys: + + - ``'ms'``: low resolution multispectral image. 
+ - ``'pan'``: high resolution panchromatic image. + - ``'pan_lr'``: (optional) low resolution panchromatic image. + + Raises: + ValueError: + If ``target`` doesn't have ``ms`` and ``pan``. + + """ + if "ms" not in target: + raise ValueError(f"Expected `target` to have key `ms`. Got target: {target.keys()}.") + if "pan" not in target: + raise ValueError(f"Expected `target` to have key `pan`. Got target: {target.keys()}.") + ms = target["ms"] + pan = target["pan"] pan_lr = target["pan_lr"] if "pan_lr" in target else None preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan, pan_lr) self.preds.append(preds) From a334561bc943fc80120742cce2e0167dfcc546d9 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 20 Dec 2023 11:12:52 +0100 Subject: [PATCH 31/32] skip on missing import --- src/torchmetrics/functional/image/d_s.py | 3 +++ src/torchmetrics/image/d_s.py | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/d_s.py b/src/torchmetrics/functional/image/d_s.py index 33f64217839..aebacb672fd 100644 --- a/src/torchmetrics/functional/image/d_s.py +++ b/src/torchmetrics/functional/image/d_s.py @@ -22,6 +22,9 @@ from torchmetrics.utilities.distributed import reduce from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE +if not _TORCHVISION_AVAILABLE: + __doctest_skip__ = ["_spatial_distortion_index_compute", "spatial_distortion_index"] + def _spatial_distortion_index_update( preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Optional[Tensor] = None diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 257b2d0e17b..722bbfa87c3 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -21,9 +21,12 @@ from torchmetrics.metric import Metric from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE +if not _TORCHVISION_AVAILABLE: + __doctest_skip__ = ["SpatialDistortionIndex", "SpatialDistortionIndex.plot"] + if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["SpatialDistortionIndex.plot"] From f926209c5aa8fe1295eaf3e976121a5056b33d82 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 20 Dec 2023 11:18:47 +0100 Subject: [PATCH 32/32] skip on missing import --- src/torchmetrics/image/d_s.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/image/d_s.py b/src/torchmetrics/image/d_s.py index 722bbfa87c3..19e9d61678e 100644 --- a/src/torchmetrics/image/d_s.py +++ b/src/torchmetrics/image/d_s.py @@ -24,12 +24,12 @@ from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE -if not _TORCHVISION_AVAILABLE: - __doctest_skip__ = ["SpatialDistortionIndex", "SpatialDistortionIndex.plot"] - if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["SpatialDistortionIndex.plot"] +if not _TORCHVISION_AVAILABLE: + __doctest_skip__ = ["SpatialDistortionIndex", "SpatialDistortionIndex.plot"] + class SpatialDistortionIndex(Metric): """Compute Spatial Distortion Index (SpatialDistortionIndex_) also now as D_s.
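
For reference, below is a minimal usage sketch of the metric as it stands at the end of this patch series, i.e. with the renamed `norm_order`/`window_size` arguments, the functional signature taking `ms`, `pan` and an optional `pan_lr`, and the module API taking a `target` dict. The shapes and seed follow the docstring examples in the patches above; the random tensors are only meant to show the expected input layout, and the `pan_lr` fallback assumes torchvision is installed.

    import torch
    from torchmetrics.functional.image import spatial_distortion_index
    from torchmetrics.image import SpatialDistortionIndex

    _ = torch.manual_seed(42)
    preds = torch.rand([16, 3, 32, 32])  # high resolution multispectral image
    ms = torch.rand([16, 3, 16, 16])     # low resolution multispectral image
    pan = torch.rand([16, 3, 32, 32])    # high resolution panchromatic image

    # Functional API. When `pan_lr` is not given, `pan` is degraded with a uniform
    # filter of size `window_size` and resized to the `ms` grid (needs torchvision).
    score = spatial_distortion_index(preds, ms, pan, norm_order=1, window_size=7)

    # Module API. The target is a dict with 'ms', 'pan' and an optional 'pan_lr'.
    metric = SpatialDistortionIndex(norm_order=1, window_size=7)
    metric.update(preds, {"ms": ms, "pan": pan})
    score = metric.compute()

In both APIs the per-band differences ``|UQI(ms_b, pan_lr_b) - UQI(preds_b, pan_b)| ** norm_order`` are reduced according to ``reduction`` and the result raised to ``1 / norm_order``, matching the reference D_s definition.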