Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support of float dtypes for draw_segmentation_masks #8150

Merged
merged 23 commits into from
Dec 18, 2023

Conversation

GsnMithra
Copy link
Contributor

Fixes: #8138

Hey there!

I've implemented support for draw_* methods to seamlessly handle both uint8 and float32 image types. The processed image will now be returned with the same data type as the input. Your feedback is always appreciated.

Thank you!

Copy link

pytorch-bot bot commented Dec 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8150

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 3dce586 with merge base c35d385 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for submitting this PR @GsnMithra

This is promising, but allowing support for float will require a little bit more work than just converting the dtype using .to(), since the convention for float images is that their valuae range is in [0, 1] instead of [0, 255].

Perhaps the easiest way to make progress here will be to focus on draw_segmentation_masks first, and to add a small unit test in https://github.com/pytorch/vision/blob/main/test/test_utils.py, that checks for equality of result for a given image with different dtypes. Something roughly like this:

from torchvision.transforms.v2.functional import to_dtype

img_uint8 = torch.randing(0, 256, (3, 100, 100), dtype=torch.uint8)
img_float = to_dtype(img_uint8, torch.float32, scale=True)

out_uint8 = draw_segmentation_mask(img_uint8, ...)
out_float = draw_segmentation_mask(img_float, ...)

torch.testing.assert_close(out_uint8, to_dtype(out_float, torch.uint8, scale=True))

Does that make sense?

Comment on lines 285 to 286
elif image.dtype not in {torch.uint8, torch.float32}:
raise ValueError(f"The image dtype must be uint8 or float32, got {image.dtype}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and in the other places, let's just check for image.is_floating_point() instead of checking specifically for float32. This way we can also support float64, etc.

Copy link
Contributor Author

@GsnMithra GsnMithra Dec 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, thanks for replying, this will be added with the upcoming commits.

torchvision/utils.py Outdated Show resolved Hide resolved
@GsnMithra
Copy link
Contributor Author

Hey @NicolasHug

Just dropping a note to mention that I've added some assertion checks for floating-point data types in the unit test method for draw_segmentation_mask() with the recent commit titled "test_draw_sementation_mask". I would love to get your thoughts on it.

Thank you for your time.

@NicolasHug
Copy link
Member

Thanks a lot @GsnMithra . It looks like there are a few things missing for now: the import of to_dtype is incorrect (it's written "to_type"), and some deleted variables are still in use, throwing an error (out_dtype). Make sure to run the tests locally before pushing :) We have instructions here https://github.com/pytorch/vision/blob/main/CONTRIBUTING.md#unit-tests

Also, I think it would be best not to modify the current tests, and instead just create a separate new test for the one suggested in #8150 (review). Thank you!

@GsnMithra
Copy link
Contributor Author

Hey @NicolasHug

I would like to sincerely apologize for the mistakes I made in my previous contributions. I am still in the learning process and appreciate your guidance.

In the latest commit, I have included a new unit test called test_draw_segmentation_masks_dtypes, which has been implemented according to the suggestions provided earlier. I would greatly appreciate your thoughts and feedback on this addition.

Once again, I apologize for any inconvenience caused and thank you for taking the time to review my work.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the follow-up @GsnMithra, and no problems for the first few hicups.

I can see the new test is passing now, good job! In order to move forward with this PR, I would suggest the following:

  • Address the comments I made below (it should be pretty easy)
  • Remove the changes that were made to draw_bounding_boxes() and draw_keypoints(), and only keep the changes and tests relating to draw_segmentation_masks(). This way, we'll be able to merge this PR straight away, and then you can send 2 separate PR (one for boxes, one for keypoints), should you wish to.

Does that sound good to you?

test/test_utils.py Outdated Show resolved Hide resolved
.gitignore Outdated Show resolved Hide resolved
Comment on lines 11 to 12
import torchvision.transforms.functional as F
from torchvision.transforms.v2.functional import to_dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to re-order the imports, check out https://github.com/pytorch/vision/actions/runs/7214131414/job/19676695616?pr=8150, or run the pre-commit hooks locally (check our contributing instructions)

test/test_utils.py Outdated Show resolved Hide resolved
Comment on lines 301 to 302
if image.is_floating_point():
image = (image * 255).to(torch.uint8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use to_dtype here instead, like you did in the test.

Comment on lines 321 to 322
if original_dtype in {torch.float16, torch.float32, torch.float64}:
out = out.float() / 255.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same: Let's use to_dtype here instead, like you did in the test.

@@ -315,7 +318,10 @@ def draw_segmentation_masks(
img_to_draw[:, mask] = color[:, None]

out = image * (1 - alpha) + img_to_draw * alpha
return out.to(out_dtype)
if original_dtype in {torch.float16, torch.float32, torch.float64}:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check for is_floating_point() instead.

GsnMithra and others added 3 commits December 15, 2023 17:28
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
@GsnMithra
Copy link
Contributor Author

Hello @NicolasHug

I hope this message finds you well. I wanted to bring up a couple of points regarding the usage of to_dtype in the torchvision.transforms.v2.functional module, particularly within the draw_segmentation_masks() method.

Firstly, when importing to_dtype globally, it leads to a circular import issue due to the _log_api_usage_once(to_dtype) being utilized within the function. One solution could be importing it inside the draw_segmentation_masks() method, or alternatively, manually converting it.

Secondly, I've encountered a situation where, when a float dtype image is passed, converting it back to float using the following code results in mismatched elements for the unit test:

out = image * (1 - alpha) + img_to_draw * alpha
if original_image.is_floating_point():
    out = to_dtype(out, torch.float, scale=True)

It seems to work only when dividing by 255.0:

out = image * (1 - alpha) + img_to_draw * alpha
if original_image.is_floating_point():
    out = to_dtype(out, torch.float) / 255.0

I appreciate your insights and look forward to any further guidance or suggestions you may have.
Thank you for taking the time.

@GsnMithra GsnMithra changed the title support for float32 for draw_* support for float for draw_segmentation_masks Dec 15, 2023
@GsnMithra GsnMithra changed the title support for float for draw_segmentation_masks support of float dtypes for draw_segmentation_masks Dec 15, 2023
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the follow-up @GsnMithra !

You are right about the need for /= 255. It was because out was a float32 image in [0, 255] when the input was a float image. So even doing dtype(scale=True) would not work as expected, as no conversion was done (since the image was already float32).

I found a simpler way to handle all this by simply converting the color's dtype and scale, instead of converting the input image. I pushed the changes and also reverted some minor things from the other functions.

I'll merge this PR when green, thanks a ton for your work on this!

(Feel free to submit follow-up PRs for the other functions if you wish to. Just giving you a heads-up that I will only be able to review them starting next year in Jan, as I'll be on leaves from tomorrow.)

Thanks again @GsnMithra !

@GsnMithra
Copy link
Contributor Author

Hey @NicolasHug
You're welcome, I'll be submitting PR's for the other functions also.

Wishing you a wonderful break ;)

@NicolasHug NicolasHug merged commit 6c2e0ae into pytorch:main Dec 18, 2023
62 of 64 checks passed
Copy link

Hey @NicolasHug!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Jan 16, 2024
Summary: Co-authored-by: Nicolas Hug <[email protected]>

Reviewed By: vmoens

Differential Revision: D52539003

fbshipit-source-id: e7b9412a496e88749dc6e9c5afdd1b5cf85b4aa0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

draw_bounding_boxes and save_image require different input dtype
3 participants