Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand the channels to 3 if the user requested as such #8229

Merged
merged 10 commits into from
Jan 26, 2024

Conversation

ahmadsharif1
Copy link
Contributor

@ahmadsharif1 ahmadsharif1 commented Jan 23, 2024

Before this patch, if we passed in an image with just one channel and requested num_output_channels=3, we would return an image with a single channel. After this patch, we will expand the channels to 3 if requested.

This issue is described in #8167 .

@NicolasHug could you give some guidance on how to test this more thoroughly?

cc @vfdev-5

Copy link

pytorch-bot bot commented Jan 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8229

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit e8979fc with merge base 7f55a1b (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR @ahmadsharif1 ! I made some minor comments below, but I'll approve now to unblock.

The tests are properly covering the changes as far as I can tell. In fact, this PR is aligning the tensor backend with the PIL backend, which was already expanding 1D images to 3D. I'm gonna mark this as a bugfix.

Thank you!

Comment on lines 36 to 37
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is already done in all of the call sites of this helper, so it's not needed here. In general I agree we should error at the lower-level functions like this one instead of earlier. Maybe it's be better to error here instead of the callsites, but since it would require more change than necessary for this feature, let's leave it out for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with a TODO. Let me know if something else is needed here.

if image.shape[-3] == 1 and num_output_channels == 3:
s = [-1] * len(image.shape)
s[-3] = 3
image = image.expand(s)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we simply do

Suggested change
image = image.expand(s)
return image.expand(s)

here? The output should be the same and it would save some extra computation from l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114). Not critical though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. This saves some ops. Done.

@ahmadsharif1 ahmadsharif1 marked this pull request as ready for review January 24, 2024 20:52
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ahmadsharif1 !

@ahmadsharif1
Copy link
Contributor Author

Turns out .expand() just creates a view into the same tensor.

I am not sure that's the behavior that the user wants/expects. I changed to torch.repeat() instead of torch.expand(), but let me know if I am mistaken.

@NicolasHug
Copy link
Member

Good catch @ahmadsharif1 , thank you. repeat() makes sense as it's a bit safer in case there are per-channel inplace modifications on the image in the future. This PR still LGTM, feel free to merge! The current CI failures are unrelated and can be ignored.

@ahmadsharif1 ahmadsharif1 merged commit e0fd033 into pytorch:main Jan 26, 2024
80 of 81 checks passed
@ahmadsharif1 ahmadsharif1 deleted the i8167 branch January 26, 2024 14:59
facebook-github-bot pushed a commit that referenced this pull request Mar 19, 2024
Summary: Co-authored-by: Nicolas Hug <[email protected]>

Reviewed By: vmoens

Differential Revision: D55062762

fbshipit-source-id: 8a36b5ada2a5926b5280b0dc420efb41ecbb3fa1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants