Skip to content

Commit

Permalink
Fixed sigma input type for v2.GaussianBlur (#7887)
Browse files Browse the repository at this point in the history
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
3 people authored Aug 30, 2023
1 parent a2f8f8e commit f1b4c7a
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 65 deletions.
45 changes: 4 additions & 41 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,37 +449,6 @@ def test__get_params(self, fill, side_range):
assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h


class TestGaussianBlur:
def test_assertions(self):
with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
transforms.GaussianBlur([10, 12, 14])

with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
transforms.GaussianBlur(4)

with pytest.raises(
TypeError, match="sigma should be a single int or float or a list/tuple with length 2 floats."
):
transforms.GaussianBlur(3, sigma=[1, 2, 3])

with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"):
transforms.GaussianBlur(3, sigma=-1.0)

with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
transforms.GaussianBlur(3, sigma=[2.0, 1.0])

@pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]])
def test__get_params(self, sigma):
transform = transforms.GaussianBlur(3, sigma=sigma)
params = transform._get_params([])

if isinstance(sigma, float):
assert params["sigma"][0] == params["sigma"][1] == 10
else:
assert sigma[0] <= params["sigma"][0] <= sigma[1]
assert sigma[0] <= params["sigma"][1] <= sigma[1]


class TestRandomPerspective:
def test_assertions(self):
with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"):
Expand All @@ -503,24 +472,18 @@ def test__get_params(self):
class TestElasticTransform:
def test_assertions(self):

with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"):
with pytest.raises(TypeError, match="alpha should be a number or a sequence of numbers"):
transforms.ElasticTransform({})

with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"):
with pytest.raises(ValueError, match="alpha is a sequence its length should be 1 or 2"):
transforms.ElasticTransform([1.0, 2.0, 3.0])

with pytest.raises(ValueError, match="alpha should be a sequence of floats"):
transforms.ElasticTransform([1, 2])

with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"):
with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
transforms.ElasticTransform(1.0, {})

with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"):
with pytest.raises(ValueError, match="sigma is a sequence its length should be 1 or 2"):
transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0])

with pytest.raises(ValueError, match="sigma should be a sequence of floats"):
transforms.ElasticTransform(1.0, [1, 2])

with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.ElasticTransform(1.0, 2.0, fill="abc")

Expand Down
43 changes: 43 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2859,3 +2859,46 @@ def test_transform_passthrough(self, make_input):
_, output = transform(make_image(self.INPUT_SIZE), input)

assert output is input


class TestGaussianBlur:
@pytest.mark.parametrize(
"make_input",
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("sigma", [5, (0.5, 2)])
def test_transform(self, make_input, device, sigma):
check_transform(transforms.GaussianBlur(kernel_size=3, sigma=sigma), make_input(device=device))

def test_assertions(self):
with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
transforms.GaussianBlur([10, 12, 14])

with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
transforms.GaussianBlur(4)

with pytest.raises(ValueError, match="If sigma is a sequence its length should be 1 or 2. Got 3"):
transforms.GaussianBlur(3, sigma=[1, 2, 3])

with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
transforms.GaussianBlur(3, sigma=-1.0)

with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
transforms.GaussianBlur(3, sigma=[2.0, 1.0])

with pytest.raises(TypeError, match="sigma should be a number or a sequence of numbers"):
transforms.GaussianBlur(3, sigma={})

@pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0], (10, 12.0), [10]])
def test__get_params(self, sigma):
transform = transforms.GaussianBlur(3, sigma=sigma)
params = transform._get_params([])

if isinstance(sigma, float):
assert params["sigma"][0] == params["sigma"][1] == sigma
elif isinstance(sigma, list) and len(sigma) == 1:
assert params["sigma"][0] == params["sigma"][1] == sigma[0]
else:
assert sigma[0] <= params["sigma"][0] <= sigma[1]
assert sigma[0] <= params["sigma"][1] <= sigma[1]
6 changes: 3 additions & 3 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
_get_fill,
_setup_angle,
_setup_fill_arg,
_setup_float_or_seq,
_setup_number_or_seq,
_setup_size,
get_bounding_boxes,
has_all,
Expand Down Expand Up @@ -1060,8 +1060,8 @@ def __init__(
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
) -> None:
super().__init__()
self.alpha = _setup_float_or_seq(alpha, "alpha", 2)
self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
self.alpha = _setup_number_or_seq(alpha, "alpha")
self.sigma = _setup_number_or_seq(sigma, "sigma")

self.interpolation = _check_interpolation(interpolation)
self.fill = fill
Expand Down
15 changes: 4 additions & 11 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F, Transform

from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor
from ._utils import _parse_labels_getter, _setup_number_or_seq, _setup_size, get_bounding_boxes, has_any, is_pure_tensor


# TODO: do we want/need to expose this?
Expand Down Expand Up @@ -198,17 +198,10 @@ def __init__(
if ks <= 0 or ks % 2 == 0:
raise ValueError("Kernel size value should be an odd and positive number.")

if isinstance(sigma, (int, float)):
if sigma <= 0:
raise ValueError("If sigma is a single number, it must be positive.")
sigma = float(sigma)
elif isinstance(sigma, Sequence) and len(sigma) == 2:
if not 0.0 < sigma[0] <= sigma[1]:
raise ValueError("sigma values should be positive and of the form (min, max).")
else:
raise TypeError("sigma should be a single int or float or a list/tuple with length 2 floats.")
self.sigma = _setup_number_or_seq(sigma, "sigma")

self.sigma = _setup_float_or_seq(sigma, "sigma", 2)
if not 0.0 < self.sigma[0] <= self.sigma[1]:
raise ValueError(f"sigma values should be positive and of the form (min, max). Got {self.sigma}")

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
Expand Down
23 changes: 13 additions & 10 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,23 @@
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]:
if not isinstance(arg, (float, Sequence)):
raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}")
if isinstance(arg, Sequence) and len(arg) != req_size:
raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}")
def _setup_number_or_seq(arg: Union[int, float, Sequence[Union[int, float]]], name: str) -> Sequence[float]:
if not isinstance(arg, (int, float, Sequence)):
raise TypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}")
if isinstance(arg, Sequence) and len(arg) not in (1, 2):
raise ValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}")
if isinstance(arg, Sequence):
for element in arg:
if not isinstance(element, float):
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
if not isinstance(element, (int, float)):
raise ValueError(f"{name} should be a sequence of numbers. Got {type(element)}")

if isinstance(arg, float):
if isinstance(arg, (int, float)):
arg = [float(arg), float(arg)]
if isinstance(arg, (list, tuple)) and len(arg) == 1:
arg = [arg[0], arg[0]]
elif isinstance(arg, Sequence):
if len(arg) == 1:
arg = [float(arg[0]), float(arg[0])]
else:
arg = [float(arg[0]), float(arg[1])]
return arg


Expand Down

0 comments on commit f1b4c7a

Please sign in to comment.