-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pil_to_tensor to functionals #2092
Changes from 5 commits
286e316
f7eb489
f90b3bc
9c2fd3b
08ab5ec
cb19ed4
38ad5f3
1fa91a8
0fefbcb
7662b23
75be7bb
123503a
eff1db0
266860a
610fc1e
b9cca77
1b10f77
fbf661c
598107f
d69048e
fa1084c
2cb7a4f
3d565fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,6 +82,54 @@ def to_tensor(pic): | |
return img | ||
|
||
|
||
def as_tensor(pic): | ||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor of same type. | ||
|
||
See ``AsTensor`` for more details. | ||
|
||
Args: | ||
pic (PIL Image or numpy.ndarray): Image to be converted to tensor. | ||
|
||
Returns: | ||
Tensor: Converted image. | ||
""" | ||
if not(_is_pil_image(pic) or _is_numpy(pic)): | ||
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) | ||
|
||
if _is_numpy(pic) and not _is_numpy_image(pic): | ||
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) | ||
|
||
if isinstance(pic, np.ndarray): | ||
# handle numpy array | ||
if pic.ndim == 2: | ||
pic = pic[:, :, None] | ||
|
||
img = torch.from_numpy(pic.transpose((2, 0, 1))) | ||
return img | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering if we should even support handling Indeed, the data in a As such, I think that we should probably only handle PIL Images -- handling numpy arrays is trivial from the user perspective ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not really for or against keeping the numpy conversion here. With that said I think the primary purpose of this function is doing the conversion to a pytorch tensor format and making it into a channels first format. The scope of the inputs can therefore be broadened or narrowed without really affecting the goals of the function. The OpenCV arrays will still come out to be channels first after being passed through this function. We do not need to make any other assumptions other than the data format is HWC (or in the case of black&white images that it is HW). So let me know if you think it's best to drop numpy. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think what we should aim for is the least amount of potential user-errors or surprises. Indeed, both OpenCV and For that reason, we could try to make the scope of this function to be as narrow as possible, so that we can be sure we won't be mishandling user inputs. Let me know what you think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am okay with the narrowed scope and have made changes to reflect that. As a separate point I would like to keep the functionality of ToPILImage to accept numpy arrays. That way users can still load in numpy (or other formats that they can convert to numpy) and as long as the user can convert the numpy array to PIL Image the sequence of compose will still work (as written below). transforms.Compose([
transforms.ToPILImage(),
...
transforms.PILToTensor(),
]) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, changing |
||
|
||
if accimage is not None and isinstance(pic, accimage.Image): | ||
xksteven marked this conversation as resolved.
Show resolved
Hide resolved
|
||
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) | ||
pic.copyto(nppic) | ||
return torch.from_numpy(nppic) | ||
|
||
# handle PIL Image | ||
if pic.mode == 'I': | ||
img = torch.from_numpy(np.array(pic, np.int32, copy=False)) | ||
elif pic.mode == 'I;16': | ||
img = torch.from_numpy(np.array(pic, np.int16, copy=False)) | ||
elif pic.mode == 'F': | ||
img = torch.from_numpy(np.array(pic, np.float32, copy=False)) | ||
elif pic.mode == '1': | ||
img = torch.from_numpy(np.array(pic, np.uint8, copy=False)) | ||
else: | ||
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) | ||
xksteven marked this conversation as resolved.
Show resolved
Hide resolved
xksteven marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) | ||
# put it from HWC to CHW format | ||
img = img.permute((2, 0, 1)).contiguous() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm unsure if we want to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you explain your rationale for the HWC memory format? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with @xksteven here. The only reason I see for changing the format is if someone just wants the import and want to squeeze every milli- / microsecond he can get. If that is the intention, I suggest we add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I miscommunicated what my intentions were, sorry about that. What I wanted to say was that images are naturally stored as HWC, while all PyTorch operations expect CHW (up to now). But there is an ongoing effort on PyTorch to add support for Given that all downstream operations in torchvision should support arbitrarily-strided tensors, I would vote for returning non-contiguous tensors, so that PyTorch, in the future when dedicated kernel support for channels_last is implemented) we will be able to handle those in an efficient manner. |
||
return img | ||
|
||
|
||
def to_pil_image(pic, mode=None): | ||
"""Convert a tensor or an ndarray to PIL Image. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we indeed only consider that this function only supports the
PIL -> tensor
conversion, then maybe a better name would bepil_to_tensor
or something like that? Open to suggestionsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you take a second pass at your earliest convenience?
One of the tests is a little awkward in that ToPILImage converts FloatTensors to bytes.
The other thing was I'm unsure of the parameter name "swap_to_channelsfirst".
Let me know.