Unified Pad and F.pad operation for PIL and Tensor inputs #2345

Merged 6 commits on Jun 26, 2020
Changes from 3 commits
27 changes: 25 additions & 2 deletions test/test_functional_tensor.py
@@ -1,17 +1,27 @@
import torch
from torch import Tensor
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
import numpy as np
import unittest
import random
import colorsys
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple

from PIL import Image


class Tester(unittest.TestCase):

def _create_data(self, height=3, width=3, channels=3):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
return tensor, pil_img

def compareTensorToPIL(self, tensor, pil_image, msg=None):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
self.assertTrue(tensor.equal(pil_tensor), msg)

def test_vflip(self):
script_vflip = torch.jit.script(F_t.vflip)
img_tensor = torch.randn(3, 16, 16)
@@ -234,6 +244,19 @@ def test_ten_crop(self):
for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))

def test_pad(self):
script_fn = torch.jit.script(F_t.pad)
tensor, pil_img = self._create_data(7, 8)
for pad in [1, [0, 1], (2, 2), [1, 0, 1, 2]]:
padding_mode = 'constant'
for fill in [0, 10, 20]:
pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode)
pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode)
self.compareTensorToPIL(pad_tensor, pad_pil_img, f'{pad}, {fill}')
if not isinstance(pad, int):
pad_tensor_script = script_fn(tensor, pad, fill=fill, padding_mode=padding_mode)
self.assertTrue(pad_tensor.equal(pad_tensor_script), f'{pad}, {fill}')


if __name__ == '__main__':
unittest.main()
34 changes: 24 additions & 10 deletions test/test_transforms_tensor.py
@@ -18,26 +18,30 @@ def compareTensorToPIL(self, tensor, pil_image):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
self.assertTrue(tensor.equal(pil_tensor))

def _test_flip(self, func, method):
tensor, pil_img = self._create_data()
flip_tensor = getattr(F, func)(tensor)
flip_pil_img = getattr(F, func)(pil_img)
self.compareTensorToPIL(flip_tensor, flip_pil_img)
def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
if fn_kwargs is None:
fn_kwargs = {}
if meth_kwargs is None:
meth_kwargs = {}
tensor, pil_img = self._create_data(height=10, width=10)
transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

scripted_fn = torch.jit.script(getattr(F, func))
flip_tensor_script = scripted_fn(tensor)
self.assertTrue(flip_tensor.equal(flip_tensor_script))
transformed_tensor_script = scripted_fn(tensor, **fn_kwargs)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

# test for class interface
f = getattr(T, method)()
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)

def test_random_horizontal_flip(self):
self._test_flip('hflip', 'RandomHorizontalFlip')
self._test_geom_op('hflip', 'RandomHorizontalFlip')

def test_random_vertical_flip(self):
self._test_flip('vflip', 'RandomVerticalFlip')
self._test_geom_op('vflip', 'RandomVerticalFlip')

def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
@@ -65,6 +69,16 @@ def test_adjustments(self):
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)

def test_pad(self):
fn_kwargs = meth_kwargs = {"padding": [4, 4], "fill": 0, "padding_mode": "constant"}
self._test_geom_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
fn_kwargs = meth_kwargs = {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": 'constant'}
self._test_geom_op(
"pad", "Pad", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)


if __name__ == '__main__':
unittest.main()
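
As a quick illustration of what the new test_pad above exercises, here is a minimal usage sketch of scripting the Pad transform on a tensor input; the image shape and padding values are illustrative assumptions, not taken verbatim from the PR:

import torch
import torchvision.transforms as T

# Illustrative uint8 image tensor of shape (C, H, W).
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)

# With this PR, Pad accepts tensors and the transform can be scripted.
pad = T.Pad(padding=[4, 4], fill=0, padding_mode="constant")
scripted_pad = torch.jit.script(pad)

out = scripted_pad(tensor)
print(out.shape)  # expected torch.Size([3, 18, 18]) after padding 4 pixels on every side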
99 changes: 25 additions & 74 deletions torchvision/transforms/functional.py
@@ -1,16 +1,20 @@
import torch
from torch import Tensor
import math
import numbers
import warnings
from collections.abc import Iterable

import numpy as np
from numpy import sin, cos, tan
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION

import torch
from torch import Tensor
from torch.jit.annotations import List

try:
import accimage
except ImportError:
accimage = None
import numpy as np
from numpy import sin, cos, tan
import numbers
from collections.abc import Sequence, Iterable
import warnings

from . import functional_pil as F_pil
from . import functional_tensor as F_t
@@ -342,20 +346,23 @@ def scale(*args, **kwargs):
return resize(*args, **kwargs)


def pad(img, padding, fill=0, padding_mode='constant'):
r"""Pad the given PIL Image on all sides with specified padding mode and fill value.
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
r"""Pad the given image on all sides with the given "pad" value.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

Args:
img (PIL Image): Image to be padded.
padding (int or tuple): Padding on each border. If a single int is provided this
img (PIL Image or Tensor): Image to be padded.
padding (int or tuple or list): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
respectively. Only list and tuple types are supported for Tensors.
fill (int or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant
This value is only used when the padding_mode is constant. Only int value is supported for Tensors.
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
Only "constant" is supported for Tensors.

- constant: pads with a constant value, this value is specified with fill

@@ -372,68 +379,12 @@ def pad(img, padding, fill=0, padding_mode='constant'):
will result in [2, 1, 1, 2, 3, 4, 4, 3]

Returns:
PIL Image: Padded image.
PIL Image or Tensor: Padded image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

if not isinstance(padding, (numbers.Number, tuple)):
raise TypeError('Got inappropriate padding arg')
if not isinstance(fill, (numbers.Number, str, tuple)):
raise TypeError('Got inappropriate fill arg')
if not isinstance(padding_mode, str):
raise TypeError('Got inappropriate padding_mode arg')

if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))

assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
'Padding mode should be either constant, edge, reflect or symmetric'

if padding_mode == 'constant':
if isinstance(fill, numbers.Number):
fill = (fill,) * len(img.getbands())
if len(fill) != len(img.getbands()):
raise ValueError('fill should have the same number of elements '
'as the number of channels in the image '
'({}), got {} instead'.format(len(img.getbands()), len(fill)))
if img.mode == 'P':
palette = img.getpalette()
image = ImageOps.expand(img, border=padding, fill=fill)
image.putpalette(palette)
return image

return ImageOps.expand(img, border=padding, fill=fill)
else:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, Sequence) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, Sequence) and len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]

if img.mode == 'P':
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img

img = np.asarray(img)
# RGB image
if len(img.shape) == 3:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
# Grayscale image
if len(img.shape) == 2:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
if not isinstance(img, torch.Tensor):
return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)

return Image.fromarray(img)
return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)


def crop(img, top, left, height, width):
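To make the new dispatch concrete, here is a minimal usage sketch of the unified entry point after this change; the image size, padding, and fill values are illustrative assumptions:

import torch
from PIL import Image
import torchvision.transforms.functional as F

# Build a small uint8 tensor and an equivalent PIL image (illustrative data).
tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())

# The same call now routes PIL images to F_pil.pad and tensors to F_t.pad.
padded_pil = F.pad(pil_img, [1, 2, 3, 4], fill=0, padding_mode="constant")
padded_tensor = F.pad(tensor, [1, 2, 3, 4], fill=0, padding_mode="constant")
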
103 changes: 102 additions & 1 deletion torchvision/transforms/functional_pil.py
@@ -1,9 +1,11 @@
import numbers

import torch
try:
import accimage
except ImportError:
accimage = None
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
from PIL import Image, ImageOps, ImageEnhance
import numpy as np


@@ -152,3 +154,102 @@ def adjust_hue(img, hue_factor):

img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img


@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
r"""Pad the given PIL.Image on all sides with the given "pad" value.

Args:
img (PIL Image): Image to be padded.
padding (int or tuple or list): Padding on each border. If a single int is provided this
is used to pad all borders. If tuple of length 2 is provided this is the padding
on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill (int or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

- constant: pads with a constant value, this value is specified with fill

- edge: pads with the last value on the edge of the image

- reflect: pads with reflection of image (without repeating the last value on the edge)

padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]

- symmetric: pads with reflection of image (repeating the last value on the edge)

padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]

Returns:
PIL Image: Padded image.
"""

if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))

if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple)):
raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg")

if isinstance(padding, list):
padding = tuple(padding)

if isinstance(padding, tuple) and len(padding) not in [2, 4]:
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))

if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

if padding_mode == "constant":
if isinstance(fill, numbers.Number):
fill = (fill,) * len(img.getbands())
if len(fill) != len(img.getbands()):
raise ValueError("fill should have the same number of elements "
"as the number of channels in the image "
"({}), got {} instead".format(len(img.getbands()), len(fill)))
if img.mode == "P":
palette = img.getpalette()
image = ImageOps.expand(img, border=padding, fill=fill)
image.putpalette(palette)
return image

return ImageOps.expand(img, border=padding, fill=fill)
else:
if isinstance(padding, int):
pad_left = pad_right = pad_top = pad_bottom = padding
if isinstance(padding, tuple) and len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
if isinstance(padding, tuple) and len(padding) == 4:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]

if img.mode == 'P':
palette = img.getpalette()
img = np.asarray(img)
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
img = Image.fromarray(img)
img.putpalette(palette)
return img

img = np.asarray(img)
# RGB image
if len(img.shape) == 3:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
# Grayscale image
if len(img.shape) == 2:
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

return Image.fromarray(img)
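
The reflect/symmetric examples in the docstring above can be reproduced directly with np.pad, which is what the non-constant branch relies on; this is just an illustrative check, not part of the PR:

import numpy as np

row = np.array([1, 2, 3, 4])

# reflect: mirrors without repeating the edge value -> [3 2 1 2 3 4 3 2]
print(np.pad(row, 2, mode="reflect"))

# symmetric: mirrors including the edge value -> [2 1 1 2 3 4 4 3]
print(np.pad(row, 2, mode="symmetric"))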