Refactor detr 3.x conditional detr (#9405)
* Add conditional detr to 3.0

Co-authored-by: lym <[email protected]>

Showing 14 changed files with 969 additions and 15 deletions.

# Conditional DETR

> [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152)

<!-- [ALGORITHM] -->

## Abstract

The DETR approach applies the transformer encoder and decoder architecture to object detection and achieves promising performance. In this paper, we handle the critical issue of slow training convergence and present a conditional cross-attention mechanism for fast DETR training. Our approach is motivated by the observation that the cross-attention in DETR relies heavily on the content embeddings while the spatial embeddings make only minor contributions, increasing the need for high-quality content embeddings and thus increasing the training difficulty.

<div align=center>
<img src="https://github.com/Atten4Vis/ConditionalDETR/blob/main/.github/attention-maps.png?raw=true"/>
</div>

Our conditional DETR learns a conditional spatial query from the decoder embedding for decoder multi-head cross-attention. The benefit is that through the conditional spatial query, each cross-attention head is able to attend to a band containing a distinct region, e.g., one object extremity or a region inside the object box (Figure 1). This narrows down the spatial range for localizing the distinct regions for object classification and box regression, thus relaxing the dependence on the content embeddings and easing the training. Empirical results show that conditional DETR converges 6.7x faster for the backbones R50 and R101 and 10x faster for the stronger backbones DC5-R50 and DC5-R101.

<div align=center>
<img src="https://github.com/Atten4Vis/ConditionalDETR/raw/main/.github/conditional-detr.png" width="48%"/>
<img src="https://github.com/Atten4Vis/ConditionalDETR/raw/main/.github/convergence-curve.png" width="48%"/>
</div>
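
To make the mechanism concrete, the following is a small PyTorch-only sketch of a conditional spatial query. It illustrates the idea rather than the MMDetection implementation: the two-layer FFN, the simplified sinusoidal embedding, and all sizes are illustrative choices.

```python
# Toy sketch of a conditional spatial query (illustrative, not the mmdet code).
# A normalized 2-d reference point is mapped to a sinusoidal embedding, the
# decoder content embedding predicts a per-channel scale, and the conditional
# spatial query is their elementwise product. Content and spatial parts are
# then concatenated to form the decoder cross-attention query.
import math

import torch
import torch.nn as nn


def sine_embed(ref: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Sinusoidal embedding of normalized points: (bs, nq, 2) -> (bs, nq, dim)."""
    half = dim // 2  # channels per coordinate
    freq = 10000**(torch.arange(half // 2, dtype=torch.float32) * 2 / half)
    pos = (ref * 2 * math.pi).unsqueeze(-1) / freq  # (bs, nq, 2, half // 2)
    return torch.stack((pos.sin(), pos.cos()), dim=-1).flatten(-3)


bs, nq, dim = 2, 300, 256
content_query = torch.randn(bs, nq, dim)  # decoder embedding
reference_points = torch.rand(bs, nq, 2)  # normalized reference (cx, cy)
to_scale = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

# conditional spatial query: scale predicted from the content, applied to the
# sinusoidal embedding of the reference point
spatial_query = to_scale(content_query) * sine_embed(reference_points, dim)
# content and spatial parts are concatenated for the cross-attention query
cross_attn_query = torch.cat([content_query, spatial_query], dim=-1)
print(cross_attn_query.shape)  # torch.Size([2, 300, 512])
```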

## Results and Models

We provide the config files and models for Conditional DETR: [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152).

| Backbone | Model            | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config                                            | Download                             |
| :------: | :--------------: | :-----: | :------: | :------------: | :----: | :-----------------------------------------------: | :----------------------------------: |
|   R-50   | Conditional DETR | 50e     | 7.9      |                | 40.9   | [config](./conditional_detr_r50_8xb2-50e_coco.py) | \[model\](# TODO) \| \[log\](# TODO) |
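
Once a trained checkpoint is available, inference with this config can go through the standard MMDetection 3.x Python API. A minimal sketch; the checkpoint path below is only a placeholder, since no weights are linked in the table yet, and `demo/demo.jpg` is the sample image shipped with the repository:

```python
# Minimal inference sketch using the high-level MMDetection 3.x API.
# The checkpoint path is a placeholder until official weights are released.
from mmdet.apis import inference_detector, init_detector

config_file = 'configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py'
checkpoint_file = 'work_dirs/conditional_detr_r50_8xb2-50e_coco/epoch_50.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')  # or 'cpu'
result = inference_detector(model, 'demo/demo.jpg')
print(result.pred_instances.bboxes.shape)  # (num_predictions, 4)
```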

## Citation

```latex
@inproceedings{meng2021-CondDETR,
  title     = {Conditional DETR for Fast Training Convergence},
  author    = {Meng, Depu and Chen, Xiaokang and Fan, Zejia and Zeng, Gang and Li, Houqiang and Yuan, Yuhui and Sun, Lei and Wang, Jingdong},
  booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
  year      = {2021}
}
```

configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py (42 additions, 0 deletions)

_base_ = ['../detr/detr_r50_8xb2-150e_coco.py']
model = dict(
    type='ConditionalDETR',
    num_queries=300,
    decoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(
                _delete_=True,
                embed_dims=256,
                num_heads=8,
                attn_drop=0.1,
                cross_attn=False),
            cross_attn_cfg=dict(
                _delete_=True,
                embed_dims=256,
                num_heads=8,
                attn_drop=0.1,
                cross_attn=True))),
    bbox_head=dict(
        type='ConditionalDETRHead',
        loss_cls=dict(
            _delete_=True,
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])))

# learning policy
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=50, val_interval=1)

param_scheduler = [dict(type='MultiStepLR', end=50, milestones=[40])]
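
A quick way to check how this file merges with its `_base_` DETR config (in particular, that `_delete_=True` replaces rather than updates the inherited attention and `loss_cls` settings) is to load it with mmengine. A small sketch, assuming it is run from the MMDetection repository root:

```python
# Sketch: inspect the merged config produced by _base_ inheritance.
# `_delete_=True` replaces an inherited dict instead of updating it, so the
# merged loss_cls should be the FocalLoss defined here, not the
# CrossEntropyLoss from the base DETR config.
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py')
print(cfg.model.type)                     # 'ConditionalDETR'
print(cfg.model.bbox_head.loss_cls.type)  # 'FocalLoss'
print(cfg.train_cfg.max_epochs)           # 50
```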

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn as nn
from mmengine.model import bias_init_with_prob
from torch import Tensor

from mmdet.models.layers.transformer import inverse_sigmoid
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from .detr_head import DETRHead


@MODELS.register_module()
class ConditionalDETRHead(DETRHead):
    """Head of Conditional DETR: Conditional DETR for Fast Training
    Convergence. More details can be found in the `paper
    <https://arxiv.org/abs/2108.06152>`_.
    """

    def init_weights(self):
        """Initialize weights of the transformer head."""
        super().init_weights()
        # The bias initialization below is important because a sigmoid-based
        # focal loss is used for loss_cls.
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from the transformer decoder. If
                `return_intermediate_dec` in detr.py is True, the output has
                shape (num_decoder_layers, bs, num_queries, dim); otherwise it
                has shape (1, bs, num_queries, dim) and only contains the last
                layer outputs.
            references (Tensor): References from the transformer decoder, has
                shape (bs, num_queries, 2).

        Returns:
            tuple[Tensor]: Results of the head containing the following tensors.

            - layers_cls_scores (Tensor): Outputs from the classification head,
              shape (num_decoder_layers, bs, num_queries, cls_out_channels).
              Note cls_out_channels should include background.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
              head with normalized coordinate format (cx, cy, w, h), has shape
              (num_decoder_layers, bs, num_queries, 4).
        """
        references_unsigmoid = inverse_sigmoid(references)
        layers_bbox_preds = []
        for layer_id in range(hidden_states.shape[0]):
            tmp_reg_preds = self.fc_reg(
                self.activate(self.reg_ffn(hidden_states[layer_id])))
            # Predict (cx, cy) as offsets to the reference points in
            # unsigmoided (logit) space, then squash back to [0, 1].
            tmp_reg_preds[..., :2] += references_unsigmoid
            outputs_coord = tmp_reg_preds.sigmoid()
            layers_bbox_preds.append(outputs_coord)
        layers_bbox_preds = torch.stack(layers_bbox_preds)

        layers_cls_scores = self.fc_cls(hidden_states)
        return layers_cls_scores, layers_bbox_preds

    def loss(self, hidden_states: Tensor, references: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            hidden_states (Tensor): Features from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder, has
                shape (bs, num_queries, 2).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_and_predict(
            self, hidden_states: Tensor, references: Tensor,
            batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples. Overwritten because
        img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Features from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder, has
                shape (bs, num_queries, 2).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
                the meta information of each image and corresponding
                annotations.

        Returns:
            tuple: The return value is a tuple containing:

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection
              results of each image after the post process.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas)
        return losses, predictions

    def predict(self,
                hidden_states: Tensor,
                references: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network. Overwritten
        because img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Features from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder, has
                shape (bs, num_queries, 2).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        # Only the last decoder layer is used at test time.
        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
        outs = self(last_layer_hidden_state, references)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)

        return predictions
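
The `forward` above predicts the box center as an offset to the reference point in unsigmoided (logit) space and only then applies the sigmoid. A self-contained, plain-PyTorch sketch of that decoding step, with shapes chosen arbitrarily for illustration and `torch.special.logit` standing in for mmdet's `inverse_sigmoid`:

```python
# Sketch of the reference-point decoding used in ConditionalDETRHead.forward:
# only (cx, cy) receive the reference offset, applied in logit space, and the
# result is squashed back to [0, 1] with a sigmoid.
import torch

num_layers, bs, num_queries = 6, 2, 300
references = torch.rand(bs, num_queries, 2)              # normalized (cx, cy)
reg_preds = torch.randn(num_layers, bs, num_queries, 4)  # raw regression outputs

references_unsigmoid = torch.special.logit(references, eps=1e-5)
reg_preds[..., :2] += references_unsigmoid  # shift predicted centers toward references
bbox_preds = reg_preds.sigmoid()            # normalized (cx, cy, w, h)
print(bbox_preds.shape)                     # torch.Size([6, 2, 300, 4])
```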

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS
from ..layers import (ConditionalDetrTransformerDecoder,
                      DetrTransformerEncoder, SinePositionalEncoding)
from .detr import DETR


@MODELS.register_module()
class ConditionalDETR(DETR):
    r"""Implementation of `Conditional DETR for Fast Training Convergence
    <https://arxiv.org/abs/2108.06152>`_.

    Code is modified from the `official github repo
    <https://github.com/Atten4Vis/ConditionalDETR>`_.
    """

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding_cfg)
        self.encoder = DetrTransformerEncoder(**self.encoder)
        self.decoder = ConditionalDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        # NOTE: embed_dims is typically passed from the inside out.
        # For example, in DETR it is passed as
        # self_attn -> the first encoder layer -> encoder -> detector.
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` and `references` of the decoder output.

            - hidden_states (Tensor): Has shape
              (num_decoder_layers, bs, num_queries, dim).
            - references (Tensor): Has shape (bs, num_queries, 2).
        """
        hidden_states, references = self.decoder(
            query=query,
            key=memory,
            value=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask)
        head_inputs_dict = dict(
            hidden_states=hidden_states, references=references)
        return head_inputs_dict
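
With the detector and head above registered, the full model can be instantiated from the config through the registry. A minimal sketch, assuming an MMDetection 3.x environment and the repository root as the working directory:

```python
# Sketch: build ConditionalDETR from its config via the MMDetection registry.
from mmengine.config import Config

from mmdet.registry import MODELS
from mmdet.utils import register_all_modules

register_all_modules()  # register mmdet modules and set the default scope
cfg = Config.fromfile(
    'configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py')
model = MODELS.build(cfg.model)
print(type(model).__name__)  # ConditionalDETR
```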