diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 1a8c77c827f..07a699345bd 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1,17 +1,27 @@ import torch -from torch import Tensor import torchvision.transforms as transforms import torchvision.transforms.functional_tensor as F_t +import torchvision.transforms.functional_pil as F_pil import torchvision.transforms.functional as F import numpy as np import unittest import random import colorsys -from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple + +from PIL import Image class Tester(unittest.TestCase): + def _create_data(self, height=3, width=3, channels=3): + tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8) + pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy()) + return tensor, pil_img + + def compareTensorToPIL(self, tensor, pil_image, msg=None): + pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) + self.assertTrue(tensor.equal(pil_tensor), msg) + def test_vflip(self): script_vflip = torch.jit.script(F_t.vflip) img_tensor = torch.randn(3, 16, 16) @@ -234,6 +244,22 @@ def test_ten_crop(self): for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor): self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img)) + def test_pad(self): + script_fn = torch.jit.script(F_t.pad) + tensor, pil_img = self._create_data(7, 8) + for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]: + padding_mode = "constant" + for fill in [0, 10, 20]: + pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode) + pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode) + self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, fill)) + if isinstance(pad, int): + script_pad = [pad, ] + else: + script_pad = pad + pad_tensor_script = script_fn(tensor, script_pad, fill=fill, padding_mode=padding_mode) + self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, fill)) + if __name__ == '__main__': unittest.main() diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 7791dd8b4f9..1479602b534 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -18,26 +18,38 @@ def compareTensorToPIL(self, tensor, pil_image): pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))) self.assertTrue(tensor.equal(pil_tensor)) - def _test_flip(self, func, method): - tensor, pil_img = self._create_data() - flip_tensor = getattr(F, func)(tensor) - flip_pil_img = getattr(F, func)(pil_img) - self.compareTensorToPIL(flip_tensor, flip_pil_img) + def _test_functional_geom_op(self, func, fn_kwargs): + if fn_kwargs is None: + fn_kwargs = {} + tensor, pil_img = self._create_data(height=10, width=10) + transformed_tensor = getattr(F, func)(tensor, **fn_kwargs) + transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs) + self.compareTensorToPIL(transformed_tensor, transformed_pil_img) + + def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None): + if fn_kwargs is None: + fn_kwargs = {} + if meth_kwargs is None: + meth_kwargs = {} + tensor, pil_img = self._create_data(height=10, width=10) + transformed_tensor = getattr(F, func)(tensor, **fn_kwargs) + transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs) + self.compareTensorToPIL(transformed_tensor, transformed_pil_img) scripted_fn = torch.jit.script(getattr(F, func)) - flip_tensor_script = scripted_fn(tensor) - self.assertTrue(flip_tensor.equal(flip_tensor_script)) + transformed_tensor_script = scripted_fn(tensor, **fn_kwargs) + self.assertTrue(transformed_tensor.equal(transformed_tensor_script)) # test for class interface - f = getattr(T, method)() + f = getattr(T, method)(**meth_kwargs) scripted_fn = torch.jit.script(f) scripted_fn(tensor) def test_random_horizontal_flip(self): - self._test_flip('hflip', 'RandomHorizontalFlip') + self._test_geom_op('hflip', 'RandomHorizontalFlip') def test_random_vertical_flip(self): - self._test_flip('vflip', 'RandomVerticalFlip') + self._test_geom_op('vflip', 'RandomVerticalFlip') def test_adjustments(self): fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation'] @@ -65,6 +77,28 @@ def test_adjustments(self): self.assertLess(max_diff, 5 / 255 + 1e-5) self.assertLess(max_diff_scripted, 5 / 255 + 1e-5) + def test_pad(self): + + # Test functional.pad (PIL and Tensor) with padding as single int + self._test_functional_geom_op( + "pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"} + ) + # Test functional.pad and transforms.Pad with padding as [int, ] + fn_kwargs = meth_kwargs = {"padding": [2, ], "fill": 0, "padding_mode": "constant"} + self._test_geom_op( + "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + # Test functional.pad and transforms.Pad with padding as list + fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"} + self._test_geom_op( + "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + # Test functional.pad and transforms.Pad with padding as tuple + fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"} + self._test_geom_op( + "pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs + ) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5d8549ea883..06a54c6aa5f 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,16 +1,20 @@ -import torch -from torch import Tensor import math +import numbers +import warnings +from collections.abc import Iterable + +import numpy as np +from numpy import sin, cos, tan from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION + +import torch +from torch import Tensor +from torch.jit.annotations import List + try: import accimage except ImportError: accimage = None -import numpy as np -from numpy import sin, cos, tan -import numbers -from collections.abc import Sequence, Iterable -import warnings from . import functional_pil as F_pil from . import functional_tensor as F_t @@ -342,20 +346,24 @@ def scale(*args, **kwargs): return resize(*args, **kwargs) -def pad(img, padding, fill=0, padding_mode='constant'): - r"""Pad the given PIL Image on all sides with specified padding mode and fill value. +def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: + r"""Pad the given image on all sides with the given "pad" value. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions Args: - img (PIL Image): Image to be padded. - padding (int or tuple): Padding on each border. If a single int is provided this + img (PIL Image or Tensor): Image to be padded. + padding (int or tuple or list): Padding on each border. If a single int is provided this is used to pad all borders. If tuple of length 2 is provided this is the padding on left/right and top/bottom respectively. If a tuple of length 4 is provided - this is the padding for the left, top, right and bottom borders - respectively. - fill: Pixel fill value for constant fill. Default is 0. If a tuple of + this is the padding for the left, top, right and bottom borders respectively. + In torchscript mode padding as single int is not supported, use a tuple or + list of length 1: ``[padding, ]``. + fill (int or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. - This value is only used when the padding_mode is constant + This value is only used when the padding_mode is constant. Only int value is supported for Tensors. padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + Only "constant" is supported for Tensors as of now. - constant: pads with a constant value, this value is specified with fill @@ -372,68 +380,12 @@ def pad(img, padding, fill=0, padding_mode='constant'): will result in [2, 1, 1, 2, 3, 4, 4, 3] Returns: - PIL Image: Padded image. + PIL Image or Tensor: Padded image. """ - if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - - if not isinstance(padding, (numbers.Number, tuple)): - raise TypeError('Got inappropriate padding arg') - if not isinstance(fill, (numbers.Number, str, tuple)): - raise TypeError('Got inappropriate fill arg') - if not isinstance(padding_mode, str): - raise TypeError('Got inappropriate padding_mode arg') - - if isinstance(padding, Sequence) and len(padding) not in [2, 4]: - raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) - - assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ - 'Padding mode should be either constant, edge, reflect or symmetric' - - if padding_mode == 'constant': - if isinstance(fill, numbers.Number): - fill = (fill,) * len(img.getbands()) - if len(fill) != len(img.getbands()): - raise ValueError('fill should have the same number of elements ' - 'as the number of channels in the image ' - '({}), got {} instead'.format(len(img.getbands()), len(fill))) - if img.mode == 'P': - palette = img.getpalette() - image = ImageOps.expand(img, border=padding, fill=fill) - image.putpalette(palette) - return image - - return ImageOps.expand(img, border=padding, fill=fill) - else: - if isinstance(padding, int): - pad_left = pad_right = pad_top = pad_bottom = padding - if isinstance(padding, Sequence) and len(padding) == 2: - pad_left = pad_right = padding[0] - pad_top = pad_bottom = padding[1] - if isinstance(padding, Sequence) and len(padding) == 4: - pad_left = padding[0] - pad_top = padding[1] - pad_right = padding[2] - pad_bottom = padding[3] - - if img.mode == 'P': - palette = img.getpalette() - img = np.asarray(img) - img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) - img = Image.fromarray(img) - img.putpalette(palette) - return img - - img = np.asarray(img) - # RGB image - if len(img.shape) == 3: - img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) - # Grayscale image - if len(img.shape) == 2: - img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + if not isinstance(img, torch.Tensor): + return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) - return Image.fromarray(img) + return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode) def crop(img, top, left, height, width): diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 84e27e79040..3786d0e31a7 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -1,9 +1,11 @@ +import numbers + import torch try: import accimage except ImportError: accimage = None -from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION +from PIL import Image, ImageOps, ImageEnhance import numpy as np @@ -152,3 +154,107 @@ def adjust_hue(img, hue_factor): img = Image.merge('HSV', (h, s, v)).convert(input_mode) return img + + +@torch.jit.unused +def pad(img, padding, fill=0, padding_mode="constant"): + r"""Pad the given PIL.Image on all sides with the given "pad" value. + + Args: + img (PIL Image): Image to be padded. + padding (int or tuple or list): Padding on each border. If a single int is provided this + is used to pad all borders. If a tuple or list of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple or list of length 4 is provided + this is the padding for the left, top, right and bottom borders respectively. For compatibility reasons + with ``functional_tensor.pad``, if a tuple or list of length 1 is provided, it is interpreted as + a single int. + fill (int or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant. + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + PIL Image: Padded image. + """ + + if not _is_pil_image(img): + raise TypeError("img should be PIL Image. Got {}".format(type(img))) + + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError("Got inappropriate fill arg") + if not isinstance(padding_mode, str): + raise TypeError("Got inappropriate padding_mode arg") + + if isinstance(padding, list): + padding = tuple(padding) + + if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]: + raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + if isinstance(padding, tuple) and len(padding) == 1: + # Compatibility with `functional_tensor.pad` + padding = padding[0] + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + if padding_mode == "constant": + if isinstance(fill, numbers.Number): + fill = (fill,) * len(img.getbands()) + if len(fill) != len(img.getbands()): + raise ValueError("fill should have the same number of elements " + "as the number of channels in the image " + "({}), got {} instead".format(len(img.getbands()), len(fill))) + if img.mode == "P": + palette = img.getpalette() + image = ImageOps.expand(img, border=padding, fill=fill) + image.putpalette(palette) + return image + + return ImageOps.expand(img, border=padding, fill=fill) + else: + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + if isinstance(padding, tuple) and len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + if isinstance(padding, tuple) and len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + if img.mode == 'P': + palette = img.getpalette() + img = np.asarray(img) + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + img = Image.fromarray(img) + img.putpalette(palette) + return img + + img = np.asarray(img) + # RGB image + if len(img.shape) == 3: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) + # Grayscale image + if len(img.shape) == 2: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + + return Image.fromarray(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 89440701d17..56703d0a1fd 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,6 +1,6 @@ import torch from torch import Tensor -from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple +from torch.jit.annotations import List, BroadcastingList2 def _is_tensor_a_torch_image(input): @@ -327,3 +327,64 @@ def _hsv2rgb(img): a4 = torch.stack((a1, a2, a3)) return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4) + + +def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constant") -> Tensor: + r"""Pad the given Tensor Image on all sides with specified padding mode and fill value. + + Args: + img (Tensor): Image to be padded. + padding (int or tuple or list): Padding on each border. If a single int is provided this + is used to pad all borders. If a tuple or list of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple or list of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. In torchscript mode padding as single int is not supported, use a tuple or + list of length 1: ``[padding, ]``. + fill (int): Pixel fill value for constant fill. Default is 0. + This value is only used when the padding_mode is constant + padding_mode (str): Type of padding. Only "constant" is supported for Tensors as of now. + + - constant: pads with a constant value, this value is specified with fill + + Returns: + Tensor: Padded image. + """ + if not _is_tensor_a_torch_image(img): + raise TypeError("tensor is not a torch image.") + + if not isinstance(padding, (int, tuple, list)): + raise TypeError("Got inappropriate padding arg") + if not isinstance(fill, (int, float)): + raise TypeError("Got inappropriate fill arg") + if not isinstance(padding_mode, str): + raise TypeError("Got inappropriate padding_mode arg") + + if isinstance(padding, tuple): + padding = list(padding) + + if isinstance(padding, list) and len(padding) not in [1, 2, 4]: + raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + if padding_mode not in ["constant", ]: + raise ValueError("Only constant padding_mode supported for torch tensors") + + if isinstance(padding, int): + if torch.jit.is_scripting(): + raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]") + pad_left = pad_right = pad_top = pad_bottom = padding + elif len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + else: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + p = [pad_left, pad_right, pad_top, pad_bottom] + + img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill)) + return img diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index d54aa5099f2..16dcca81a72 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -287,20 +287,23 @@ def __repr__(self): return self.__class__.__name__ + '(size={0})'.format(self.size) -class Pad(object): - """Pad the given PIL Image on all sides with the given "pad" value. +class Pad(torch.nn.Module): + """Pad the given image on all sides with the given "pad" value. + The image can be a PIL Image or a torch Tensor, in which case it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions Args: - padding (int or tuple): Padding on each border. If a single int is provided this + padding (int or tuple or list): Padding on each border. If a single int is provided this is used to pad all borders. If tuple of length 2 is provided this is the padding on left/right and top/bottom respectively. If a tuple of length 4 is provided - this is the padding for the left, top, right and bottom borders - respectively. + this is the padding for the left, top, right and bottom borders respectively. + In torchscript mode padding as single int is not supported, use a tuple or + list of length 1: ``[padding, ]``. fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. This value is only used when the padding_mode is constant padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. - Default is constant. + Default is constant. Only "constant" is supported for Tensors as of now. - constant: pads with a constant value, this value is specified with fill @@ -317,25 +320,32 @@ class Pad(object): will result in [2, 1, 1, 2, 3, 4, 4, 3] """ - def __init__(self, padding, fill=0, padding_mode='constant'): - assert isinstance(padding, (numbers.Number, tuple)) - assert isinstance(fill, (numbers.Number, str, tuple)) - assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] - if isinstance(padding, Sequence) and len(padding) not in [2, 4]: - raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + + def __init__(self, padding, fill=0, padding_mode="constant"): + super().__init__() + if not isinstance(padding, (numbers.Number, tuple, list)): + raise TypeError("Got inappropriate padding arg") + + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError("Got inappropriate fill arg") + + if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: + raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") + + if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: + raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding))) self.padding = padding self.fill = fill self.padding_mode = padding_mode - def __call__(self, img): + def forward(self, img): """ Args: - img (PIL Image): Image to be padded. + img (PIL Image or Tensor): Image to be padded. Returns: - PIL Image: Padded image. + PIL Image or Tensor: Padded image. """ return F.pad(img, self.padding, self.fill, self.padding_mode)