Skip to content

Commit

Permalink
Fixed the hflip not being along the right coordinate
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandre-SCHOEPP committed Dec 17, 2024
1 parent cabce1c commit d1b27ad
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:


def horizontal_flip_keypoints(kp: torch.Tensor, canvas_size: Tuple[int, int]):
kp[0] = kp[0].sub_(canvas_size[1]).neg_()
kp[..., 0] = kp[..., 0].sub_(canvas_size[1]).neg_()
return kp


Expand Down Expand Up @@ -135,7 +135,7 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:

@_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
def vertical_flip_keypoints(kp: tv_tensors.KeyPoints):
kp[1] = kp[1].sub_(kp.canvas_size[0]).neg_()
kp[..., 1] = kp[..., 1].sub_(kp.canvas_size[0]).neg_()
return kp


Expand Down Expand Up @@ -363,8 +363,8 @@ def resize_keypoints(

w_ratio = new_width / old_width
h_ratio = new_height / old_height
ratios = torch.tensor([w_ratio, h_ratio])
kp.data = kp.data.mul(ratios).to(kp.dtype)
ratios = torch.tensor([w_ratio, h_ratio], device=kp.device)
kp = kp.mul(ratios).to(kp.dtype)

return kp, (new_height, new_width)

Expand Down Expand Up @@ -880,14 +880,14 @@ def affine_keypoints(

@_register_kernel_internal(affine, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
def _affine_keypoints_dispatch(
inpt: tv_tensors.BoundingBoxes,
inpt: tv_tensors.KeyPoints,
angle: Union[int, float],
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
**kwargs,
) -> tv_tensors.BoundingBoxes:
) -> tv_tensors.KeyPoints:
output, canvas_size = affine_keypoints(
inpt.as_subclass(torch.Tensor),
canvas_size=inpt.canvas_size,
Expand Down Expand Up @@ -2490,7 +2490,7 @@ def resized_crop_keypoints(


@_register_kernel_internal(resized_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
def _resized_crop_dispatch(
def _resized_crop_keypoints_dispatch(
inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
):
out, canvas_size = resized_crop_keypoints(
Expand Down

0 comments on commit d1b27ad

Please sign in to comment.