Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DeformableDetrModel support fp16 #29013

Merged
merged 11 commits into from
Feb 15, 2024
Merged

DeformableDetrModel support fp16 #29013

merged 11 commits into from
Feb 15, 2024

Conversation

DonggeunYu
Copy link
Contributor

@DonggeunYu DonggeunYu commented Feb 14, 2024

What does this PR do?

This PR for DeformableDetrModel support fp16.

#29011

Who can review?

@amyeroberts

@@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines 1543 to 1550
def get_valid_ratio(self, mask, dtype=torch.float32):
"""Get the valid ratio of all feature maps."""

_, height, width = mask.shape
valid_height = torch.sum(mask[:, :, 0], 1)
valid_width = torch.sum(mask[:, 0, :], 1)
valid_ratio_heigth = valid_height.float() / height
valid_ratio_width = valid_width.float() / width
valid_ratio_heigth = valid_height.to(dtype) / height
valid_ratio_width = valid_width.to(dtype) / width
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reference:

def get_valid_ratio(self, mask, dtype=torch.float32):
"""Get the valid ratio of all feature maps."""
_, height, width = mask.shape
valid_height = torch.sum(~mask[:, :, 0], 1)
valid_width = torch.sum(~mask[:, 0, :], 1)
valid_ratio_heigth = valid_height.to(dtype) / height
valid_ratio_width = valid_width.to(dtype) / width
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
return valid_ratio

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice! 🔥

Thanks for taking the time to link to all the relevant references, it really helps for a quick review.

Only comment is a small spelling nit. I realise it was there before, but let's take the opportunity to fix it.

@amyeroberts
Copy link
Collaborator

Actually, one thing we'll need to add is a test e.g. like here for MT5.

For the quality checks, running make fix-copies and pushing the changes should resolve the issues. You make need to make some additional adjustments to other modeling files to properly reflect the changes.

@amyeroberts amyeroberts self-requested a review February 14, 2024 10:17
@DonggeunYu
Copy link
Contributor Author

@amyeroberts
I took all of your feedback on board.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great - thanks for adding this great contribution to the library!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts amyeroberts merged commit 5b6fa23 into huggingface:main Feb 15, 2024
18 checks passed
hackyon pushed a commit to hackyon/transformers that referenced this pull request Feb 15, 2024
* Update ms_deform_attn_cuda.cu

* Update ms_deform_attn_cuda.cuh

* Update modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update modeling_deformable_detr.py

* python utils/check_copies.py --fix_and_overwrite

* Fix dtype missmatch error

* Update test_modeling_deformable_detr.py

* Update test_modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* Update modeling_deformable_detr.py

---------

Co-authored-by: amyeroberts <[email protected]>
itazap pushed a commit that referenced this pull request May 14, 2024
* Update ms_deform_attn_cuda.cu

* Update ms_deform_attn_cuda.cuh

* Update modeling_deformable_detr.py

* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py

Co-authored-by: amyeroberts <[email protected]>

* Update modeling_deformable_detr.py

* python utils/check_copies.py --fix_and_overwrite

* Fix dtype missmatch error

* Update test_modeling_deformable_detr.py

* Update test_modeling_deformable_detr.py

* Update modeling_deformable_detr.py

* Update modeling_deformable_detr.py

---------

Co-authored-by: amyeroberts <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants