
Bug: torchvision/transforms/functional/to_pil_image always converts 1-channel (gray) FloatTensor images to 8-bit unsigned int #448

Open
mathski opened this issue Mar 19, 2018 · 16 comments


@mathski commented Mar 19, 2018

  • OS: Ubuntu 16.04.4 LTS x64
  • PyTorch version: 0.3.0
  • Torchvision version: 0.2.0
  • How you installed PyTorch (conda, pip, source): conda
  • Python version: 3
  • CUDA/cuDNN version: 8.0
  • GPU models and configuration: Titan X (Maxwell)

Error:

```
ValueError: Incorrect mode (<class 'float'>) supplied for input type <class 'numpy.dtype'>. Should be L
```

The torchvision transform `ToPILImage(mode=float)` will always break for input of type `torch.FloatTensor`.
`ToPILImage()` uses the internal function `to_pil_image` found in `torchvision/transforms/functional.py`.

In https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py:
Line 104 checks whether the input is of type `torch.FloatTensor`.
If so, line 105 scales the input by 255, but then converts it to byte.
Lines 113-127 check whether the user-specified mode matches the expected mode, and raise an error if not.
The expected mode is derived from `npimg.dtype`, which returns `np.uint8` if line 105 is executed.

I believe the bug can be fixed by changing line 105 from:

```python
pic = pic.mul(255).byte()
```

to:

```python
pic = pic.mul(255)
```

Test script:

```python
import torch
from torchvision import transforms

a = torch.FloatTensor(1, 64, 64)
tform = transforms.Compose([transforms.ToPILImage(mode='F')])
b = tform(a)
```

Please let me know if I am in error.
Thank you.

@fmassa (Member) commented Mar 21, 2018

Yes, it looks like we currently don't handle this case properly.
I'm not even sure what the expected range is in PIL for float32 images, so I don't even know whether we should multiply by 255.

One workaround for the moment seems to be to convert the torch tensor to a numpy array, but it would be better to fix this case.
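As a sketch of that workaround (the tensor shape and variable names here are illustrative), one can bypass `ToPILImage` entirely and build the 32-bit float ('F' mode) image with PIL directly:

```python
import numpy as np
import torch
from PIL import Image

# Illustrative single-channel float tensor (C, H, W); values need not be in [0, 1]
tensor = torch.rand(1, 64, 64)

# Drop the channel dimension and convert to a float32 numpy array,
# then let PIL construct a 32-bit floating-point ('F' mode) image from it
array = tensor.squeeze(0).numpy().astype(np.float32)
img = Image.fromarray(array, mode='F')

# Round-tripping back to a tensor preserves the float values exactly
back = torch.from_numpy(np.array(img))
```

This sidesteps the mode check in `to_pil_image`, since PIL infers 'F' mode directly from the float32 dtype.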

@mathski (Author) commented Mar 21, 2018

Thank you very much for your response.

My workaround has been to use local copies of a few of the torchvision functions, such as `ToPILImage()`, so I can edit them directly. I changed `if isinstance(pic, torch.FloatTensor)` to `if isinstance(pic, torch.ByteTensor)` on line 104 (in the version at https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py) and removed the scaling by 255.

Currently PIL has very little support for scientific imaging, e.g. Gaussian blur on a float32 image, or handling an image with several float32 color channels. This is rather limiting for those looking to use signal and image processing techniques. I'm not sure what the demand is for these features, but I would be happy to contribute any of my implementations to torchvision.

Thank you again.

@fmassa (Member) commented Mar 21, 2018

Unfortunately, such a change would not be backward compatible, as we assume that float tensor images are in the 0-1 range, so a different fix should be added.

@marii-moe commented

Fairly new to coding/deep learning/pytorch/vision, so hope I am doing this right. Here is what I found:

  1. 650eb32
    `ToPILImage` was originally added for testing purposes. The only requirement at the time was that `ToTensor(ToPILImage(tensor))` return the identity, because it was used in testing other transformations. It originally took a FloatTensor, converted it to a byte array, and passed it into `Image.fromarray()` without "mode" being part of the API.
  2. bodokaiser@683852d
    First commit to add the concept of mode to the code, but doesn't change anything related to this bug.
  3. b8e69d8 (Mar 23, 2017)
    This commit introduced the concept of passing in a mode, in order to specify what logic to apply.
  4. 901c1ad (Oct 17, 2017)
    This commit added checking to verify that the passed-in mode matched the final output mode. This was not previously a requirement, and is why we are having issues here.

Throughout (1)-(4), the API has required that mode=None inputs be floats in the 0-1 range. Since (4) (Oct 17, 2017), passing a FloatTensor with mode='F' has been broken.

My thoughts on the options:

  1. Since mode='F' has been broken for a while now, I think we can possibly have different behavior here. When mode='F' is specified, we allow the range to be 0-255 or whatever is appropriate for `Image.fromarray()`. Maybe emit a DeprecationWarning for mode=None and revisit it at a later date/version?
  2. Break the API and switch both mode='F' and mode=None. This is mainly because the behavior of mode=None and mode='F' is currently inconsistent with the rest of the API. Also, mode=None currently assumes mode='F' behavior, which is a bit strange given how the API evolved over time.
  3. Turn off the check added in (4) for mode='F' only. This would be most consistent with the API from the time between (3) and (4).

What is the preference on the fix? Either way, I would be interested in trying to code this one if it isn't too much of a problem.

@fmassa (Member) commented Aug 28, 2018

My current thinking to solve this issue is to have a wrapper Image class that knows what the expected ranges are and performs the conversion as expected. This means that users can specify the ranges for their data if they want.
I'm sketching an API for that to see what it would look like, and whether it would solve a number of problems that have been mentioned here already.
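For illustration, a minimal sketch of what such a wrapper might look like (every name here, e.g. `RangedImage` and `to_range`, is hypothetical and not an actual torchvision API):

```python
import torch

class RangedImage:
    """Hypothetical wrapper pairing pixel data with its expected value range."""

    def __init__(self, data: torch.Tensor, value_range=(0.0, 1.0)):
        self.data = data
        self.value_range = value_range

    def to_range(self, new_range):
        """Linearly rescale the pixel data into a new value range."""
        lo, hi = self.value_range
        new_lo, new_hi = new_range
        scaled = (self.data - lo) / (hi - lo) * (new_hi - new_lo) + new_lo
        return RangedImage(scaled, new_range)

# A float image in [0, 1] rescaled to the byte range [0, 255]
img = RangedImage(torch.rand(1, 64, 64))
as_bytes = img.to_range((0.0, 255.0))
```

With ranges carried alongside the data, a conversion to PIL could pick the right mode and scaling automatically instead of guessing from the dtype.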

@mathski (Author) commented Sep 20, 2018

I'd be interested to see this.
As you suggest, I imagine an Image class would be most useful for those seeking flexibility in range and data type for image processing.

Please let me know if there is anything I can contribute here.
Thanks.

@fmassa (Member) commented Sep 21, 2018

Hi @mathski ,

I'm currently looking for datasets from other domains, like medical imagery, astronomy, etc., which do have images but in specialized formats.
I've never worked with those kinds of data, so having an idea of what's out there will definitely help.

@mathski (Author) commented Sep 23, 2018

Hey @fmassa ,

I've worked with multi-spectral satellite and medical image data previously, but those datasets are not publicly available.
Instead, I asked some colleagues to provide me with public analogues, and they referred me to an IARPA satellite database with 8-band multispectral TIFF images (downloaded through Amazon AWS):
https://www.iarpa.gov/challenges/fmow.html

I actually opened this issue not because I was dealing with an uncommon dataset type, but because I was attempting to do some simple image processing on intermediate outputs during training.
During training, my data was 3-band RGB data represented as a `torch.FloatTensor`, and I wanted to convert it with `ToPILImage(mode='F')` to take advantage of functions like resize and Gaussian blurring.
But I was having trouble with vanishing gradients when converting back and forth with strict controls on range and data type.
As I mentioned earlier, I was able to solve the problem on my end. But it should be possible to perform image processing (and other) operations in conjunction with network training.

Thanks for the support in any case.

@fmassa (Member) commented Sep 24, 2018

I see.
In those cases, I think it might have been better for you to leverage functions that support backprop, like `interpolate` (for image resizing), `conv2d` (for blurring), and `grid_sample` (for generic warpings of the image).

I think we might want to make torchvision support backpropagation, and possibly avoid the need to convert back and forth to PIL images or numpy arrays. I think we are getting there with better support for `interpolate` and `grid_sample`, but it still needs to be wrapped up more nicely (to avoid having to create the transformation flow yourself, which can be error-prone).
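As a sketch of that approach (pure torch, with illustrative shapes and a simple averaging kernel standing in for a Gaussian), resizing and blurring can stay inside the autograd graph:

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 3, 64, 64, requires_grad=True)

# Differentiable resize via bilinear interpolation
resized = F.interpolate(x, size=(32, 32), mode='bilinear', align_corners=False)

# Differentiable blur via a fixed 3x3 averaging kernel, applied per channel
# (weight shape (3, 1, 3, 3) with groups=3 convolves each channel separately)
kernel = torch.full((3, 1, 3, 3), 1.0 / 9.0)
blurred = F.conv2d(resized, kernel, padding=1, groups=3)

# Gradients flow all the way back to the input
blurred.sum().backward()
```

Because no PIL or numpy round-trip is involved, `x.grad` is populated after the backward pass.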

And thanks for the dataset!

@mathski (Author) commented Sep 24, 2018

Just so I understand, functions like conv2d and grid_sample allow the user to define the function values?
For example, if I want to define and apply my own blurring kernel rather than learning a blurring kernel, conv2d allows me to do that?

@fmassa (Member) commented Sep 24, 2018

yes, definitely! If you want your kernel to be fixed, you can do something like

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianBlur(nn.Module):
    def __init__(self):
        super().__init__()
        # A fixed (non-learned) kernel: register it as a buffer, not a Parameter,
        # so it moves with .to()/.cuda() but is not updated by the optimizer
        self.register_buffer('filter', torch.rand(1, 1, 3, 3))

    def forward(self, input):
        return F.conv2d(input, self.filter, padding=1)
```

@mathski (Author) commented Sep 24, 2018

Excellent. Thank you very much.

@tbung commented Feb 20, 2020

Would it be possible to document this behavior in relevant places until this issue gets resolved?

I encountered this issue when using the FakeData dataset: I got unexpected results and went on debugging my own code until I stumbled onto this discussion, so it would probably be helpful to others as well to put a warning there.

@cpuhrsch (Contributor) commented

@tbung - I'll flag this for follow-up to make a decision on this

@fmassa (Member) commented Feb 25, 2020

@tbung I agree, I think we should improve the documentation of `to_pil_image` to make explicit what conversions we perform.

Would you mind sending a PR improving the documentation?

@mathski (Author) commented Jul 21, 2021

Thanks for the help with this a while back. I have a follow up question:

Is there a list of torchvision functions that are supported by autograd?
Is it safe to assume that any class/function that inherits from torch.nn.Module supports backprop?
For example, would GaussianBlur defined here support backprop on tensor input?
https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#GaussianBlur

I'm trying to figure out which transformation functions I can avoid re-implementing myself.
