Add PerSAM [bis] #23659

Merged 5 commits on May 23, 2023. The diff below shows changes from all commits.
src/transformers/models/sam/modeling_sam.py: 39 changes (35 additions & 4 deletions)
@@ -224,7 +224,7 @@ def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
         hidden_states = hidden_states.transpose(1, 2)
         return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
 
-    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
         # Input projections
         query = self.q_proj(query)
         key = self.k_proj(key)
@@ -242,6 +242,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
         attn = attn / math.sqrt(c_per_head)
         attn = torch.softmax(attn, dim=-1)
 
+        if attention_similarity is not None:
+            attn = attn + attention_similarity
+            attn = torch.softmax(attn, dim=-1)
+
         # Get output
         out = attn @ value
         out = self._recombine_heads(out, point_batch_size)
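Note for readers following along: the new branch biases the already-normalized attention map with a similarity prior and renormalizes it, which is the target-guided attention described in the PerSAM docstrings added further down. A standalone sketch of the computation (shapes and the random `attention_similarity` are illustrative only; in PerSAM the prior comes from reference-image feature similarity):

```python
import math

import torch

batch, n_heads, n_tokens, n_img_tokens, c_per_head = 1, 8, 5, 64 * 64, 32
query = torch.randn(batch, n_heads, n_tokens, c_per_head)
key = torch.randn(batch, n_heads, n_img_tokens, c_per_head)

# Standard scaled dot-product attention, as computed just above in the diff.
attn = torch.softmax(query @ key.permute(0, 1, 3, 2) / math.sqrt(c_per_head), dim=-1)

# The PerSAM addition: add a similarity prior over image locations, then
# renormalize so each row is again a probability distribution.
attention_similarity = torch.randn(batch, n_heads, n_tokens, n_img_tokens)
attn = torch.softmax(attn + attention_similarity, dim=-1)
```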
@@ -290,6 +294,7 @@ def forward(
         keys: Tensor,
         query_point_embedding: Tensor,
         key_point_embedding: Tensor,
+        attention_similarity: Tensor,
         output_attentions: bool = False,
     ):
         # Self attention block
@@ -305,7 +310,9 @@ def forward(
         query = queries + query_point_embedding
         key = keys + key_point_embedding
 
-        attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
+        attn_out = self.cross_attn_token_to_image(
+            query=query, key=key, value=keys, attention_similarity=attention_similarity
+        )
         queries = queries + attn_out
 
         queries = self.layer_norm2(queries)
@@ -353,6 +360,8 @@ def forward(
         point_embeddings: Tensor,
         image_embeddings: Tensor,
         image_positional_embeddings: Tensor,
+        attention_similarity: Tensor,
+        target_embedding=None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -377,11 +386,15 @@
 
         # Apply transformer blocks and final layernorm
         for layer in self.layers:
+            if target_embedding is not None:
+                queries += target_embedding
+
             queries, keys, attention_outputs = layer(
                 queries=queries,
                 keys=keys,
                 query_point_embedding=point_embeddings,
                 key_point_embedding=image_positional_embeddings,
+                attention_similarity=attention_similarity,
                 output_attentions=output_attentions,
             )
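Worth noting: `target_embedding` is re-added to the query tokens before every decoder layer, not just once at the input, which is the target-semantic prompting referenced in the docstrings below. A toy sketch of the pattern (the `Linear` layers stand in for the real two-way attention blocks; shapes are illustrative):

```python
import torch

hidden_size = 256
queries = torch.randn(1, 1, 8, hidden_size)           # (batch, point_batch, tokens, hidden)
target_embedding = torch.randn(1, 1, 1, hidden_size)  # broadcasts over the token axis

layers = [torch.nn.Linear(hidden_size, hidden_size) for _ in range(2)]
for layer in layers:
    queries = queries + target_embedding  # re-inject the concept before each layer
    queries = layer(queries)
```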

@@ -460,6 +473,8 @@ def forward(
         dense_prompt_embeddings: torch.Tensor,
         multimask_output: bool,
         output_attentions: Optional[bool] = None,
+        attention_similarity: torch.Tensor = None,
+        target_embedding: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Predict masks given image and prompt embeddings.
@@ -500,6 +515,8 @@
             point_embeddings=point_embeddings,
             image_embeddings=image_embeddings,
             image_positional_embeddings=image_positional_embeddings,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
             output_attentions=output_attentions,
         )
         iou_token_out = point_embedding[:, :, 0, :]
@@ -576,8 +593,12 @@ def __init__(self, config: SamPromptEncoderConfig):
         self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
         self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
         self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
-        self.layer_norm1 = SamLayerNorm(self.mask_input_channels, config.layer_norm_eps)
-        self.layer_norm2 = SamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps)
+        self.layer_norm1 = SamLayerNorm(
+            self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
+        )
+        self.layer_norm2 = SamLayerNorm(
+            self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
+        )
 
     def forward(self, masks):
         hidden_states = self.conv1(masks)
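This hunk is a bugfix rather than PerSAM plumbing: the convolutions above emit NCHW (`channels_first`) activations, while `SamLayerNorm` defaults to `channels_last` and would normalize over the wrong axis. A hand-rolled illustration of the two reductions (not the library class itself):

```python
import torch

x = torch.randn(2, 4, 8, 8)  # NCHW conv output: channels_first

# A channels_last-style norm reduces over the trailing axis (width here),
# which is the wrong statistic for NCHW activations.
u = x.mean(-1, keepdim=True)
wrong = (x - u) / torch.sqrt((x - u).pow(2).mean(-1, keepdim=True) + 1e-6)

# channels_first reduces over dim 1, the channel axis, matching the
# data_format the constructor now passes explicitly.
u = x.mean(1, keepdim=True)
right = (x - u) / torch.sqrt((x - u).pow(2).mean(1, keepdim=True) + 1e-6)
```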
@@ -1146,6 +1167,12 @@ def _init_weights(self, module):
             In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
             bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
             "best" mask, by specifying `multimask_output=False`.
+        attention_similarity (`torch.FloatTensor`, *optional*):
+            Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
+            model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
+        target_embedding (`torch.FloatTensor`, *optional*):
+            Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
+            the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
         output_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
             tensors for more detail.
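Putting it together, the two tensors surface as optional keyword arguments on `SamModel.forward`. A hedged usage sketch: the checkpoint and processor calls follow the existing SAM docs, but the random `attention_similarity` and `target_embedding` tensors, and their broadcast shapes of `(1, 1, 1, 64 * 64)` against the 64x64 image-embedding grid and `(1, 1, 1, 256)` against the query tokens, are my reading of the diff and only demonstrate the wiring; in PerSAM they are computed from a reference image and mask:

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")

image = Image.open("test_image.png")  # any RGB image
inputs = processor(image, input_points=[[[450, 600]]], return_tensors="pt")

# Placeholders for PerSAM's reference-derived tensors (see the paper).
attention_similarity = torch.rand(1, 1, 1, 64 * 64)
target_embedding = torch.rand(1, 1, 1, 256)

with torch.no_grad():
    outputs = model(
        **inputs,
        attention_similarity=attention_similarity,
        target_embedding=target_embedding,
        multimask_output=False,
    )
print(outputs.pred_masks.shape)
```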
@@ -1265,6 +1292,8 @@ def forward(
         input_masks: Optional[torch.LongTensor] = None,
         image_embeddings: Optional[torch.FloatTensor] = None,
         multimask_output: bool = True,
+        attention_similarity: Optional[torch.FloatTensor] = None,
+        target_embedding: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict=None,
@@ -1374,6 +1403,8 @@ def forward(
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
             multimask_output=multimask_output,
+            attention_similarity=attention_similarity,
+            target_embedding=target_embedding,
             output_attentions=output_attentions,
         )
