From f1b4c7a6fd65479a096ed6ae44fb5e762af6c0f4 Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 30 Aug 2023 18:13:02 +0200 Subject: [PATCH] Fixed sigma input type for v2.GaussianBlur (#7887) Co-authored-by: Philip Meier Co-authored-by: Nicolas Hug --- test/test_transforms_v2.py | 45 +++----------------------- test/test_transforms_v2_refactored.py | 43 ++++++++++++++++++++++++ torchvision/transforms/v2/_geometry.py | 6 ++-- torchvision/transforms/v2/_misc.py | 15 +++------ torchvision/transforms/v2/_utils.py | 23 +++++++------ 5 files changed, 67 insertions(+), 65 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 175a3ac161c..3f0056e96ab 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -449,37 +449,6 @@ def test__get_params(self, fill, side_range): assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h -class TestGaussianBlur: - def test_assertions(self): - with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"): - transforms.GaussianBlur([10, 12, 14]) - - with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"): - transforms.GaussianBlur(4) - - with pytest.raises( - TypeError, match="sigma should be a single int or float or a list/tuple with length 2 floats." - ): - transforms.GaussianBlur(3, sigma=[1, 2, 3]) - - with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"): - transforms.GaussianBlur(3, sigma=-1.0) - - with pytest.raises(ValueError, match="sigma values should be positive and of the form"): - transforms.GaussianBlur(3, sigma=[2.0, 1.0]) - - @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]]) - def test__get_params(self, sigma): - transform = transforms.GaussianBlur(3, sigma=sigma) - params = transform._get_params([]) - - if isinstance(sigma, float): - assert params["sigma"][0] == params["sigma"][1] == 10 - else: - assert sigma[0] <= params["sigma"][0] <= sigma[1] - assert sigma[0] <= params["sigma"][1] <= sigma[1] - - class TestRandomPerspective: def test_assertions(self): with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"): @@ -503,24 +472,18 @@ def test__get_params(self): class TestElasticTransform: def test_assertions(self): - with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"): + with pytest.raises(TypeError, match="alpha should be a number or a sequence of numbers"): transforms.ElasticTransform({}) - with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"): + with pytest.raises(ValueError, match="alpha is a sequence its length should be 1 or 2"): transforms.ElasticTransform([1.0, 2.0, 3.0]) - with pytest.raises(ValueError, match="alpha should be a sequence of floats"): - transforms.ElasticTransform([1, 2]) - - with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"): + with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"): transforms.ElasticTransform(1.0, {}) - with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"): + with pytest.raises(ValueError, match="sigma is a sequence its length should be 1 or 2"): transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0]) - with pytest.raises(ValueError, match="sigma should be a sequence of floats"): - transforms.ElasticTransform(1.0, [1, 2]) - with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.ElasticTransform(1.0, 2.0, fill="abc") diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index ad5cd8e00d8..b2e21fc4aca 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -2859,3 +2859,46 @@ def test_transform_passthrough(self, make_input): _, output = transform(make_image(self.INPUT_SIZE), input) assert output is input + + +class TestGaussianBlur: + @pytest.mark.parametrize( + "make_input", + [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video], + ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + @pytest.mark.parametrize("sigma", [5, (0.5, 2)]) + def test_transform(self, make_input, device, sigma): + check_transform(transforms.GaussianBlur(kernel_size=3, sigma=sigma), make_input(device=device)) + + def test_assertions(self): + with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"): + transforms.GaussianBlur([10, 12, 14]) + + with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"): + transforms.GaussianBlur(4) + + with pytest.raises(ValueError, match="If sigma is a sequence its length should be 1 or 2. Got 3"): + transforms.GaussianBlur(3, sigma=[1, 2, 3]) + + with pytest.raises(ValueError, match="sigma values should be positive and of the form"): + transforms.GaussianBlur(3, sigma=-1.0) + + with pytest.raises(ValueError, match="sigma values should be positive and of the form"): + transforms.GaussianBlur(3, sigma=[2.0, 1.0]) + + with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"): + transforms.GaussianBlur(3, sigma={}) + + @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]]) + def test__get_params(self, sigma): + transform = transforms.GaussianBlur(3, sigma=sigma) + params = transform._get_params([]) + + if isinstance(sigma, float): + assert params["sigma"][0] == params["sigma"][1] == sigma + elif isinstance(sigma, list) and len(sigma) == 1: + assert params["sigma"][0] == params["sigma"][1] == sigma[0] + else: + assert sigma[0] <= params["sigma"][0] <= sigma[1] + assert sigma[0] <= params["sigma"][1] <= sigma[1] diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index ba3e690dd4d..721e9b7e452 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -21,7 +21,7 @@ _get_fill, _setup_angle, _setup_fill_arg, - _setup_float_or_seq, + _setup_number_or_seq, _setup_size, get_bounding_boxes, has_all, @@ -1060,8 +1060,8 @@ def __init__( fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0, ) -> None: super().__init__() - self.alpha = _setup_float_or_seq(alpha, "alpha", 2) - self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + self.alpha = _setup_number_or_seq(alpha, "alpha") + self.sigma = _setup_number_or_seq(sigma, "sigma") self.interpolation = _check_interpolation(interpolation) self.fill = fill diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 739f2fb7ff5..67aaf4f3753 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -9,7 +9,7 @@ from torchvision import transforms as _transforms, tv_tensors from torchvision.transforms.v2 import functional as F, Transform -from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor +from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor # TODO: do we want/need to expose this? @@ -198,17 +198,10 @@ def __init__( if ks <= 0 or ks % 2 == 0: raise ValueError("Kernel size value should be an odd and positive number.") - if isinstance(sigma, (int, float)): - if sigma <= 0: - raise ValueError("If sigma is a single number, it must be positive.") - sigma = float(sigma) - elif isinstance(sigma, Sequence) and len(sigma) == 2: - if not 0.0 < sigma[0] <= sigma[1]: - raise ValueError("sigma values should be positive and of the form (min, max).") - else: - raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.") + self.sigma = _setup_number_or_seq(sigma, "sigma") - self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + if not 0.0 < self.sigma[0] <= self.sigma[1]: + raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}") def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item() diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index d5669f5739f..6147180a986 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -18,20 +18,23 @@ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT -def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: - if not isinstance(arg, (float, Sequence)): - raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") - if isinstance(arg, Sequence) and len(arg) != req_size: - raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") +def _setup_number_or_seq(arg: Union[int, float, Sequence[Union[int, float]]], name: str) -> Sequence[float]: + if not isinstance(arg, (int, float, Sequence)): + raise TypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}") + if isinstance(arg, Sequence) and len(arg) not in (1, 2): + raise ValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}") if isinstance(arg, Sequence): for element in arg: - if not isinstance(element, float): - raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") + if not isinstance(element, (int, float)): + raise ValueError(f"{name} should be a sequence of numbers. Got {type(element)}") - if isinstance(arg, float): + if isinstance(arg, (int, float)): arg = [float(arg), float(arg)] - if isinstance(arg, (list, tuple)) and len(arg) == 1: - arg = [arg[0], arg[0]] + elif isinstance(arg, Sequence): + if len(arg) == 1: + arg = [float(arg[0]), float(arg[0])] + else: + arg = [float(arg[0]), float(arg[1])] return arg