Add pil_to_tensor to functionals #2092

Merged · 23 commits · May 18, 2020
Changes from 5 commits
44 changes: 44 additions & 0 deletions test/test_transforms.py
@@ -512,6 +512,50 @@ def test_accimage_to_tensor(self):
        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    def test_as_tensor(self):
        test_channels = [1, 3, 4]
        height, width = 4, 4
        trans = transforms.AsTensor()

        with self.assertRaises(TypeError):
            trans(np.random.rand(1, height, width).tolist())

        # each bad shape needs its own assertRaises block; a second call
        # inside the same block would never run once the first one raises
        with self.assertRaises(ValueError):
            trans(np.random.rand(height))
        with self.assertRaises(ValueError):
            trans(np.random.rand(1, 1, height, width))

        for channels in test_channels:
            input_data = torch.ByteTensor(channels, height, width).random_(0, 255)
            img = transforms.ToPILImage()(input_data)
            output = trans(img)
            self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

            ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1))
            self.assertTrue(np.allclose(output.numpy(), expected_output))

            ndarray = np.random.rand(height, width, channels).astype(np.float32)
            output = trans(ndarray)
            expected_output = ndarray.transpose((2, 0, 1))
            self.assertTrue(np.allclose(output.numpy(), expected_output))

        # separate test for mode '1' PIL images
        input_data = torch.ByteTensor(1, height, width).bernoulli_()
        img = transforms.ToPILImage()(input_data.mul(255)).convert('1')
        output = trans(img)
        self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_as_tensor(self):
        trans = transforms.AsTensor()

        expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB'))
        output = trans(accimage.Image(GRACE_HOPPER))

        self.assertEqual(expected_output.size(), output.size())
        self.assertTrue(np.allclose(output.numpy(), expected_output.numpy()))

    @unittest.skipIf(accimage is None, 'accimage not available')
    def test_accimage_resize(self):
        trans = transforms.Compose([
48 changes: 48 additions & 0 deletions torchvision/transforms/functional.py
@@ -82,6 +82,54 @@ def to_tensor(pic):
    return img


def as_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to a tensor of the same type.

    See ``AsTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
Member: If we consider that this function only supports the PIL -> tensor conversion, then maybe a better name would be pil_to_tensor or something like that? Open to suggestions.

Contributor Author: Can you take a second pass at your earliest convenience? One of the tests is a little awkward in that ToPILImage converts FloatTensors to bytes. The other thing is that I'm unsure of the parameter name "swap_to_channelsfirst". Let me know.

    if not (_is_pil_image(pic) or _is_numpy(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if _is_numpy(pic) and not _is_numpy_image(pic):
        raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

    if isinstance(pic, np.ndarray):
        # handle numpy array: add a channel axis to 2-D (H x W) inputs
        if pic.ndim == 2:
            pic = pic[:, :, None]

        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        return img
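A standalone sketch of this 2-D branch (plain numpy and torch, illustrative 4x4 shape, no torchvision needed):

```python
import numpy as np
import torch

# a grayscale HW array gains a trailing channel axis before the transpose
gray = np.random.rand(4, 4).astype(np.float32)
pic = gray[:, :, None]                            # H x W x 1
img = torch.from_numpy(pic.transpose((2, 0, 1)))  # 1 x H x W
print(img.shape)  # torch.Size([1, 4, 4])
```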
Member: I'm wondering if we should even support handling np.ndarray in this function. Indeed, the data in a np.ndarray can have any format (for example, it can be a float array with range 0-255), and we can't properly handle all possible cases; that is the responsibility of the user. Plus, if the user passes OpenCV arrays to the function, they will be in BGR format (different from the RGB used by Pillow and by torchvision).

As such, I think we should probably only handle PIL Images -- handling numpy arrays is trivial from the user perspective (torch.as_tensor(ndarray)).
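To make the "trivial from the user perspective" point concrete, a sketch with plain torch and numpy (the 4x4x3 shape is an arbitrary example):

```python
import numpy as np
import torch

# an HWC uint8 array, as OpenCV or scikit-image would return
ndarray = np.random.randint(0, 255, size=(4, 4, 3), dtype=np.uint8)

# the numpy -> tensor step really is a one-liner, plus a permute for CHW
img = torch.as_tensor(ndarray).permute(2, 0, 1)
print(img.shape)  # torch.Size([3, 4, 4])
```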

Contributor Author: I'm not really for or against keeping the numpy conversion here.

With that said, I think the primary purpose of this function is converting to a PyTorch tensor in a channels-first format; the scope of the inputs can therefore be broadened or narrowed without really affecting the goals of the function.

OpenCV arrays will still come out channels-first after passing through this function. We do not need to make any assumption other than that the data format is HWC (or, for black-and-white images, HW).

So let me know if you think it's best to drop numpy.

Member: I think what we should aim for is the least amount of potential user errors or surprises.

Indeed, both OpenCV and scikit-image return images as ndarrays in HWC format, but the color convention is not the same, and from our perspective there is no way to know whether the array is actually HWC (imagine multi-band images, for example). What scares me is that the ndarray passed in could also be CHW for some reason, and the function would just return something wrong.

For that reason, we should try to keep the scope of this function as narrow as possible, so that we can be sure we won't mishandle user inputs. PIL Images and accimage have a well-defined format representation (although I'm not sure many people actually use accimage), which is not the case for ndarrays, which are generic data containers.

Let me know what you think.

Contributor Author: I am okay with the narrowed scope and have made changes to reflect that.

As a separate point, I would like to keep the ability of ToPILImage to accept numpy arrays. That way users can still load numpy data (or other formats they can convert to numpy), and as long as the array can be converted to a PIL Image, a Compose sequence like the one below will still work:

transforms.Compose([
    transforms.ToPILImage(),
    ...
    transforms.PILToTensor(),
])

Member: Sounds good. Changing ToPILImage was not in the plans (it would be a backwards-incompatible change).


    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
    # put it from HWC to CHW format
    img = img.permute((2, 0, 1)).contiguous()
Member: I'm unsure if we want to call contiguous() here. In fact, I was thinking about letting the tensor keep a different memory format (channels_last, HWC).

Contributor Author: Could you explain your rationale for the HWC memory format? Almost all downstream operations expect the CHW format, so should there be a separate function that handles this?

Collaborator: I agree with @xksteven here. The only reason I see for changing the format is if someone just wants the import and wants to squeeze every milli- or microsecond they can get. If that is the intention, I suggest we add a channels_first flag that defaults to True.

Member: I miscommunicated my intentions, sorry about that.

What I wanted to say was that images are naturally stored as HWC, while all PyTorch operations (up to now) expect CHW. But there is an ongoing effort in PyTorch to add support for channels_last, which takes tensors of shape CHW but with strides such that the data is just a transposed HWC (no contiguous call).

Given that all downstream operations in torchvision should support arbitrarily-strided tensors, I would vote for returning non-contiguous tensors, so that in the future, when PyTorch implements dedicated kernel support for channels_last, we will be able to handle those efficiently.
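The strides argument can be sketched concretely (plain torch; illustrative 4x4x3 shape):

```python
import torch

hwc = torch.arange(4 * 4 * 3).reshape(4, 4, 3)  # data laid out as HWC
chw = hwc.permute(2, 0, 1)                      # CHW view over the same storage

print(chw.shape)            # torch.Size([3, 4, 4])
print(chw.is_contiguous())  # False: the strides still walk the HWC layout
print(chw.stride())         # (1, 12, 3)
```

Calling .contiguous() on chw would materialize a CHW copy; leaving it out keeps the original HWC memory layout, which is exactly what channels_last support could later exploit.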

    return img
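For reference, a hedged standalone equivalent of the byte-mode PIL branch above, using numpy's frombuffer in place of ByteStorage.from_buffer (illustrative 4x4 RGB input):

```python
import numpy as np
import torch
from PIL import Image

pil_img = Image.fromarray(np.random.randint(0, 255, size=(4, 4, 3), dtype=np.uint8))

# read the raw bytes, then reshape HWC and permute to CHW, as above
flat = np.frombuffer(pil_img.tobytes(), dtype=np.uint8).copy()
img = torch.from_numpy(flat)
img = img.view(pil_img.size[1], pil_img.size[0], len(pil_img.getbands()))
img = img.permute((2, 0, 1))
print(img.shape)  # torch.Size([3, 4, 4])
```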


def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to PIL Image.

24 changes: 23 additions & 1 deletion torchvision/transforms/transforms.py
@@ -15,7 +15,7 @@
from . import functional as F


-__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
+__all__ = ["Compose", "ToTensor", "AsTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
            "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
            "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
            "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
@@ -95,6 +95,28 @@ def __repr__(self):
        return self.__class__.__name__ + '()'


class AsTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to a tensor of the same type.

    Converts a PIL Image or numpy.ndarray (H x W x C) to a torch.Tensor of shape (C x H x W)
    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
    or if the numpy.ndarray has dtype = np.uint8.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.as_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class ToPILImage(object):
    """Convert a tensor or an ndarray to PIL Image.
