
Commit

Fix the issue ivy-llc#15056
Aryan8912 committed Sep 28, 2023
1 parent af38d41 commit 8b1671f
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions ivy/functional/frontends/paddle/nn/functional/vision.py
@@ -170,3 +170,83 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW"):
    return ivy.reshape(
        ivy.permute_dims(input_reshaped, (0, 1, 4, 2, 5, 3)), (b, oh, ow, oc)
    )


"Add NN Vision Functions to Paddle Frontend "

def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
    """
    Sample elements from the input tensor using bilinear or nearest-neighbor sampling.

    :param input: Input tensor of shape (batch_size, channels, height, width).
    :param grid: Sampling grid of shape (batch_size, out_height, out_width, 2), holding
        x/y coordinates normalized to the range [-1, 1].
    :param mode: Sampling mode, 'bilinear' or 'nearest'. Default is 'bilinear'.
    :param padding_mode: Padding mode for out-of-bounds grid values, 'zeros' or 'border'.
    :return: The sampled output tensor of shape (batch_size, channels, out_height, out_width).
    """

    # Common setup: shapes, normalized -> pixel coordinates, per-batch index grid
    B, C, H, W = input.shape
    _, H_prime, W_prime, _ = grid.shape

    # Map the normalized grid coordinates from [-1, 1] to pixel coordinates
    x = (grid[..., 0] + 1) * (W - 1) / 2
    y = (grid[..., 1] + 1) * (H - 1) / 2

    # Batch indices used for per-sample gathering
    batch_idx = torch.arange(B).view(B, 1, 1).expand(B, H_prime, W_prime)

    def gather(y_idx, x_idx):
        # Advanced indexing yields (B, H', W', C); move channels back to dim 1
        return input[batch_idx, :, y_idx, x_idx].permute(0, 3, 1, 2)

    # Bilinear sampling
    if mode == 'bilinear':
        # Integer corner indices surrounding each sampling point
        x0 = torch.floor(x).long()
        y0 = torch.floor(y).long()
        x1 = x0 + 1
        y1 = y0 + 1

        # Clamp the corner indices so they stay inside the image (border behaviour)
        x0c, x1c = x0.clamp(0, W - 1), x1.clamp(0, W - 1)
        y0c, y1c = y0.clamp(0, H - 1), y1.clamp(0, H - 1)

        # Pixel values at the four surrounding corners
        values_tl = gather(y0c, x0c)
        values_tr = gather(y0c, x1c)
        values_bl = gather(y1c, x0c)
        values_br = gather(y1c, x1c)

        # Bilinear interpolation weights, broadcast over the channel dimension
        wa = ((x1.float() - x) * (y1.float() - y)).unsqueeze(1)  # top-left
        wb = ((x - x0.float()) * (y1.float() - y)).unsqueeze(1)  # top-right
        wc = ((x1.float() - x) * (y - y0.float())).unsqueeze(1)  # bottom-left
        wd = ((x - x0.float()) * (y - y0.float())).unsqueeze(1)  # bottom-right

        sampled_output = wa * values_tl + wb * values_tr + wc * values_bl + wd * values_br

    # Nearest-neighbor sampling
    elif mode == 'nearest':
        # Round the pixel coordinates to the closest integer indices
        x_rounded = torch.round(x).long()
        y_rounded = torch.round(y).long()

        # Clamp so the gather itself never indexes out of bounds
        x_clamped = torch.clamp(x_rounded, 0, W - 1)
        y_clamped = torch.clamp(y_rounded, 0, H - 1)
        sampled_output = gather(y_clamped, x_clamped)

        if padding_mode == 'zeros':
            # Zero out positions whose original indices were out of bounds
            mask_x = torch.logical_or(x_rounded < 0, x_rounded >= W)
            mask_y = torch.logical_or(y_rounded < 0, y_rounded >= H)
            mask = torch.logical_or(mask_x, mask_y).unsqueeze(1)
            sampled_output = torch.where(
                mask, torch.zeros_like(sampled_output), sampled_output
            )

        elif padding_mode != 'border':
            raise ValueError("Unsupported padding_mode. Expected 'zeros' or 'border'.")

    else:
        raise ValueError("Unsupported mode. Expected 'bilinear' or 'nearest'.")

    return sampled_output
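A minimal usage sketch for the added function (illustrative only, not part of the commit); the shapes, the identity grid, and the expected result are assumptions for demonstration:

# Hypothetical example: sample a 1x1x4x4 tensor with an identity grid.
import torch

inp = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# Identity grid in normalized [-1, 1] coordinates; the output should match the input.
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, 4), torch.linspace(-1, 1, 4), indexing="ij"
)
grid = torch.stack([xs, ys], dim=-1).unsqueeze(0)  # shape (1, 4, 4, 2)

out = grid_sample(inp, grid, mode="bilinear")
print(out.shape)  # torch.Size([1, 1, 4, 4])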
