Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
soma2000-lang committed Jan 4, 2023
1 parent b0e963c commit f85d3c8
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions src/torchmetrics/functional/image/psnrb.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,32 +39,28 @@ def _compute_bef(
h_b = torch.tensor(range(block_size - 1, width - 1, block_size))
h_bc = torch.tensor(list(set(h).symmetric_difference(h_b)))

h = torch.arange(height - 1)
v = torch.arange(height - 1)
v_b = torch.tensor(range(block_size - 1, height - 1, block_size))
v_bc = torch.tensor(list(set(v).symmetric_difference(v_b)))

d_b = 0
d_bc = 0

# h_b for loop
for i in list(h_b):
diff = target[:, i] - target[:, i + 1]
d_b += torch.sum(torch.square(diff))
h_b = torch.arange(0, target.shape[1] - 1, dtype=torch.long)
h_bc = h_b + 1
v_b = torch.arange(0, target.shape[0] - 1, dtype=torch.long)
v_bc = v_b + 1
diff = target.gather(1, h_b.unsqueeze(-1)) - torch.gather(1, h_b.unsqueeze(-1))
d_b += torch.sum(torch.square(diff))
diff = torch.gather(0, v_b.unsqueeze(0)) - torch.gather(0, v_b.unsqueeze(0))
d_b += torch.sum(torch.square(diff))

# h_bc for loop
for i in list(h_bc):
diff = target[:, i] - target[:, i + 1]
d_bc += torch.sum(torch.square(diff))
diff = torch.gather(1, h_bc.unsqueeze(-1)) - torch.gather(1, h_b.unsqueeze(-1))
d_bc += torch.sum(torch.square(diff))
diff = torch.gather(0, v_bc.unsqueeze(0)) - torch.gather(0, v_b.unsqueeze(0))
d_bc += torch.sum(torch.square(diff))

# v_b for loop
for j in list(v_b):
diff = target[j, :] - target[j + 1, :]
d_b += torch.sum(torch.square(diff))

# V_bc for loop
for j in list(v_bc):
diff = target[j, :] - target[j + 1, :]
d_bc += torch.sum(tensor.square(diff))

# N code
n_hb = height * (width / block_size) - 1
Expand All @@ -84,6 +80,7 @@ def _compute_bef(

return bef


def _psnr_compute(
sum_squared_error: Tensor,
n_obs: Tensor,
Expand Down Expand Up @@ -155,7 +152,7 @@ def _psnr_update(
return sum_squared_error, n_obs


def peak_signal_noise_ratio(
def peak_signal_noise_ratio_with_blocked_effect(
preds: Tensor,
target: Tensor,
data_range: Optional[float] = None,
Expand Down

0 comments on commit f85d3c8

Please sign in to comment.