Skip to content

Commit

Permalink
DeformableDETR two stage support bfloat16 (huggingface#30907)
Browse files Browse the repository at this point in the history
Update modeling_deformable_detr.py
  • Loading branch information
DonggeunYu authored and zucchini-nlp committed Jun 11, 2024
1 parent e7af110 commit e6a72b3
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1616,8 +1616,8 @@ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes)
valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

grid_y, grid_x = meshgrid(
torch.linspace(0, height - 1, height, dtype=torch.float32, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=torch.float32, device=enc_output.device),
torch.linspace(0, height - 1, height, dtype=enc_output.dtype, device=enc_output.device),
torch.linspace(0, width - 1, width, dtype=enc_output.dtype, device=enc_output.device),
indexing="ij",
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
Expand Down

0 comments on commit e6a72b3

Please sign in to comment.