🚨🚨🚨 fix(Mask2Former): torch export 🚨🚨🚨 (#34393)
* fix(Mask2Former): torch export

Signed-off-by: Phillip Kuznetsov <[email protected]>

* revert level_start_index and create a level_start_index_list

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Add a comment to explain the level_start_index_list

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Address comment

Signed-off-by: Phillip Kuznetsov <[email protected]>

* add torch.export.export test

Signed-off-by: Phillip Kuznetsov <[email protected]>

* rename arg

Signed-off-by: Phillip Kuznetsov <[email protected]>

* remove spatial_shapes

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Use the version check from pytorch_utils

Signed-off-by: Phillip Kuznetsov <[email protected]>

* [run_slow] mask2former

Signed-off-by: Phillip Kuznetsov <[email protected]>

---------

Signed-off-by: Phillip Kuznetsov <[email protected]>
philkuz authored Nov 19, 2024
1 parent 5815243 commit 5fa4f64
Showing 2 changed files with 63 additions and 25 deletions.
62 changes: 37 additions & 25 deletions src/transformers/models/mask2former/modeling_mask2former.py
@@ -926,7 +926,7 @@ def forward(
encoder_attention_mask=None,
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -936,7 +936,8 @@ def forward(

batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
total_elements = sum(height * width for height, width in spatial_shapes_list)
if total_elements != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
@@ -957,7 +958,11 @@ def forward(
)
# batch_size, num_queries, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
offset_normalizer = torch.tensor(
[[shape[1], shape[0]] for shape in spatial_shapes_list],
dtype=torch.long,
device=reference_points.device,
)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
@@ -970,7 +975,7 @@ def forward(
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")

output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = multi_scale_deformable_attention(value, spatial_shapes_list, sampling_locations, attention_weights)
output = self.output_proj(output)

return output, attention_weights
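
For context, here is a standalone sketch (with made-up, non-square feature-map sizes; `spatial_shapes_list` and its values are hypothetical, not taken from the model) showing that the list-based quantities used above match the tensor-based ones from the previous code path:

```python
import torch

# Hypothetical per-level (height, width) sizes, chosen non-square so the axis swap matters.
spatial_shapes_list = [(64, 96), (32, 48), (16, 24)]

# New, list-based pattern: plain Python ints at trace time.
total_elements = sum(height * width for height, width in spatial_shapes_list)
offset_normalizer = torch.tensor(
    [[shape[1], shape[0]] for shape in spatial_shapes_list], dtype=torch.long
)

# Old, tensor-based pattern, kept here only for comparison.
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long)
total_elements_old = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum()
offset_normalizer_old = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)

assert total_elements == int(total_elements_old)
assert torch.equal(offset_normalizer, offset_normalizer_old)
```

The numerical result is identical; the change only moves the shape bookkeeping into Python, which is the pattern the rest of this diff applies as well.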
@@ -1001,7 +1006,7 @@ def forward(
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
output_attentions: bool = False,
):
@@ -1015,8 +1020,8 @@ def forward(
Position embeddings, to be added to `hidden_states`.
reference_points (`torch.FloatTensor`, *optional*):
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes of the backbone feature maps.
spatial_shapes_list (`list` of `tuple`):
Spatial shapes of the backbone feature maps as a list of tuples.
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
output_attentions (`bool`, *optional*):
@@ -1033,7 +1038,7 @@ def forward(
encoder_attention_mask=attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1086,13 +1091,13 @@ def __init__(self, config: Mask2FormerConfig):
)

@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
def get_reference_points(spatial_shapes_list, valid_ratios, device):
"""
Get reference points for each feature map. Used in decoder.
Args:
spatial_shapes (`torch.LongTensor`):
Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`.
spatial_shapes_list (`list` of `tuple`):
Spatial shapes of the backbone feature maps as a list of tuples.
valid_ratios (`torch.FloatTensor`):
Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
device (`torch.device`):
@@ -1101,7 +1106,7 @@ def get_reference_points(spatial_shapes, valid_ratios, device):
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
"""
reference_points_list = []
for lvl, (height, width) in enumerate(spatial_shapes):
for lvl, (height, width) in enumerate(spatial_shapes_list):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
@@ -1122,7 +1127,7 @@ def forward(
inputs_embeds=None,
attention_mask=None,
position_embeddings=None,
spatial_shapes=None,
spatial_shapes_list=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
@@ -1140,8 +1145,8 @@ def forward(
[What are attention masks?](../glossary#attention-mask)
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Position embeddings that are added to the queries and keys in each self-attention layer.
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
spatial_shapes_list (`list` of `tuple`):
Spatial shapes of each feature map as a list of tuples.
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
Starting index of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
@@ -1162,7 +1167,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

hidden_states = inputs_embeds
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
reference_points = self.get_reference_points(spatial_shapes_list, valid_ratios, device=inputs_embeds.device)

all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1176,7 +1181,7 @@ def forward(
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
@@ -1302,9 +1307,9 @@ def forward(
]

# Prepare encoder inputs (by flattening)
spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
spatial_shapes_list = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device)
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device)
masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)

position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
@@ -1320,7 +1325,7 @@ def forward(
inputs_embeds=input_embeds_flat,
attention_mask=masks_flat,
position_embeddings=level_pos_embed_flat,
spatial_shapes=spatial_shapes,
spatial_shapes_list=spatial_shapes_list,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
@@ -1331,18 +1336,23 @@ def forward(
last_hidden_state = encoder_outputs.last_hidden_state
batch_size = last_hidden_state.shape[0]

# We compute level_start_index_list separately from the tensor version level_start_index
# to avoid iterating over a tensor which breaks torch.compile/export.
level_start_index_list = [0]
for height, width in spatial_shapes_list[:-1]:
level_start_index_list.append(level_start_index_list[-1] + height * width)
split_sizes = [None] * self.num_feature_levels
for i in range(self.num_feature_levels):
if i < self.num_feature_levels - 1:
split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i]
else:
split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i]

encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)

# Compute final features
outputs = [
x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1])
x.transpose(1, 2).view(batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1])
for i, x in enumerate(encoder_output)
]
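
The start-index/split-size bookkeeping above can be exercised on its own; this is a minimal sketch with made-up shapes (the `(8, 8)`/`(4, 4)`/`(2, 2)` levels and the random hidden state are hypothetical), showing that `torch.split` now receives plain Python ints instead of sizes pulled out of a tensor with `.item()`:

```python
import torch

spatial_shapes_list = [(8, 8), (4, 4), (2, 2)]  # hypothetical per-level (height, width)

# Same bookkeeping as in the diff: start offsets kept as a Python list.
level_start_index_list = [0]
for height, width in spatial_shapes_list[:-1]:
    level_start_index_list.append(level_start_index_list[-1] + height * width)
# -> [0, 64, 80]

sequence_length = sum(height * width for height, width in spatial_shapes_list)  # 84
last_hidden_state = torch.randn(1, sequence_length, 256)

split_sizes = [
    level_start_index_list[i + 1] - level_start_index_list[i]
    if i < len(spatial_shapes_list) - 1
    else sequence_length - level_start_index_list[i]
    for i in range(len(spatial_shapes_list))
]
# split_sizes == [64, 16, 4], all Python ints, so torch.split traces cleanly.
per_level = torch.split(last_hidden_state, split_sizes, dim=1)
assert [t.shape[1] for t in per_level] == [64, 16, 4]
```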

@@ -1876,7 +1886,9 @@ def forward(
else:
level_index = idx % self.num_feature_levels

attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
# Multiply the attention mask instead of indexing to avoid issue in torch.export.
attention_mask = attention_mask * where.unsqueeze(-1)

layer_outputs = decoder_layer(
hidden_states,
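
The masking change in the last hunk can also be checked in isolation. Below is a minimal sketch with a hypothetical 2×3 boolean mask (not model output), showing that the multiply-by-selector form clears fully-masked rows exactly like the old in-place indexing did:

```python
import torch

attention_mask = torch.tensor(
    [
        [True, True, True],   # fully masked row: would zero out attention, so clear it
        [True, False, True],  # partially masked row: keep as is
    ]
)

# Old pattern (in-place write through data-dependent indexing).
old_mask = attention_mask.clone()
old_mask[torch.where(old_mask.sum(-1) == old_mask.shape[-1])] = False

# New pattern: build a per-row keep/clear selector and multiply.
where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
new_mask = attention_mask * where.unsqueeze(-1)

assert torch.equal(old_mask, new_mask)
```

Both forms produce the same mask; the new one simply avoids the in-place write behind `torch.where`, which is what the comment in the diff says `torch.export` trips over.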
26 changes: 26 additions & 0 deletions tests/models/mask2former/test_modeling_mask2former.py
@@ -20,6 +20,7 @@

from tests.test_modeling_common import floats_tensor
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
from transformers.testing_utils import (
require_timm,
require_torch,
@@ -481,3 +482,28 @@ def test_with_segmentation_maps_and_loss(self):
outputs = model(**inputs)

self.assertTrue(outputs.loss is not None)

def test_export(self):
if not is_torch_greater_or_equal_than_2_4:
self.skipTest(reason="This test requires torch >= 2.4 to run.")
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(image, return_tensors="pt").to(torch_device)

exported_program = torch.export.export(
model,
args=(inputs["pixel_values"], inputs["pixel_mask"]),
strict=True,
)
with torch.no_grad():
eager_outputs = model(**inputs)
exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["pixel_mask"])
self.assertEqual(eager_outputs.masks_queries_logits.shape, exported_outputs.masks_queries_logits.shape)
self.assertTrue(
torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=TOLERANCE)
)
self.assertEqual(eager_outputs.class_queries_logits.shape, exported_outputs.class_queries_logits.shape)
self.assertTrue(
torch.allclose(eager_outputs.class_queries_logits, exported_outputs.class_queries_logits, atol=TOLERANCE)
)
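
Beyond the equivalence check in the test, an `ExportedProgram` can also be serialized and reloaded. A hedged sketch with a toy module (not Mask2Former, to keep it self-contained) using `torch.export.save`/`torch.export.load`:

```python
import torch

class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x) + 1.0

example_args = (torch.randn(2, 3),)
exported_program = torch.export.export(TinyModel(), args=example_args, strict=True)

# Save the exported graph to disk and reload it for inference elsewhere.
torch.export.save(exported_program, "tiny_model.pt2")
reloaded = torch.export.load("tiny_model.pt2")

assert torch.allclose(reloaded.module()(*example_args), TinyModel()(*example_args))
```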
