Commit

[feat] reversing changes on undesired models

Lorenzobattistela committed Jul 13, 2023
1 parent 10342e9 commit df63b45

Showing 3 changed files with 481 additions and 1,900 deletions.
@@ -555,16 +555,16 @@ def __init__(
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

- def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
- return tensor if object_queries is None else tensor + object_queries
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+ return tensor if position_embeddings is None else tensor + position_embeddings

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
- object_queries: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
- spatial_position_embeddings: Optional[torch.Tensor] = None,
+ key_value_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -575,14 +575,14 @@ def forward(
batch_size, target_len, embed_dim = hidden_states.size()

# add position embeddings to the hidden states before projecting to queries and keys
- if object_queries is not None:
+ if position_embeddings is not None:
hidden_states_original = hidden_states
- hidden_states = self.with_pos_embed(hidden_states, object_queries)
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

# add key-value position embeddings to the key value states
- if spatial_position_embeddings is not None:
+ if key_value_position_embeddings is not None:
key_value_states_original = key_value_states
- key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
+ key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
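The hunk above restores the naming for the standard DETR-style attention pattern, in which position embeddings are added to the hidden states only for the query/key projections while values are projected from the original, embedding-free states. A minimal, self-contained sketch of that pattern (single head, no masking, hypothetical layer names, not the exact module in this file):

import torch
from torch import nn

class PositionAwareSelfAttention(nn.Module):
    # Sketch of the pattern above: position embeddings are added to the hidden
    # states before the query/key projections, while values are projected from
    # the original (position-free) hidden states.
    def __init__(self, embed_dim: int):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.scaling = embed_dim**-0.5

    def with_pos_embed(self, tensor, position_embeddings=None):
        return tensor if position_embeddings is None else tensor + position_embeddings

    def forward(self, hidden_states, position_embeddings=None):
        hidden_states_original = hidden_states
        hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
        query_states = self.q_proj(hidden_states) * self.scaling
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states_original)  # values stay position-free
        attn_weights = torch.softmax(query_states @ key_states.transpose(-2, -1), dim=-1)
        return attn_weights @ value_states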
@@ -790,7 +790,7 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
- object_queries: torch.Tensor = None,
+ position_embeddings: torch.Tensor = None,
output_attentions: bool = False,
):
"""
@@ -799,8 +799,7 @@ def forward(
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
- object_queries (`torch.FloatTensor`, *optional*):
- Object queries (also called content embeddings), to be added to the hidden states.
+ position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
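The docstring above describes `attention_mask` as an additive mask of shape `(batch, 1, target_len, source_len)` whose padded positions hold very large negative values. A minimal sketch of building such a mask from a boolean padding mask (a hypothetical helper, not part of this diff):

import torch

def expand_padding_mask(padding_mask: torch.Tensor, target_len: int, dtype=torch.float32) -> torch.Tensor:
    # padding_mask: (batch, source_len), True for real tokens, False for padding.
    batch_size, source_len = padding_mask.shape
    mask = padding_mask[:, None, None, :].expand(batch_size, 1, target_len, source_len)
    # Real tokens contribute 0; padded positions get a very large negative value,
    # so they vanish after the softmax inside the attention layer.
    return (1.0 - mask.to(dtype)) * torch.finfo(dtype).min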
@@ -809,7 +808,7 @@ def forward(
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
- object_queries=object_queries,
+ position_embeddings=position_embeddings,
output_attentions=output_attentions,
)

@@ -1151,7 +1150,7 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):
Small tweak for ConditionalDETR:
- - object_queries are added to the forward pass.
+ - position_embeddings are added to the forward pass.
Args:
config: ConditionalDetrConfig
@@ -1174,7 +1173,7 @@ def forward(
self,
inputs_embeds=None,
attention_mask=None,
- object_queries=None,
+ position_embeddings=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -1192,8 +1191,8 @@ def forward(
[What are attention masks?](../glossary#attention-mask)
- object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Object queries that are added to the queries in each self-attention layer.
+ position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Position embeddings that are added to the queries and keys in each self-attention layer.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -1233,11 +1232,11 @@ def forward(
if to_drop:
layer_outputs = (None, None)
else:
- # we add object_queries as extra input to the encoder_layer
+ # we add position_embeddings as extra input to the encoder_layer
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
- object_queries=object_queries,
+ position_embeddings=position_embeddings,
output_attentions=output_attentions,
)
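The `to_drop` flag above implements LayerDrop: while training, each encoder layer is skipped with some probability (the encoder layerdrop setting), and skipped layers report `(None, None)` outputs. A minimal sketch of the sampling step (an assumption about the collapsed code, which is not shown in this hunk):

import torch

def should_drop_layer(layerdrop: float, training: bool) -> bool:
    # LayerDrop: sample once per layer per forward pass; layers are only
    # skipped during training, never at inference time.
    return training and torch.rand(1).item() < layerdrop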

@@ -2661,4 +2660,4 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
m[: img.shape[1], : img.shape[2]] = False
else:
raise ValueError("Only 3-dimensional tensors are supported")
- return NestedTensor(tensor, mask)
+ return NestedTensor(tensor, mask)
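The tail of `nested_tensor_from_tensor_list` above only shows the mask bookkeeping; the idea is to pad a list of differently sized images into one batch and record which pixels are padding. A simplified, self-contained sketch of that padding step (not the exact library implementation, and without the `NestedTensor` wrapper):

import torch
from typing import List, Tuple

def pad_images_with_mask(images: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # images: list of (C, H, W) tensors with possibly different H and W.
    num_channels = max(img.shape[0] for img in images)
    max_height = max(img.shape[1] for img in images)
    max_width = max(img.shape[2] for img in images)
    batch = torch.zeros(len(images), num_channels, max_height, max_width, dtype=images[0].dtype)
    # mask is True on padded pixels and False on real pixels, matching the diff above.
    mask = torch.ones(len(images), max_height, max_width, dtype=torch.bool)
    for i, img in enumerate(images):
        batch[i, : img.shape[0], : img.shape[1], : img.shape[2]] = img
        mask[i, : img.shape[1], : img.shape[2]] = False
    return batch, mask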
30 changes: 15 additions & 15 deletions src/transformers/models/maskformer/modeling_maskformer.py
@@ -437,16 +437,16 @@ def __init__(
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

- def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
- return tensor if object_queries is None else tensor + object_queries
+ def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
+ return tensor if position_embeddings is None else tensor + position_embeddings

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
- object_queries: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
- spatial_position_embeddings: Optional[torch.Tensor] = None,
+ key_value_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -457,14 +457,14 @@ def forward(
batch_size, target_len, embed_dim = hidden_states.size()

# add position embeddings to the hidden states before projecting to queries and keys
- if object_queries is not None:
+ if position_embeddings is not None:
hidden_states_original = hidden_states
- hidden_states = self.with_pos_embed(hidden_states, object_queries)
+ hidden_states = self.with_pos_embed(hidden_states, position_embeddings)

# add key-value position embeddings to the key value states
- if spatial_position_embeddings is not None:
+ if key_value_position_embeddings is not None:
key_value_states_original = key_value_states
- key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)
+ key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
@@ -563,7 +563,7 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
- object_queries: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
@@ -575,8 +575,8 @@ def forward(
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
- object_queries (`torch.FloatTensor`, *optional*):
- object_queries that are added to the hidden states
+ position_embeddings (`torch.FloatTensor`, *optional*):
+ position embeddings that are added to the queries and keys
in the cross-attention layer.
query_position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
@@ -595,7 +595,7 @@ def forward(
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
- object_queries=query_position_embeddings,
+ position_embeddings=query_position_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
@@ -611,10 +611,10 @@ def forward(

hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
- object_queries=object_queries,
+ position_embeddings=query_position_embeddings,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
- spatial_position_embeddings=query_position_embeddings,
+ key_value_position_embeddings=position_embeddings,
output_attentions=output_attentions,
)

@@ -1839,4 +1839,4 @@ def forward(
output = tuple(v for v in output.values())
if loss is not None:
output = ((loss)) + output
- return output
+ return output

0 comments on commit df63b45
