Add Tensor support for some transforms #1104

Closed
wants to merge 8 commits
154 changes: 81 additions & 73 deletions test/test_transforms.py
@@ -23,48 +23,45 @@
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


def transform_helper(t, is_pil=True):
t = [t]
if is_pil:
t.insert(0, transforms.ToPILImage())
t.append(transforms.ToTensor())
return transforms.Compose(t)


class Tester(unittest.TestCase):

def test_crop(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2

img = torch.ones(3, height, width)
oh1 = (height - oheight) // 2
ow1 = (width - owidth) // 2
imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
imgnarrow.fill_(0)
result = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
assert result.sum() == 0, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1
owidth += 1
result = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
sum1 = result.sum()
assert sum1 > 1, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1
owidth += 1
result = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
sum2 = result.sum()
assert sum2 > 0, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
assert sum2 > sum1, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
for is_pil in [True, False]:
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
img = torch.ones(3, height, width)
oh1 = (height - oheight) // 2
ow1 = (width - owidth) // 2
imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth]
imgnarrow.fill_(0)
result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
assert result.sum() == 0, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1
owidth += 1
result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
sum1 = result.sum()
assert sum1 > 1, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1
owidth += 1
result = transform_helper(transforms.CenterCrop((oheight, owidth)), is_pil)(img)
sum2 = result.sum()
assert sum2 > 0, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
assert sum2 > sum1, "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)

def test_five_crop(self):
to_pil_image = transforms.ToPILImage()
@@ -87,7 +84,6 @@ def test_five_crop(self):
for crop in results:
assert crop.size == (crop_w, crop_h)

to_pil_image = transforms.ToPILImage()
tl = to_pil_image(img[:, 0:crop_h, 0:crop_w])
tr = to_pil_image(img[:, 0:crop_h, w - crop_w:])
bl = to_pil_image(img[:, h - crop_h:, 0:crop_w])
@@ -175,46 +171,37 @@ def test_randomperspective(self):
def test_resize(self):
height = random.randint(24, 32) * 2
width = random.randint(24, 32) * 2
osize = random.randint(5, 12) * 2

img = torch.ones(3, height, width)
result = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(osize),
transforms.ToTensor(),
])(img)
assert osize in result.size()
if height < width:
assert result.size(1) <= result.size(2)
elif width < height:
assert result.size(1) >= result.size(2)

result = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize([osize, osize]),
transforms.ToTensor(),
])(img)
assert osize in result.size()
assert result.size(1) == osize
assert result.size(2) == osize

osize = random.randint(5, 12) * 2
oheight = random.randint(5, 12) * 2
owidth = random.randint(5, 12) * 2
result = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((oheight, owidth)),
transforms.ToTensor(),
])(img)
assert result.size(1) == oheight
assert result.size(2) == owidth

result = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize([oheight, owidth]),
transforms.ToTensor(),
])(img)
assert result.size(1) == oheight
assert result.size(2) == owidth
img = torch.rand(3, height, width)

for is_pil in [True, False]:
result = transform_helper(transforms.Resize(osize), is_pil)(img)
self.assertIn(osize, result.size())
if height < width:
self.assertTrue(result.size(1) <= result.size(2))
elif width < height:
self.assertTrue(result.size(1) >= result.size(2))

for size in [[osize, osize], (oheight, owidth), [oheight, owidth]]:
result = transform_helper(transforms.Resize(size), is_pil)(img)
self.assertTrue(result.size(1) == size[0])
self.assertTrue(result.size(2) == size[1])

# test resize on 3d and 4d images for tensor inputs
t = transform_helper(transforms.Resize((oheight, owidth)), is_pil=False)
img = torch.rand(3, height, width)
r = t(img)
self.assertEqual(tuple(r.shape), (3, oheight, owidth))
img = torch.rand(1, 3, height, width)
r = t(img)
self.assertEqual(tuple(r.shape), (1, 3, oheight, owidth))
img = torch.rand(2, 3, height, width)
r = t(img)
self.assertEqual(tuple(r.shape), (2, 3, oheight, owidth))

def test_random_crop(self):
height = random.randint(10, 32) * 2
@@ -737,6 +724,27 @@ def test_ndarray_bad_types_to_pil_image(self):
with self.assertRaises(ValueError):
transforms.ToPILImage()(np.ones([1, 4, 4, 3]))

def _test_flip(self, method):
img = torch.rand(3, 10, 10)
pil_img = transforms.functional.to_pil_image(img)

func = getattr(transforms.functional, method)

f_img = func(img)
f_pil_img = func(pil_img)
f_pil_img = transforms.functional.to_tensor(f_pil_img)
# there are rounding differences with PIL due to uint8 conversion
self.assertTrue((f_img - f_pil_img).abs().max() < 1.0 / 255)

ff_img = func(f_img)
self.assertTrue(img.equal(ff_img))

def test_vertical_flip(self):
self._test_flip('vflip')

def test_horizontal_flip(self):
self._test_flip('hflip')

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_vertical_flip(self):
random_state = random.getstate()
58 changes: 46 additions & 12 deletions torchvision/transforms/functional.py
@@ -31,6 +31,15 @@ def _is_tensor_image(img):
return torch.is_tensor(img) and img.ndimension() == 3


def _get_image_size(img):
if _is_pil_image(img):
return img.size
elif isinstance(img, torch.Tensor) and img.dim() > 2:
return img.shape[-2:][::-1]
else:
raise TypeError("Unexpected type {}".format(type(img)))


def _is_numpy(img):
return isinstance(img, np.ndarray)

@@ -234,26 +243,42 @@ def resize(img, size, interpolation=Image.BILINEAR):
Returns:
PIL Image: Resized image.
"""
if not _is_pil_image(img):
if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
Review comment (Member):
the comment above needs to be updated that it takes torch.Tensor (and has to specify what range the Tensor's values have to be)
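
A possible wording for the updated docstring, offered only as a sketch (the (C, H, W)/(N, C, H, W) shape convention and the float value range, e.g. [0, 1], are assumptions and are not stated in this PR):

def resize(img, size, interpolation=Image.BILINEAR):
    """Resize the input image to the given size.

    Args:
        img (PIL Image or torch.Tensor): Image to be resized. If a Tensor, it is
            expected to have shape (C, H, W) or (N, C, H, W) with float values,
            e.g. in the range [0, 1].
        size (sequence or int): Desired output size.
        interpolation (int, optional): Desired interpolation. Default is
            PIL.Image.BILINEAR.

    Returns:
        PIL Image or torch.Tensor: Resized image, of the same type as the input.
    """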

raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
raise TypeError('Got inappropriate size arg: {}'.format(size))

if isinstance(size, int):
w, h = img.size
w, h = _get_image_size(img)
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), interpolation)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), interpolation)
else:
size = (oh, ow)
if _is_pil_image(img):
return img.resize(size[::-1], interpolation)

# tensor codepath
# TODO maybe move this outside
_PIL_TO_TORCH_INTERP_MODE = {
Image.NEAREST: "nearest",
Image.BILINEAR: "bilinear"
}
should_unsqueeze = False
Review comment (Member):
this is actually should_squeeze as you squeeze below in 278
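
A minimal sketch of the rename suggested here (illustrative only; bilinear interpolation is hard-coded for brevity, whereas the PR dispatches on the interpolation argument):

should_squeeze = False
if img.dim() == 3:
    img = img[None]  # interpolate expects a batch dimension: (N, C, H, W)
    should_squeeze = True
out = torch.nn.functional.interpolate(img, size=size, mode="bilinear",
                                      align_corners=False)
if should_squeeze:
    out = out[0]  # drop the batch dimension that was added above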

if img.dim() == 3:
img = img[None]
should_unsqueeze = True
out = torch.nn.functional.interpolate(img, size=size,
mode=_PIL_TO_TORCH_INTERP_MODE[interpolation],
align_corners=False)
if should_unsqueeze:
out = out[0]
return out


def scale(*args, **kwargs):
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
@@ -362,16 +387,19 @@ def crop(img, i, j, h, w):
Returns:
PIL Image: Cropped image.
"""
if not _is_pil_image(img):
if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
Review comment (Member):
same comment as above

raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.crop((j, i, j + w, i + h))
if _is_pil_image(img):
return img.crop((j, i, j + w, i + h))

return img[..., i:(i + h), j:(j + w)]


def center_crop(img, output_size):
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
w, h = img.size
w, h = _get_image_size(img)
th, tw = output_size
i = int(round((h - th) / 2.))
j = int(round((w - tw) / 2.))
@@ -410,10 +438,13 @@ def hflip(img):
Returns:
PIL Image: Horizontally flipped image.
"""
if not _is_pil_image(img):
if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
Review comment (Member):
same documentation comment as above

raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.transpose(Image.FLIP_LEFT_RIGHT)
if _is_pil_image(img):
Review comment (Member):
it's much cleaner to write explicit if/else, rather than have if shortcut to a return
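
A sketch of the explicit if/else form suggested here (illustrative, not the code as submitted):

if _is_pil_image(img):
    return img.transpose(Image.FLIP_LEFT_RIGHT)
else:
    return img.flip(dims=(-1,))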

return img.transpose(Image.FLIP_LEFT_RIGHT)

return img.flip(dims=(-1,))


def _get_perspective_coeffs(startpoints, endpoints):
@@ -468,10 +499,13 @@ def vflip(img):
Returns:
PIL Image: Vertically flipped image.
"""
if not _is_pil_image(img):
if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
Review comment (Member):
doc comment as above

raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

return img.transpose(Image.FLIP_TOP_BOTTOM)
if _is_pil_image(img):
return img.transpose(Image.FLIP_TOP_BOTTOM)

return img.flip(dims=(-2,))


def five_crop(img, size):