diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py
index f4aea415adf..4cc96b1652d 100644
--- a/src/transformers/models/mask2former/modeling_mask2former.py
+++ b/src/transformers/models/mask2former/modeling_mask2former.py
@@ -926,7 +926,7 @@ def forward(
         encoder_attention_mask=None,
         position_embeddings: Optional[torch.Tensor] = None,
         reference_points=None,
-        spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -936,7 +936,8 @@ def forward(
 
         batch_size, num_queries, _ = hidden_states.shape
         batch_size, sequence_length, _ = encoder_hidden_states.shape
-        if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
+        total_elements = sum(height * width for height, width in spatial_shapes_list)
+        if total_elements != sequence_length:
             raise ValueError(
                 "Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
             )
@@ -957,7 +958,11 @@ def forward(
         )
         # batch_size, num_queries, n_heads, n_levels, n_points, 2
         if reference_points.shape[-1] == 2:
-            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+            offset_normalizer = torch.tensor(
+                [[shape[1], shape[0]] for shape in spatial_shapes_list],
+                dtype=torch.long,
+                device=reference_points.device,
+            )
             sampling_locations = (
                 reference_points[:, :, None, :, None, :]
                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
@@ -970,7 +975,7 @@ def forward(
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
 
-        output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+        output = multi_scale_deformable_attention(value, spatial_shapes_list, sampling_locations, attention_weights)
         output = self.output_proj(output)
 
         return output, attention_weights
@@ -1001,7 +1006,7 @@ def forward(
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor = None,
         reference_points=None,
-        spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         output_attentions: bool = False,
     ):
@@ -1015,8 +1020,8 @@ def forward(
                 Position embeddings, to be added to `hidden_states`.
             reference_points (`torch.FloatTensor`, *optional*):
                 Reference points.
-            spatial_shapes (`torch.LongTensor`, *optional*):
-                Spatial shapes of the backbone feature maps.
+            spatial_shapes_list (`list` of `tuple`):
+                Spatial shapes of the backbone feature maps as a list of tuples.
             level_start_index (`torch.LongTensor`, *optional*):
                 Level start index.
             output_attentions (`bool`, *optional*):
@@ -1033,7 +1038,7 @@ def forward(
             encoder_attention_mask=attention_mask,
             position_embeddings=position_embeddings,
             reference_points=reference_points,
-            spatial_shapes=spatial_shapes,
+            spatial_shapes_list=spatial_shapes_list,
             level_start_index=level_start_index,
             output_attentions=output_attentions,
         )
@@ -1086,13 +1091,13 @@ def __init__(self, config: Mask2FormerConfig):
         )
 
     @staticmethod
-    def get_reference_points(spatial_shapes, valid_ratios, device):
+    def get_reference_points(spatial_shapes_list, valid_ratios, device):
         """
         Get reference points for each feature map. Used in decoder.
 
         Args:
-            spatial_shapes (`torch.LongTensor`):
-                Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`.
+            spatial_shapes_list (`list` of `tuple`):
+                Spatial shapes of the backbone feature maps as a list of tuples.
             valid_ratios (`torch.FloatTensor`):
                 Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
             device (`torch.device`):
@@ -1101,7 +1106,7 @@ def get_reference_points(spatial_shapes, valid_ratios, device):
             `torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
         """
         reference_points_list = []
-        for lvl, (height, width) in enumerate(spatial_shapes):
+        for lvl, (height, width) in enumerate(spatial_shapes_list):
             ref_y, ref_x = torch.meshgrid(
                 torch.linspace(0.5, height - 0.5, height, dtype=valid_ratios.dtype, device=device),
                 torch.linspace(0.5, width - 0.5, width, dtype=valid_ratios.dtype, device=device),
@@ -1122,7 +1127,7 @@ def forward(
         inputs_embeds=None,
         attention_mask=None,
         position_embeddings=None,
-        spatial_shapes=None,
+        spatial_shapes_list=None,
         level_start_index=None,
         valid_ratios=None,
         output_attentions=None,
@@ -1140,8 +1145,8 @@ def forward(
                 [What are attention masks?](../glossary#attention-mask)
             position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 Position embeddings that are added to the queries and keys in each self-attention layer.
-            spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
-                Spatial shapes of each feature map.
+            spatial_shapes_list (`list` of `tuple`):
+                Spatial shapes of each feature map as a list of tuples.
             level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
                 Starting index of each feature map.
             valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
@@ -1162,7 +1167,7 @@ def forward(
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         hidden_states = inputs_embeds
-        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
+        reference_points = self.get_reference_points(spatial_shapes_list, valid_ratios, device=inputs_embeds.device)
 
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -1176,7 +1181,7 @@ def forward(
                 attention_mask,
                 position_embeddings=position_embeddings,
                 reference_points=reference_points,
-                spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 output_attentions=output_attentions,
             )
@@ -1302,9 +1307,9 @@ def forward(
         ]
 
         # Prepare encoder inputs (by flattening)
-        spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
+        spatial_shapes_list = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
         input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
-        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device)
+        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=input_embeds_flat.device)
         masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)
 
         position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
@@ -1320,7 +1325,7 @@ def forward(
                 inputs_embeds=input_embeds_flat,
                 attention_mask=masks_flat,
                 position_embeddings=level_pos_embed_flat,
-                spatial_shapes=spatial_shapes,
+                spatial_shapes_list=spatial_shapes_list,
                 level_start_index=level_start_index,
                 valid_ratios=valid_ratios,
                 output_attentions=output_attentions,
@@ -1331,18 +1336,23 @@ def forward(
         last_hidden_state = encoder_outputs.last_hidden_state
         batch_size = last_hidden_state.shape[0]
 
+        # We compute level_start_index_list separately from the tensor version level_start_index
+        # to avoid iterating over a tensor which breaks torch.compile/export.
+        level_start_index_list = [0]
+        for height, width in spatial_shapes_list[:-1]:
+            level_start_index_list.append(level_start_index_list[-1] + height * width)
         split_sizes = [None] * self.num_feature_levels
         for i in range(self.num_feature_levels):
             if i < self.num_feature_levels - 1:
-                split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
+                split_sizes[i] = level_start_index_list[i + 1] - level_start_index_list[i]
             else:
-                split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
+                split_sizes[i] = last_hidden_state.shape[1] - level_start_index_list[i]
 
-        encoder_output = torch.split(last_hidden_state, [size.item() for size in split_sizes], dim=1)
+        encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
 
         # Compute final features
         outputs = [
-            x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1])
+            x.transpose(1, 2).view(batch_size, -1, spatial_shapes_list[i][0], spatial_shapes_list[i][1])
             for i, x in enumerate(encoder_output)
         ]
 
@@ -1876,7 +1886,9 @@ def forward(
             else:
                 level_index = idx % self.num_feature_levels
 
-                attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
+                where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
+                # Multiply the attention mask instead of indexing to avoid issue in torch.export.
+                attention_mask = attention_mask * where.unsqueeze(-1)
 
                 layer_outputs = decoder_layer(
                     hidden_states,
diff --git a/tests/models/mask2former/test_modeling_mask2former.py b/tests/models/mask2former/test_modeling_mask2former.py
index ba78cf9ce3f..a3caefe14ab 100644
--- a/tests/models/mask2former/test_modeling_mask2former.py
+++ b/tests/models/mask2former/test_modeling_mask2former.py
@@ -20,6 +20,7 @@
 
 from tests.test_modeling_common import floats_tensor
 from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import (
     require_timm,
     require_torch,
@@ -481,3 +482,28 @@ def test_with_segmentation_maps_and_loss(self):
             outputs = model(**inputs)
 
         self.assertTrue(outputs.loss is not None)
+
+    def test_export(self):
+        if not is_torch_greater_or_equal_than_2_4:
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+        model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").to(torch_device)
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["pixel_values"], inputs["pixel_mask"]),
+            strict=True,
+        )
+        with torch.no_grad():
+            eager_outputs = model(**inputs)
+            exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["pixel_mask"])
+        self.assertEqual(eager_outputs.masks_queries_logits.shape, exported_outputs.masks_queries_logits.shape)
+        self.assertTrue(
+            torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=TOLERANCE)
+        )
+        self.assertEqual(eager_outputs.class_queries_logits.shape, exported_outputs.class_queries_logits.shape)
+        self.assertTrue(
+            torch.allclose(eager_outputs.class_queries_logits, exported_outputs.class_queries_logits, atol=TOLERANCE)
+        )
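Note (appended for illustration, not part of the patch): the sketch below is a minimal, self-contained rendering of the export-friendly pattern the pixel-decoder hunk above relies on. The feature-map sizes and names such as feature_maps are assumptions made up for the example; the point is that level offsets and split sizes are derived from spatial_shapes_list as plain Python ints, so torch.split no longer needs the per-element .item() calls that break torch.compile/torch.export.

import torch

# Three hypothetical multi-scale feature maps (batch 2, 256 channels).
feature_maps = [torch.randn(2, 256, h, w) for h, w in [(32, 32), (16, 16), (8, 8)]]
spatial_shapes_list = [(fm.shape[2], fm.shape[3]) for fm in feature_maps]

# Flatten every level to (batch, height * width, channels) and concatenate along the sequence dim.
flat = torch.cat([fm.flatten(2).transpose(1, 2) for fm in feature_maps], dim=1)

# Accumulate level offsets with plain Python ints instead of indexing a tensor,
# so nothing data-dependent has to be traced.
level_start_index_list = [0]
for height, width in spatial_shapes_list[:-1]:
    level_start_index_list.append(level_start_index_list[-1] + height * width)

# Split sizes are plain ints as well, so torch.split works without .item().
split_sizes = [
    level_start_index_list[i + 1] - level_start_index_list[i]
    if i < len(spatial_shapes_list) - 1
    else flat.shape[1] - level_start_index_list[i]
    for i in range(len(spatial_shapes_list))
]
per_level = torch.split(flat, split_sizes, dim=1)

# Restore the (batch, channels, height, width) layout per level.
outputs = [
    x.transpose(1, 2).reshape(x.shape[0], -1, height, width)
    for x, (height, width) in zip(per_level, spatial_shapes_list)
]
assert [tuple(o.shape[-2:]) for o in outputs] == spatial_shapes_list

With fixed input sizes these split sizes are ordinary Python ints at trace time, which is what lets the .item() calls in the original code be dropped.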