fixing name position_embeddings to object_queries #24652

Merged
Changes from 11 commits

Commits (19)
cde3c93
fixing name position_embeddings to object_queries
Lorenzobattistela Jul 4, 2023
7af5048
[fix] renaming variable and docstring do object queries
Lorenzobattistela Jul 11, 2023
234e2be
[fix] comment position_embedding to object queries
Lorenzobattistela Jul 13, 2023
56e3e9e
[feat] changes from make-fix-copies to keep consistency
Lorenzobattistela Jul 20, 2023
0928a48
Revert "[feat] changes from make-fix-copies to keep consistency"
Lorenzobattistela Jul 20, 2023
74549a4
[tests] fix wrong expected score
Lorenzobattistela Jul 20, 2023
6b4b43b
[fix] wrong assignment causing wrong tensor shapes
Lorenzobattistela Jul 20, 2023
d1457d8
[fix] fixing position_embeddings to object queries to keep consistenc…
Lorenzobattistela Jul 20, 2023
f235267
[fix] make fix copies, renaming position_embeddings to object_queries
Lorenzobattistela Jul 20, 2023
3ceb749
[fix] positional_embeddingss to object queries, fixes from make fix c…
Lorenzobattistela Jul 20, 2023
cae8807
[fix] comments frmo make fix copies
Lorenzobattistela Jul 20, 2023
5edbec2
[fix] adding args validation to keep version support
Lorenzobattistela Jul 25, 2023
402de69
[fix] adding args validation to keep version support -conditional detr
Lorenzobattistela Jul 25, 2023
614ab95
[fix] adding args validation to keep version support - maskformer
Lorenzobattistela Jul 25, 2023
94bfb75
[style] make fixup style fixes
Lorenzobattistela Jul 25, 2023
fbd42a0
[feat] adding args checking
Lorenzobattistela Jul 25, 2023
83adc38
[feat] fixcopies and args checking
Lorenzobattistela Jul 25, 2023
acf09ad
make fixup
Lorenzobattistela Jul 25, 2023
e031207
make fixup
Lorenzobattistela Jul 25, 2023
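Several of the later commits add "args validation to keep version support", meaning callers that still pass the old `position_embeddings` keyword should keep working while the code uses the new `object_queries` name internally. The snippet below is only a minimal sketch of what such a backward-compatibility check can look like, assuming a kwargs-based deprecation path; the helper name `handle_renamed_kwarg` and the warning wording are illustrative, not the exact code merged in this PR.

import warnings


def handle_renamed_kwarg(old_value, new_value, old_name="position_embeddings", new_name="object_queries"):
    """Resolve a renamed keyword argument while keeping the deprecated name usable."""
    if old_value is not None and new_value is not None:
        # Passing both the old and the new name is ambiguous, so reject it outright.
        raise ValueError(f"Cannot specify both `{new_name}` and `{old_name}`. Please use `{new_name}` only.")
    if old_value is not None:
        # Accept the legacy keyword, but tell the caller it is deprecated.
        warnings.warn(
            f"`{old_name}` is deprecated and will be removed in a future version. Use `{new_name}` instead.",
            FutureWarning,
        )
        return old_value
    return new_value

A forward method could then begin with something like `object_queries = handle_renamed_kwarg(kwargs.pop("position_embeddings", None), object_queries)` and continue with the new name only.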
Diff view
@@ -555,16 +555,16 @@ def __init__(
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]):
return tensor if object_queries is None else tensor + object_queries

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
object_queries: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
key_value_position_embeddings: Optional[torch.Tensor] = None,
spatial_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -575,14 +575,14 @@ def forward(
batch_size, target_len, embed_dim = hidden_states.size()

# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
if object_queries is not None:
hidden_states_original = hidden_states
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
hidden_states = self.with_pos_embed(hidden_states, object_queries)

# add key-value position embeddings to the key value states
if key_value_position_embeddings is not None:
if spatial_position_embeddings is not None:
key_value_states_original = key_value_states
key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings)

# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
@@ -790,7 +790,7 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor = None,
object_queries: torch.Tensor = None,
output_attentions: bool = False,
):
"""
@@ -799,7 +799,8 @@ def forward(
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
position_embeddings (`torch.FloatTensor`, *optional*): position embeddings, to be added to hidden_states.
object_queries (`torch.FloatTensor`, *optional*):
Object queries (also called content embeddings), to be added to the hidden states.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -808,7 +809,7 @@ def forward(
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
object_queries=object_queries,
output_attentions=output_attentions,
)

@@ -885,7 +886,7 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
object_queries: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
query_sine_embed: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -899,11 +900,11 @@ def forward(
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
values.
position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
object_queries (`torch.FloatTensor`, *optional*):
object_queries that are added to the queries and keys
in the cross-attention layer.
query_position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
object_queries that are added to the queries and keys
in the self-attention layer.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -954,7 +955,7 @@ def forward(
batch_size, num_queries, n_model = q_content.shape
_, source_len, _ = k_content.shape

k_pos = self.ca_kpos_proj(position_embeddings)
k_pos = self.ca_kpos_proj(object_queries)

# For the first decoder layer, we concatenate the positional embedding predicted from
# the object query (the positional embedding) into the original query (key) in DETR.
@@ -1150,7 +1151,7 @@ class ConditionalDetrEncoder(ConditionalDetrPreTrainedModel):

Small tweak for ConditionalDETR:

- position_embeddings are added to the forward pass.
- object_queries are added to the forward pass.

Args:
config: ConditionalDetrConfig
@@ -1173,7 +1174,7 @@ def forward(
self,
inputs_embeds=None,
attention_mask=None,
position_embeddings=None,
object_queries=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -1191,8 +1192,8 @@ def forward(

[What are attention masks?](../glossary#attention-mask)

position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Position embeddings that are added to the queries and keys in each self-attention layer.
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Object queries that are added to the queries in each self-attention layer.

output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -1232,11 +1233,11 @@ def forward(
if to_drop:
layer_outputs = (None, None)
else:
# we add position_embeddings as extra input to the encoder_layer
# we add object_queries as extra input to the encoder_layer
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
object_queries=object_queries,
output_attentions=output_attentions,
)

@@ -1263,7 +1264,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):

Some small tweaks for Conditional DETR:

- position_embeddings and query_position_embeddings are added to the forward pass.
- object_queries and query_position_embeddings are added to the forward pass.
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.

Args:
@@ -1296,7 +1297,7 @@ def forward(
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
position_embeddings=None,
object_queries=None,
query_position_embeddings=None,
output_attentions=None,
output_hidden_states=None,
@@ -1324,7 +1325,7 @@ def forward(
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).

position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Position embeddings that are added to the queries and keys in each cross-attention layer.
query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
, *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
@@ -1404,7 +1405,7 @@ def custom_forward(*inputs):
create_custom_forward(decoder_layer),
hidden_states,
combined_attention_mask,
position_embeddings,
object_queries,
query_position_embeddings,
query_sine_embed,
encoder_hidden_states,
@@ -1416,7 +1417,7 @@ def custom_forward(*inputs):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
position_embeddings=position_embeddings,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
query_sine_embed=query_sine_embed,
encoder_hidden_states=encoder_hidden_states,
@@ -1484,8 +1485,8 @@ def __init__(self, config: ConditionalDetrConfig):

# Create backbone + positional encoding
backbone = ConditionalDetrConvEncoder(config)
position_embeddings = build_position_encoding(config)
self.backbone = ConditionalDetrConvModel(backbone, position_embeddings)
object_queries = build_position_encoding(config)
self.backbone = ConditionalDetrConvModel(backbone, object_queries)

# Create projection layer
self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1)
@@ -1569,7 +1570,7 @@ def forward(
# First, sent pixel_values + pixel_mask through Backbone to obtain the features
# pixel_values should be of shape (batch_size, num_channels, height, width)
# pixel_mask should be of shape (batch_size, height, width)
features, position_embeddings_list = self.backbone(pixel_values, pixel_mask)
features, object_queries_list = self.backbone(pixel_values, pixel_mask)

# get final feature map and downsampled mask
feature_map, mask = features[-1]
@@ -1580,21 +1581,21 @@ def forward(
# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
projected_feature_map = self.input_projection(feature_map)

# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
# Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

flattened_mask = mask.flatten(1)

# Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
# Fourth, sent flattened_features + flattened_mask + object_queries through encoder
# flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
if encoder_outputs is None:
encoder_outputs = self.encoder(
inputs_embeds=flattened_features,
attention_mask=flattened_mask,
position_embeddings=position_embeddings,
object_queries=object_queries,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
Expand All @@ -1607,15 +1608,15 @@ def forward(
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)

# Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
# Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1)
queries = torch.zeros_like(query_position_embeddings)

# decoder outputs consists of (dec_features, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
inputs_embeds=queries,
attention_mask=None,
position_embeddings=position_embeddings,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=flattened_mask,
@@ -1931,29 +1932,29 @@ def forward(
if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=device)

# First, get list of feature maps and position embeddings
features, position_embeddings_list = self.conditional_detr.model.backbone(pixel_values, pixel_mask=pixel_mask)
# First, get list of feature maps and object_queries
features, object_queries_list = self.conditional_detr.model.backbone(pixel_values, pixel_mask=pixel_mask)

# Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
feature_map, mask = features[-1]
batch_size, num_channels, height, width = feature_map.shape
projected_feature_map = self.conditional_detr.model.input_projection(feature_map)

# Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
# Third, flatten the feature map + object_queries of shape NxCxHxW to NxCxHW, and permute it to NxHWxC
# In other words, turn their shape into (batch_size, sequence_length, hidden_size)
flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1)
position_embeddings = position_embeddings_list[-1].flatten(2).permute(0, 2, 1)
object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1)

flattened_mask = mask.flatten(1)

# Fourth, sent flattened_features + flattened_mask + position embeddings through encoder
# Fourth, sent flattened_features + flattened_mask + object_queries through encoder
# flattened_features is a Tensor of shape (batch_size, heigth*width, hidden_size)
# flattened_mask is a Tensor of shape (batch_size, heigth*width)
if encoder_outputs is None:
encoder_outputs = self.conditional_detr.model.encoder(
inputs_embeds=flattened_features,
attention_mask=flattened_mask,
position_embeddings=position_embeddings,
object_queries=object_queries,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -1966,7 +1967,7 @@ def forward(
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)

# Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output)
# Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output)
query_position_embeddings = self.conditional_detr.model.query_position_embeddings.weight.unsqueeze(0).repeat(
batch_size, 1, 1
)
@@ -1976,7 +1977,7 @@ def forward(
decoder_outputs = self.conditional_detr.model.decoder(
inputs_embeds=queries,
attention_mask=None,
position_embeddings=position_embeddings,
object_queries=object_queries,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=flattened_mask,
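For reference, the renamed `with_pos_embed` helper keeps the behavior it had under the old name: `object_queries` is simply added to the hidden states before the query/key projections, and `None` leaves the input untouched. The standalone snippet below restates that logic outside the model class to make the expected `(batch_size, sequence_length, hidden_size)` shapes concrete; it is an illustration, not code taken from this PR.

from typing import Optional

import torch
from torch import Tensor


def with_pos_embed(tensor: torch.Tensor, object_queries: Optional[Tensor]) -> torch.Tensor:
    # Same logic as the renamed method in the diff above: add the object queries
    # (formerly `position_embeddings`) to the hidden states, or pass them through unchanged.
    return tensor if object_queries is None else tensor + object_queries


batch_size, sequence_length, hidden_size = 2, 10, 256
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)
object_queries = torch.randn(batch_size, sequence_length, hidden_size)

assert with_pos_embed(hidden_states, None) is hidden_states
assert with_pos_embed(hidden_states, object_queries).shape == (batch_size, sequence_length, hidden_size)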