Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3D extension for SSIM #818

Merged
merged 125 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
125 commits
Select commit Hold shift + click to select a range
848142f
3D ssim, first try
Jan 30, 2022
6b726b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2022
f802930
changelog
SkafteNicki Jan 31, 2022
2eee2ea
fix doctest
SkafteNicki Jan 31, 2022
a8e9d02
Update torchmetrics/functional/image/ssim.py
weningerleon Jan 31, 2022
207c5bf
Update torchmetrics/functional/image/ssim.py
weningerleon Jan 31, 2022
fc2d6ed
Update torchmetrics/functional/image/ssim.py
weningerleon Jan 31, 2022
b11817b
Update torchmetrics/functional/image/ssim.py
weningerleon Jan 31, 2022
30345d7
update ssim 3d
Feb 1, 2022
26e6799
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2022
967209d
adding 3d ssim tests
Feb 1, 2022
002c932
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2022
75d93fb
Merge branch 'PyTorchLightning:master' into master
weningerleon Feb 1, 2022
0a552a0
Merge branch 'master' into master
Borda Feb 5, 2022
95b78c9
Merge branch 'master' into weningerleon/master
Borda Feb 8, 2022
bacbecb
update
Borda Feb 8, 2022
68011c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2022
df1b9e7
Merge branch 'master' into master
weningerleon Feb 8, 2022
8199e78
Merge branch 'master' into master
Borda Feb 8, 2022
50be3f3
fixed formatting errors
Feb 9, 2022
b12a922
bug fix ssim
Feb 9, 2022
ad285e7
_ssim_update
Feb 9, 2022
665c92c
Merge branch 'master' into master
justusschock Feb 9, 2022
57ae26e
Apply suggestions from code review
SkafteNicki Feb 10, 2022
769a7f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2022
084fc4f
docs
SkafteNicki Feb 10, 2022
bc66a4c
Merge branch 'master' of https://github.com/weningerleon/metrics into…
SkafteNicki Feb 10, 2022
faf9900
Merge branch 'master' into master
Borda Feb 10, 2022
98d1490
Merge branch 'master' of https://github.com/weningerleon/metrics into…
SkafteNicki Feb 10, 2022
bf0fcb2
Merge branch 'master' into master
mergify[bot] Feb 10, 2022
5fe1410
Merge branch 'master' into master
Borda Feb 10, 2022
378a07e
Merge branch 'master' into master
mergify[bot] Feb 10, 2022
ca6b90f
Merge branch 'master' into master
mergify[bot] Feb 11, 2022
3ae6db8
Merge branch 'master' into master
mergify[bot] Feb 11, 2022
2d40262
update kernel_size default
Feb 11, 2022
d38d8a2
ms-ssim in 3d
Feb 11, 2022
c0098a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2022
a7766ee
formatting
Feb 11, 2022
9d2dc0d
merge
Feb 11, 2022
e33877a
Merge branch 'master' into master
mergify[bot] Feb 14, 2022
b73b0cd
Merge branch 'master' into master
mergify[bot] Feb 14, 2022
248eb95
Use our own 3D reflection padding
stancld Feb 16, 2022
9f2c13c
pytorch implementation depending on version, user warning if deprecated
Feb 17, 2022
2cc5873
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
f45c908
update torch version checking
Feb 17, 2022
fb38f21
Merge branch 'master' of github.com:weningerleon/metrics
Feb 17, 2022
d72b5b2
Apply suggestions from code review
stancld Feb 17, 2022
3c57d36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
ed68ec2
Use smaller batch size due to OOM
stancld Feb 17, 2022
7d256f6
Fix test according to weningerleon's suggestsion + apply a small batch
stancld Feb 17, 2022
afe6de8
Clean reference sk metric
stancld Feb 17, 2022
08c0f6f
updates and bug fixes
Feb 17, 2022
20f5336
merge
Feb 17, 2022
6e8476d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
8309a78
docs
Feb 17, 2022
709d4e7
adapt ssim
Feb 17, 2022
a2be8bc
Merge branch 'master' of github.com:weningerleon/metrics
Feb 17, 2022
b37a512
adapt ssim
Feb 17, 2022
24314df
Merge branch 'master' into master
mergify[bot] Feb 18, 2022
986048e
fix tests
Feb 18, 2022
28aded0
old atol ssim
Feb 18, 2022
c6d9f83
Merge branch 'master' of github.com:weningerleon/metrics
Feb 18, 2022
98fef1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2022
5b22686
formatting
Feb 18, 2022
7618c8b
merge
Feb 18, 2022
11c73b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2022
86bd470
fix ms ssim
Feb 18, 2022
8185fbd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2022
60f0719
docs
Feb 18, 2022
b4ffccc
num_batches
Feb 18, 2022
7dd1578
doctest
Feb 18, 2022
7e974ca
torch tensor
Feb 18, 2022
f7fdf88
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2022
bc3a781
Merge branch 'master' into master
mergify[bot] Feb 21, 2022
d88d02f
changelog
Feb 21, 2022
47adccf
changelog
Feb 21, 2022
ab09084
Merge branch 'master' into master
weningerleon Feb 21, 2022
40368c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2022
15c0510
Merge branch 'master' into master
mergify[bot] Feb 21, 2022
b33e4e9
add dict
Feb 21, 2022
2082feb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2022
5a8a503
Merge branch 'master' into master
weningerleon Feb 23, 2022
69f8d56
typing
Feb 23, 2022
42d0aa1
merging
Feb 23, 2022
e5fc229
Merge branch 'master' into master
mergify[bot] Feb 23, 2022
0dc5a84
Merge branch 'master' into master
mergify[bot] Feb 23, 2022
1fafb1a
Merge branch 'master' into master
mergify[bot] Feb 24, 2022
dc2e79a
Update tests/image/test_ssim.py
weningerleon Feb 24, 2022
9f46a36
Update tests/image/test_ssim.py
weningerleon Feb 24, 2022
fcc2870
Update torchmetrics/functional/image/helper.py
weningerleon Feb 24, 2022
afbb379
Update torchmetrics/image/ssim.py
weningerleon Feb 24, 2022
a2ac937
removed kernel size parametrization
Feb 24, 2022
b93e36b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2022
371572f
add Raise to docstring
Feb 24, 2022
af9c82d
Merge branch 'master' of github.com:weningerleon/metrics
Feb 24, 2022
98675a1
Update torchmetrics/image/ssim.py
weningerleon Feb 24, 2022
8785563
Update torchmetrics/functional/image/ssim.py
weningerleon Feb 24, 2022
0923588
Merge branch 'master' into master
mergify[bot] Feb 25, 2022
0126a82
Apply suggestions from code review
Borda Feb 25, 2022
28363af
Merge branch 'master' into master
mergify[bot] Feb 28, 2022
fc897b6
Merge branch 'master' into master
mergify[bot] Mar 1, 2022
d35c88f
Merge branch 'master' into master
mergify[bot] Mar 1, 2022
cb960af
add reduce
Mar 2, 2022
1eb5dfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2022
299ca82
Merge branch 'master' into master
mergify[bot] Mar 3, 2022
7f583dc
re fix tests
Mar 3, 2022
cb6332d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2022
84dccfc
reduce, other settings
Mar 3, 2022
74c3ebf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2022
9205d1c
Merge branch 'master' into master
mergify[bot] Mar 3, 2022
f732cd0
Merge branch 'master' into master
mergify[bot] Mar 7, 2022
d47be94
Merge branch 'master' into master
mergify[bot] Mar 11, 2022
20a2877
Merge branch 'master' into master
Borda Mar 20, 2022
4eed069
Merge branch 'master' into master
mergify[bot] Mar 21, 2022
e27b0ce
Merge branch 'master' into master
mergify[bot] Mar 21, 2022
6841417
missing docstrings
SkafteNicki Mar 21, 2022
5b455be
fix doctest
SkafteNicki Mar 21, 2022
cd0dbad
fix doc test
SkafteNicki Mar 21, 2022
6bcbee2
Merge branch 'master' into master
SkafteNicki Mar 21, 2022
daf5ae3
Merge branch 'master' into master
mergify[bot] Mar 22, 2022
7a8576b
fix memory issues
SkafteNicki Mar 22, 2022
c93b2c1
device placement
SkafteNicki Mar 22, 2022
1af87a8
lower memory
SkafteNicki Mar 22, 2022
eba0214
Merge branch 'master' into weningerleon/master
Borda Mar 24, 2022
b057212
Merge branch 'master' into master
mergify[bot] Mar 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new image metric `UniversalImageQualityIndex` ([#824](https://github.com/PyTorchLightning/metrics/pull/824))


- Added support for 3D image in `StructuralSimilarityIndexMeasure` ([#818](https://github.com/PyTorchLightning/metrics/pull/818))


- Added smart update of `MetricCollection` ([#709](https://github.com/PyTorchLightning/metrics/pull/709))


Expand Down
43 changes: 26 additions & 17 deletions tests/image/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,19 @@
(14, 1, 0.7, False, torch.double),
(15, 3, 0.6, True, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
preds2d = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds,
target=preds * coef,
preds=preds2d,
target=preds2d * coef,
multichannel=multichannel,
)
)
preds3d = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds3d,
target=preds3d * coef,
multichannel=multichannel,
)
)
Expand All @@ -52,16 +60,19 @@ def _sk_ssim(preds, target, data_range, multichannel, kernel_size):
sk_preds = sk_preds[:, :, :, 0]
sk_target = sk_target[:, :, :, 0]

return structural_similarity(
sk_target,
sk_preds,
data_range=data_range,
multichannel=multichannel,
gaussian_weights=True,
win_size=kernel_size,
sigma=1.5,
use_sample_covariance=False,
)
results = torch.zeros(sk_preds.shape[0])
for i in range(sk_preds.shape[0]):
results[i] = structural_similarity(
sk_target[i],
sk_preds[i],
data_range=data_range,
multichannel=multichannel,
gaussian_weights=True,
win_size=kernel_size,
sigma=1.5,
use_sample_covariance=False,
)
return results
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
weningerleon marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
Expand All @@ -70,8 +81,6 @@ def _sk_ssim(preds, target, data_range, multichannel, kernel_size):
)
@pytest.mark.parametrize("kernel_size", [5, 11])
class TestSSIM(MetricTester):
stancld marked this conversation as resolved.
Show resolved Hide resolved
atol = 6e-3

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step):
Expand All @@ -81,7 +90,7 @@ def test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_
target,
StructuralSimilarityIndexMeasure,
partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
metric_args={"data_range": 1.0, "kernel_size": kernel_size},
dist_sync_on_step=dist_sync_on_step,
)

Expand All @@ -91,7 +100,7 @@ def test_ssim_functional(self, preds, target, multichannel, kernel_size):
target,
structural_similarity_index_measure,
partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
metric_args={"data_range": 1.0, "kernel_size": kernel_size},
)

# SSIM half + cpu does not work due to missing support in torch.log
Expand Down
33 changes: 29 additions & 4 deletions torchmetrics/functional/image/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Sequence, Union

import torch
from torch import Tensor
Expand All @@ -22,8 +22,12 @@ def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)


def _gaussian_kernel(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
def _gaussian_kernel_2d(
channel: int,
kernel_size: Sequence[int],
sigma: Sequence[float],
dtype: torch.dtype,
device: Union[torch.device, str],
) -> Tensor:
"""Computes 2D gaussian kernel.

Expand All @@ -35,7 +39,7 @@ def _gaussian_kernel(
device: device of the output tensor

Example:
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
>>> _gaussian_kernel_2d(1, (5,5), (1,1), torch.float, "cpu")
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
Expand All @@ -48,3 +52,24 @@ def _gaussian_kernel(
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])


def _gaussian_kernel_3d(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
"""Computes 3D gaussian kernel.

Args:
channel: number of channels in the image
kernel_size: size of the gaussian kernel as a tuple (h, w, d)
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor
"""

gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
gaussian_kernel_z = _gaussian(kernel_size[2], sigma[2], dtype, device)
kernel_xy = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)
kernel = torch.mul(kernel_xy, gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]))
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1], kernel_size[2])
54 changes: 41 additions & 13 deletions torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.nn import functional as F
from typing_extensions import Literal

from torchmetrics.functional.image.helper import _gaussian_kernel
from torchmetrics.functional.image.helper import _gaussian_kernel_2d, _gaussian_kernel_3d
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce

Expand All @@ -38,9 +38,9 @@ def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
f" Got preds: {preds.dtype} and target: {target.dtype}."
)
_check_same_shape(preds, target)
if len(preds.shape) != 4:
if len(preds.shape) not in (4, 5):
raise ValueError(
"Expected `preds` and `target` to have BxCxHxW shape."
"Expected `preds` and `target` to have BxCxHxW or BxCxDxHxW shape."
f" Got preds: {preds.shape} and target: {target.shape}."
)
return preds, target
Expand All @@ -49,8 +49,8 @@ def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
def _ssim_compute(
preds: Tensor,
target: Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
kernel_size: Union[int, Sequence[int]] = 11,
sigma: Union[float, Sequence[float]] = 1.5,
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
Expand Down Expand Up @@ -81,10 +81,26 @@ def _ssim_compute(
>>> _ssim_compute(preds, target)
tensor(0.9219)
"""
if len(kernel_size) != 2 or len(sigma) != 2:
is_3d = len(preds.shape) == 5

if not isinstance(kernel_size, Sequence):
kernel_size = 3 * [kernel_size] if is_3d else 2 * [kernel_size]
if not isinstance(sigma, Sequence):
sigma = 3 * [sigma] if is_3d else 2 * [sigma]

if len(kernel_size) != len(target.shape) - 2:
raise ValueError(
f"`kernel_size` has dimension {len(kernel_size)}, but expected to be two less that target dimensionality, "
f"which is: {len(target.shape)}"
Borda marked this conversation as resolved.
Show resolved Hide resolved
)
if len(kernel_size) != len(sigma):
raise ValueError(
"Expected `kernel_size` and `sigma` to have the length of two."
f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
"Expected `kernel_size` and `sigma` to have the same length."
f" Kernel_size dimensionality: {len(kernel_size)}, sigma dimensionality: {len(sigma)}."
)
if len(kernel_size) not in (2, 3):
raise ValueError(
f"Expected `kernel_size` dimension to be 2 or 3. `kernel_size` dimensionality: {len(kernel_size)}"
)

if any(x % 2 == 0 or x <= 0 for x in kernel_size):
Expand All @@ -102,15 +118,27 @@ def _ssim_compute(

channel = preds.size(1)
dtype = preds.dtype
kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device)
pad_h = (kernel_size[0] - 1) // 2
pad_w = (kernel_size[1] - 1) // 2

preds = F.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
target = F.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
if is_3d:
pad_d = (kernel_size[2] - 1) // 2
preds = F.pad(preds, (pad_h, pad_h, pad_w, pad_w, pad_d, pad_d), mode="reflect")
target = F.pad(target, (pad_h, pad_h, pad_w, pad_w, pad_d, pad_d), mode="reflect")
kernel = _gaussian_kernel_3d(channel, kernel_size, sigma, dtype, device)

else:
preds = F.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
target = F.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
kernel = _gaussian_kernel_2d(channel, kernel_size, sigma, dtype, device)

input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
outputs = F.conv2d(input_list, kernel, groups=channel)

if is_3d:
outputs = F.conv3d(input_list, kernel, groups=channel)
else:
outputs = F.conv2d(input_list, kernel, groups=channel)

output_list = outputs.split(preds.shape[0])

mu_pred_sq = output_list[0].pow(2)
Expand Down Expand Up @@ -145,7 +173,7 @@ def structural_similarity_index_measure(
k1: float = 0.01,
k2: float = 0.03,
) -> Tensor:
weningerleon marked this conversation as resolved.
Show resolved Hide resolved
"""Computes Structural Similarity Index Measure.
"""Computes Structural Similarity Index Measure. Supports both 2D and 3D images.

Args:
preds: estimated image
Expand Down
4 changes: 2 additions & 2 deletions torchmetrics/functional/image/uqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.nn import functional as F
from typing_extensions import Literal

from torchmetrics.functional.image.helper import _gaussian_kernel
from torchmetrics.functional.image.helper import _gaussian_kernel_2d
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.distributed import reduce

Expand Down Expand Up @@ -95,7 +95,7 @@ def _uqi_compute(
device = preds.device
channel = preds.size(1)
dtype = preds.dtype
kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device)
kernel = _gaussian_kernel_2d(channel, kernel_size, sigma, dtype, device)
pad_h = (kernel_size[0] - 1) // 2
pad_w = (kernel_size[1] - 1) // 2

Expand Down
8 changes: 4 additions & 4 deletions torchmetrics/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Tuple
from typing import Any, List, Optional, Sequence, Tuple, Union

from torch import Tensor
from typing_extensions import Literal
Expand All @@ -23,7 +23,7 @@


class StructuralSimilarityIndexMeasure(Metric):
"""Computes Structual Similarity Index Measure (SSIM_).
"""Computes Structual Similarity Index Measure (SSIM_). Supports both 2D and 3D images.

Args:
kernel_size: size of the gaussian kernel
Expand Down Expand Up @@ -57,8 +57,8 @@ class StructuralSimilarityIndexMeasure(Metric):

def __init__(
self,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
kernel_size: Union[int, Sequence[int]] = 11,
sigma: Union[float, Sequence[float]] = 1.5,
stancld marked this conversation as resolved.
Show resolved Hide resolved
reduction: str = "elementwise_mean",
data_range: Optional[float] = None,
k1: float = 0.01,
Expand Down