Refactor detr 3.x conditional detr (#9405)
* Add conditional detr to 3.0
Co-authored-by: lym <[email protected]>
LYMDLUT authored Dec 15, 2022
1 parent f874d5c commit 899d4b8
Showing 14 changed files with 969 additions and 15 deletions.
39 changes: 39 additions & 0 deletions configs/conditional_detr/README.md
@@ -0,0 +1,39 @@
# Conditional DETR

> [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152)
<!-- [ALGORITHM] -->

## Abstract

The DETR approach applies the transformer encoder and decoder architecture to object detection and achieves promising performance. In this paper, we handle the critical issue, slow training convergence, and present a conditional cross-attention mechanism for fast DETR training. Our approach is motivated by that the cross-attention in DETR relies highly on the content embeddings and that the spatial embeddings make minor contributions, increasing the need for high-quality content embeddings and thus increasing the training difficulty.

<div align=center>
<img src="https://github.com/Atten4Vis/ConditionalDETR/blob/main/.github/attention-maps.png?raw=true"/>
</div>

Our conditional DETR learns a conditional spatial query from the decoder embedding for decoder multi-head cross-attention. The benefit is that through the conditional spatial query, each cross-attention head is able to attend to a band containing a distinct region, e.g., one object extremity or a region inside the object box (Figure 1). This narrows down the spatial range for localizing the distinct regions for object classification and box regression, thus relaxing the dependence on the content embeddings and easing the training. Empirical results show that conditional DETR converges 6.7x faster for the backbones R50 and R101 and 10x faster for stronger backbones DC5-R50 and DC5-R101.

<div align=center>
<img src="https://github.com/Atten4Vis/ConditionalDETR/raw/main/.github/conditional-detr.png" width="48%"/>
<img src="https://github.com/Atten4Vis/ConditionalDETR/raw/main/.github/convergence-curve.png" width="48%"/>
</div>
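
The mechanism can be sketched in a few lines: a small FFN maps the decoder embedding to a transformation that modulates the sinusoidal embedding of the reference point, and the resulting conditional spatial query is concatenated with the content query before cross-attention. The snippet below is a minimal, illustrative reconstruction under these assumptions, not the mmdet implementation; the `sine_pos_embed` helper and all tensor names are hypothetical, and single-head attention is used for brevity.

```python
import torch
import torch.nn as nn


def sine_pos_embed(ref_points: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Toy sinusoidal embedding of normalized 2-d reference points."""
    freqs = 10000 ** (torch.arange(dim // 4, dtype=torch.float32) / (dim // 4))
    pos = ref_points.unsqueeze(-1) * 2 * torch.pi / freqs   # (bs, nq, 2, dim // 4)
    return torch.cat([pos.sin(), pos.cos()], dim=-1).flatten(-2)  # (bs, nq, dim)


bs, num_queries, num_keys, dim = 2, 300, 1024, 256
content_query = torch.randn(bs, num_queries, dim)  # decoder embedding f
memory = torch.randn(bs, num_keys, dim)            # encoder output (content keys)
key_pos = torch.randn(bs, num_keys, dim)           # positional embedding of memory
ref_points = torch.rand(bs, num_queries, 2)        # normalized reference points s

# Conditional spatial query: p_q = T(f) * PE(s), with T a small FFN.
T = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
spatial_query = T(content_query) * sine_pos_embed(ref_points, dim)

# Content and spatial parts are concatenated, so the attention logits
# decompose into a content term plus a spatial term.
q = torch.cat([content_query, spatial_query], dim=-1)   # (bs, nq, 2 * dim)
k = torch.cat([memory, key_pos], dim=-1)                # (bs, nk, 2 * dim)
attn = (q @ k.transpose(-1, -2) / (2 * dim) ** 0.5).softmax(-1)  # (bs, nq, nk)
```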

## Results and Models

We provide the config files and models for Conditional DETR: [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152).

| Backbone | Model | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
| :------: | :--------------: | :-----: | :------: | :------------: | :----: | :-----------------------------------------------: | :----------------------------------: |
| R-50 | Conditional DETR | 50e | 7.9 | | 40.9 | [config](./conditional_detr_r50_8xb2-50e_coco.py) | \[model\](# TODO) \| \[log\](# TODO) |
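
The model and log links above are still marked TODO. Once a checkpoint is released, inference should work through the standard MMDetection 3.x high-level API; the following is a sketch under that assumption, with the checkpoint path as a placeholder.

```python
from mmdet.apis import inference_detector, init_detector

config = 'configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py'
checkpoint = 'conditional_detr_r50_8xb2-50e_coco.pth'  # placeholder, not yet released

model = init_detector(config, checkpoint, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')  # a DetDataSample in mmdet 3.x
print(result.pred_instances.scores[:5], result.pred_instances.bboxes[:5])
```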

## Citation

```latex
@inproceedings{meng2021-CondDETR,
  title     = {Conditional DETR for Fast Training Convergence},
  author    = {Meng, Depu and Chen, Xiaokang and Fan, Zejia and Zeng, Gang and Li, Houqiang and Yuan, Yuhui and Sun, Lei and Wang, Jingdong},
  booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
  year      = {2021}
}
```
42 changes: 42 additions & 0 deletions configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,42 @@
_base_ = ['../detr/detr_r50_8xb2-150e_coco.py']
model = dict(
    type='ConditionalDETR',
    num_queries=300,
    decoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(
                _delete_=True,
                embed_dims=256,
                num_heads=8,
                attn_drop=0.1,
                cross_attn=False),
            cross_attn_cfg=dict(
                _delete_=True,
                embed_dims=256,
                num_heads=8,
                attn_drop=0.1,
                cross_attn=True))),
    bbox_head=dict(
        type='ConditionalDETRHead',
        loss_cls=dict(
            _delete_=True,
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])))

# learning policy
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=50, val_interval=1)

param_scheduler = [dict(type='MultiStepLR', end=50, milestones=[40])]
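
This config only overrides pieces of the inherited DETR config; `_delete_=True` drops the base keys (e.g. the CrossEntropyLoss-based `loss_cls`) before the new values are merged in. A hedged sketch of inspecting the merged result with mmengine, assuming an MMDetection 3.x environment:

```python
from mmengine.config import Config

from mmdet.registry import MODELS
from mmdet.utils import register_all_modules

register_all_modules()  # register mmdet modules under the default scope

cfg = Config.fromfile(
    'configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py')
print(cfg.model.bbox_head.loss_cls)  # FocalLoss settings after _delete_ + merge
print(cfg.model.decoder.layer_cfg.cross_attn_cfg)

model = MODELS.build(cfg.model)  # builds the ConditionalDETR detector
```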
4 changes: 3 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -7,6 +7,7 @@
 from .centernet_head import CenterNetHead
 from .centernet_update_head import CenterNetUpdateHead
 from .centripetal_head import CentripetalHead
+from .conditional_detr_head import ConditionalDETRHead
 from .corner_head import CornerHead
 from .ddod_head import DDODHead
 from .deformable_detr_head import DeformableDETRHead
@@ -56,5 +57,6 @@
     'DeformableDETRHead', 'CenterNetHead', 'YOLOXHead', 'SOLOHead',
     'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead',
     'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead',
-    'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead'
+    'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead',
+    'ConditionalDETRHead'
 ]
168 changes: 168 additions & 0 deletions mmdet/models/dense_heads/conditional_detr_head.py
@@ -0,0 +1,168 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn as nn
from mmengine.model import bias_init_with_prob
from torch import Tensor

from mmdet.models.layers.transformer import inverse_sigmoid
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from .detr_head import DETRHead


@MODELS.register_module()
class ConditionalDETRHead(DETRHead):
"""Head of Conditional DETR. Conditional DETR: Conditional DETR for Fast
Training Convergence. More details can be found in the `paper.
<https://arxiv.org/abs/2108.06152>`_ .
"""

    def init_weights(self):
        """Initialize weights of the transformer head."""
        super().init_weights()
        # The initialization below is important because FocalLoss
        # is used for loss_cls.
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)
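
    # Illustrative note (not part of the original file): with sigmoid-based
    # FocalLoss, the classification bias is initialized so that the predicted
    # foreground probability starts at a small prior (0.01 here), which keeps
    # the early training loss from being dominated by the many easy negatives.
    # bias_init_with_prob(p) is assumed to compute -log((1 - p) / p), e.g.:
    #
    #     import math
    #     p = 0.01
    #     bias = -math.log((1 - p) / p)  # ~ -4.595 for p = 0.01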

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from transformer decoder. If
                `return_intermediate_dec` in detr.py is True, output has shape
                (num_decoder_layers, bs, num_queries, dim), else has shape
                (1, bs, num_queries, dim) which only contains the last layer
                outputs.
            references (Tensor): References from transformer decoder, has
                shape (bs, num_queries, 2).

        Returns:
            tuple[Tensor]: results of head containing the following tensors.

            - layers_cls_scores (Tensor): Outputs from the classification
              head, shape (num_decoder_layers, bs, num_queries,
              cls_out_channels). Note cls_out_channels should include
              background.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
              head with normalized coordinate format (cx, cy, w, h), has shape
              (num_decoder_layers, bs, num_queries, 4).
        """
        references_unsigmoid = inverse_sigmoid(references)
        layers_bbox_preds = []
        for layer_id in range(hidden_states.shape[0]):
            tmp_reg_preds = self.fc_reg(
                self.activate(self.reg_ffn(hidden_states[layer_id])))
            tmp_reg_preds[..., :2] += references_unsigmoid
            outputs_coord = tmp_reg_preds.sigmoid()
            layers_bbox_preds.append(outputs_coord)
        layers_bbox_preds = torch.stack(layers_bbox_preds)

        layers_cls_scores = self.fc_cls(hidden_states)
        return layers_cls_scores, layers_bbox_preds
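
    # Illustrative note (not part of the original file): only the first two
    # regression channels (cx, cy) are offset by the un-sigmoided reference
    # points, so every decoder layer predicts a delta around its reference
    # point in logit space. Assuming inverse_sigmoid(x) = log(x / (1 - x)),
    # a zero delta reproduces the reference point exactly:
    #
    #     import torch
    #     ref = torch.tensor([0.25, 0.60])             # normalized (cx, cy)
    #     delta = torch.zeros(2)                       # raw head output
    #     pred = (delta + torch.logit(ref)).sigmoid()  # == ref when delta is 0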

    def loss(self, hidden_states: Tensor, references: Tensor,
             batch_data_samples: SampleList) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            hidden_states (Tensor): Feature from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder, has
                shape (bs, num_queries, 2).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)
        return losses

    def loss_and_predict(
            self, hidden_states: Tensor, references: Tensor,
            batch_data_samples: SampleList) -> Tuple[dict, InstanceList]:
        """Perform forward propagation of the head, then calculate loss and
        predictions from the features and data samples. Over-write because
        img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Feature from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder, has
                shape (bs, num_queries, 2).
            batch_data_samples (list[:obj:`DetDataSample`]): Each item
                contains the meta information of each image and corresponding
                annotations.

        Returns:
            tuple: The return value is a tuple that contains:

            - losses (dict[str, Tensor]): A dictionary of loss components.
            - predictions (list[:obj:`InstanceData`]): Detection results of
              each image after the post process.
        """
        batch_gt_instances = []
        batch_img_metas = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)

        outs = self(hidden_states, references)
        loss_inputs = outs + (batch_gt_instances, batch_img_metas)
        losses = self.loss_by_feat(*loss_inputs)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas)
        return losses, predictions

    def predict(self,
                hidden_states: Tensor,
                references: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network. Over-write
        because img_metas are needed as inputs for bbox_head.

        Args:
            hidden_states (Tensor): Feature from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder, has
                shape (bs, num_queries, 2).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
        outs = self(last_layer_hidden_state, references)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)

        return predictions
9 changes: 5 additions & 4 deletions mmdet/models/dense_heads/detr_head.py
@@ -446,17 +446,18 @@ def loss_and_predict(
         img_metas are needed as inputs for bbox_head.
         Args:
-            hidden_states (tuple[Tensor]): Features from FPN.
+            hidden_states (tuple[Tensor]): Feature from the transformer
+                decoder, has shape (num_decoder_layers, bs, num_queries, dim).
             batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
                 the meta information of each image and corresponding
                 annotations.
         Returns:
             tuple: the return value is a tuple contains:
             - losses: (dict[str, Tensor]): A dictionary of loss components.
             - predictions (list[:obj:`InstanceData`]): Detection
               results of each image after the post process.
         """
         batch_gt_instances = []
         batch_img_metas = []
3 changes: 2 additions & 1 deletion mmdet/models/detectors/__init__.py
@@ -5,6 +5,7 @@
 from .base_detr import DetectionTransformer
 from .cascade_rcnn import CascadeRCNN
 from .centernet import CenterNet
+from .conditional_detr import ConditionalDETR
 from .cornernet import CornerNet
 from .ddod import DDOD
 from .deformable_detr import DeformableDETR
@@ -59,5 +60,5 @@
     'SOLOv2', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
     'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD',
     'MaskFormer', 'DDOD', 'Mask2Former', 'SemiBaseDetector', 'SoftTeacher',
-    'DetectionTransformer', 'RTMDet'
+    'DetectionTransformer', 'RTMDet', 'ConditionalDETR'
 ]
75 changes: 75 additions & 0 deletions mmdet/models/detectors/conditional_detr.py
@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict

import torch.nn as nn
from torch import Tensor

from mmdet.registry import MODELS
from ..layers import (ConditionalDetrTransformerDecoder,
                      DetrTransformerEncoder, SinePositionalEncoding)
from .detr import DETR


@MODELS.register_module()
class ConditionalDETR(DETR):
r"""Implementation of `Conditional DETR for Fast Training Convergence.
<https://arxiv.org/abs/2108.06152>`_.
Code is modified from the `official github repo
<https://github.com/Atten4Vis/ConditionalDETR>`_.
"""

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding_cfg)
        self.encoder = DetrTransformerEncoder(**self.encoder)
        self.decoder = ConditionalDetrTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        # NOTE The embed_dims is typically passed from the inside out.
        # For example in DETR, the embed_dims is passed as
        # self_attn -> the first encoder layer -> encoder -> detector.
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'
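
    # Illustrative note (not part of the original file): SinePositionalEncoding
    # produces `num_feats` channels per spatial axis (x and y), so the base
    # DETR config is assumed to pair num_feats=128 with embed_dims=256, which
    # is exactly what the assertion above enforces:
    #
    #     positional_encoding=dict(num_feats=128, normalize=True)  # 2 * 128 == 256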

    def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
                        memory_mask: Tensor, memory_pos: Tensor) -> Dict:
        """Forward with Transformer decoder.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries, dim).
            query_pos (Tensor): The positional queries of decoder inputs,
                has shape (bs, num_queries, dim).
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            memory_pos (Tensor): The positional embeddings of memory, has
                shape (bs, num_feat_points, dim).

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` and `references` of the decoder output.

            - hidden_states (Tensor): Has shape
              (num_decoder_layers, bs, num_queries, dim)
            - references (Tensor): Has shape (bs, num_queries, 2)
        """
        hidden_states, references = self.decoder(
            query=query,
            key=memory,
            value=memory,
            query_pos=query_pos,
            key_pos=memory_pos,
            key_padding_mask=memory_mask)
        head_inputs_dict = dict(
            hidden_states=hidden_states, references=references)
        return head_inputs_dict
5 changes: 4 additions & 1 deletion mmdet/models/detectors/detr.py
@@ -199,8 +199,11 @@ def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output.
            - hidden_states (Tensor): Has shape
              (num_decoder_layers, bs, num_queries, dim)
        """
        # (num_decoder_layers, bs, num_queries, dim)

        hidden_states = self.decoder(
            query=query,
            key=memory,
