diff --git a/docs/source/links.rst b/docs/source/links.rst index aa3cb5e3d06..95655661551 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -126,4 +126,4 @@ .. _kid ref2: https://arxiv.org/abs/1706.08500 .. _Spectral Angle Mapper: https://ntrs.nasa.gov/citations/19940012238 .. _Multilabel coverage error: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34 -.. _Peak Signal to Noise Ratio with Blocked Effect:https://ieeexplore.ieee.org/abstract/document/5535179 \ No newline at end of file +.. _Peak Signal to Noise Ratio with Blocked Effect:https://ieeexplore.ieee.org/abstract/document/5535179 diff --git a/src/torchmetrics/functional/image/__init__.py b/src/torchmetrics/functional/image/__init__.py index 3b7bc8eaf0f..a6e19c6a686 100644 --- a/src/torchmetrics/functional/image/__init__.py +++ b/src/torchmetrics/functional/image/__init__.py @@ -14,7 +14,8 @@ from torchmetrics.functional.image.d_lambda import spectral_distortion_index # noqa: F401 from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis # noqa: F401 from torchmetrics.functional.image.gradients import image_gradients # noqa: F401 -from torchmetrics.functional.image.psnr import peak_signal_noise_ratio # noqa: F401 +from torchmetrics.functional.image.psnr import peak_signal_noise_ratio +from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio # noqa: F401 from torchmetrics.functional.image.sam import spectral_angle_mapper # noqa: F401 from torchmetrics.functional.image.ssim import ( # noqa: F401 multiscale_structural_similarity_index_measure, @@ -22,4 +23,3 @@ ) from torchmetrics.functional.image.tv import total_variation # noqa: F401 from torchmetrics.functional.image.uqi import universal_image_quality_index # noqa: F401 -from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio # noqa: F401 \ No newline at end of file diff --git a/src/torchmetrics/functional/image/psnrb.py b/src/torchmetrics/functional/image/psnrb.py index 409351f9d28..d42208c00e6 100644 --- a/src/torchmetrics/functional/image/psnrb.py +++ b/src/torchmetrics/functional/image/psnrb.py @@ -19,75 +19,73 @@ from torchmetrics.utilities import rank_zero_warn, reduce -def _compute_bef(self, - target: Tensor, - dim: Optional[Union[int, Tuple[int, ...]]] = None, - block_size=8) -> Tuple[Tensor, Tensor]: - - if dim == 3: - height, width, channels = target.Size - elif dim == 2: - height, width = target.Size - channels = 1 - else: - raise ValueError("Not a 1-channel/3-channel grayscale image") - - if channels > 1: - raise ValueError("Not for color images") - - h = torch.tensor(range(0, width - 1)) - h_b = torch.tensor(range(block_size - 1, width - 1, block_size)) - h_bc = torch.tensor(list(set(h).symmetric_difference(h_b))) - - v = torch.tensor(range(0, height - 1)) - v_b = torch.tensor(range(block_size - 1, height - 1, block_size)) - v_bc = torch.tensor(list(set(v).symmetric_difference(v_b))) - - d_b = 0 - d_bc = 0 - - # h_b for loop - for i in list(h_b): - diff = target[:, i] - target[:, i+1] - d_b += torch.sum(torch.square(diff)) - - # h_bc for loop - for i in list(h_bc): - diff = target[:, i] - target[:, i+1] - d_bc += torch.sum(torch.square(diff)) - - # v_b for loop - for j in list(v_b): - diff = target[j, :] - target[j+1, :] - d_b += torch.sum(torch.square(diff)) - - # V_bc for loop - for j in list(v_bc): - diff = target[j, :] - target[j+1, :] - d_bc += torch.sum(tensor.square(diff)) - - # N code - n_hb = height * (width/block_size) - 1 - n_hbc = (height * (width - 1)) - n_hb - n_vb = width * (height/block_size) - 1 - n_vbc = (width * (height - 1)) - n_vb - - # D code - d_b /= (n_hb + n_vb) - d_bc /= (n_hbc + n_vbc) - - # Log - if d_b > d_bc: - t = torch.log2(block_size)/torch.log2(min(height, width)) - else: - t = 0 - - # BEF - bef = t*(d_b - d_bc) - - return bef +def _compute_bef( + self, target: Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, block_size=8 +) -> Tuple[Tensor, Tensor]: + + if dim == 3: + height, width, channels = target.Size + elif dim == 2: + height, width = target.Size + channels = 1 + else: + raise ValueError("Not a 1-channel/3-channel grayscale image") + + if channels > 1: + raise ValueError("Not for color images") + + h = torch.tensor(range(0, width - 1)) + h_b = torch.tensor(range(block_size - 1, width - 1, block_size)) + h_bc = torch.tensor(list(set(h).symmetric_difference(h_b))) + + v = torch.tensor(range(0, height - 1)) + v_b = torch.tensor(range(block_size - 1, height - 1, block_size)) + v_bc = torch.tensor(list(set(v).symmetric_difference(v_b))) + + d_b = 0 + d_bc = 0 + + # h_b for loop + for i in list(h_b): + diff = target[:, i] - target[:, i + 1] + d_b += torch.sum(torch.square(diff)) + + # h_bc for loop + for i in list(h_bc): + diff = target[:, i] - target[:, i + 1] + d_bc += torch.sum(torch.square(diff)) + + # v_b for loop + for j in list(v_b): + diff = target[j, :] - target[j + 1, :] + d_b += torch.sum(torch.square(diff)) + + # V_bc for loop + for j in list(v_bc): + diff = target[j, :] - target[j + 1, :] + d_bc += torch.sum(tensor.square(diff)) + + # N code + n_hb = height * (width / block_size) - 1 + n_hbc = (height * (width - 1)) - n_hb + n_vb = width * (height / block_size) - 1 + n_vbc = (width * (height - 1)) - n_vb + + # D code + d_b /= n_hb + n_vb + d_bc /= n_hbc + n_vbc + + # Log + if d_b > d_bc: + t = torch.log2(block_size) / torch.log2(min(height, width)) + else: + t = 0 + + # BEF + bef = t * (d_b - d_bc) + return bef def _psnr_compute(