Make ColorJitter torchscriptable #2298
Conversation
Thanks a lot for the PR!
I think the only thing missing so that the PR can be merged is to add a test for the `ColorJitter` transform, ensuring that it is torchscriptable and can run with a dummy tensor. I think there might still be a few things that will need to be changed so that `ColorJitter` is torchscriptable. This might require changing uses of `random.uniform` to `torch.rand(1)` in `get_params`, and also removing the lambdas and instead having methods, to be compatible with torchscript.
Let me know if you have questions.
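For illustration, a minimal sketch of the `random.uniform` → `torch.rand(1)` swap being suggested (the helper name `sample_uniform` is hypothetical, not from the PR):

```python
import torch

def sample_uniform(low: float, high: float) -> float:
    # Equivalent of random.uniform(low, high) built from torch ops:
    # torch.rand(1) draws from [0, 1), which is then scaled and
    # shifted into [low, high). Using torch here keeps the sampling
    # inside the torchscript-compatible subset.
    return (torch.rand(1) * (high - low) + low).item()
```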
I'm not exactly sure how to replace the lambdas with methods. I would still need to define an inner function for it to work with `Compose`. Something like: […]
But I always end up with […]. Do you have something specific in mind to counter this issue?
@clemkoa I would not use `Compose`, and instead do something like:

```python
fn_idx = torch.randperm(4)
for fn_id in fn_idx:
    if fn_id == 0:
        brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
        img = F.adjust_brightness(img, brightness_factor)
    elif ...
```
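For illustration, here is a runnable sketch of that `randperm` dispatch pattern. The brightness/contrast bodies below are simplified stand-ins for `F.adjust_brightness` / `F.adjust_contrast` (the real code calls into `torchvision.transforms.functional`), and the function name is hypothetical:

```python
import torch

def jitter(img: torch.Tensor,
           brightness=(0.8, 1.2),
           contrast=(0.8, 1.2)) -> torch.Tensor:
    # Apply the adjustments in a random order; every op here is a
    # torch op, so the loop stays inside the torchscript subset.
    fn_idx = torch.randperm(2)
    for fn_id in fn_idx:
        if fn_id == 0:
            factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
            img = (img * factor).clamp(0.0, 1.0)
        elif fn_id == 1:
            factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
            mean = img.mean()
            img = ((img - mean) * factor + mean).clamp(0.0, 1.0)
    return img
```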
@fmassa thanks for the advice, it is indeed easier without `Compose`. The only thing I don't really like is […]
Looks great, thanks a lot!
If you could add a `@torch.jit.unused` to both `get_params` and `_check_inputs`, I think it would be safer, because those functions are not supposed to be serialized to torchscript.
Also, I have some more comments on how I think we could improve the code a bit further; let me know what you think.
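A small sketch of what that decorator does in practice (the module and method names here are hypothetical, not the PR's actual code): `@torch.jit.unused` marks a Python-only helper so that `torch.jit.script` does not try to compile its body.

```python
import random

import torch
from torch import nn

class MyTransform(nn.Module):
    @staticmethod
    @torch.jit.unused
    def get_params(low: float, high: float) -> float:
        # Python-only helper: @torch.jit.unused tells the scripting
        # compiler to skip this body (calling it from scripted code
        # would raise at runtime instead).
        return random.uniform(low, high)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1.0

scripted = torch.jit.script(MyTransform())
```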
torchvision/transforms/transforms.py (Outdated)

```python
return torch.FloatTensor()
return torch.FloatTensor(value)
```
This can probably be kept as a `float`, by annotating it with `Optional[float]`. Then, in the `forward` where it's used, you'll need to do something like

```python
if fn_id == 0 and self.brightness is not None:
    brightness = self.brightness
    assert brightness is not None
    ...
```

This is needed to tell torchscript that it can specialize the type of the argument to `float`, instead of `Optional[float]`.
`_check_input` returns a list of 2 items, not a float. I could change the name of `value` because it's misleading.
Since it's returning a list, I think a tensor is the only way to go? Sorry if I'm missing your point.
You can also return an `Optional[Tuple[float, float]]` if it's always two elements that are returned, or an `Optional[List[float]]` if you really want a list.
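As a sketch of the tuple-returning option (the helper name and the exact bound logic here are hypothetical, loosely modeled on `_check_input`'s center/bound behavior):

```python
from typing import Optional, Tuple

def check_bounds(value: float, center: float = 1.0) -> Optional[Tuple[float, float]]:
    # A value of 0 disables the jitter (None); otherwise return the
    # (center - value, center + value) range, clipped at 0 on the low
    # end. The fixed-size tuple gives torchscript a precise type.
    if value == 0.0:
        return None
    return (max(0.0, center - value), center + value)
```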
I found out that just converting the input to float in `_check_input` makes torchscript happy. The tests pass on Travis (at least `test_functional_tensor.py`: https://travis-ci.org/github/pytorch/vision/jobs/696793453).
I didn't need the type annotation nor the assert. If you want them I'm happy to add them, but I thought I'd let you know they don't seem required.
I've also added a simple test with an int as input for the torchscript.
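A rough sketch of that float-conversion idea (hypothetical helper name; the real `_check_input` has more validation): converting eagerly means torchscript always sees a list of two floats, even when the user passes an int.

```python
def check_input_float(value, center: float = 1.0):
    # Whether value is a single int/float or a 2-element sequence,
    # always return a list of two floats, so torchscript infers one
    # consistent element type without Optional annotations or asserts.
    if isinstance(value, (int, float)):
        return [max(0.0, center - float(value)), center + float(value)]
    return [float(value[0]), float(value[1])]
```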
Looks great, thanks a lot @clemkoa !
Thanks a lot for your help and advice @fmassa ! I'd be happy to contribute more in the future!
```diff
@@ -909,6 +912,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
         return value

     @staticmethod
+    @torch.jit.unused
     def get_params(brightness, contrast, saturation, hue):
```
For backwards compatibility. The way it was implemented was pretty non-canonical, but I would be ok if we made a BC-breaking change to use it again, maybe by returning a list of transform names / params somehow. This seems lower priority to me, though.
I can give that a shot, as I'm working on `F.adjust_hue` and it is used by `ColorJitter`...
Sounds good to me, but can you do it in a separate PR which only does this?
No problem, I can do it in a separate PR.
This is the first half of #2292, regarding `adjust_brightness`, `adjust_contrast` and `adjust_saturation`.
I've added some tests trying to mimic what was in `test_functional_tensors.py`, because the results of the adjustments are not strictly equal between PIL transforms and Tensor transforms.
Note that I'm not testing the class interface like it is done for hflip and flip. I wanted to go for something like this: […]
But it raises the error `torch.jit.frontend.UnsupportedNodeError: Lambda aren't supported:`. Not sure what to do about this.
I hope this PR is up to your standards. This is my first contribution to this repo, so feel free to give me some feedback!