From d09c9816f5dcd9af56e1ef9e530b5e3353252c3e Mon Sep 17 00:00:00 2001 From: Gu Wang Date: Thu, 28 Sep 2023 12:41:52 +0800 Subject: [PATCH 1/3] allow size to be generic Sequence For example, when we pass a `size` arg in a config created by omegaconf, only allowing `list, tuple` will raise an error. --- torchvision/transforms/v2/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index df5d82e75ad..5f373e9df55 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -135,7 +135,7 @@ def __init__( if isinstance(size, int): size = [size] - elif isinstance(size, (list, tuple)) and len(size) in {1, 2}: + elif isinstance(size, (list, tuple, Sequence)) and len(size) in {1, 2}: size = list(size) else: raise ValueError( From 6dca625a21aa24b2df581596043a18967e899c58 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 28 Sep 2023 09:58:23 +0100 Subject: [PATCH 2/3] Update torchvision/transforms/v2/_geometry.py --- torchvision/transforms/v2/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 5f373e9df55..ee5d50a64a1 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -135,7 +135,7 @@ def __init__( if isinstance(size, int): size = [size] - elif isinstance(size, (list, tuple, Sequence)) and len(size) in {1, 2}: + elif isinstance(size, Sequence) and len(size) in {1, 2}: size = list(size) else: raise ValueError( From 0e77079073b19e425cc7762f2eca6ed8a18af795 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 28 Sep 2023 10:45:03 +0100 Subject: [PATCH 3/3] Fix error message --- test/test_transforms_v2_refactored.py | 2 +- torchvision/transforms/v2/_geometry.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 92caaa4db4d..78d7da5a054 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -699,7 +699,7 @@ def test_interpolation_int(self, interpolation, make_input): assert_equal(actual, expected) def test_transform_unknown_size_error(self): - with pytest.raises(ValueError, match="size can either be an integer or a list or tuple of one or two integers"): + with pytest.raises(ValueError, match="size can either be an integer or a sequence of one or two integers"): transforms.Resize(size=object()) @pytest.mark.parametrize( diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index ee5d50a64a1..4d3f3fc7fc5 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -139,7 +139,7 @@ def __init__( size = list(size) else: raise ValueError( - f"size can either be an integer or a list or tuple of one or two integers, " f"but got {size} instead." + f"size can either be an integer or a sequence of one or two integers, but got {size} instead." ) self.size = size