diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 5909b68966b..676ddb13b4f 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -123,6 +123,7 @@ Transforms on PIL Image and torch.\*Tensor Resize TenCrop GaussianBlur + GaussianNoise RandomInvert RandomPosterize RandomSolarize diff --git a/gallery/plot_transforms.py b/gallery/plot_transforms.py index c6e44a14e22..20b6c290490 100644 --- a/gallery/plot_transforms.py +++ b/gallery/plot_transforms.py @@ -119,6 +119,15 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): blurred_imgs = [blurrer(orig_img) for _ in range(4)] plot(blurred_imgs) +#################################### +# GaussianNoise +# ~~~~~~~~~~~~~ +# The :class:`~torchvision.transforms.GaussianNoise` transform +# perturbs the input image with gaussian noise. +noisy = T.GaussianNoise(mean=0, sigma=(5., 50.)) +noisy_imgs = [noisy(orig_img) for _ in range(2)] +plot(noisy_imgs) + #################################### # RandomPerspective # ~~~~~~~~~~~~~~~~~ diff --git a/test/test_transforms.py b/test/test_transforms.py index 57e61bbad70..557c6d649f0 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1324,6 +1324,24 @@ def test_gaussian_blur_asserts(): transforms.GaussianBlur(3, "sigma_string") +def test_gaussian_noise(): + np_img = np.ones((100, 100, 3), dtype=np.uint8) * 255 + img = F.to_pil_image(np_img, "RGB") + transforms.GaussianNoise(2.0, (0.1, 2.0))(img) + + with pytest.raises(TypeError, match="Tensor is not a torch image"): + transforms.GaussianNoise(2.0, (0.1, 2.0))(torch.ones(4)) + + with pytest.raises(ValueError, match="Mean should be a positive number"): + transforms.GaussianNoise(-1) + + with pytest.raises(ValueError, match="If sigma is a single number, it must be positive."): + transforms.GaussianNoise(2.0, -1) + + with pytest.raises(ValueError, match="sigma should be a single number or a list/tuple with length 2."): + transforms.GaussianNoise(2.0, (1, 2, 3)) + + def test_lambda(): trans = transforms.Lambda(lambda x: x.add(10)) x = torch.randn(10) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index c5b2a71d0d7..13484e56c8d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1390,6 +1390,49 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa return output +def gaussian_noise(img: Tensor, mean: float, sigma: float) -> Tensor: + """Performs Gaussian blurring on the image by given kernel. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + img (PIL Image or Tensor): Image to be blurred + mean (float): Mean of the desired noise corruption. + sigma (float): Gaussian noise standard deviation. Can be a single float. + + .. note:: + In torchscript mode sigma as single float is + not supported, use a sequence of length 1: ``[sigma, ]``. + + Returns: + PIL Image or Tensor: Gaussian Blurred version of the image. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(gaussian_noise) + + if sigma is None: + raise ValueError("The value of sigma cannot be None.") + + if sigma is not None and not isinstance(sigma, (int, float)): + raise TypeError(f"sigma should be a float. Got {type(sigma)}") + if sigma <= 0.0: + raise ValueError(f"sigma should have positive values. Got {sigma}") + + t_img = img + if not isinstance(img, torch.Tensor): + if not F_pil._is_pil_image(img): + raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}") + + t_img = pil_to_tensor(img) + + output = F_t.gaussian_noise(t_img, mean, sigma) + + if not isinstance(img, torch.Tensor): + output = to_pil_image(output, mode=img.mode) + + return output + + def invert(img: Tensor) -> Tensor: """Invert the colors of an RGB/grayscale image. diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index d0e7c17882b..d546d4946d5 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -764,6 +764,21 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te return img +def gaussian_noise(img: Tensor, mean: float, sigma: float) -> Tensor: + if not (isinstance(img, torch.Tensor)): + raise TypeError(f"img should be Tensor. Got {type(img)}") + + _assert_image_tensor(img) + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [dtype]) + # add the gaussian noise with the given mean and sigma. + noise = sigma * torch.randn_like(img) + mean + img = img + noise + + img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) + return img + + def invert(img: Tensor) -> Tensor: _assert_image_tensor(img) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 90cb0374eee..37dd2994e77 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -46,6 +46,7 @@ "RandomPerspective", "RandomErasing", "GaussianBlur", + "GaussianNoise", "InterpolationMode", "RandomInvert", "RandomPosterize", @@ -1816,6 +1817,64 @@ def __repr__(self) -> str: return s +class GaussianNoise(torch.nn.Module): + """Adds Gaussian noise to the image with specified mean and standard deviation. + If the image is torch Tensor, it is expected + to have [..., C, H, W] shape, where ... means an arbitrary number of leading dimensions. + + Args: + mean (float or sequence): Mean of the sampling gaussian distribution . + sigma (float or tuple of float (min, max)): Standard deviation to be used for + sampling the gaussian noise. If float, sigma is fixed. If it is tuple + of float (min, max), sigma is chosen uniformly at random to lie in the + given range. + + Returns: + PIL Image or Tensor: Input image perturbed with Gaussian Noise. + + """ + + def __init__(self, mean, sigma=(0.1, 0.5)): + super().__init__() + _log_api_usage_once(self) + + if mean < 0: + raise ValueError("Mean should be a positive number") + + if isinstance(sigma, numbers.Number): + if sigma <= 0: + raise ValueError("If sigma is a single number, it must be positive.") + sigma = (sigma, sigma) + elif isinstance(sigma, Sequence) and len(sigma) == 2: + if not 0.0 < sigma[0] <= sigma[1]: + raise ValueError("sigma values should be positive and of the form (min, max).") + else: + raise ValueError("sigma should be a single number or a list/tuple with length 2.") + + self.mean = mean + self.sigma = sigma + + @staticmethod + def get_params(sigma_min: float, sigma_max: float) -> float: + return torch.empty(1).uniform_(sigma_min, sigma_max).item() + + def forward(self, image: Tensor) -> Tensor: + """ + Args: + image (PIL Image or Tensor): image to be perturbed with gaussian noise. + + Returns: + PIL Image or Tensor: Image added with gaussian noise. + """ + sigma = self.get_params(self.sigma[0], self.sigma[1]) + output = F.gaussian_noise(image, self.mean, sigma) + return output + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(mean={self.mean}, sigma={self.sigma})" + return s + + def _setup_size(size, error_msg): if isinstance(size, numbers.Number): return int(size), int(size)