Skip to content

Commit

Permalink
Add PerSAM [bis] (#23659)
Browse files Browse the repository at this point in the history
* Add PerSAM args

* Make attn_sim optional

* Rename to attention_similarity

* Add docstrigns

* Improve docstrings
  • Loading branch information
NielsRogge authored May 23, 2023
1 parent aa30cd4 commit 527ab89
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions src/transformers/models/sam/modeling_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tens
hidden_states = hidden_states.transpose(1, 2)
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
Expand All @@ -242,6 +242,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)

if attention_similarity is not None:
attn = attn + attention_similarity
attn = torch.softmax(attn, dim=-1)

# Get output
out = attn @ value
out = self._recombine_heads(out, point_batch_size)
Expand Down Expand Up @@ -290,6 +294,7 @@ def forward(
keys: Tensor,
query_point_embedding: Tensor,
key_point_embedding: Tensor,
attention_similarity: Tensor,
output_attentions: bool = False,
):
# Self attention block
Expand All @@ -305,7 +310,9 @@ def forward(
query = queries + query_point_embedding
key = keys + key_point_embedding

attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
attn_out = self.cross_attn_token_to_image(
query=query, key=key, value=keys, attention_similarity=attention_similarity
)
queries = queries + attn_out

queries = self.layer_norm2(queries)
Expand Down Expand Up @@ -353,6 +360,8 @@ def forward(
point_embeddings: Tensor,
image_embeddings: Tensor,
image_positional_embeddings: Tensor,
attention_similarity: Tensor,
target_embedding=None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
Expand All @@ -377,11 +386,15 @@ def forward(

# Apply transformer blocks and final layernorm
for layer in self.layers:
if target_embedding is not None:
queries += target_embedding

queries, keys, attention_outputs = layer(
queries=queries,
keys=keys,
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity,
output_attentions=output_attentions,
)

Expand Down Expand Up @@ -460,6 +473,8 @@ def forward(
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
output_attentions: Optional[bool] = None,
attention_similarity: torch.Tensor = None,
target_embedding: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Expand Down Expand Up @@ -500,6 +515,8 @@ def forward(
point_embeddings=point_embeddings,
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)
iou_token_out = point_embedding[:, :, 0, :]
Expand Down Expand Up @@ -576,8 +593,12 @@ def __init__(self, config: SamPromptEncoderConfig):
self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
self.layer_norm1 = SamLayerNorm(self.mask_input_channels, config.layer_norm_eps)
self.layer_norm2 = SamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps)
self.layer_norm1 = SamLayerNorm(
self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
)
self.layer_norm2 = SamLayerNorm(
self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
)

def forward(self, masks):
hidden_states = self.conv1(masks)
Expand Down Expand Up @@ -1146,6 +1167,12 @@ def _init_weights(self, module):
In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
"best" mask, by specifying `multimask_output=False`.
attention_similarity (`torch.FloatTensor`, *optional*):
Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
target_embedding (`torch.FloatTensor`, *optional*):
Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
Expand Down Expand Up @@ -1265,6 +1292,8 @@ def forward(
input_masks: Optional[torch.LongTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
multimask_output: bool = True,
attention_similarity: Optional[torch.FloatTensor] = None,
target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict=None,
Expand Down Expand Up @@ -1374,6 +1403,8 @@ def forward(
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
attention_similarity=attention_similarity,
target_embedding=target_embedding,
output_attentions=output_attentions,
)

Expand Down

0 comments on commit 527ab89

Please sign in to comment.