
Make ColorJitter torchscriptable #2298

Merged 7 commits into pytorch:master, Jun 10, 2020

Conversation

clemkoa (Contributor) commented Jun 7, 2020

This is the first half of #2292, covering adjust_brightness, adjust_contrast and adjust_saturation.

I've added some tests that mimic what is in test_functional_tensors.py, because the results of the adjustments are not strictly equal between the PIL transforms and the Tensor transforms.

Note that I'm not testing the class interface the way it is done for hflip and flip. I wanted to do something like this:

    # test for class interface
    f = T.ColorJitter(brightness=factor)
    scripted_fn = torch.jit.script(f)
    scripted_fn(tensor)

But it raises the error torch.jit.frontend.UnsupportedNodeError: Lambda aren't supported:. I'm not sure what to do about this.

I hope this PR is up to your standards. This is my first contribution to this repo, so feel free to give me feedback!

fmassa (Member) left a comment:

Thanks a lot for the PR!

I think the only thing missing so that the PR can be merged is to add a test for the ColorJitter transform, ensuring that it is torchscriptable and can run with a dummy tensor. I think there might still be a few things that will need to be changed so that ColorJitter is torchscriptable.

This might require changing the uses of random.uniform in get_params to torch.rand(1), and also replacing the lambdas with methods, so that the code is compatible with torchscript.

Let me know if you have questions.
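As a standalone sketch of the torchscript-friendly sampling suggested above (sample_factor is a hypothetical name, not torchvision code), a drop-in replacement for random.uniform might look like:

```python
from typing import Tuple

import torch

def sample_factor(bounds: Tuple[float, float]) -> float:
    # torch.empty(1).uniform_(lo, hi) is torchscript-compatible,
    # unlike random.uniform from the Python standard library.
    return float(torch.empty(1).uniform_(bounds[0], bounds[1]))

# The function compiles under torchscript and behaves the same way.
scripted = torch.jit.script(sample_factor)
factor = scripted((0.5, 1.5))
```

Converting the one-element tensor back to a plain float keeps the downstream F.adjust_* calls unchanged.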

test/test_transforms_tensor.py (review thread resolved)
clemkoa (Contributor, Author) commented Jun 8, 2020

I'm not exactly sure how to replace the lambdas with methods. I would still need to define an inner function for it to work with Compose. Something like:

contrast_factor = ...
def c(img):
    return F.adjust_contrast(img, contrast_factor)

transforms.append(c)

But I always end up with torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:, whether it's an inner function in get_params, a static method, or even a function outside of the class.

Do you have something specific in mind to work around this issue?

fmassa (Member) commented Jun 9, 2020

@clemkoa I would not use Compose here, but instead manually re-implement the logic for selecting the hyperparameters and functions. Maybe something like:

fn_idx = torch.randperm(4)
for fn_id in fn_idx:
    if fn_id == 0:
        brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
        img = F.adjust_brightness(img, brightness_factor)
    elif ...
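Expanded into a self-contained, runnable sketch (the jitter name, the fixed ranges, and the arithmetic stand-ins for F.adjust_brightness / F.adjust_contrast are all illustrative, not the merged torchvision code):

```python
import torch

def jitter(img: torch.Tensor) -> torch.Tensor:
    brightness = (0.8, 1.2)
    contrast = (0.8, 1.2)
    # Pick a random order of operations; torch.randperm is
    # torchscript-compatible, unlike shuffling a list of lambdas.
    fn_idx = torch.randperm(2)
    for fn_id in fn_idx:
        if fn_id == 0:
            factor = float(torch.empty(1).uniform_(brightness[0], brightness[1]))
            img = img * factor  # stand-in for F.adjust_brightness
        elif fn_id == 1:
            factor = float(torch.empty(1).uniform_(contrast[0], contrast[1]))
            mean = img.mean()
            img = (img - mean) * factor + mean  # stand-in for F.adjust_contrast
    return img

# Both the eager and the scripted versions run on a dummy tensor.
scripted = torch.jit.script(jitter)
out = jitter(torch.rand(3, 4, 4))
out_scripted = scripted(torch.rand(3, 4, 4))
```

The key point is that every construct in the loop (randperm, tensor comparison, uniform_ sampling) has a torchscript implementation, so no lambdas or Compose are needed.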

clemkoa (Contributor, Author) commented Jun 10, 2020

@fmassa thanks for the advice, it is indeed easier without Compose.
I've added the changes you requested, as well as a test for torchscript in test_functional_tensor.

The only thing I don't really like is adjust_hue. I had to move the PIL version to functional_pil.py with a @torch.jit.unused decorator, otherwise torchscript was throwing errors. The remaining function in functional.py is a bit odd, with the exception raising. Tell me if you're happy with it.
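The dispatch pattern described here can be sketched as follows (the _adjust_hue_pil / _adjust_hue_tensor bodies are placeholders, not the real torchvision implementations):

```python
import torch
from torch import Tensor

@torch.jit.unused
def _adjust_hue_pil(img, hue_factor: float) -> Tensor:
    # Placeholder for the PIL implementation. torchscript never
    # compiles this body because of @torch.jit.unused; calling it
    # from scripted code raises at runtime instead.
    raise RuntimeError("PIL path is not available here")

def _adjust_hue_tensor(img: Tensor, hue_factor: float) -> Tensor:
    # Placeholder for the real tensor implementation.
    return img

def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
    # Under torchscript, img is always a Tensor, so only the
    # tensor branch is ever taken; in eager mode a PIL image
    # would fall through to the PIL path.
    if not isinstance(img, Tensor):
        return _adjust_hue_pil(img, hue_factor)
    return _adjust_hue_tensor(img, hue_factor)

scripted = torch.jit.script(adjust_hue)
t = torch.rand(3, 2, 2)
out = scripted(t, 0.1)
```

This is why the remaining function in functional.py looks odd: the exception raising lives behind @torch.jit.unused so the compiler never has to type-check the PIL code.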

fmassa (Member) left a comment:

Looks great, thanks a lot!

If you could add a @torch.jit.unused to both get_params and _check_inputs, I think it would be safer because those functions are not supposed to be serialized to torchscript.

Also, I have some more comments on how I think we could improve the code a bit further; let me know what you think.

torchvision/transforms/functional.py (review thread resolved)
Comment on lines 909 to 910:

    return torch.FloatTensor()
    return torch.FloatTensor(value)
fmassa (Member):

This can probably be kept as a float, by annotating it with Optional[float].

Then, in the forward where it's used, you'll need to do something like

if fn_id == 0 and self.brightness is not None:
    brightness = self.brightness
    assert brightness is not None
    ...

This is needed to tell torchscript that it can specialize the type of the variable to float, instead of Optional[float].
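As a standalone illustration of that Optional refinement pattern (the Scale module here is hypothetical, not the ColorJitter code itself):

```python
from typing import Optional

import torch

class Scale(torch.nn.Module):
    # The class-level annotation lets torchscript type the attribute
    # as Optional[float] even when the module is built with None.
    factor: Optional[float]

    def __init__(self, factor: Optional[float] = None):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        factor = self.factor
        if factor is not None:
            # After the None check, torchscript refines the local
            # from Optional[float] to float, so the multiply typechecks.
            x = x * factor
        return x

scripted = torch.jit.script(Scale(2.0))
out = scripted(torch.ones(3))
```

Copying the attribute into a local before the None check is what makes the refinement stick; checking self.factor directly does not always narrow the type.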

clemkoa (Contributor, Author) replied Jun 10, 2020:

_check_input returns a list of 2 items, not a float. I could rename value, because it's misleading.

Since it's returning a list, I think a tensor is the only way to go? Sorry if I'm missing your point.

fmassa (Member):

You can also return an Optional[Tuple[float, float]] if it's always two elements that are returned, or Optional[List[float]] if you really want a list.

clemkoa (Contributor, Author):

I found out that simply converting the input to float in _check_input makes torchscript happy. The tests pass on Travis (at least test_functional_tensor.py: https://travis-ci.org/github/pytorch/vision/jobs/696793453).

I didn't need the type annotation or the assert. If you want them I'm happy to add them, but I thought I'd let you know that they don't seem to be required.

I've also added a simple test with an int as input for the torchscripted version.

torchvision/transforms/transforms.py (outdated; review thread resolved)
fmassa (Member) left a comment:

Looks great, thanks a lot @clemkoa !

fmassa merged commit 883f1fb into pytorch:master on Jun 10, 2020
clemkoa (Contributor, Author) commented Jun 10, 2020

Thanks a lot for your help and advice @fmassa ! I'd be happy to contribute more in the future!

clemkoa deleted the color-jitter-torchscript branch on June 10, 2020.
@@ -909,6 +912,7 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
return value

@staticmethod
@torch.jit.unused
def get_params(brightness, contrast, saturation, hue):
vfdev-5 (Collaborator) commented Aug 7, 2020:
@fmassa @clemkoa it looks like get_params has become unused; why did we keep it?

fmassa (Member) replied Aug 7, 2020:
For backwards compatibility. The way it was implemented was pretty non-canonical, though. I would be OK with a BC-breaking change to use it again, perhaps by returning a list of transform names / params somehow, but that seems lower priority to me.

vfdev-5 (Collaborator):

I can give that a shot, since I'm working on F.adjust_hue and it is used by ColorJitter...

fmassa (Member):

Sounds good to me, but can you do it in a separate PR which only does this?

vfdev-5 (Collaborator):

No problem, I can do it in a separate PR.
