DeformableDetrModel support fp16 #29013
@@ -64,7 +64,7 @@ at::Tensor ms_deform_attn_cuda_forward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] {
     def get_valid_ratio(self, mask, dtype=torch.float32):
         """Get the valid ratio of all feature maps."""

         _, height, width = mask.shape
         valid_height = torch.sum(mask[:, :, 0], 1)
         valid_width = torch.sum(mask[:, 0, :], 1)
-        valid_ratio_heigth = valid_height.float() / height
-        valid_ratio_width = valid_width.float() / width
+        valid_ratio_heigth = valid_height.to(dtype) / height
+        valid_ratio_width = valid_width.to(dtype) / width
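To illustrate why the hard-coded `.float()` gets in the way of fp16, here is a minimal sketch, not the library's code. It assumes a standalone version of `get_valid_ratio` (no `self`), a mask that is `True` at valid pixels, and a made-up `reference_points` tensor standing in for whatever fp16 tensor the ratios are later combined with. It runs on CPU and only inspects dtypes.

```python
import torch


def get_valid_ratio(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Standalone sketch: mask is (batch, height, width), True at valid pixels."""
    _, height, width = mask.shape
    valid_height = torch.sum(mask[:, :, 0], 1)  # valid rows, counted along the first column
    valid_width = torch.sum(mask[:, 0, :], 1)   # valid columns, counted along the first row
    valid_ratio_height = valid_height.to(dtype) / height
    valid_ratio_width = valid_width.to(dtype) / width
    return torch.stack([valid_ratio_width, valid_ratio_height], -1)


# Two 4x8 feature maps: the first is valid only in its top-left 2x4 corner.
mask = torch.zeros(2, 4, 8, dtype=torch.bool)
mask[0, :2, :4] = True
mask[1] = True

ratio_fp32 = get_valid_ratio(mask)                       # behaves like the old .float() path
ratio_fp16 = get_valid_ratio(mask, dtype=torch.float16)  # follows the requested dtype

# Hypothetical fp16 tensor that the ratios are multiplied with downstream.
reference_points = torch.full((2, 5, 2), 0.5, dtype=torch.float16)

promoted = reference_points * ratio_fp32[:, None]  # fp16 * fp32 is promoted to fp32
kept = reference_points * ratio_fp16[:, None]      # stays in fp16

print(promoted.dtype, kept.dtype)
```

With float32 ratios, the product is silently promoted to float32, which is the kind of mixed-dtype tensor that can later collide with fp16 weights or kernels. Passing the model's dtype through keeps the whole computation in half precision.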
Reference:
transformers/src/transformers/models/oneformer/modeling_oneformer.py
Lines 1355 to 1364 in de6029a
    def get_valid_ratio(self, mask, dtype=torch.float32):
        """Get the valid ratio of all feature maps."""
        _, height, width = mask.shape
        valid_height = torch.sum(~mask[:, :, 0], 1)
        valid_width = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_heigth = valid_height.to(dtype) / height
        valid_ratio_width = valid_width.to(dtype) / width
        valid_ratio = torch.stack([valid_ratio_width, valid_ratio_heigth], -1)
        return valid_ratio
Very nice! 🔥
Thanks for taking the time to link to all the relevant references, it really helps for a quick review.
Only comment is a small spelling nit. I realise it was there before, but let's take the opportunity to fix it.
src/transformers/models/deformable_detr/modeling_deformable_detr.py
Actually, one thing we'll need to add is a test e.g. like here for MT5. For the quality checks, running
…tr.py Co-authored-by: amyeroberts <[email protected]>
@amyeroberts
Looks great - thanks for adding this great contribution to the library!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
* Update ms_deform_attn_cuda.cu
* Update ms_deform_attn_cuda.cuh
* Update modeling_deformable_detr.py
* Update src/transformers/models/deformable_detr/modeling_deformable_detr.py Co-authored-by: amyeroberts <[email protected]>
* Update modeling_deformable_detr.py
* python utils/check_copies.py --fix_and_overwrite
* Fix dtype missmatch error
* Update test_modeling_deformable_detr.py
* Update test_modeling_deformable_detr.py
* Update modeling_deformable_detr.py
* Update modeling_deformable_detr.py
---------
Co-authored-by: amyeroberts <[email protected]>
What does this PR do?
This PR adds fp16 support to DeformableDetrModel.
#29011
Who can review?
@amyeroberts