From 7f1a05a346cf1d5c6fe6b0af3ffb3e4b0b651ff3 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 2 Dec 2020 18:19:47 +0000 Subject: [PATCH] Check num of channels on adjust_* transformations (#3069) * Fixing upperbound value on tests and documentation. * Limit the number of channels on adjust_* transoforms. --- test/common_utils.py | 4 +-- torchvision/transforms/functional_tensor.py | 28 +++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index fbfc64bd76b..5b4d1cf2d14 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -339,13 +339,13 @@ def freeze_rng_state(): class TransformsTester(unittest.TestCase): def _create_data(self, height=3, width=3, channels=3, device="cpu"): - tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device) + tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device) pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy()) return tensor, pil_img def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"): batch_tensor = torch.randint( - 0, 255, + 0, 256, (num_samples, channels, height, width), dtype=torch.uint8, device=device diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index d21e2d6220e..0c72a745bba 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Dict, Tuple +from typing import Optional, Tuple import torch from torch import Tensor @@ -45,6 +45,12 @@ def _max_value(dtype: torch.dtype) -> float: return max_value.item() +def _assert_channels(img: Tensor, permitted: List[int]) -> None: + c = _get_image_num_channels(img) + if c not in permitted: + raise TypeError("Input image tensor permitted channel values are {}, but found {}".format(permitted, c)) + + def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor: """PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly @@ -210,9 +216,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: """ if img.ndim < 3: raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim)) - c = img.shape[-3] - if c != 3: - raise TypeError("Input image tensor should 3 channels, but found {}".format(c)) + _assert_channels(img, [3]) if num_output_channels not in (1, 3): raise ValueError('num_output_channels should be either 1 or 3') @@ -230,7 +234,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: - """PRIVATE METHOD. Adjust brightness of an RGB image. + """PRIVATE METHOD. Adjust brightness of a Grayscale or RGB image. .. warning:: @@ -252,6 +256,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: if not _is_tensor_a_torch_image(img): raise TypeError('tensor is not a torch image.') + _assert_channels(img, [1, 3]) + return _blend(img, torch.zeros_like(img), brightness_factor) @@ -278,6 +284,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: if not _is_tensor_a_torch_image(img): raise TypeError('tensor is not a torch image.') + _assert_channels(img, [3]) + dtype = img.dtype if torch.is_floating_point(img) else torch.float32 mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True) @@ -285,7 +293,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: - """PRIVATE METHOD. Adjust hue of an image. + """PRIVATE METHOD. Adjust hue of an RGB image. .. warning:: @@ -320,6 +328,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): raise TypeError('Input img should be Tensor image') + _assert_channels(img, [3]) + orig_dtype = img.dtype if img.dtype == torch.uint8: img = img.to(dtype=torch.float32) / 255.0 @@ -359,11 +369,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: if not _is_tensor_a_torch_image(img): raise TypeError('tensor is not a torch image.') + _assert_channels(img, [3]) + return _blend(img, rgb_to_grayscale(img), saturation_factor) def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: - r"""PRIVATE METHOD. Adjust gamma of an RGB image. + r"""PRIVATE METHOD. Adjust gamma of a Grayscale or RGB image. .. warning:: @@ -391,6 +403,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: if not isinstance(img, torch.Tensor): raise TypeError('Input img should be a Tensor.') + _assert_channels(img, [1, 3]) + if gamma < 0: raise ValueError('Gamma should be a non-negative real number')