diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index 7df46117509700..29111c14436216 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -224,7 +224,7 @@ def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tens
         hidden_states = hidden_states.transpose(1, 2)
         return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
 
-    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
         # Input projections
         query = self.q_proj(query)
         key = self.k_proj(key)
@@ -242,6 +242,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
         attn = attn / math.sqrt(c_per_head)
         attn = torch.softmax(attn, dim=-1)
 
+        if attention_similarity is not None:
+            attn = attn + attention_similarity
+            attn = torch.softmax(attn, dim=-1)
+
         # Get output
         out = attn @ value
         out = self._recombine_heads(out, point_batch_size)
@@ -290,6 +294,7 @@ def forward(
         keys: Tensor,
         query_point_embedding: Tensor,
         key_point_embedding: Tensor,
+        attention_similarity: Tensor,
         output_attentions: bool = False,
     ):
         # Self attention block
@@ -305,7 +310,9 @@ def forward(
         query = queries + query_point_embedding
         key = keys + key_point_embedding
 
-        attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
+        attn_out = self.cross_attn_token_to_image(
+            query=query, key=key, value=keys, attention_similarity=attention_similarity
+        )
         queries = queries + attn_out
 
         queries = self.layer_norm2(queries)
@@ -353,6 +360,8 @@ def forward(
         point_embeddings: Tensor,
         image_embeddings: Tensor,
         image_positional_embeddings: Tensor,
+        attention_similarity: Tensor,
+        target_embedding=None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -377,11 +386,15 @@ def forward(
 
         # Apply transformer blocks and final layernorm
         for layer in self.layers:
+            if target_embedding is not None:
+                queries += target_embedding
+
             queries, keys, attention_outputs = layer(
                 queries=queries,
                 keys=keys,
                 query_point_embedding=point_embeddings,
                 key_point_embedding=image_positional_embeddings,
+                attention_similarity=attention_similarity,
                 output_attentions=output_attentions,
             )
 
@@ -460,6 +473,8 @@ def forward(
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
         output_attentions: Optional[bool] = None,
+        attention_similarity: torch.Tensor = None,
+        target_embedding: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Predict masks given image and prompt embeddings.
@@ -500,6 +515,8 @@ def forward(
             point_embeddings=point_embeddings,
             image_embeddings=image_embeddings,
             image_positional_embeddings=image_positional_embeddings,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
             output_attentions=output_attentions,
         )
         iou_token_out = point_embedding[:, :, 0, :]
@@ -576,8 +593,12 @@ def __init__(self, config: SamPromptEncoderConfig):
         self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
         self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
         self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
-        self.layer_norm1 = SamLayerNorm(self.mask_input_channels, config.layer_norm_eps)
-        self.layer_norm2 = SamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps)
+        self.layer_norm1 = SamLayerNorm(
+            self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
+        )
+        self.layer_norm2 = SamLayerNorm(
+            self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
+        )
 
     def forward(self, masks):
         hidden_states = self.conv1(masks)
@@ -1146,6 +1167,12 @@ def _init_weights(self, module):
             In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
             bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
             "best" mask, by specifying `multimask_output=False`.
+        attention_similarity (`torch.FloatTensor`, *optional*):
+            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+            model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
+        target_embedding (`torch.FloatTensor`, *optional*):
+            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+            the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
@@ -1265,6 +1292,8 @@ def forward(
         input_masks: Optional[torch.LongTensor] = None,
         image_embeddings: Optional[torch.FloatTensor] = None,
         multimask_output: bool = True,
+        attention_similarity: Optional[torch.FloatTensor] = None,
+        target_embedding: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict=None,
@@ -1374,6 +1403,8 @@ def forward(
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=multimask_output,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
             output_attentions=output_attentions,
         )
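For reference, a minimal sketch of how the two new inputs could be wired up at inference time. The image path, the point prompt, and both tensors below are placeholders: in the PerSAM setup, `target_embedding` would be pooled from reference-image features and `attention_similarity` from a similarity map between those features and the test image, rather than sampled randomly. Shapes assume the ViT-base checkpoint (hidden size 256, a 64x64 image-embedding grid).

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

model = SamModel.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Placeholder test image and a single foreground point prompt.
image = Image.open("test_image.png").convert("RGB")
inputs = processor(image, input_points=[[[450, 600]]], return_tensors="pt")

# Placeholder PerSAM tensors (random here; in practice derived from
# reference-image features as in the PerSAM paper):
# `target_embedding` is broadcast-added to the decoder queries ...
target_embedding = torch.randn(1, 1, 256)
# ... and `attention_similarity` is broadcast-added to the token-to-image
# attention weights over the 64 * 64 = 4096 image tokens.
attention_similarity = torch.randn(1, 1, 1, 64 * 64)

with torch.no_grad():
    outputs = model(
        **inputs,
        attention_similarity=attention_similarity,
        target_embedding=target_embedding,
        multimask_output=False,
    )

# Upscale the low-resolution masks back to the original image size.
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
```

Note the design choice in the attention hook above: the similarity map is added after the first softmax and the result is renormalized with a second softmax, so `attention_similarity` reweights the learned attention distribution rather than replacing the raw logits.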