Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 4, 2023
1 parent cdd5b91 commit 2c0854d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 70 deletions.
2 changes: 1 addition & 1 deletion docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,4 @@
.. _kid ref2: https://arxiv.org/abs/1706.08500
.. _Spectral Angle Mapper: https://ntrs.nasa.gov/citations/19940012238
.. _Multilabel coverage error: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
.. _Peak Signal to Noise Ratio with Blocked Effect:https://ieeexplore.ieee.org/abstract/document/5535179
.. _Peak Signal to Noise Ratio with Blocked Effect:https://ieeexplore.ieee.org/abstract/document/5535179
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from torchmetrics.functional.image.d_lambda import spectral_distortion_index # noqa: F401
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis # noqa: F401
from torchmetrics.functional.image.gradients import image_gradients # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio # noqa: F401
from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401
from torchmetrics.functional.image.ssim import ( # noqa: F401
multiscale_structural_similarity_index_measure,
structural_similarity_index_measure,
)
from torchmetrics.functional.image.tv import total_variation # noqa: F401
from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401
from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio # noqa: F401
132 changes: 65 additions & 67 deletions src/torchmetrics/functional/image/psnrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,75 +19,73 @@

from torchmetrics.utilities import rank_zero_warn, reduce

def _compute_bef(self,
target: Tensor,
dim: Optional[Union[int, Tuple[int, ...]]] = None,
block_size=8) -> Tuple[Tensor, Tensor]:

if dim == 3:
height, width, channels = target.Size
elif dim == 2:
height, width = target.Size
channels = 1
else:
raise ValueError("Not a 1-channel/3-channel grayscale image")

if channels > 1:
raise ValueError("Not for color images")

h = torch.tensor(range(0, width - 1))
h_b = torch.tensor(range(block_size - 1, width - 1, block_size))
h_bc = torch.tensor(list(set(h).symmetric_difference(h_b)))

v = torch.tensor(range(0, height - 1))
v_b = torch.tensor(range(block_size - 1, height - 1, block_size))
v_bc = torch.tensor(list(set(v).symmetric_difference(v_b)))

d_b = 0
d_bc = 0

# h_b for loop
for i in list(h_b):
diff = target[:, i] - target[:, i+1]
d_b += torch.sum(torch.square(diff))

# h_bc for loop
for i in list(h_bc):
diff = target[:, i] - target[:, i+1]
d_bc += torch.sum(torch.square(diff))

# v_b for loop
for j in list(v_b):
diff = target[j, :] - target[j+1, :]
d_b += torch.sum(torch.square(diff))

# V_bc for loop
for j in list(v_bc):
diff = target[j, :] - target[j+1, :]
d_bc += torch.sum(tensor.square(diff))

# N code
n_hb = height * (width/block_size) - 1
n_hbc = (height * (width - 1)) - n_hb
n_vb = width * (height/block_size) - 1
n_vbc = (width * (height - 1)) - n_vb

# D code
d_b /= (n_hb + n_vb)
d_bc /= (n_hbc + n_vbc)

# Log
if d_b > d_bc:
t = torch.log2(block_size)/torch.log2(min(height, width))
else:
t = 0

# BEF
bef = t*(d_b - d_bc)

return bef

def _compute_bef(
self, target: Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, block_size=8
) -> Tuple[Tensor, Tensor]:

if dim == 3:
height, width, channels = target.Size
elif dim == 2:
height, width = target.Size
channels = 1
else:
raise ValueError("Not a 1-channel/3-channel grayscale image")

if channels > 1:
raise ValueError("Not for color images")

h = torch.tensor(range(0, width - 1))
h_b = torch.tensor(range(block_size - 1, width - 1, block_size))
h_bc = torch.tensor(list(set(h).symmetric_difference(h_b)))

v = torch.tensor(range(0, height - 1))
v_b = torch.tensor(range(block_size - 1, height - 1, block_size))
v_bc = torch.tensor(list(set(v).symmetric_difference(v_b)))

d_b = 0
d_bc = 0

# h_b for loop
for i in list(h_b):
diff = target[:, i] - target[:, i + 1]
d_b += torch.sum(torch.square(diff))

# h_bc for loop
for i in list(h_bc):
diff = target[:, i] - target[:, i + 1]
d_bc += torch.sum(torch.square(diff))

# v_b for loop
for j in list(v_b):
diff = target[j, :] - target[j + 1, :]
d_b += torch.sum(torch.square(diff))

# V_bc for loop
for j in list(v_bc):
diff = target[j, :] - target[j + 1, :]
d_bc += torch.sum(tensor.square(diff))

# N code
n_hb = height * (width / block_size) - 1
n_hbc = (height * (width - 1)) - n_hb
n_vb = width * (height / block_size) - 1
n_vbc = (width * (height - 1)) - n_vb

# D code
d_b /= n_hb + n_vb
d_bc /= n_hbc + n_vbc

# Log
if d_b > d_bc:
t = torch.log2(block_size) / torch.log2(min(height, width))
else:
t = 0

# BEF
bef = t * (d_b - d_bc)

return bef


def _psnr_compute(
Expand Down

0 comments on commit 2c0854d

Please sign in to comment.