Skip to content

Commit

Permalink
Fix: Change tensors to integers for torch.dynamo and torch.compile co…
Browse files Browse the repository at this point in the history
…mpatibility (#23475)

* Fix: Change tensors to integers in torch.split() for torch.dynamo and torch.compile compatibility

* Applied the suggested fix to the utils/check_copies.py test

* Applied the suggested fix by changing the original function that gets copied
  • Loading branch information
loevlie authored May 19, 2023
1 parent 389bdba commit 847e569
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def multi_scale_deformable_attention(
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/deta/modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def multi_scale_deformable_attention(
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/mask2former/modeling_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ def multi_scale_deformable_attention(
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
Expand Down Expand Up @@ -1340,7 +1340,7 @@ def forward(
else:
split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]

encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)

# Compute final features
outputs = [
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/oneformer/modeling_oneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def multi_scale_deformable_attention(
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
value_list = value.split([height.item() * width.item() for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
Expand Down

0 comments on commit 847e569

Please sign in to comment.