From 3c5b5b26c4d1ea52703007d8aa8337aeecb510c1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 5 Mar 2024 09:42:02 +0800 Subject: [PATCH 1/7] add rotation fast path --- torchvision/transforms/_functional_pil.py | 14 ++++++++++++++ torchvision/transforms/v2/functional/_geometry.py | 6 ++++++ 2 files changed, 20 insertions(+) diff --git a/torchvision/transforms/_functional_pil.py b/torchvision/transforms/_functional_pil.py index 277848224ac..0dfb25e7867 100644 --- a/torchvision/transforms/_functional_pil.py +++ b/torchvision/transforms/_functional_pil.py @@ -1,3 +1,4 @@ +import math import numbers from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -310,6 +311,19 @@ def rotate( if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. Got {type(img)}") + angle = angle - math.floor(angle / 360) * 360 # shift angle to [0, 360) range + + # fast path: transpose without affine transform + if expand or center is None: + if angle == 0: + return img + elif angle == 90: + return img.transpose(Image.ROTATE_90) + elif angle == 180: + return img.transpose(Image.ROTATE_180) + elif angle == 270: + return img.transpose(Image.ROTATE_270) + opts = _parse_fill(fill, img) return img.rotate(angle, interpolation, expand, center, **opts) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index b681346ab09..1c35415195c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -997,6 +997,12 @@ def rotate_image( center: Optional[List[float]] = None, fill: _FillTypeJIT = None, ) -> torch.Tensor: + angle = angle - math.floor(angle / 360) * 360 # shift angle to [0, 360) range + + # fast path: transpose without affine transform + if (expand or center is None) and angle in (0, 90, 180, 270): + return torch.rot90(image, k=angle // 90, dims=(-1, -2)) + interpolation = _check_interpolation(interpolation) input_height, input_width = image.shape[-2:] From cb4f9436de58185dcc223f658393c26039d04e92 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 5 Mar 2024 21:26:55 +0800 Subject: [PATCH 2/7] remove PIL fast path --- torchvision/transforms/_functional_pil.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/torchvision/transforms/_functional_pil.py b/torchvision/transforms/_functional_pil.py index 0dfb25e7867..277848224ac 100644 --- a/torchvision/transforms/_functional_pil.py +++ b/torchvision/transforms/_functional_pil.py @@ -1,4 +1,3 @@ -import math import numbers from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union @@ -311,19 +310,6 @@ def rotate( if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. Got {type(img)}") - angle = angle - math.floor(angle / 360) * 360 # shift angle to [0, 360) range - - # fast path: transpose without affine transform - if expand or center is None: - if angle == 0: - return img - elif angle == 90: - return img.transpose(Image.ROTATE_90) - elif angle == 180: - return img.transpose(Image.ROTATE_180) - elif angle == 270: - return img.transpose(Image.ROTATE_270) - opts = _parse_fill(fill, img) return img.rotate(angle, interpolation, expand, center, **opts) From dbb36a833e2c0afee67a4bf73bdfeeb60e068ba6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 5 Mar 2024 21:30:08 +0800 Subject: [PATCH 3/7] return clone() for angle=0 --- torchvision/transforms/v2/functional/_geometry.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 1c35415195c..c6139eed4c8 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -997,11 +997,15 @@ def rotate_image( center: Optional[List[float]] = None, fill: _FillTypeJIT = None, ) -> torch.Tensor: - angle = angle - math.floor(angle / 360) * 360 # shift angle to [0, 360) range + angle = angle % 360 # shift angle to [0, 360) range # fast path: transpose without affine transform - if (expand or center is None) and angle in (0, 90, 180, 270): - return torch.rot90(image, k=angle // 90, dims=(-1, -2)) + if expand or center is None: + if angle == 0: + return image.clone() + + if angle in (90, 180, 270): + return torch.rot90(image, k=angle // 90, dims=(-1, -2)) interpolation = _check_interpolation(interpolation) From f4d9e0063ab5f9f8601304ba64bc48b98be2937c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 5 Mar 2024 22:08:16 +0800 Subject: [PATCH 4/7] fix torchscript (and typing) --- torchvision/transforms/v2/functional/_geometry.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index c6139eed4c8..1dfc23c2acf 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1003,9 +1003,12 @@ def rotate_image( if expand or center is None: if angle == 0: return image.clone() - - if angle in (90, 180, 270): - return torch.rot90(image, k=angle // 90, dims=(-1, -2)) + if angle == 90: + return torch.rot90(image, k=1, dims=(-1, -2)) + if angle == 180: + return torch.rot90(image, k=2, dims=(-1, -2)) + if angle == 270: + return torch.rot90(image, k=3, dims=(-1, -2)) interpolation = _check_interpolation(interpolation) From 7cad32105073c4bf18608516013febc414115ff9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 5 Mar 2024 22:16:07 +0800 Subject: [PATCH 5/7] remove expand check for fast path --- torchvision/transforms/v2/functional/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 1dfc23c2acf..684a4caa478 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1000,7 +1000,7 @@ def rotate_image( angle = angle % 360 # shift angle to [0, 360) range # fast path: transpose without affine transform - if expand or center is None: + if center is None: if angle == 0: return image.clone() if angle == 90: From 872d4c45ce8710cefabb5949d4b4c06a7e20a434 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 6 Mar 2024 21:25:42 +0800 Subject: [PATCH 6/7] add test --- test/test_transforms_v2.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 0fb3ee6c11f..69062c7b48c 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1777,6 +1777,15 @@ def test_transform_unknown_fill_error(self): with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.RandomAffine(degrees=0, fill="fill") + @pytest.mark.parametrize("angle", [0, 90, 180, 270]) + def test_functional_image_fast_path_correctness(self, angle): + image = make_image(dtype=torch.uint8, device="cpu") + + actual = F.rotate(image, angle=angle) + expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle)) + + torch.testing.assert_close(actual, expected) + class TestContainerTransforms: class BuiltinTransform(transforms.Transform): From 154d7159f5dc3ec759b431009f151d8fa9e696b6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 14 Mar 2024 00:39:33 +0800 Subject: [PATCH 7/7] fix fast path --- test/test_transforms_v2.py | 10 ++++++---- torchvision/transforms/v2/functional/_geometry.py | 12 +++++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 7b30c07a9a0..b469a630b4a 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1782,12 +1782,14 @@ def test_transform_unknown_fill_error(self): with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.RandomAffine(degrees=0, fill="fill") + @pytest.mark.parametrize("size", [(11, 17), (16, 16)]) @pytest.mark.parametrize("angle", [0, 90, 180, 270]) - def test_functional_image_fast_path_correctness(self, angle): - image = make_image(dtype=torch.uint8, device="cpu") + @pytest.mark.parametrize("expand", [False, True]) + def test_functional_image_fast_path_correctness(self, size, angle, expand): + image = make_image(size, dtype=torch.uint8, device="cpu") - actual = F.rotate(image, angle=angle) - expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle)) + actual = F.rotate(image, angle=angle, expand=expand) + expected = F.to_image(F.rotate(F.to_pil_image(image), angle=angle, expand=expand)) torch.testing.assert_close(actual, expected) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 684a4caa478..2a1250ddf6c 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -1003,12 +1003,14 @@ def rotate_image( if center is None: if angle == 0: return image.clone() - if angle == 90: - return torch.rot90(image, k=1, dims=(-1, -2)) if angle == 180: - return torch.rot90(image, k=2, dims=(-1, -2)) - if angle == 270: - return torch.rot90(image, k=3, dims=(-1, -2)) + return torch.rot90(image, k=2, dims=(-2, -1)) + + if expand or image.shape[-1] == image.shape[-2]: + if angle == 90: + return torch.rot90(image, k=1, dims=(-2, -1)) + if angle == 270: + return torch.rot90(image, k=3, dims=(-2, -1)) interpolation = _check_interpolation(interpolation)