Skip to content

Commit

Permalink
[fbsync] Fast rotation for right angles (#8295)
Browse files Browse the repository at this point in the history
Summary: Co-authored-by: Thien Tran <[email protected]>

Reviewed By: vmoens

Differential Revision: D55062767

fbshipit-source-id: b67a7f7fecab8a33143b95e6233637c336c74cd4
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Mar 21, 2024
1 parent 240e044 commit 419ea82
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,6 +1782,17 @@ def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")

@pytest.mark.parametrize("size", [(11, 17), (16, 16)])
@pytest.mark.parametrize("angle", [0, 90, 180, 270])
@pytest.mark.parametrize("expand", [False, True])
def test_functional_image_fast_path_correctness(self, size, angle, expand):
image = make_image(size, dtype=torch.uint8, device="cpu")

actual = F.rotate(image, angle=angle, expand=expand)
expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle, expand=expand))

torch.testing.assert_close(actual, expected)


class TestContainerTransforms:
class BuiltinTransform(transforms.Transform):
Expand Down
15 changes: 15 additions & 0 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,21 @@ def rotate_image(
center: Optional[List[float]] = None,
fill: _FillTypeJIT = None,
) -> torch.Tensor:
angle = angle % 360 # shift angle to [0, 360) range

# fast path: transpose without affine transform
if center is None:
if angle == 0:
return image.clone()
if angle == 180:
return torch.rot90(image, k=2, dims=(-2, -1))

if expand or image.shape[-1] == image.shape[-2]:
if angle == 90:
return torch.rot90(image, k=1, dims=(-2, -1))
if angle == 270:
return torch.rot90(image, k=3, dims=(-2, -1))

interpolation = _check_interpolation(interpolation)

input_height, input_width = image.shape[-2:]
Expand Down

0 comments on commit 419ea82

Please sign in to comment.