From 537f084df14b845eced8452c8d313fde52f795b1 Mon Sep 17 00:00:00 2001
From: LiuYi <1150854440@qq.com>
Date: Wed, 13 Sep 2023 13:22:43 +0800
Subject: [PATCH] support edpose
---
.../edpose/coco/edpose_coco.md | 59 +
.../edpose/coco/edpose_coco.yml | 25 +
.../edpose/coco/edpose_res50_coco.py | 217 +++
docs/src/papers/algorithms/edpose.md | 31 +
mmpose/codecs/__init__.py | 3 +-
mmpose/codecs/edpose_label.py | 170 ++
mmpose/datasets/transforms/__init__.py | 6 +-
.../transforms/bottomup_transforms.py | 409 ++++-
mmpose/models/data_preprocessors/__init__.py | 4 +-
.../data_preprocessors/data_preprocessor.py | 184 +++
mmpose/models/heads/__init__.py | 4 +-
.../heads/transformer_heads/__init__.py | 19 +
.../base_transformer_head.py | 112 ++
.../heads/transformer_heads/edpose_head.py | 1439 +++++++++++++++++
.../transformers/__init__.py | 17 +
.../transformers/deformable_detr_layers.py | 251 +++
.../transformers/detr_layers.py | 333 ++++
.../transformer_heads/transformers/utils.py | 114 ++
mmpose/models/necks/__init__.py | 4 +-
mmpose/models/necks/chanel_mapper_neck.py | 109 ++
.../test_bottomup_transforms.py | 167 +-
21 files changed, 3668 insertions(+), 9 deletions(-)
create mode 100644 configs/body_2d_keypoint/edpose/coco/edpose_coco.md
create mode 100644 configs/body_2d_keypoint/edpose/coco/edpose_coco.yml
create mode 100644 configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py
create mode 100644 docs/src/papers/algorithms/edpose.md
create mode 100644 mmpose/codecs/edpose_label.py
create mode 100644 mmpose/models/heads/transformer_heads/__init__.py
create mode 100644 mmpose/models/heads/transformer_heads/base_transformer_head.py
create mode 100644 mmpose/models/heads/transformer_heads/edpose_head.py
create mode 100644 mmpose/models/heads/transformer_heads/transformers/__init__.py
create mode 100644 mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py
create mode 100644 mmpose/models/heads/transformer_heads/transformers/detr_layers.py
create mode 100644 mmpose/models/heads/transformer_heads/transformers/utils.py
create mode 100644 mmpose/models/necks/chanel_mapper_neck.py
diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_coco.md b/configs/body_2d_keypoint/edpose/coco/edpose_coco.md
new file mode 100644
index 0000000000..76dca297f7
--- /dev/null
+++ b/configs/body_2d_keypoint/edpose/coco/edpose_coco.md
@@ -0,0 +1,59 @@
+
+
+
+ED-Pose (ICLR'2023)
+
+```bibtex
+@inproceedings{
+yang2023explicit,
+title={Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation},
+author={Jie Yang and Ailing Zeng and Shilong Liu and Feng Li and Ruimao Zhang and Lei Zhang},
+booktitle={International Conference on Learning Representations},
+year={2023},
+url={https://openreview.net/forum?id=s4WVupnJjmX}
+}
+```
+
+
+
+
+
+
+ResNet (CVPR'2016)
+
+```bibtex
+@inproceedings{he2016deep,
+ title={Deep residual learning for image recognition},
+ author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
+ booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
+ pages={770--778},
+ year={2016}
+}
+```
+
+
+
+
+
+
+COCO (ECCV'2014)
+
+```bibtex
+@inproceedings{lin2014microsoft,
+ title={Microsoft coco: Common objects in context},
+ author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
+ booktitle={European conference on computer vision},
+ pages={740--755},
+ year={2014},
+ organization={Springer}
+}
+```
+
+
+
+Results on COCO val2017
+
+| Arch | BackBone | AP | AP50 | AP75 | AR | AR50 | ckpt | log |
+| :-------------------------------------------- | :-------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :--------------------------------------------: | :-------------------------------------------: |
+| [edpose_res50_coco](/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py) | ResNet-50 | 0.716 | 0.897 | 0.783 | 0.793 | 0.943 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.json) |
+| | | | | | | | | |
diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml b/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml
new file mode 100644
index 0000000000..4b2901d625
--- /dev/null
+++ b/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml
@@ -0,0 +1,25 @@
+Collections:
+- Name: ED-Pose
+ Paper:
+ Title: Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation
+ URL: https://arxiv.org/pdf/2302.01593.pdf
+ README: https://github.com/open-mmlab/mmpose/blob/main/docs/src/papers/algorithms/edpose.md
+Models:
+- Config: configs/body_2d_keypoint/edpose/coco/edpose_resnet50_coco.py
+ In Collection: ED-Pose
+ Metadata:
+ Architecture: &id001
+ - ED-Pose
+ - ResNet
+ Training Data: COCO
+ Name: edpose_resnet50_coco
+ Results:
+ - Dataset: COCO
+ Metrics:
+ AP: 0.716
+ AP@0.5: 0.897
+ AP@0.75: 0.783
+ AR: 0.793
+ AR@0.5: 0.943
+ Task: Body 2D Keypoint
+ Weights: https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth
diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py b/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py
new file mode 100644
index 0000000000..1c78377b8b
--- /dev/null
+++ b/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py
@@ -0,0 +1,217 @@
+_base_ = ['../../../_base_/default_runtime.py']
+
+# runtime
+train_cfg = dict(max_epochs=140, val_interval=10)
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+ type='Adam',
+ lr=1e-3,
+))
+
+# learning policy
+param_scheduler = [
+ dict(
+ type='LinearLR', begin=0, end=500, start_factor=0.001,
+ by_epoch=False), # warm-up
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=140,
+ milestones=[90, 120],
+ gamma=0.1,
+ by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=80)
+
+# hooks
+default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+ type='EDPoseLabel', num_select=50, num_body_points=17, not_to_xyxy=False)
+
+# model settings
+model = dict(
+ type='BottomupPoseEstimator',
+ data_preprocessor=dict(
+ type='BatchShapeDataPreprocessor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225],
+ bgr_to_rgb=True,
+ pad_size_divisor=1,
+ normalize_bakend='pillow'),
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='FrozenBatchNorm2d', requires_grad=False),
+ norm_eval=True,
+ style='pytorch',
+ init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+ neck=dict(
+ type='ChannelMapper',
+ in_channels=[512, 1024, 2048],
+ kernel_size=1,
+ out_channels=256,
+ act_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=32),
+ num_outs=4),
+ head=dict(
+ type='EDPoseHead',
+ num_queries=900,
+ num_feature_levels=4,
+ num_body_points=17,
+ as_two_stage=True,
+ encoder=dict(
+ num_layers=6,
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
+ embed_dims=256,
+ num_heads=8,
+ num_levels=4,
+ num_points=4,
+ batch_first=True),
+ ffn_cfg=dict(
+ embed_dims=256,
+ feedforward_channels=2048,
+ num_fcs=2,
+ ffn_drop=0.0))),
+ decoder=dict(
+ num_layers=6,
+ embed_dims=256,
+ layer_cfg=dict( # DeformableDetrTransformerDecoderLayer
+ self_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ batch_first=True),
+ cross_attn_cfg=dict( # MultiScaleDeformableAttention
+ embed_dims=256,
+ batch_first=True),
+ ffn_cfg=dict(
+ embed_dims=256, feedforward_channels=2048, ffn_drop=0.1)),
+ query_dim=4,
+ num_feature_levels=4,
+ num_group=100,
+ num_dn=100,
+ num_box_decoder_layers=2,
+ return_intermediate=True),
+ out_head=dict(num_classes=2),
+ positional_encoding=dict(
+ num_pos_feats=128,
+ temperatureH=20,
+ temperatureW=20,
+ normalize=True),
+ denosing_cfg=dict(
+ dn_box_noise_scale=0.4,
+ dn_label_noise_ratio=0.5,
+ dn_labelbook_size=100,
+ dn_attn_mask_type_list=['match2dn', 'dn2dn', 'group2group']),
+ data_decoder=codec),
+ test_cfg=dict(Pmultiscale_test=False, flip_test=False, num_select=50),
+ train_cfg=dict())
+
+# enable DDP training when rescore net is used
+find_unused_parameters = True
+
+# base dataset settings
+dataset_type = 'CocoDataset'
+data_mode = 'bottomup'
+data_root = 'data/coco/'
+
+# pipelines
+train_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(
+ type='RandomChoice',
+ transforms=[
+ [
+ dict(
+ type='RandomChoiceResize',
+ scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+ (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+ (736, 1333), (768, 1333), (800, 1333)],
+ keep_ratio=True)
+ ],
+ [
+ dict(
+ type='BottomupRandomChoiceResize',
+ # The radio of all image in train dataset < 7
+ # follow the original implement
+ scales=[(400, 4200), (500, 4200), (600, 4200)],
+ keep_ratio=True),
+ dict(
+ type='BottomupRandomCrop',
+ crop_type='absolute_range',
+ crop_size=(384, 600),
+ allow_negative_crop=True),
+ dict(
+ type='BottomupRandomChoiceResize',
+ scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+ (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+ (736, 1333), (768, 1333), (800, 1333)],
+ keep_ratio=True)
+ ]
+ ]),
+ dict(type='PackPoseInputs'),
+]
+
+val_pipeline = [
+ dict(type='LoadImage', imdecode_backend='pillow'),
+ dict(
+ type='BottomupRandomChoiceResize',
+ scales=[(800, 1333)],
+ keep_ratio=True,
+ backend='pillow'),
+ dict(
+ type='PackPoseInputs',
+ meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
+ 'img_shape', 'input_size', 'input_center', 'input_scale',
+ 'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
+ 'skeleton_links'))
+]
+
+# data loaders
+train_dataloader = dict(
+ batch_size=1,
+ num_workers=1,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline=train_pipeline,
+ ))
+val_dataloader = dict(
+ batch_size=4,
+ num_workers=8,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_val2017.json',
+ data_prefix=dict(img='val2017/'),
+ test_mode=True,
+ pipeline=val_pipeline,
+ ))
+test_dataloader = val_dataloader
+
+# evaluators
+val_evaluator = dict(
+ type='CocoMetric',
+ ann_file=data_root + 'annotations/person_keypoints_val2017.json',
+ nms_mode='none',
+ score_mode='keypoint',
+)
+test_evaluator = val_evaluator
diff --git a/docs/src/papers/algorithms/edpose.md b/docs/src/papers/algorithms/edpose.md
new file mode 100644
index 0000000000..07acf2edb5
--- /dev/null
+++ b/docs/src/papers/algorithms/edpose.md
@@ -0,0 +1,31 @@
+# Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation
+
+
+
+
+ED-Pose (ICLR'2023)
+
+```bibtex
+@inproceedings{
+yang2023explicit,
+title={Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation},
+author={Jie Yang and Ailing Zeng and Shilong Liu and Feng Li and Ruimao Zhang and Lei Zhang},
+booktitle={International Conference on Learning Representations},
+year={2023},
+url={https://openreview.net/forum?id=s4WVupnJjmX}
+}
+```
+
+
+
+## Abstract
+
+
+
+This paper presents a novel end-to-end framework with Explicit box Detection for multi-person Pose estimation, called ED-Pose, where it unifies the contextual learning between human-level (global) and keypoint-level (local) information. Different from previous one-stage methods, ED-Pose re-considers this task as two explicit box detection processes with a unified representation and regression supervision. First, we introduce a human detection decoder from encoded tokens to extract global features. It can provide a good initialization for the latter keypoint detection, making the training process converge fast. Second, to bring in contextual information near keypoints, we regard pose estimation as a keypoint box detection problem to learn both box positions and contents for each keypoint. A human-to-keypoint detection decoder adopts an interactive learning strategy between human and keypoint features to further enhance global and local feature aggregation. In general, ED-Pose is conceptually simple without post-processing and dense heatmap supervision. It demonstrates its effectiveness and efficiency compared with both two-stage and one-stage methods. Notably, explicit box detection boosts the pose estimation performance by 4.5 AP on COCO and 9.9 AP on CrowdPose. For the first time, as a fully end-to-end framework with a L1 regression loss, ED-Pose surpasses heatmap-based Top-down methods under the same backbone by 1.2 AP on COCO and achieves the state-of-the-art with 76.6 AP on CrowdPose without bells and whistles. Code is available at https://github.com/IDEA-Research/ED-Pose.
+
+
+
+
+
+
diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py
index 1a48b7f851..224d860457 100644
--- a/mmpose/codecs/__init__.py
+++ b/mmpose/codecs/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .associative_embedding import AssociativeEmbedding
from .decoupled_heatmap import DecoupledHeatmap
+from .edpose_label import EDPoseLabel
from .image_pose_lifting import ImagePoseLifting
from .integral_regression_label import IntegralRegressionLabel
from .megvii_heatmap import MegviiHeatmap
@@ -16,5 +17,5 @@
'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel',
'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR',
'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting',
- 'MotionBERTLabel'
+ 'MotionBERTLabel', 'EDPoseLabel'
]
diff --git a/mmpose/codecs/edpose_label.py b/mmpose/codecs/edpose_label.py
new file mode 100644
index 0000000000..bc0bef61e8
--- /dev/null
+++ b/mmpose/codecs/edpose_label.py
@@ -0,0 +1,170 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import KEYPOINT_CODECS
+from .base import BaseKeypointCodec
+
+
+@KEYPOINT_CODECS.register_module()
+class EDPoseLabel(BaseKeypointCodec):
+ r"""Generate keypoint and label coordinates for `ED-Pose`_ by
+ Yang J. et al (2023).
+
+ Note:
+
+ - instance number: N
+ - keypoint number: K
+ - keypoint dimension: D
+ - image size: [w, h]
+
+ Encoded:
+
+ - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape
+ (N, K, D)
+ - area (np.ndarray): Area in shape (N)
+ - bbox (np.ndarray): Bbox in shape (N, 4)
+
+ Args:
+ num_select (int): The number of candidate keypoints
+ num_body_points (int): The Number of keypoints
+ not_to_xyxy (bool): Whether convert bbox from cxcy to
+ xyxy.
+ """
+
+ auxiliary_encode_keys = {'area', 'bboxes', 'img_shape'}
+
+ def __init__(self,
+ num_select: int = 100,
+ num_body_points: int = 17,
+ not_to_xyxy: bool = False):
+ super().__init__()
+
+ self.num_select = num_select
+ self.num_body_points = num_body_points
+ self.not_to_xyxy = not_to_xyxy
+
+ def encode(
+ self,
+ img_shape,
+ keypoints: np.ndarray,
+ keypoints_visible: Optional[np.ndarray] = None,
+ area: Optional[np.ndarray] = None,
+ bboxes: Optional[np.ndarray] = None,
+ ) -> dict:
+ """Encoding keypoints 、area、bbox from input image space to normalized
+ space.
+
+ Args:
+ - keypoints (np.ndarray): Keypoint coordinates in
+ shape (N, K, D).
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape
+ (N, K, D)
+ - area (np.ndarray):
+ - bboxes (np.ndarray):
+
+ Returns:
+ encoded (dict): Contains the following items:
+
+ - keypoint_labels (np.ndarray): The processed keypoints in
+ shape like (N, K, D).
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape
+ (N, K, D)
+ - area_labels (np.ndarray): The processed target
+ area in shape (N).
+ - bboxes_labels: The processed target bbox in
+ shape (N, 4).
+ """
+ w, h = img_shape
+
+ if keypoints_visible is None:
+ keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+ if bboxes is not None:
+ bboxes = self.box_xyxy_to_cxcywh(bboxes)
+ bboxes_labels = bboxes / np.array([w, h, w, h], dtype=np.float32)
+
+ if area is not None:
+ area_labels = area / (
+ np.array(w, dtype=np.float32) * np.array(h, dtype=np.float32))
+
+ if keypoints is not None:
+ keypoint_labels = keypoints / np.array([w, h], dtype=np.float32)
+
+ encoded = dict(
+ keypoint_labels=keypoint_labels,
+ area_labels=area_labels,
+ bboxes_labels=bboxes_labels,
+ keypoints_visible=keypoints_visible)
+
+ return encoded
+
+ def decode(self, input_shapes: np.ndarray, pred_logits: np.ndarray,
+ pred_boxes: np.ndarray, pred_keypoints: np.ndarray):
+ """Select the final top-k keypoints, and decode the results from
+ normalize size to origin input size.
+
+ Args:
+ input_shapes (Tensor): The size of input image resize.
+ test_cfg (ConfigType): Config of testing.
+ pred_logits (Tensor): The result of score.
+ pred_boxes (Tensor): The result of bbox.
+ pred_keypoints (Tensor): The result of keypoints.
+
+ Returns:
+ """
+
+ num_body_points = self.num_body_points
+
+ prob = pred_logits
+
+ prob_reshaped = prob.reshape(-1)
+ topk_indexes = np.argsort(-prob_reshaped)[:self.num_select]
+ topk_values = np.take_along_axis(prob_reshaped, topk_indexes, axis=0)
+
+ scores = np.tile(topk_values[:, np.newaxis], [1, num_body_points])
+
+ # bbox
+ topk_boxes = topk_indexes // pred_logits.shape[1]
+ if self.not_to_xyxy:
+ boxes = pred_boxes
+ else:
+ x_c, y_c, w, h = np.split(pred_boxes, 4, axis=-1)
+ b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w),
+ (y_c + 0.5 * h)]
+ boxes = np.concatenate(b, axis=1)
+
+ boxes = np.take_along_axis(
+ boxes, np.tile(topk_boxes[:, np.newaxis], [1, 4]), axis=0)
+
+ # from relative [0, 1] to absolute [0, height] coordinates
+ img_h, img_w = np.split(input_shapes, 2, axis=0)
+ scale_fct = np.hstack([img_w, img_h, img_w, img_h])
+ boxes = boxes * scale_fct[np.newaxis, :]
+
+ # keypoints
+ topk_keypoints = topk_indexes // pred_logits.shape[1]
+ keypoints = np.take_along_axis(
+ pred_keypoints,
+ np.tile(topk_keypoints[:, np.newaxis], [1, num_body_points * 3]),
+ axis=0)
+
+ Z_pred = keypoints[:, :(num_body_points * 2)]
+ V_pred = keypoints[:, (num_body_points * 2):]
+ Z_pred = Z_pred * np.tile(
+ np.hstack([img_w, img_h]), [num_body_points])[np.newaxis, :]
+ keypoints_res = np.zeros_like(keypoints)
+ keypoints_res[..., 0::3] = Z_pred[..., 0::2]
+ keypoints_res[..., 1::3] = Z_pred[..., 1::2]
+ keypoints_res[..., 2::3] = V_pred[..., 0::1]
+
+ keypoint = keypoints_res.reshape(-1, num_body_points, 3)[:, :, :2]
+
+ return keypoint, scores, boxes
+
+ def box_xyxy_to_cxcywh(self, x):
+ x0, y0, x1, y1 = x.unbind(-1)
+ b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+ return np.stack(b, dim=-1)
diff --git a/mmpose/datasets/transforms/__init__.py b/mmpose/datasets/transforms/__init__.py
index 7ccbf7dac2..a1240a1666 100644
--- a/mmpose/datasets/transforms/__init__.py
+++ b/mmpose/datasets/transforms/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bottomup_transforms import (BottomupGetHeatmapMask, BottomupRandomAffine,
- BottomupResize)
+ BottomupRandomChoiceResize,
+ BottomupRandomCrop, BottomupResize)
from .common_transforms import (Albumentation, GenerateTarget,
GetBBoxCenterScale, PhotometricDistortion,
RandomBBoxTransform, RandomFlip,
@@ -16,5 +17,6 @@
'RandomHalfBody', 'TopdownAffine', 'Albumentation',
'PhotometricDistortion', 'PackPoseInputs', 'LoadImage',
'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize',
- 'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot'
+ 'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot',
+ 'BottomupRandomCrop', 'BottomupRandomChoiceResize'
]
diff --git a/mmpose/datasets/transforms/bottomup_transforms.py b/mmpose/datasets/transforms/bottomup_transforms.py
index c31e0ae17d..c6ba5e5bc1 100644
--- a/mmpose/datasets/transforms/bottomup_transforms.py
+++ b/mmpose/datasets/transforms/bottomup_transforms.py
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import xtcocotools.mask as cocomask
from mmcv.image import imflip_, imresize
+from mmcv.image.geometric import imrescale
from mmcv.transforms import BaseTransform
from mmcv.transforms.utils import cache_randomness
from scipy.stats import truncnorm
@@ -515,3 +516,409 @@ def transform(self, results: Dict) -> Optional[dict]:
results['aug_scale'] = None
return results
+
+
+@TRANSFORMS.register_module()
+class BottomupRandomCrop(BaseTransform):
+ """Random crop the image & bboxes & masks.
+
+ The absolute ``crop_size`` is sampled based on ``crop_type`` and
+ ``image_size``, then the cropped results are generated.
+
+ Required Keys:
+
+ - img
+ - keypoints
+ - bbox (optional)
+ - masks (BitmapMasks | PolygonMasks) (optional)
+
+ Modified Keys:
+
+ - img
+ - img_shape
+ - keypoints
+ - keypoints_visible
+ - num_keypoints
+ - bbox (optional)
+ - bbox_score (optional)
+ - id (optional)
+ - category_id (optional)
+ - raw_ann_info (optional)
+ - iscrowd (optional)
+ - segmentation (optional)
+ - masks (optional)
+
+ Added Keys:
+
+ - warp_mat
+
+ Args:
+ crop_size (tuple): The relative ratio or absolute pixels of
+ (width, height).
+ crop_type (str, optional): One of "relative_range", "relative",
+ "absolute", "absolute_range". "relative" randomly crops
+ (h * crop_size[0], w * crop_size[1]) part from an input of size
+ (h, w). "relative_range" uniformly samples relative crop size from
+ range [crop_size[0], 1] and [crop_size[1], 1] for height and width
+ respectively. "absolute" crops from an input with absolute size
+ (crop_size[0], crop_size[1]). "absolute_range" uniformly samples
+ crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w
+ in range [crop_size[0], min(w, crop_size[1])].
+ Defaults to "absolute".
+ allow_negative_crop (bool, optional): Whether to allow a crop that does
+ not contain any bbox area. Defaults to False.
+ recompute_bbox (bool, optional): Whether to re-compute the boxes based
+ on cropped instance masks. Defaults to False.
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+
+ Note:
+ - If the image is smaller than the absolute crop size, return the
+ original image.
+ - If the crop does not contain any gt-bbox region and
+ ``allow_negative_crop`` is set to False, skip this image.
+ """
+
+ def __init__(self,
+ crop_size: tuple,
+ crop_type: str = 'absolute',
+ allow_negative_crop: bool = False,
+ recompute_bbox: bool = False,
+ bbox_clip_border: bool = True) -> None:
+ if crop_type not in [
+ 'relative_range', 'relative', 'absolute', 'absolute_range'
+ ]:
+ raise ValueError(f'Invalid crop_type {crop_type}.')
+ if crop_type in ['absolute', 'absolute_range']:
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ assert isinstance(crop_size[0], int) and isinstance(
+ crop_size[1], int)
+ if crop_type == 'absolute_range':
+ assert crop_size[0] <= crop_size[1]
+ else:
+ assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
+ self.crop_size = crop_size
+ self.crop_type = crop_type
+ self.allow_negative_crop = allow_negative_crop
+ self.bbox_clip_border = bbox_clip_border
+ self.recompute_bbox = recompute_bbox
+
+ def _crop_data(self, results: dict, crop_size: Tuple[int, int],
+ allow_negative_crop: bool) -> Union[dict, None]:
+ """Function to randomly crop images, bounding boxes, masks, semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ crop_size (Tuple[int, int]): Expected absolute size after
+ cropping, (h, w).
+ allow_negative_crop (bool): Whether to allow a crop that does not
+ contain any bbox area.
+
+ Returns:
+ results (Union[dict, None]): Randomly cropped results, 'img_shape'
+ key in result dict is updated according to crop size. None will
+ be returned when there is no valid bbox after cropping.
+ """
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ img = results['img']
+ margin_h = max(img.shape[0] - crop_size[0], 0)
+ margin_w = max(img.shape[1] - crop_size[1], 0)
+ offset_h, offset_w = self._rand_offset((margin_h, margin_w))
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
+
+ # Record the warp matrix for the RandomCrop
+ warp_mat = np.array([[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]],
+ dtype=np.float32)
+ if results.get('warp_mat', None) is None:
+ results['warp_mat'] = warp_mat
+ else:
+ results['warp_mat'] = warp_mat @ results['warp_mat']
+
+ # crop the image
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+ img_shape = img.shape
+ results['img'] = img
+ results['img_shape'] = img_shape[:2]
+
+ # crop bboxes accordingly and clip to the image boundary
+ if results.get('bbox', None) is not None:
+ distances = (-offset_w, -offset_h)
+ bboxes = results['bbox']
+ bboxes = bboxes + np.tile(np.asarray(distances), 2)
+
+ if self.bbox_clip_border:
+ bboxes[..., 0::2] = bboxes[..., 0::2].clip(0, img_shape[1])
+ bboxes[..., 1::2] = bboxes[..., 1::2].clip(0, img_shape[0])
+
+ valid_inds = (bboxes[..., 0] < img_shape[1]) & \
+ (bboxes[..., 1] < img_shape[0]) & \
+ (bboxes[..., 2] > 0) & \
+ (bboxes[..., 3] > 0)
+
+ # If the crop does not contain any gt-bbox area and
+ # allow_negative_crop is False, skip this image.
+ if (not valid_inds.any() and not allow_negative_crop):
+ return None
+
+ results['bbox'] = bboxes[valid_inds]
+ meta_keys = [
+ 'bbox_score', 'id', 'category_id', 'raw_ann_info', 'iscrowd'
+ ]
+ for key in meta_keys:
+ if results.get(key):
+ if isinstance(results[key], list):
+ results[key] = np.asarray(
+ results[key])[valid_inds].tolist()
+ else:
+ results[key] = results[key][valid_inds]
+
+ if results.get('keypoints', None) is not None:
+ keypoints = results['keypoints']
+ distances = np.asarray(distances).reshape(1, 1, 2)
+ keypoints = keypoints + distances
+ if self.bbox_clip_border:
+ keypoints_outside_x = keypoints[:, :, 0] < 0
+ keypoints_outside_y = keypoints[:, :, 1] < 0
+ keypoints_outside_width = keypoints[:, :, 0] > img_shape[1]
+ keypoints_outside_height = keypoints[:, :,
+ 1] > img_shape[0]
+
+ kpt_outside = np.logical_or.reduce(
+ (keypoints_outside_x, keypoints_outside_y,
+ keypoints_outside_width, keypoints_outside_height))
+
+ results['keypoints_visible'][kpt_outside] *= 0
+ keypoints[:, :, 0] = keypoints[:, :, 0].clip(0, img_shape[1])
+ keypoints[:, :, 1] = keypoints[:, :, 1].clip(0, img_shape[0])
+ results['keypoints'] = keypoints[valid_inds]
+ results['keypoints_visible'] = results['keypoints_visible'][
+ valid_inds]
+
+ if results.get('segmentation', None) is not None:
+ results['segmentation'] = results['segmentation'][
+ crop_y1:crop_y2, crop_x1:crop_x2]
+
+ if results.get('masks', None) is not None:
+ results['masks'] = results['masks'][valid_inds.nonzero(
+ )[0]].crop(np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
+ if self.recompute_bbox:
+ results['bbox'] = results['masks'].get_bboxes(
+ type(results['bbox']))
+
+ return results
+
+ @cache_randomness
+ def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]:
+ """Randomly generate crop offset.
+
+ Args:
+ margin (Tuple[int, int]): The upper bound for the offset generated
+ randomly.
+
+ Returns:
+ Tuple[int, int]: The random offset for the crop.
+ """
+ margin_h, margin_w = margin
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+
+ return offset_h, offset_w
+
+ @cache_randomness
+ def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
+ """Randomly generates the absolute crop size based on `crop_type` and
+ `image_size`.
+
+ Args:
+ image_size (Tuple[int, int]): (h, w).
+
+ Returns:
+ crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels.
+ """
+ h, w = image_size
+ if self.crop_type == 'absolute':
+ return min(self.crop_size[1], h), min(self.crop_size[0], w)
+ elif self.crop_type == 'absolute_range':
+ crop_h = np.random.randint(
+ min(h, self.crop_size[0]),
+ min(h, self.crop_size[1]) + 1)
+ crop_w = np.random.randint(
+ min(w, self.crop_size[0]),
+ min(w, self.crop_size[1]) + 1)
+ return crop_h, crop_w
+ elif self.crop_type == 'relative':
+ crop_w, crop_h = self.crop_size
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+ else:
+ # 'relative_range'
+ crop_size = np.asarray(self.crop_size, dtype=np.float32)
+ crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+
+ def transform(self, results: dict) -> Union[dict, None]:
+ """Transform function to randomly crop images, bounding boxes, masks,
+ semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ results (Union[dict, None]): Randomly cropped results, 'img_shape'
+ key in result dict is updated according to crop size. None will
+ be returned when there is no valid bbox after cropping.
+ """
+ image_size = results['img'].shape[:2]
+ crop_size = self._get_crop_size(image_size)
+ results = self._crop_data(results, crop_size, self.allow_negative_crop)
+ return results
+
+
+@TRANSFORMS.register_module()
+class BottomupRandomChoiceResize(BaseTransform):
+ """Resize images & bbox & mask from a list of multiple scales.
+
+ This transform resizes the input image to some scale. Bboxes and masks are
+ then resized with the same scale factor. Resize scale will be randomly
+ selected from ``scales``.
+
+ How to choose the target scale to resize the image will follow the rules
+ below:
+
+ - if `scale` is a list of tuple, the target scale is sampled from the list
+ uniformally.
+ - if `scale` is a tuple, the target scale will be set to the tuple.
+
+ Required Keys:
+
+ - img
+ - bbox
+ - keypoints
+
+ Modified Keys:
+
+ - img
+ - img_shape
+ - bbox
+ - keypoints
+
+ Added Keys:
+
+ - scale
+ - scale_factor
+ - scale_idx
+
+ Args:
+ scales (Union[list, Tuple]): Images scales for resizing.
+
+ **resize_kwargs: Other keyword arguments for the ``resize_type``.
+ """
+
+ def __init__(
+ self,
+ scales: Sequence[Union[int, Tuple]],
+ keep_ratio: bool = False,
+ clip_object_border: bool = True,
+ backend: str = 'cv2',
+ **resize_kwargs,
+ ) -> None:
+ super().__init__()
+ if isinstance(scales, list):
+ self.scales = scales
+ else:
+ self.scales = [scales]
+
+ self.keep_ratio = keep_ratio
+ self.clip_object_border = clip_object_border
+ self.backend = backend
+
+ @cache_randomness
+ def _random_select(self) -> Tuple[int, int]:
+ """Randomly select an scale from given candidates.
+
+ Returns:
+ (tuple, int): Returns a tuple ``(scale, scale_dix)``,
+ where ``scale`` is the selected image scale and
+ ``scale_idx`` is the selected index in the given candidates.
+ """
+
+ scale_idx = np.random.randint(len(self.scales))
+ scale = self.scales[scale_idx]
+ return scale, scale_idx
+
+ def _resize_img(self, results: dict) -> None:
+ """Resize images with ``self.scale``."""
+
+ if self.keep_ratio:
+
+ img, scale_factor = imrescale(
+ results['img'],
+ self.scale,
+ interpolation='bilinear',
+ return_scale=True,
+ backend=self.backend)
+ # the w_scale and h_scale has minor difference
+ # a real fix should be done in the mmcv.imrescale in the future
+ new_h, new_w = img.shape[:2]
+ h, w = results['img'].shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ else:
+ img, w_scale, h_scale = imresize(
+ results['img'],
+ self.scale,
+ interpolation='bilinear',
+ return_scale=True,
+ backend=self.backend)
+ results['img'] = img
+ results['img_shape'] = img.shape[:2]
+ results['scale_factor'] = (w_scale, h_scale)
+
+ def _resize_bboxes(self, results: dict) -> None:
+ """Resize bounding boxes with ``self.scale``."""
+ if results.get('bbox', None) is not None:
+ bboxes = results['bbox'] * np.tile(
+ np.array(results['scale_factor']), 2)
+ if self.clip_object_border:
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0,
+ results['img_shape'][1])
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0,
+ results['img_shape'][0])
+ results['bbox'] = bboxes
+
+ def _resize_keypoints(self, results: dict) -> None:
+ """Resize keypoints with ``self.scale``."""
+ if results.get('keypoints', None) is not None:
+ keypoints = results['keypoints']
+
+ keypoints[:, :, :2] = keypoints[:, :, :2] * np.array(
+ results['scale_factor'])
+ if self.clip_object_border:
+ keypoints[:, :, 0] = np.clip(keypoints[:, :, 0], 0,
+ results['img_shape'][1])
+ keypoints[:, :, 1] = np.clip(keypoints[:, :, 1], 0,
+ results['img_shape'][0])
+ results['keypoints'] = keypoints
+
+ def transform(self, results: dict) -> dict:
+ """Apply resize transforms on results from a list of scales.
+
+ Args:
+ results (dict): Result dict contains the data to transform.
+
+ Returns:
+ dict: Resized results, 'img', 'bbox',
+ 'keypoints', 'scale', 'scale_factor', 'img_shape',
+ and 'keep_ratio' keys are updated in result dict.
+ """
+
+ target_scale, scale_idx = self._random_select()
+
+ self.scale = target_scale
+ self._resize_img(results)
+ self._resize_bboxes(results)
+ self._resize_keypoints(results)
+
+ results['scale_idx'] = scale_idx
+ return results
diff --git a/mmpose/models/data_preprocessors/__init__.py b/mmpose/models/data_preprocessors/__init__.py
index 7c9bd22e2b..03fd83a30c 100644
--- a/mmpose/models/data_preprocessors/__init__.py
+++ b/mmpose/models/data_preprocessors/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .data_preprocessor import PoseDataPreprocessor
+from .data_preprocessor import BatchShapeDataPreprocessor, PoseDataPreprocessor
-__all__ = ['PoseDataPreprocessor']
+__all__ = ['PoseDataPreprocessor', 'BatchShapeDataPreprocessor']
diff --git a/mmpose/models/data_preprocessors/data_preprocessor.py b/mmpose/models/data_preprocessors/data_preprocessor.py
index bcfe54ab59..7ee3a42c04 100644
--- a/mmpose/models/data_preprocessors/data_preprocessor.py
+++ b/mmpose/models/data_preprocessors/data_preprocessor.py
@@ -1,5 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from numbers import Number
+from typing import Optional, Sequence, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional
from mmengine.model import ImgDataPreprocessor
+from mmengine.model.utils import stack_batch
+from mmengine.utils import is_seq_of
+from PIL import Image
from mmpose.registry import MODELS
@@ -7,3 +18,176 @@
@MODELS.register_module()
class PoseDataPreprocessor(ImgDataPreprocessor):
"""Image pre-processor for pose estimation tasks."""
+
+
+@MODELS.register_module()
+class BatchShapeDataPreprocessor(ImgDataPreprocessor):
+ """Image pre-processor for pose estimation tasks.
+
+ Comparing with the :class:`PoseDataPreprocessor`,
+
+ 1. It will additionally append batch_input_shape
+ to data_samples considering the DETR-based pose estimation tasks.
+
+ 2. Add a 'pillow backend' pipeline based normalize operation, convert
+ np.array to PIL.Image, and normalize it through torchvision.
+
+ It provides the data pre-processing as follows
+
+ - Collate and move data to the target device.
+ - Pad inputs to the maximum size of current batch with defined
+ ``pad_value``. The padding size can be divisible by a defined
+ ``pad_size_divisor``
+ - Stack inputs to batch_inputs.
+ - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
+ - Normalize image with defined std and mean.
+
+ Args:
+ - mean (Sequence[Number], optional): The pixel mean of R, G, B
+ channels. Defaults to None.
+ - std (Sequence[Number], optional): The pixel standard deviation
+ of R, G, B channels. Defaults to None.
+ - pad_size_divisor (int): The size of padded image should be
+ divisible by ``pad_size_divisor``. Defaults to 1.
+ - pad_value (Number): The padded pixel value. Defaults to 0.
+ - bgr_to_rgb (bool): whether to convert image from BGR to RGB.
+ Defaults to False.
+ - rgb_to_bgr (bool): whether to convert image from RGB to RGB.
+ Defaults to False.
+ - non_blocking (bool): Whether block current process
+ when transferring data to device. Defaults to False.
+ - normalize_bakend (str): choose the normalize backend
+ in ['cv2', 'pillow']
+ """
+
+ def __init__(self,
+ mean: Sequence[Number] = None,
+ std: Sequence[Number] = None,
+ pad_size_divisor: int = 1,
+ pad_value: Union[float, int] = 0,
+ bgr_to_rgb: bool = False,
+ rgb_to_bgr: bool = False,
+ non_blocking: Optional[bool] = False,
+ normalize_bakend: str = 'cv2'):
+ super().__init__(
+ mean=mean,
+ std=std,
+ pad_size_divisor=pad_size_divisor,
+ pad_value=pad_value,
+ bgr_to_rgb=bgr_to_rgb,
+ rgb_to_bgr=rgb_to_bgr,
+ non_blocking=non_blocking)
+ self.normalize_bakend = normalize_bakend
+
+ def forward(self, data: dict, training: bool = False) -> dict:
+ """Perform normalization、padding and bgr2rgb conversion based on
+ ``BaseDataPreprocessor``.
+
+ Args:
+ data (dict): Data sampled from dataloader.
+ training (bool): Whether to enable training time augmentation.
+
+ Returns:
+ dict: Data in the same format as the model input.
+ """
+ if self.normalize_bakend == 'cv2':
+ data = super().forward(data=data, training=training)
+ else:
+ data = self.normalize_pillow(data=data, training=training)
+
+ inputs, data_samples = data['inputs'], data['data_samples']
+
+ if data_samples is not None:
+ # NOTE the batched image size information may be useful, e.g.
+ # in DETR, this is needed for the construction of masks, which is
+ # then used for the transformer_head.
+ batch_input_shape = tuple(inputs[0].size()[-2:])
+ for data_sample in data_samples:
+
+ w, h = data_sample.ori_shape
+ center = np.array([w / 2, h / 2], dtype=np.float32)
+ scale = np.array([w, h], dtype=np.float32)
+ data_sample.set_metainfo({
+ 'batch_input_shape': batch_input_shape,
+ 'input_size': data_sample.img_shape,
+ 'input_center': center,
+ 'input_scale': scale
+ })
+ return {'inputs': inputs, 'data_samples': data_samples}
+
+ def normalize_pillow(self,
+ data: dict,
+ training: bool = False) -> Union[dict, list]:
+
+ data = self.cast_data(data) # type: ignore
+ _batch_inputs = data['inputs']
+ # Process data with `pseudo_collate`.
+ if is_seq_of(_batch_inputs, torch.Tensor):
+ batch_inputs = []
+ for _batch_input in _batch_inputs:
+ # channel transform
+ if self._channel_conversion:
+ _batch_input = _batch_input[[2, 1, 0], ...]
+
+ _batch_input_array = _batch_input.detach().cpu().numpy(
+ ).transpose(1, 2, 0)
+ assert _batch_input_array.dtype == np.uint8, \
+ 'Pillow backend only support uint8 type'
+ pil_image = Image.fromarray(_batch_input_array)
+ _batch_input = torchvision.transforms.functional.to_tensor(
+ pil_image).to(_batch_input.device)
+
+ # Normalization.
+ if self._enable_normalize:
+ if self.mean.shape[0] == 3:
+ assert _batch_input.dim(
+ ) == 3 and _batch_input.shape[0] == 3, (
+ 'If the mean has 3 values, the input tensor '
+ 'should in shape of (3, H, W), but got the tensor '
+ f'with shape {_batch_input.shape}')
+ _batch_input = torchvision.transforms.functional.normalize(
+ _batch_input, mean=self.mean, std=self.std)
+ batch_inputs.append(_batch_input)
+ # Pad and stack Tensor.
+ batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor,
+ self.pad_value)
+ # Process data with `default_collate`.
+ elif isinstance(_batch_inputs, torch.Tensor):
+ assert _batch_inputs.dim() == 4, (
+ 'The input of `ImgDataPreprocessor` should be a NCHW tensor '
+ 'or a list of tensor, but got a tensor with shape: '
+ f'{_batch_inputs.shape}')
+ if self._channel_conversion:
+ _batch_inputs = _batch_inputs[:, [2, 1, 0], ...]
+ # Convert to float after channel conversion to ensure
+ # efficiency
+ _batch_inputs_array = _batch_inputs.detach().cpu().numpy(
+ ).transpose(0, 2, 3, 1)
+ assert _batch_inputs.dtype == np.uint8, \
+ 'Pillow backend only support uint8 type'
+ pil_image = Image.fromarray(_batch_inputs_array)
+ _batch_inputs = torchvision.transforms.functional.to_tensor(
+ pil_image).to(_batch_inputs.device)
+
+ if self._enable_normalize:
+ _batch_inputs = torchvision.transforms.functional.normalize(
+ _batch_inputs,
+ mean=(self.mean / 255).tolist(),
+ std=(self.std / 255).tolist())
+
+ h, w = _batch_inputs.shape[2:]
+ target_h = math.ceil(
+ h / self.pad_size_divisor) * self.pad_size_divisor
+ target_w = math.ceil(
+ w / self.pad_size_divisor) * self.pad_size_divisor
+ pad_h = target_h - h
+ pad_w = target_w - w
+ batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h),
+ 'constant', self.pad_value)
+ else:
+ raise TypeError('Output of `cast_data` should be a dict of '
+ 'list/tuple with inputs and data_samples, '
+ f'but got {type(data)}: {data}')
+ data['inputs'] = batch_inputs
+ data.setdefault('data_samples', None)
+ return data
diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py
index ef0e17d98e..2a2a49b0dc 100644
--- a/mmpose/models/heads/__init__.py
+++ b/mmpose/models/heads/__init__.py
@@ -8,11 +8,13 @@
MotionRegressionHead, RegressionHead, RLEHead,
TemporalRegressionHead,
TrajectoryRegressionHead)
+from .transformer_heads import EDPoseHead, FrozenBatchNorm2d
__all__ = [
'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead',
'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead',
'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead',
'CIDHead', 'RTMCCHead', 'TemporalRegressionHead',
- 'TrajectoryRegressionHead', 'MotionRegressionHead'
+ 'TrajectoryRegressionHead', 'MotionRegressionHead', 'EDPoseHead',
+ 'FrozenBatchNorm2d'
]
diff --git a/mmpose/models/heads/transformer_heads/__init__.py b/mmpose/models/heads/transformer_heads/__init__.py
new file mode 100644
index 0000000000..f6d58126ef
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .edpose_head import EDPoseHead, FrozenBatchNorm2d
+from .transformers import (MLP, DeformableDetrTransformerDecoder,
+ DeformableDetrTransformerDecoderLayer,
+ DeformableDetrTransformerEncoder,
+ DeformableDetrTransformerEncoderLayer,
+ DetrTransformerDecoder, DetrTransformerDecoderLayer,
+ DetrTransformerEncoder, DetrTransformerEncoderLayer,
+ PositionEmbeddingSineHW, inverse_sigmoid)
+
+__all__ = [
+ 'EDPoseHead', 'FrozenBatchNorm2d', 'DetrTransformerEncoder',
+ 'DetrTransformerDecoder', 'DetrTransformerEncoderLayer',
+ 'DetrTransformerDecoderLayer', 'DeformableDetrTransformerEncoder',
+ 'DeformableDetrTransformerDecoder',
+ 'DeformableDetrTransformerEncoderLayer',
+ 'DeformableDetrTransformerDecoderLayer', 'inverse_sigmoid',
+ 'PositionEmbeddingSineHW', 'MLP'
+]
diff --git a/mmpose/models/heads/transformer_heads/base_transformer_head.py b/mmpose/models/heads/transformer_heads/base_transformer_head.py
new file mode 100644
index 0000000000..bf94cc03b6
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/base_transformer_head.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+from typing import Dict, Tuple
+
+from torch import Tensor
+
+from mmpose.models.utils.tta import flip_coordinates
+from mmpose.registry import MODELS
+from mmpose.utils.typing import (Features, OptConfigType, OptMultiConfig,
+ OptSampleList, Predictions)
+from ..base_head import BaseHead
+
+
+@MODELS.register_module()
+class TransformerHead(BaseHead):
+ r"""Implementation of `Deformable DETR: Deformable Transformers for
+ End-to-End Object Detection `_
+
+ Code is modified from the `official github repo
+ `_.
+
+ Args:
+ encoder (:obj:`ConfigDict` or dict, optional): Config of the
+ Transformer encoder. Defaults to None.
+ decoder (:obj:`ConfigDict` or dict, optional): Config of the
+ Transformer decoder. Defaults to None.
+ out_head (:obj:`ConfigDict` or dict, optional): Config for the
+ bounding final out head module. Defaults to None.
+ positional_encoding (:obj:`ConfigDict` or dict): Config for
+ transformer position encoding. Defaults None
+ num_queries (int): Number of query in Transformer.
+ """
+ _version = 2
+
+ def __init__(self,
+ encoder: OptConfigType = None,
+ decoder: OptConfigType = None,
+ out_head: OptConfigType = None,
+ positional_encoding: OptConfigType = None,
+ num_queries: int = 100,
+ loss: OptConfigType = None,
+ init_cfg: OptMultiConfig = None):
+
+ if init_cfg is None:
+ init_cfg = self.default_init_cfg
+
+ super().__init__(init_cfg)
+
+ self.encoder_cfg = encoder
+ self.decoder_cfg = decoder
+ self.out_head_cfg = out_head
+ self.positional_encoding_cfg = positional_encoding
+ self.num_queries = num_queries
+
+ def forward(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None) -> Dict:
+ """Forward the network."""
+ encoder_outputs_dict = self.forward_encoder(feats, batch_data_samples)
+
+ decoder_outputs_dict = self.forward_decoder(**encoder_outputs_dict)
+
+ head_outputs_dict = self.forward_out_head(batch_data_samples,
+ **decoder_outputs_dict)
+ return head_outputs_dict
+
+ def predict(self,
+ feats: Features,
+ batch_data_samples: OptSampleList,
+ test_cfg: OptConfigType = {}) -> Predictions:
+ """Predict results from features."""
+
+ if test_cfg.get('flip_test', False):
+ # TTA: flip test -> feats = [orig, flipped]
+ assert isinstance(feats, list) and len(feats) == 2
+ flip_indices = batch_data_samples[0].metainfo['flip_indices']
+ input_size = batch_data_samples[0].metainfo['input_size']
+ _feats, _feats_flip = feats
+
+ _batch_coords = self.forward(_feats)
+ _batch_coords_flip = flip_coordinates(
+ self.forward(_feats_flip),
+ flip_indices=flip_indices,
+ shift_coords=test_cfg.get('shift_coords', True),
+ input_size=input_size)
+ batch_coords = (_batch_coords + _batch_coords_flip) * 0.5
+ else:
+ batch_coords = self.forward(feats, batch_data_samples) # (B, K, D)
+
+ return batch_coords
+
+ def loss(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: OptConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+ pass
+
+ @abstractmethod
+ def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
+ feat_pos: Tensor, **kwargs) -> Dict:
+ pass
+
+ @abstractmethod
+ def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
+ **kwargs) -> Dict:
+ pass
+
+ @abstractmethod
+ def forward_out_head(self, query: Tensor, query_pos: Tensor,
+ memory: Tensor, **kwargs) -> Dict:
+ pass
diff --git a/mmpose/models/heads/transformer_heads/edpose_head.py b/mmpose/models/heads/transformer_heads/edpose_head.py
new file mode 100644
index 0000000000..f2dd154563
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/edpose_head.py
@@ -0,0 +1,1439 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmcv.ops import MultiScaleDeformableAttention
+from mmengine.model import BaseModule, ModuleList, constant_init
+from mmengine.structures import InstanceData
+from torch import Tensor, nn
+
+from mmpose.registry import KEYPOINT_CODECS, MODELS
+from mmpose.utils.tensor_utils import to_numpy
+from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
+ OptSampleList, Predictions)
+from .base_transformer_head import TransformerHead
+from .transformers.deformable_detr_layers import (
+ DeformableDetrTransformerDecoderLayer, DeformableDetrTransformerEncoder)
+from .transformers.utils import MLP, PositionEmbeddingSineHW, inverse_sigmoid
+
+
+@MODELS.register_module()
+class FrozenBatchNorm2d(BaseModule):
+ """BatchNorm2d where the batch statistics and the affine parameters are
+ fixed.
+
+ Copy-paste from torchvision.misc.ops with added eps before rqsrt, without
+ which any other models than torchvision.models.resnet[18,34,50,101] produce
+ nans.
+ """
+
+ def __init__(self, n, eps: int = 1e-5):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer('weight', torch.ones(n))
+ self.register_buffer('bias', torch.zeros(n))
+ self.register_buffer('running_mean', torch.zeros(n))
+ self.register_buffer('running_var', torch.ones(n))
+ self.eps = eps
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super(FrozenBatchNorm2d,
+ self)._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys,
+ unexpected_keys, error_msgs)
+
+ def forward(self, x):
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ scale = w * (rv + self.eps).rsqrt()
+ bias = b - rm * scale
+ return x * scale + bias
+
+
+class EDPoseDecoder(BaseModule):
+ """Transformer decoder of EDPose: `Explicit Box Detection Unifies End-to-
+ End Multi-Person Pose Estimation.
+
+ Args:
+ - layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
+ layer. All the layers will share the same config.
+ - num_layers (int): Number of decoder layers.
+ - return_intermediate (bool, optional): Whether to return outputs of
+ intermediate layers. Defaults to `True`.
+ - embed_dims (int): Dims of embed.
+ - query_dim (int): Dims of queries.
+ - num_feature_levels (int): Number of feature levels.
+ - num_box_decoder_layers (int): Number of box decoder layers.
+ - num_body_points (int): Number of datasets' body keypoints.
+ - num_dn (int): Number of denosing points.
+ - num_group (int): Number of decoder layers.
+ """
+
+ def __init__(self,
+ layer_cfg,
+ num_layers,
+ return_intermediate,
+ embed_dims: int = 256,
+ query_dim=4,
+ num_feature_levels=1,
+ num_box_decoder_layers=2,
+ num_body_points=17,
+ num_dn=100,
+ num_group=100):
+ super().__init__()
+
+ self.layer_cfg = layer_cfg
+ self.num_layers = num_layers
+ self.embed_dims = embed_dims
+
+ assert return_intermediate, 'support return_intermediate only'
+ self.return_intermediate = return_intermediate
+
+ assert query_dim in [
+ 2, 4
+ ], 'query_dim should be 2/4 but {}'.format(query_dim)
+ self.query_dim = query_dim
+
+ self.num_feature_levels = num_feature_levels
+
+ self.layers = ModuleList([
+ DeformableDetrTransformerDecoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.norm = nn.LayerNorm(self.embed_dims)
+
+ self.ref_point_head = MLP(self.query_dim // 2 * self.embed_dims,
+ self.embed_dims, self.embed_dims, 2)
+
+ self.num_body_points = num_body_points
+ self.query_scale = None
+ self.bbox_embed = None
+ self.class_embed = None
+ self.pose_embed = None
+ self.pose_hw_embed = None
+ self.num_box_decoder_layers = num_box_decoder_layers
+ self.box_pred_damping = None
+ self.num_group = num_group
+ self.rm_detach = None
+ self.num_dn = num_dn
+ self.hw = nn.Embedding(self.num_body_points, 2)
+ self.keypoint_embed = nn.Embedding(self.num_body_points, embed_dims)
+ self.kpt_index = [
+ x for x in range(self.num_group * (self.num_body_points + 1))
+ if x % (self.num_body_points + 1) != 0
+ ]
+
+ def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor,
+ reference_points: Tensor, spatial_shapes: Tensor,
+ level_start_index: Tensor, valid_ratios: Tensor,
+ humandet_attn_mask: Tensor, human2pose_attn_mask: Tensor,
+ **kwargs) -> Tuple[Tensor]:
+ """Forward function of decoder
+ Args:
+ query (Tensor): The input queries, has shape (bs, num_queries,
+ dim).
+ value (Tensor): The input values, has shape (bs, num_value, dim).
+ key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
+ input. ByteTensor, has shape (bs, num_value).
+ reference_points (Tensor): The initial reference, has shape
+ (bs, num_queries, 4) with the last dimension arranged as
+ (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
+ shape (bs, num_queries, 2) with the last dimension arranged
+ as (cx, cy).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+ reg_branches: (obj:`nn.ModuleList`, optional): Used for refining
+ the regression results.
+
+ Returns:
+ tuple[Tensor]: Outputs of Deformable Transformer Decoder.
+
+ - output (Tensor): Output embeddings of the last decoder, has
+ shape (num_queries, bs, embed_dims) when `return_intermediate`
+ is `False`. Otherwise, Intermediate output embeddings of all
+ decoder layers, has shape (num_decoder_layers, num_queries, bs,
+ embed_dims).
+ - reference_points (Tensor): The reference of the last decoder
+ layer, has shape (bs, num_queries, 4) when `return_intermediate`
+ is `False`. Otherwise, Intermediate references of all decoder
+ layers, has shape (num_decoder_layers, bs, num_queries, 4). The
+ coordinates are arranged as (cx, cy, w, h)
+ """
+ output = query
+ attn_mask = humandet_attn_mask
+ intermediate = []
+ intermediate_reference_points = [reference_points]
+ effect_num_dn = self.num_dn if self.training else 0
+ inter_select_number = self.num_group
+ for layer_id, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = \
+ reference_points[:, :, None] * \
+ torch.cat([valid_ratios, valid_ratios], -1)[None, :]
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = \
+ reference_points[:, :, None] * \
+ valid_ratios[None, :]
+
+ query_sine_embed = self.get_proposal_pos_embed(
+ reference_points_input[:, :, 0, :]) # nq, bs, 256*2
+ query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256
+
+ output = layer(
+ output.transpose(0, 1),
+ query_pos=query_pos.transpose(0, 1),
+ value=value.transpose(0, 1),
+ key_padding_mask=key_padding_mask,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reference_points=reference_points_input.transpose(
+ 0, 1).contiguous(),
+ self_attn_mask=attn_mask,
+ **kwargs)
+ output = output.transpose(0, 1)
+ intermediate.append(self.norm(output))
+
+ # human update
+ if layer_id < self.num_box_decoder_layers:
+ delta_unsig = self.bbox_embed[layer_id](output)
+ new_reference_points = delta_unsig + inverse_sigmoid(
+ reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+
+ # query expansion
+ if layer_id == self.num_box_decoder_layers - 1:
+ dn_output = output[:effect_num_dn]
+ dn_new_reference_points = new_reference_points[:effect_num_dn]
+ class_unselected = self.class_embed[layer_id](
+ output)[effect_num_dn:]
+ topk_proposals = torch.topk(
+ class_unselected.max(-1)[0], inter_select_number, dim=0)[1]
+ new_reference_points_for_box = torch.gather(
+ new_reference_points[effect_num_dn:], 0,
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
+ new_output_for_box = torch.gather(
+ output[effect_num_dn:], 0,
+ topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims))
+ bs = new_output_for_box.shape[1]
+ new_output_for_keypoint = new_output_for_box[:, None, :, :] \
+ + self.keypoint_embed.weight[None, :, None, :]
+ if self.num_body_points == 17:
+ delta_xy = self.pose_embed[-1](new_output_for_keypoint)[
+ ..., :2]
+ else:
+ delta_xy = self.pose_embed[0](new_output_for_keypoint)[
+ ..., :2]
+ keypoint_xy = (inverse_sigmoid(
+ new_reference_points_for_box[..., :2][:, None]) +
+ delta_xy).sigmoid()
+ num_queries, _, bs, _ = keypoint_xy.shape
+ keypoint_wh_weight = self.hw.weight.unsqueeze(0).unsqueeze(
+ -2).repeat(num_queries, 1, bs, 1).sigmoid()
+ keypoint_wh = keypoint_wh_weight * \
+ new_reference_points_for_box[..., 2:][:, None]
+ new_reference_points_for_keypoint = torch.cat(
+ (keypoint_xy, keypoint_wh), dim=-1)
+ new_reference_points = torch.cat(
+ (new_reference_points_for_box.unsqueeze(1),
+ new_reference_points_for_keypoint),
+ dim=1).flatten(0, 1)
+ output = torch.cat(
+ (new_output_for_box.unsqueeze(1), new_output_for_keypoint),
+ dim=1).flatten(0, 1)
+ new_reference_points = torch.cat(
+ (dn_new_reference_points, new_reference_points), dim=0)
+ output = torch.cat((dn_output, output), dim=0)
+ attn_mask = human2pose_attn_mask
+
+ # human-to-keypoints update
+ if layer_id >= self.num_box_decoder_layers:
+ effect_num_dn = self.num_dn if self.training else 0
+ inter_select_number = self.num_group
+ ref_before_sigmoid = inverse_sigmoid(reference_points)
+ output_bbox_dn = output[:effect_num_dn]
+ output_bbox_norm = output[effect_num_dn:][0::(
+ self.num_body_points + 1)]
+ ref_before_sigmoid_bbox_dn = \
+ ref_before_sigmoid[:effect_num_dn]
+ ref_before_sigmoid_bbox_norm = \
+ ref_before_sigmoid[effect_num_dn:][0::(
+ self.num_body_points + 1)]
+ delta_unsig_dn = self.bbox_embed[layer_id](output_bbox_dn)
+ delta_unsig_norm = self.bbox_embed[layer_id](output_bbox_norm)
+ outputs_unsig_dn = delta_unsig_dn + ref_before_sigmoid_bbox_dn
+ outputs_unsig_norm = delta_unsig_norm + \
+ ref_before_sigmoid_bbox_norm
+ new_reference_points_for_box_dn = outputs_unsig_dn.sigmoid()
+ new_reference_points_for_box_norm = outputs_unsig_norm.sigmoid(
+ )
+ output_kpt = output[effect_num_dn:].index_select(
+ 0, torch.tensor(self.kpt_index, device=output.device))
+ delta_xy_unsig = self.pose_embed[layer_id -
+ self.num_box_decoder_layers](
+ output_kpt)
+ outputs_unsig = ref_before_sigmoid[
+ effect_num_dn:].index_select(
+ 0, torch.tensor(self.kpt_index,
+ device=output.device)).clone()
+ delta_hw_unsig = self.pose_hw_embed[
+ layer_id - self.num_box_decoder_layers](
+ output_kpt)
+ outputs_unsig[..., :2] += delta_xy_unsig[..., :2]
+ outputs_unsig[..., 2:] += delta_hw_unsig
+ new_reference_points_for_keypoint = outputs_unsig.sigmoid()
+ bs = new_reference_points_for_box_norm.shape[1]
+ new_reference_points_norm = torch.cat(
+ (new_reference_points_for_box_norm.unsqueeze(1),
+ new_reference_points_for_keypoint.view(
+ -1, self.num_body_points, bs, 4)),
+ dim=1).flatten(0, 1)
+ new_reference_points = torch.cat(
+ (new_reference_points_for_box_dn,
+ new_reference_points_norm),
+ dim=0)
+
+ reference_points = new_reference_points.detach()
+ intermediate_reference_points.append(reference_points)
+
+ return [[itm_out.transpose(0, 1) for itm_out in intermediate],
+ [
+ itm_refpoint.transpose(0, 1)
+ for itm_refpoint in intermediate_reference_points
+ ]]
+
+ @staticmethod
+ def get_proposal_pos_embed(pos_tensor: Tensor,
+ temperature: int = 10000,
+ num_pos_feats: int = 128) -> Tensor:
+ """Get the position embedding of the proposal.
+
+ Args:
+ pos_tensor (Tensor): Not normalized proposals, has shape
+ (bs, num_queries, 4) with the last dimension arranged as
+ (cx, cy, w, h).
+ temperature (int, optional): The temperature used for scaling the
+ position embedding. Defaults to 10000.
+ num_pos_feats (int, optional): The feature dimension for each
+ position along x, y, w, and h-axis. Note the final returned
+ dimension for each position is 4 times of num_pos_feats.
+ Default to 128.
+
+ Returns:
+ Tensor: The position embedding of proposal, has shape
+ (bs, num_queries, num_pos_feats * 4), with the last dimension
+ arranged as (cx, cy, w, h)
+ """
+
+ scale = 2 * math.pi
+ dim_t = torch.arange(
+ num_pos_feats, dtype=torch.float32, device=pos_tensor.device)
+ dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats)
+ x_embed = pos_tensor[:, :, 0] * scale
+ y_embed = pos_tensor[:, :, 1] * scale
+ pos_x = x_embed[:, :, None] / dim_t
+ pos_y = y_embed[:, :, None] / dim_t
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()),
+ dim=3).flatten(2)
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()),
+ dim=3).flatten(2)
+ if pos_tensor.size(-1) == 2:
+ pos = torch.cat((pos_y, pos_x), dim=2)
+ elif pos_tensor.size(-1) == 4:
+ w_embed = pos_tensor[:, :, 2] * scale
+ pos_w = w_embed[:, :, None] / dim_t
+ pos_w = torch.stack(
+ (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()),
+ dim=3).flatten(2)
+
+ h_embed = pos_tensor[:, :, 3] * scale
+ pos_h = h_embed[:, :, None] / dim_t
+ pos_h = torch.stack(
+ (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()),
+ dim=3).flatten(2)
+
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+ else:
+ raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
+ pos_tensor.size(-1)))
+ return pos
+
+
+class EDPoseOutHead(BaseModule):
+ """Final Head of EDPose: `Explicit Box Detection Unifies End-to-End Multi-
+ Person Pose Estimation.
+
+ Args:
+ - num_classes (int): The number of classes.
+ - num_body_points (int): The number of datasets' body keypoints.
+ - num_queries (int): The number of queries.
+ - cls_no_bias (bool): Weather add the bias to class embed.
+ - embed_dims (int): The dims of embed.
+ - as_two_stage (bool, optional): Whether to generate the proposal
+ from the outputs of encoder. Defaults to `False`.
+ - refine_queries_num (int): The number of refines queries after
+ decoders.
+ - num_box_decoder_layers (int): The number of bbox decoder layer.
+ - num_group (int): The number of groups.
+ - num_pred_layer (int): The number of the prediction layers.
+ Defaults to 6.
+ - dec_pred_class_embed_share (bool): Whether to share parameters
+ for all the class prediction layers. Defaults to `False`.
+ - dec_pred_bbox_embed_share (bool): Whether to share parameters
+ for all the bbox prediction layers. Defaults to `False`.
+ - dec_pred_pose_embed_share (bool): Whether to share parameters
+ for all the pose prediction layers. Defaults to `False`.
+ """
+
+ def __init__(self,
+ num_classes,
+ num_body_points: int = 17,
+ num_queries: int = 900,
+ cls_no_bias: bool = False,
+ embed_dims: int = 256,
+ as_two_stage: bool = False,
+ refine_queries_num: int = 100,
+ num_box_decoder_layers: int = 2,
+ num_group: int = 100,
+ num_pred_layer: int = 6,
+ dec_pred_class_embed_share: bool = False,
+ dec_pred_bbox_embed_share: bool = False,
+ dec_pred_pose_embed_share: bool = False,
+ **kwargs):
+ super().__init__()
+ self.embed_dims = embed_dims
+ self.as_two_stage = as_two_stage
+ self.num_classes = num_classes
+ self.refine_queries_num = refine_queries_num
+ self.num_box_decoder_layers = num_box_decoder_layers
+ self.num_body_points = num_body_points
+ self.num_queries = num_queries
+
+ # prepare pred layers
+ self.dec_pred_class_embed_share = dec_pred_class_embed_share
+ self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
+ self.dec_pred_pose_embed_share = dec_pred_pose_embed_share
+ # prepare class & box embed
+ _class_embed = nn.Linear(
+ self.embed_dims, self.num_classes, bias=(not cls_no_bias))
+ if not cls_no_bias:
+ prior_prob = 0.01
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
+ _class_embed.bias.data = torch.ones(self.num_classes) * bias_value
+
+ _bbox_embed = MLP(self.embed_dims, self.embed_dims, 4, 3)
+ _pose_embed = MLP(self.embed_dims, self.embed_dims, 2, 3)
+ _pose_hw_embed = MLP(self.embed_dims, self.embed_dims, 2, 3)
+
+ self.num_group = num_group
+ if dec_pred_bbox_embed_share:
+ box_embed_layerlist = [_bbox_embed for i in range(num_pred_layer)]
+ else:
+ box_embed_layerlist = [
+ copy.deepcopy(_bbox_embed) for i in range(num_pred_layer)
+ ]
+ if dec_pred_class_embed_share:
+ class_embed_layerlist = [
+ _class_embed for i in range(num_pred_layer)
+ ]
+ else:
+ class_embed_layerlist = [
+ copy.deepcopy(_class_embed) for i in range(num_pred_layer)
+ ]
+
+ if num_body_points == 17:
+ if dec_pred_pose_embed_share:
+ pose_embed_layerlist = [
+ _pose_embed
+ for i in range(num_pred_layer - num_box_decoder_layers + 1)
+ ]
+ else:
+ pose_embed_layerlist = [
+ copy.deepcopy(_pose_embed)
+ for i in range(num_pred_layer - num_box_decoder_layers + 1)
+ ]
+ else:
+ if dec_pred_pose_embed_share:
+ pose_embed_layerlist = [
+ _pose_embed
+ for i in range(num_pred_layer - num_box_decoder_layers)
+ ]
+ else:
+ pose_embed_layerlist = [
+ copy.deepcopy(_pose_embed)
+ for i in range(num_pred_layer - num_box_decoder_layers)
+ ]
+
+ pose_hw_embed_layerlist = [
+ _pose_hw_embed
+ for i in range(num_pred_layer - num_box_decoder_layers)
+ ]
+ self.bbox_embed = nn.ModuleList(box_embed_layerlist)
+ self.class_embed = nn.ModuleList(class_embed_layerlist)
+ self.pose_embed = nn.ModuleList(pose_embed_layerlist)
+ self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist)
+
+ def init_weights(self) -> None:
+ """Initialize weights of the Deformable DETR head."""
+
+ for m in self.bbox_embed:
+ constant_init(m[-1], 0, bias=0)
+ for m in self.pose_embed:
+ constant_init(m[-1], 0, bias=0)
+
+ def forward(self, hidden_states: List[Tensor], references: List[Tensor],
+ mask_dict: Dict, hidden_states_enc: Tensor,
+ referens_enc: Tensor, batch_data_samples) -> Dict:
+ """Forward function.
+
+ Args:
+ hidden_states (Tensor): Hidden states output from each decoder
+ layer, has shape (num_decoder_layers, bs, num_queries, dim).
+ references (list[Tensor]): List of the reference from the decoder.
+
+ Returns:
+ tuple[Tensor]: results of head containing the following tensor.
+
+ - pred_logits (Tensor): Outputs from the
+ classification head, the socres of every bboxes.
+ - pred_boxes (Tensor): The output boxes.
+ - pred_keypoints (Tensor): The output keypoints.
+ """
+ # update human boxes
+ effec_dn_num = self.refine_queries_num if self.training else 0
+ outputs_coord_list = []
+ outputs_class = []
+ for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_cls_embed,
+ layer_hs) in enumerate(
+ zip(references[:-1], self.bbox_embed,
+ self.class_embed, hidden_states)):
+ if dec_lid < self.num_box_decoder_layers:
+ layer_delta_unsig = layer_bbox_embed(layer_hs)
+ layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid(
+ layer_ref_sig)
+ layer_outputs_unsig = layer_outputs_unsig.sigmoid()
+ layer_cls = layer_cls_embed(layer_hs)
+ outputs_coord_list.append(layer_outputs_unsig)
+ outputs_class.append(layer_cls)
+ else:
+ layer_hs_bbox_dn = layer_hs[:, :effec_dn_num, :]
+ layer_hs_bbox_norm = \
+ layer_hs[:, effec_dn_num:, :][:, 0::(
+ self.num_body_points + 1), :]
+ bs = layer_ref_sig.shape[0]
+ ref_before_sigmoid_bbox_dn = \
+ layer_ref_sig[:, : effec_dn_num, :]
+ ref_before_sigmoid_bbox_norm = \
+ layer_ref_sig[:, effec_dn_num:, :][:, 0::(
+ self.num_body_points + 1), :]
+ layer_delta_unsig_dn = layer_bbox_embed(layer_hs_bbox_dn)
+ layer_delta_unsig_norm = layer_bbox_embed(layer_hs_bbox_norm)
+ layer_outputs_unsig_dn = layer_delta_unsig_dn + \
+ inverse_sigmoid(ref_before_sigmoid_bbox_dn)
+ layer_outputs_unsig_dn = layer_outputs_unsig_dn.sigmoid()
+ layer_outputs_unsig_norm = layer_delta_unsig_norm + \
+ inverse_sigmoid(ref_before_sigmoid_bbox_norm)
+ layer_outputs_unsig_norm = layer_outputs_unsig_norm.sigmoid()
+ layer_outputs_unsig = torch.cat(
+ (layer_outputs_unsig_dn, layer_outputs_unsig_norm), dim=1)
+ layer_cls_dn = layer_cls_embed(layer_hs_bbox_dn)
+ layer_cls_norm = layer_cls_embed(layer_hs_bbox_norm)
+ layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1)
+ outputs_class.append(layer_cls)
+ outputs_coord_list.append(layer_outputs_unsig)
+
+ # update keypoints boxes
+ outputs_keypoints_list = []
+ kpt_index = [
+ x for x in range(self.num_group * (self.num_body_points + 1))
+ if x % (self.num_body_points + 1) != 0
+ ]
+ for dec_lid, (layer_ref_sig, layer_hs) in enumerate(
+ zip(references[:-1], hidden_states)):
+ if dec_lid < self.num_box_decoder_layers:
+ assert isinstance(layer_hs, torch.Tensor)
+ bs = layer_hs.shape[0]
+ layer_res = layer_hs.new_zeros(
+ (bs, self.num_queries, self.num_body_points * 3))
+ outputs_keypoints_list.append(layer_res)
+ else:
+ bs = layer_ref_sig.shape[0]
+ layer_hs_kpt = \
+ layer_hs[:, effec_dn_num:, :].index_select(
+ 1, torch.tensor(kpt_index, device=layer_hs.device))
+ delta_xy_unsig = self.pose_embed[dec_lid -
+ self.num_box_decoder_layers](
+ layer_hs_kpt)
+ layer_ref_sig_kpt = \
+ layer_ref_sig[:, effec_dn_num:, :].index_select(
+ 1, torch.tensor(kpt_index, device=layer_hs.device))
+ layer_outputs_unsig_keypoints = delta_xy_unsig + \
+ inverse_sigmoid(layer_ref_sig_kpt[..., :2])
+ vis_xy_unsig = torch.ones_like(
+ layer_outputs_unsig_keypoints,
+ device=layer_outputs_unsig_keypoints.device)
+ xyv = torch.cat((layer_outputs_unsig_keypoints,
+ vis_xy_unsig[:, :, 0].unsqueeze(-1)),
+ dim=-1)
+ xyv = xyv.sigmoid()
+ layer_res = xyv.reshape(
+ (bs, self.num_group, self.num_body_points,
+ 3)).flatten(2, 3)
+ layer_res = self.keypoint_xyzxyz_to_xyxyzz(layer_res)
+ outputs_keypoints_list.append(layer_res)
+
+ dn_mask_dict = mask_dict
+ if self.refine_queries_num > 0 and dn_mask_dict is not None:
+ outputs_class, outputs_coord_list, outputs_keypoints_list = \
+ self.dn_post_process2(
+ outputs_class, outputs_coord_list,
+ outputs_keypoints_list, dn_mask_dict
+ )
+ # TODO:the denosing strateges are used in training stage
+
+ for idx, (_out_class, _out_bbox, _out_keypoint) in enumerate(
+ zip(outputs_class, outputs_coord_list,
+ outputs_keypoints_list)):
+ assert _out_class.shape[1] == \
+ _out_bbox.shape[1] == _out_keypoint.shape[1]
+
+ out = {
+ 'pred_logits': outputs_class[-1],
+ 'pred_boxes': outputs_coord_list[-1],
+ 'pred_keypoints': outputs_keypoints_list[-1]
+ }
+
+ # TODO: if refine_queries_num and dn_mask_dict are used:
+
+ return out
+
+ def keypoint_xyzxyz_to_xyxyzz(self, keypoints: torch.Tensor):
+ """
+ Args:
+ keypoints (torch.Tensor): ..., 51
+ """
+ res = torch.zeros_like(keypoints)
+ num_points = keypoints.shape[-1] // 3
+ res[..., 0:2 * num_points:2] = keypoints[..., 0::3]
+ res[..., 1:2 * num_points:2] = keypoints[..., 1::3]
+ res[..., 2 * num_points:] = keypoints[..., 2::3]
+ return res
+
+
+@MODELS.register_module()
+class EDPoseHead(TransformerHead):
+ """Head introduced in `Explicit Box Detection Unifies End-to-End Multi-
+ Person Pose Estimation`_ by J Yang1 et al (2023). The head is composed of
+ Encoder、Decoder、Out_head.
+
+ Code is modified from the `official github repo
+ `_.
+
+ More details can be found in the `paper
+ `_ .
+
+ Args:
+ num_queries (int): Number of query in Transformer.
+ num_feature_levels (int, optional): Number of feature levels.
+ Defaults to 4.
+ as_two_stage (bool, optional): Whether to generate the proposal
+ from the outputs of encoder. Defaults to `False`.
+ encoder (:obj:`ConfigDict` or dict, optional): Config of the
+ Transformer encoder. Defaults to None.
+ decoder (:obj:`ConfigDict` or dict, optional): Config of the
+ Transformer decoder. Defaults to None.
+ out_head (:obj:`ConfigDict` or dict, optional): Config for the
+ bounding final out head module. Defaults to None.
+ positional_encoding (:obj:`ConfigDict` or dict): Config for
+ transformer position encoding. Defaults None.
+ denosing_cfg (:obj:`ConfigDict` or dict, optional): Config of the
+ human query denoising training strategy.
+ data_decoder (:obj:`ConfigDict` or dict, optional): Config of the
+ data decoder which transform the results from output space to
+ input space.
+ dec_pred_class_embed_share (bool): Whether to share the class embed
+ layer. Default False.
+ dec_pred_bbox_embed_share (bool): Whether to share the bbox embed
+ layer. Default False.
+ refine_queries_num (int): Number of refined human content queries
+ and their position queries .
+ two_stage_keep_all_tokens (bool): Whether to keep all tokens.
+ """
+
+ def __init__(self,
+ num_queries: int = 100,
+ num_feature_levels: int = 4,
+ num_body_points: int = 17,
+ as_two_stage: bool = False,
+ encoder: OptConfigType = None,
+ decoder: OptConfigType = None,
+ out_head: OptConfigType = None,
+ positional_encoding: OptConfigType = None,
+ data_decoder: OptConfigType = None,
+ denosing_cfg: OptConfigType = None,
+ dec_pred_class_embed_share: bool = False,
+ dec_pred_bbox_embed_share: bool = False,
+ refine_queries_num: int = 100,
+ two_stage_keep_all_tokens: bool = False) -> None:
+
+ self.as_two_stage = as_two_stage
+ self.num_feature_levels = num_feature_levels
+ self.refine_queries_num = refine_queries_num
+ self.dec_pred_class_embed_share = dec_pred_class_embed_share
+ self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share
+ self.two_stage_keep_all_tokens = two_stage_keep_all_tokens
+ self.num_heads = decoder['layer_cfg']['self_attn_cfg']['num_heads']
+ self.num_group = decoder['num_group']
+ self.num_body_points = num_body_points
+ self.denosing_cfg = denosing_cfg
+ if data_decoder is not None:
+ self.data_decoder = KEYPOINT_CODECS.build(data_decoder)
+ else:
+ self.data_decoder = None
+
+ super().__init__(
+ encoder=encoder,
+ decoder=decoder,
+ out_head=out_head,
+ positional_encoding=positional_encoding,
+ num_queries=num_queries)
+
+ self.positional_encoding = PositionEmbeddingSineHW(
+ **self.positional_encoding_cfg)
+ self.encoder = DeformableDetrTransformerEncoder(**self.encoder_cfg)
+ self.decoder = EDPoseDecoder(
+ num_body_points=num_body_points, **self.decoder_cfg)
+ self.out_head = EDPoseOutHead(
+ num_body_points=num_body_points,
+ as_two_stage=as_two_stage,
+ refine_queries_num=refine_queries_num,
+ **self.out_head_cfg,
+ **self.decoder_cfg)
+
+ self.embed_dims = self.encoder.embed_dims
+ self.label_enc = nn.Embedding(
+ self.denosing_cfg['dn_labelbook_size'] + 1, self.embed_dims)
+
+ if not self.as_two_stage:
+ self.query_embedding = nn.Embedding(self.num_queries,
+ self.embed_dims)
+ self.refpoint_embedding = nn.Embedding(self.num_queries, 4)
+
+ self.level_embed = nn.Parameter(
+ torch.Tensor(self.num_feature_levels, self.embed_dims))
+
+ self.decoder.bbox_embed = self.out_head.bbox_embed
+ self.decoder.pose_embed = self.out_head.pose_embed
+ self.decoder.pose_hw_embed = self.out_head.pose_hw_embed
+ self.decoder.class_embed = self.out_head.class_embed
+
+ if self.as_two_stage:
+ self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
+ self.memory_trans_norm = nn.LayerNorm(self.embed_dims)
+ if dec_pred_class_embed_share and dec_pred_bbox_embed_share:
+ self.enc_out_bbox_embed = self.out_head.bbox_embed[0]
+ else:
+ self.enc_out_bbox_embed = copy.deepcopy(
+ self.out_head.bbox_embed[0])
+
+ if dec_pred_class_embed_share and dec_pred_bbox_embed_share:
+ self.enc_out_class_embed = self.out_head.class_embed[0]
+ else:
+ self.enc_out_class_embed = copy.deepcopy(
+ self.out_head.class_embed[0])
+
+ def init_weights(self) -> None:
+ """Initialize weights for Transformer and other components."""
+ super().init_weights()
+ for coder in self.encoder, self.decoder:
+ for p in coder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MultiScaleDeformableAttention):
+ m.init_weights()
+ if self.as_two_stage:
+ nn.init.xavier_uniform_(self.memory_trans_fc.weight)
+
+ nn.init.normal_(self.level_embed)
+
+ def pre_transformer(self,
+ img_feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None
+ ) -> Tuple[Dict]:
+ """Process image features before feeding them to the transformer.
+
+ Args:
+ img_feats (tuple[Tensor]): Multi-level features that may have
+ different resolutions, output from neck. Each feature has
+ shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'.
+ batch_data_samples (list[:obj:`DetDataSample`], optional): The
+ batch data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
+
+ Returns:
+ tuple[dict]: The first dict contains the inputs of encoder and the
+ second dict contains the inputs of decoder.
+
+ - encoder_inputs_dict (dict): The keyword args dictionary of
+ `self.encoder()`.
+ - decoder_inputs_dict (dict): The keyword args dictionary of
+ `self.forward_decoder()`, which includes 'memory_mask'.
+ """
+ batch_size = img_feats[0].size(0)
+ # construct binary masks for the transformer.
+ assert batch_data_samples is not None
+ batch_input_shape = batch_data_samples[0].batch_input_shape
+ img_shape_list = [sample.img_shape for sample in batch_data_samples]
+ input_img_h, input_img_w = batch_input_shape
+ masks = img_feats[0].new_ones((batch_size, input_img_h, input_img_w))
+ for img_id in range(batch_size):
+ img_h, img_w = img_shape_list[img_id]
+ masks[img_id, :img_h, :img_w] = 0
+ # NOTE following the official DETR repo, non-zero values representing
+ # ignored positions, while zero values means valid positions.
+
+ mlvl_masks = []
+ mlvl_pos_embeds = []
+ for feat in img_feats:
+ mlvl_masks.append(
+ F.interpolate(masks[None],
+ size=feat.shape[-2:]).to(torch.bool).squeeze(0))
+ mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1]))
+
+ feat_flatten = []
+ lvl_pos_embed_flatten = []
+ mask_flatten = []
+ spatial_shapes = []
+ for lvl, (feat, mask, pos_embed) in enumerate(
+ zip(img_feats, mlvl_masks, mlvl_pos_embeds)):
+ batch_size, c, h, w = feat.shape
+ # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
+ feat = feat.view(batch_size, c, -1).permute(0, 2, 1)
+ pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)
+ lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
+ # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
+ mask = mask.flatten(1)
+ spatial_shape = (h, w)
+
+ feat_flatten.append(feat)
+ lvl_pos_embed_flatten.append(lvl_pos_embed)
+ mask_flatten.append(mask)
+ spatial_shapes.append(spatial_shape)
+
+ # (bs, num_feat_points, dim)
+ feat_flatten = torch.cat(feat_flatten, 1)
+ lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
+ # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
+ mask_flatten = torch.cat(mask_flatten, 1)
+
+ spatial_shapes = torch.as_tensor( # (num_level, 2)
+ spatial_shapes,
+ dtype=torch.long,
+ device=feat_flatten.device)
+ level_start_index = torch.cat((
+ spatial_shapes.new_zeros((1, )), # (num_level)
+ spatial_shapes.prod(1).cumsum(0)[:-1]))
+ valid_ratios = torch.stack( # (bs, num_level, 2)
+ [self.get_valid_ratio(m) for m in mlvl_masks], 1)
+
+ if self.refine_queries_num > 0 or batch_data_samples is not None:
+ input_query_label, input_query_bbox, humandet_attn_mask, \
+ human2pose_attn_mask, mask_dict =\
+ self.prepare_for_denosing(
+ batch_data_samples,
+ device=img_feats[0].device)
+ else:
+ assert batch_data_samples is None
+ input_query_bbox = input_query_label = \
+ humandet_attn_mask = human2pose_attn_mask = mask_dict = None
+
+ encoder_inputs_dict = dict(
+ query=feat_flatten,
+ query_pos=lvl_pos_embed_flatten,
+ key_padding_mask=mask_flatten,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios)
+ decoder_inputs_dict = dict(
+ memory_mask=mask_flatten,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ humandet_attn_mask=humandet_attn_mask,
+ human2pose_attn_mask=human2pose_attn_mask,
+ input_query_bbox=input_query_bbox,
+ input_query_label=input_query_label,
+ mask_dict=mask_dict)
+ return encoder_inputs_dict, decoder_inputs_dict
+
+ def forward_encoder(self,
+ img_feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None) -> Dict:
+ """Forward with Transformer encoder.
+
+ The forward procedure is defined as:
+ 'pre_transformer' -> 'encoder'
+
+ Args:
+ img_feats (tuple[Tensor]): Multi-level features that may have
+ different resolutions, output from neck. Each feature has
+ shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'.
+ batch_data_samples (list[:obj:`DetDataSample`], optional): The
+ batch data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
+
+ Returns:
+ dict: The dictionary of encoder outputs, which includes the
+ `memory` of the encoder output.
+ """
+ encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
+ img_feats, batch_data_samples)
+
+ memory = self.encoder(**encoder_inputs_dict)
+ encoder_outputs_dict = dict(memory=memory, **decoder_inputs_dict)
+ return encoder_outputs_dict
+
+ def pre_decoder(self, memory: Tensor, memory_mask: Tensor,
+ spatial_shapes: Tensor, input_query_bbox: Tensor,
+ input_query_label: Tensor) -> Tuple[Dict, Dict]:
+ """Prepare intermediate variables before entering Transformer decoder,
+ such as `query` and `reference_points`.
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (bs, num_feat_points, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat_points). It will only be used when
+ `as_two_stage` is `True`.
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ It will only be used when `as_two_stage` is `True`.
+ input_query_bbox (Tensor): Denosing bbox query for training.
+ input_query_label (Tensor): Denosing label query for training.
+
+ Returns:
+ tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict.
+
+ - decoder_inputs_dict (dict): The keyword dictionary args of
+ `self.decoder()`.
+ - head_inputs_dict (dict): The keyword dictionary args of the
+ bbox_head functions.
+ """
+ bs, _, c = memory.shape
+ if self.as_two_stage:
+ output_memory, output_proposals = \
+ self.gen_encoder_output_proposals(
+ memory, memory_mask, spatial_shapes)
+ enc_outputs_class = self.enc_out_class_embed(output_memory)
+ enc_outputs_coord_unact = self.enc_out_bbox_embed(
+ output_memory) + output_proposals
+
+ topk_proposals = torch.topk(
+ enc_outputs_class.max(-1)[0], self.num_queries, dim=1)[1]
+ topk_coords_undetach = torch.gather(
+ enc_outputs_coord_unact, 1,
+ topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
+ topk_coords_unact = topk_coords_undetach.detach()
+ reference_points = topk_coords_unact.sigmoid()
+
+ query_undetach = torch.gather(
+ output_memory, 1,
+ topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims))
+ query = query_undetach.detach()
+
+ if input_query_bbox is not None:
+ reference_points = torch.cat(
+ [input_query_bbox, topk_coords_unact], dim=1).sigmoid()
+ query = torch.cat([input_query_label, query], dim=1)
+ if self.two_stage_keep_all_tokens:
+ hidden_states_enc = output_memory.unsqueeze(0)
+ referens_enc = enc_outputs_coord_unact.unsqueeze(0)
+ else:
+ hidden_states_enc = query_undetach.unsqueeze(0)
+ referens_enc = topk_coords_undetach.sigmoid().unsqueeze(0)
+ else:
+ hidden_states_enc, referens_enc = None, None
+ query = self.query_embedding.weight[:, None, :].repeat(
+ 1, bs, 1).transpose(0, 1)
+ reference_points = \
+ self.refpoint_embedding.weight[:, None, :].repeat(1, bs, 1)
+
+ if input_query_bbox is not None:
+ reference_points = torch.cat(
+ [input_query_bbox, reference_points], dim=1)
+ query = torch.cat([input_query_label, query], dim=1)
+ reference_points = reference_points.sigmoid()
+
+ decoder_inputs_dict = dict(
+ query=query, reference_points=reference_points)
+ head_inputs_dict = dict(
+ hidden_states_enc=hidden_states_enc, referens_enc=referens_enc)
+ return decoder_inputs_dict, head_inputs_dict
+
+ def forward_decoder(self, memory: Tensor, memory_mask: Tensor,
+ spatial_shapes: Tensor, level_start_index: Tensor,
+ valid_ratios: Tensor, humandet_attn_mask: Tensor,
+ human2pose_attn_mask: Tensor, input_query_bbox: Tensor,
+ input_query_label: Tensor, mask_dict: Dict) -> Dict:
+ """Forward with Transformer decoder.
+
+ The forward procedure is defined as:
+ 'pre_decoder' -> 'decoder'
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (bs, num_feat_points, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat_points).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+ humandet_attn_mask (Tensor): Human attention mask.
+ human2pose_attn_mask (Tensor): Human to pose attention mask.
+ input_query_bbox (Tensor): Denosing bbox query for training.
+ input_query_label (Tensor): Denosing label query for training.
+
+ Returns:
+ dict: The dictionary of decoder outputs, which includes the
+ `hidden_states` of the decoder output and `references` including
+ the initial and intermediate reference_points.
+ """
+ decoder_in, head_in = self.pre_decoder(memory, memory_mask,
+ spatial_shapes,
+ input_query_bbox,
+ input_query_label)
+
+ inter_states, inter_references = self.decoder(
+ query=decoder_in['query'].transpose(0, 1),
+ value=memory.transpose(0, 1),
+ key_padding_mask=memory_mask, # for cross_attn
+ reference_points=decoder_in['reference_points'].transpose(0, 1),
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ humandet_attn_mask=humandet_attn_mask,
+ human2pose_attn_mask=human2pose_attn_mask)
+ references = inter_references
+ decoder_outputs_dict = dict(
+ hidden_states=inter_states,
+ references=references,
+ mask_dict=mask_dict)
+ decoder_outputs_dict.update(head_in)
+ return decoder_outputs_dict
+
+ def forward_out_head(self, batch_data_samples: OptSampleList,
+ hidden_states: List[Tensor], references: List[Tensor],
+ mask_dict: Dict, hidden_states_enc: Tensor,
+ referens_enc: Tensor) -> Tuple[Tensor]:
+ """Forward function."""
+ out = self.out_head(hidden_states, references, mask_dict,
+ hidden_states_enc, referens_enc,
+ batch_data_samples)
+ return out
+
+ def forward(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None) -> Dict:
+ """Forward the network."""
+ encoder_outputs_dict = self.forward_encoder(feats, batch_data_samples)
+
+ decoder_outputs_dict = self.forward_decoder(**encoder_outputs_dict)
+
+ head_outputs_dict = self.forward_out_head(batch_data_samples,
+ **decoder_outputs_dict)
+ return head_outputs_dict
+
+ def predict(self,
+ feats: Features,
+ batch_data_samples: OptSampleList,
+ test_cfg: ConfigType = {}) -> Predictions:
+ """Predict results from features."""
+ input_shapes = np.array(
+ [d.metainfo['input_size'] for d in batch_data_samples], )
+
+ if test_cfg.get('flip_test', False):
+ assert NotImplementedError
+ else:
+ batch_out = self.forward(feats, batch_data_samples) # (B, K, D)
+
+ pred = self.decode(input_shapes, **batch_out)
+ return pred
+
+ def decode(self, input_shapes: np.ndarray, pred_logits: Tensor,
+ pred_boxes: Tensor, pred_keypoints: Tensor):
+ """Select the final top-k keypoints, and decode the results from
+ normalize size to origin input size.
+
+ Args:
+ input_shapes (Tensor): The size of input image.
+ pred_logits (Tensor): The result of score.
+ pred_boxes (Tensor): The result of bbox.
+ pred_keypoints (Tensor): The result of keypoints.
+
+ Returns:
+ """
+
+ if self.data_decoder is None:
+ raise RuntimeError(f'The data decoder has not been set in \
+ {self.__class__.__name__}. '
+ 'Please set the data decoder configs in \
+ the init parameters to '
+ 'enable head methods `head.predict()` and \
+ `head.decode()`')
+
+ preds = []
+ batch_size = input_shapes.shape[0]
+
+ pred_logits = pred_logits.sigmoid()
+ pred_logits, pred_boxes, pred_keypoints = to_numpy(
+ [pred_logits, pred_boxes, pred_keypoints])
+
+ for b in range(batch_size):
+ keypoint, keypoint_score, bbox = self.data_decoder.decode(
+ input_shapes[b], pred_logits[b], pred_boxes[b],
+ pred_keypoints[b])
+
+ # pack outputs
+ preds.append(
+ InstanceData(
+ keypoints=keypoint,
+ keypoint_scores=keypoint_score,
+ boxes=bbox))
+
+ return preds
+
+ @staticmethod
+ def get_valid_ratio(mask: Tensor) -> Tensor:
+ """Get the valid radios of feature map in a level.
+
+ .. code:: text
+
+ |---> valid_W <---|
+ ---+-----------------+-----+---
+ A | | | A
+ | | | | |
+ | | | | |
+ valid_H | | | |
+ | | | | H
+ | | | | |
+ V | | | |
+ ---+-----------------+ | |
+ | | V
+ +-----------------------+---
+ |---------> W <---------|
+
+ The valid_ratios are defined as:
+ r_h = valid_H / H, r_w = valid_W / W
+ They are the factors to re-normalize the relative coordinates of the
+ image to the relative coordinates of the current level feature map.
+
+ Args:
+ mask (Tensor): Binary mask of a feature map, has shape (bs, H, W).
+
+ Returns:
+ Tensor: valid ratios [r_w, r_h] of a feature map, has shape (1, 2).
+ """
+ _, H, W = mask.shape
+ valid_H = torch.sum(~mask[:, :, 0], 1)
+ valid_W = torch.sum(~mask[:, 0, :], 1)
+ valid_ratio_h = valid_H.float() / H
+ valid_ratio_w = valid_W.float() / W
+ valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+ return valid_ratio
+
+ def gen_encoder_output_proposals(self, memory: Tensor, memory_mask: Tensor,
+ spatial_shapes: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ """Generate proposals from encoded memory. The function will only be
+ used when `as_two_stage` is `True`.
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (bs, num_feat_points, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat_points).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+
+ Returns:
+ tuple: A tuple of transformed memory and proposals.
+
+ - output_memory (Tensor): The transformed memory for obtaining
+ top-k proposals, has shape (bs, num_feat_points, dim).
+ - output_proposals (Tensor): The inverse-normalized proposal, has
+ shape (batch_size, num_keys, 4) with the last dimension arranged
+ as (cx, cy, w, h).
+ """
+ bs = memory.size(0)
+ proposals = []
+ _cur = 0 # start index in the sequence of the current level
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ mask_flatten_ = memory_mask[:,
+ _cur:(_cur + H * W)].view(bs, H, W, 1)
+ valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1)
+ valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1)
+
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(
+ 0, H - 1, H, dtype=torch.float32, device=memory.device),
+ torch.linspace(
+ 0, W - 1, W, dtype=torch.float32, device=memory.device))
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+ scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
+ grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
+ wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+ proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
+ proposals.append(proposal)
+ _cur += (H * W)
+ output_proposals = torch.cat(proposals, 1)
+ output_proposals_valid = ((output_proposals > 0.01) &
+ (output_proposals < 0.99)).all(
+ -1, keepdim=True)
+ # inverse_sigmoid
+ output_proposals = torch.log(output_proposals / (1 - output_proposals))
+ output_proposals = output_proposals.masked_fill(
+ memory_mask.unsqueeze(-1), float('inf'))
+ output_proposals = output_proposals.masked_fill(
+ ~output_proposals_valid, float('inf'))
+
+ output_memory = memory
+ output_memory = output_memory.masked_fill(
+ memory_mask.unsqueeze(-1), float(0))
+ output_memory = output_memory.masked_fill(~output_proposals_valid,
+ float(0))
+ output_memory = self.memory_trans_fc(output_memory)
+ output_memory = self.memory_trans_norm(output_memory)
+ # [bs, sum(hw), 2]
+ return output_memory, output_proposals
+
+ @property
+ def default_init_cfg(self):
+ init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)]
+ return init_cfg
+
+ def prepare_for_denosing(self, targets: OptSampleList, device):
+ """prepare for dn components in forward function."""
+ if not self.training:
+ bs = len(targets)
+ attn_mask_infere = torch.zeros(
+ bs,
+ self.num_heads,
+ self.num_group * (self.num_body_points + 1),
+ self.num_group * (self.num_body_points + 1),
+ device=device,
+ dtype=torch.bool)
+ group_bbox_kpt = (self.num_body_points + 1)
+ kpt_index = [
+ x for x in range(self.num_group * (self.num_body_points + 1))
+ if x % (self.num_body_points + 1) == 0
+ ]
+ for matchj in range(self.num_group * (self.num_body_points + 1)):
+ sj = (matchj // group_bbox_kpt) * group_bbox_kpt
+ ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt
+ if sj > 0:
+ attn_mask_infere[:, :, matchj, :sj] = True
+ if ej < self.num_group * (self.num_body_points + 1):
+ attn_mask_infere[:, :, matchj, ej:] = True
+ for match_x in range(self.num_group * (self.num_body_points + 1)):
+ if match_x % group_bbox_kpt == 0:
+ attn_mask_infere[:, :, match_x, kpt_index] = False
+
+ attn_mask_infere = attn_mask_infere.flatten(0, 1)
+ return None, None, None, attn_mask_infere, None
+
+ # targets, dn_scalar, noise_scale = dn_args
+ device = targets[0]['boxes'].device
+ bs = len(targets)
+ refine_queries_num = self.refine_queries_num
+
+ # gather gt boxes and labels
+ gt_boxes = [t['boxes'] for t in targets]
+ gt_labels = [t['labels'] for t in targets]
+ gt_keypoints = [t['keypoints'] for t in targets]
+
+ # repeat them
+ def get_indices_for_repeat(now_num, target_num, device='cuda'):
+ """
+ Input:
+ - now_num: int
+ - target_num: int
+ Output:
+ - indices: tensor[target_num]
+ """
+ out_indice = []
+ base_indice = torch.arange(now_num).to(device)
+ multiplier = target_num // now_num
+ out_indice.append(base_indice.repeat(multiplier))
+ residue = target_num % now_num
+ out_indice.append(base_indice[torch.randint(
+ 0, now_num, (residue, ), device=device)])
+ return torch.cat(out_indice)
+
+ gt_boxes_expand = []
+ gt_labels_expand = []
+ gt_keypoints_expand = []
+ for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate(
+ zip(gt_boxes, gt_labels, gt_keypoints)):
+ num_gt_i = gt_boxes_i.shape[0]
+ if num_gt_i > 0:
+ indices = get_indices_for_repeat(num_gt_i, refine_queries_num,
+ device)
+ gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4
+ gt_labels_expand_i = gt_labels_i[indices]
+ gt_keypoints_expand_i = gt_keypoint_i[indices]
+ else:
+ # all negative samples when no gt boxes
+ gt_boxes_expand_i = torch.rand(
+ refine_queries_num, 4, device=device)
+ gt_labels_expand_i = torch.ones(
+ refine_queries_num, dtype=torch.int64,
+ device=device) * int(self.num_classes)
+ gt_keypoints_expand_i = torch.rand(
+ refine_queries_num,
+ self.num_body_points * 3,
+ device=device)
+ gt_boxes_expand.append(gt_boxes_expand_i)
+ gt_labels_expand.append(gt_labels_expand_i)
+ gt_keypoints_expand.append(gt_keypoints_expand_i)
+ gt_boxes_expand = torch.stack(gt_boxes_expand)
+ gt_labels_expand = torch.stack(gt_labels_expand)
+ gt_keypoints_expand = torch.stack(gt_keypoints_expand)
+ knwon_boxes_expand = gt_boxes_expand.clone()
+ knwon_labels_expand = gt_labels_expand.clone()
+
+ # add noise
+ if self.denosing_cfg['dn_label_noise_ratio'] > 0:
+ prob = torch.rand_like(knwon_labels_expand.float())
+ chosen_indice = prob < self.denosing_cfg['dn_label_noise_ratio']
+ new_label = torch.randint_like(
+ knwon_labels_expand[chosen_indice], 0,
+ self.dn_labelbook_size) # randomly put a new one here
+ knwon_labels_expand[chosen_indice] = new_label
+
+ if self.denosing_cfg['dn_box_noise_scale'] > 0:
+ diff = torch.zeros_like(knwon_boxes_expand)
+ diff[..., :2] = knwon_boxes_expand[..., 2:] / 2
+ diff[..., 2:] = knwon_boxes_expand[..., 2:]
+ knwon_boxes_expand += torch.mul(
+ (torch.rand_like(knwon_boxes_expand) * 2 - 1.0),
+ diff) * self.denosing_cfg['dn_box_noise_scale']
+ knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0)
+
+ input_query_label = self.label_enc(knwon_labels_expand)
+ input_query_bbox = inverse_sigmoid(knwon_boxes_expand)
+
+ # prepare mask
+
+ if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']:
+ attn_mask = torch.zeros(
+ bs,
+ self.num_heads,
+ refine_queries_num + self.num_queries,
+ refine_queries_num + self.num_queries,
+ device=device,
+ dtype=torch.bool)
+ attn_mask[:, :, refine_queries_num:, :refine_queries_num] = True
+ for idx, (gt_boxes_i,
+ gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)):
+ num_gt_i = gt_boxes_i.shape[0]
+ if num_gt_i == 0:
+ continue
+ for matchi in range(refine_queries_num):
+ si = (matchi // num_gt_i) * num_gt_i
+ ei = (matchi // num_gt_i + 1) * num_gt_i
+ if si > 0:
+ attn_mask[idx, :, matchi, :si] = True
+ if ei < refine_queries_num:
+ attn_mask[idx, :, matchi, ei:refine_queries_num] = True
+ attn_mask = attn_mask.flatten(0, 1)
+
+ if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']:
+ attn_mask2 = torch.zeros(
+ bs,
+ self.num_heads,
+ refine_queries_num + self.num_group *
+ (self.num_body_points + 1),
+ refine_queries_num + self.num_group *
+ (self.num_body_points + 1),
+ device=device,
+ dtype=torch.bool)
+ attn_mask2[:, :, refine_queries_num:, :refine_queries_num] = True
+ group_bbox_kpt = (self.num_body_points + 1)
+ kpt_index = [
+ x for x in range(self.num_group * (self.num_body_points + 1))
+ if x % (self.num_body_points + 1) == 0
+ ]
+ for matchj in range(self.num_group * (self.num_body_points + 1)):
+ sj = (matchj // group_bbox_kpt) * group_bbox_kpt
+ ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt
+ if sj > 0:
+ attn_mask2[:, :, refine_queries_num:,
+ refine_queries_num:][:, :, matchj, :sj] = True
+ if ej < self.num_group * (self.num_body_points + 1):
+ attn_mask2[:, :, refine_queries_num:,
+ refine_queries_num:][:, :, matchj, ej:] = True
+
+ for match_x in range(self.num_group * (self.num_body_points + 1)):
+ if match_x % group_bbox_kpt == 0:
+ attn_mask2[:, :, refine_queries_num:,
+ refine_queries_num:][:, :, match_x,
+ kpt_index] = False
+
+ for idx, (gt_boxes_i,
+ gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)):
+ num_gt_i = gt_boxes_i.shape[0]
+ if num_gt_i == 0:
+ continue
+ for matchi in range(refine_queries_num):
+ si = (matchi // num_gt_i) * num_gt_i
+ ei = (matchi // num_gt_i + 1) * num_gt_i
+ if si > 0:
+ attn_mask2[idx, :, matchi, :si] = True
+ if ei < refine_queries_num:
+ attn_mask2[idx, :, matchi,
+ ei:refine_queries_num] = True
+ attn_mask2 = attn_mask2.flatten(0, 1)
+
+ mask_dict = {
+ 'pad_size': refine_queries_num,
+ 'known_bboxs': gt_boxes_expand,
+ 'known_labels': gt_labels_expand,
+ 'known_keypoints': gt_keypoints_expand
+ }
+
+ return input_query_label, input_query_bbox, \
+ attn_mask, attn_mask2, mask_dict
+
+ def loss(self,
+ feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList,
+ train_cfg: OptConfigType = {}) -> dict:
+ """Calculate losses from a batch of inputs and data samples."""
+
+ assert NotImplementedError
diff --git a/mmpose/models/heads/transformer_heads/transformers/__init__.py b/mmpose/models/heads/transformer_heads/transformers/__init__.py
new file mode 100644
index 0000000000..d678b2522c
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/transformers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .deformable_detr_layers import (DeformableDetrTransformerDecoder,
+ DeformableDetrTransformerDecoderLayer,
+ DeformableDetrTransformerEncoder,
+ DeformableDetrTransformerEncoderLayer)
+from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
+ DetrTransformerEncoder, DetrTransformerEncoderLayer)
+from .utils import MLP, PositionEmbeddingSineHW, inverse_sigmoid
+
+__all__ = [
+ 'DetrTransformerEncoder', 'DetrTransformerDecoder',
+ 'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer',
+ 'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder',
+ 'DeformableDetrTransformerEncoderLayer',
+ 'DeformableDetrTransformerDecoderLayer', 'inverse_sigmoid',
+ 'PositionEmbeddingSineHW', 'MLP'
+]
diff --git a/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py b/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py
new file mode 100644
index 0000000000..49dd526092
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py
@@ -0,0 +1,251 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple, Union
+
+import torch
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.ops import MultiScaleDeformableAttention
+from mmengine.model import ModuleList
+from torch import Tensor, nn
+
+from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
+ DetrTransformerEncoder, DetrTransformerEncoderLayer)
+from .utils import inverse_sigmoid
+
+
+class DeformableDetrTransformerEncoder(DetrTransformerEncoder):
+ """Transformer encoder of Deformable DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize encoder layers."""
+ self.layers = ModuleList([
+ DeformableDetrTransformerEncoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+
+ def forward(self, query: Tensor, query_pos: Tensor,
+ key_padding_mask: Tensor, spatial_shapes: Tensor,
+ level_start_index: Tensor, valid_ratios: Tensor,
+ **kwargs) -> Tensor:
+ """Forward function of Transformer encoder.
+
+ Args:
+ query (Tensor): The input query, has shape (bs, num_queries, dim).
+ query_pos (Tensor): The positional encoding for query, has shape
+ (bs, num_queries, dim).
+ key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
+ input. ByteTensor, has shape (bs, num_queries).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+
+ Returns:
+ Tensor: Output queries of Transformer encoder, which is also
+ called 'encoder output embeddings' or 'memory', has shape
+ (bs, num_queries, dim)
+ """
+ reference_points = self.get_encoder_reference_points(
+ spatial_shapes, valid_ratios, device=query.device)
+ for layer in self.layers:
+ query = layer(
+ query=query,
+ query_pos=query_pos,
+ key_padding_mask=key_padding_mask,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reference_points=reference_points,
+ **kwargs)
+ return query
+
+ @staticmethod
+ def get_encoder_reference_points(spatial_shapes: Tensor,
+ valid_ratios: Tensor,
+ device: Union[torch.device,
+ str]) -> Tensor:
+ """Get the reference points used in encoder.
+
+ Args:
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+ device (obj:`device` or str): The device acquired by the
+ `reference_points`.
+
+ Returns:
+ Tensor: Reference points used in decoder, has shape (bs, length,
+ num_levels, 2).
+ """
+
+ reference_points_list = []
+ for lvl, (H, W) in enumerate(spatial_shapes):
+ ref_y, ref_x = torch.meshgrid(
+ torch.linspace(
+ 0.5, H - 0.5, H, dtype=torch.float32, device=device),
+ torch.linspace(
+ 0.5, W - 0.5, W, dtype=torch.float32, device=device))
+ ref_y = ref_y.reshape(-1)[None] / (
+ valid_ratios[:, None, lvl, 1] * H)
+ ref_x = ref_x.reshape(-1)[None] / (
+ valid_ratios[:, None, lvl, 0] * W)
+ ref = torch.stack((ref_x, ref_y), -1)
+ reference_points_list.append(ref)
+ reference_points = torch.cat(reference_points_list, 1)
+ # [bs, sum(hw), num_level, 2]
+ reference_points = reference_points[:, :, None] * valid_ratios[:, None]
+ return reference_points
+
+
+class DeformableDetrTransformerDecoder(DetrTransformerDecoder):
+ """Transformer Decoder of Deformable DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize decoder layers."""
+ self.layers = ModuleList([
+ DeformableDetrTransformerDecoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+ if self.post_norm_cfg is not None:
+ raise ValueError('There is not post_norm in '
+ f'{self._get_name()}')
+
+ def forward(self,
+ query: Tensor,
+ query_pos: Tensor,
+ value: Tensor,
+ key_padding_mask: Tensor,
+ reference_points: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ valid_ratios: Tensor,
+ reg_branches: Optional[nn.Module] = None,
+ **kwargs) -> Tuple[Tensor]:
+ """Forward function of Transformer decoder.
+
+ Args:
+ query (Tensor): The input queries, has shape (bs, num_queries,
+ dim).
+ query_pos (Tensor): The input positional query, has shape
+ (bs, num_queries, dim). It will be added to `query` before
+ forward function.
+ value (Tensor): The input values, has shape (bs, num_value, dim).
+ key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
+ input. ByteTensor, has shape (bs, num_value).
+ reference_points (Tensor): The initial reference, has shape
+ (bs, num_queries, 4) with the last dimension arranged as
+ (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
+ shape (bs, num_queries, 2) with the last dimension arranged
+ as (cx, cy).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+ reg_branches: (obj:`nn.ModuleList`, optional): Used for refining
+ the regression results. Only would be passed when
+ `with_box_refine` is `True`, otherwise would be `None`.
+
+ Returns:
+ tuple[Tensor]: Outputs of Deformable Transformer Decoder.
+
+ - output (Tensor): Output embeddings of the last decoder, has
+ shape (num_queries, bs, embed_dims) when `return_intermediate`
+ is `False`. Otherwise, Intermediate output embeddings of all
+ decoder layers, has shape (num_decoder_layers, num_queries, bs,
+ embed_dims).
+ - reference_points (Tensor): The reference of the last decoder
+ layer, has shape (bs, num_queries, 4) when `return_intermediate`
+ is `False`. Otherwise, Intermediate references of all decoder
+ layers, has shape (num_decoder_layers, bs, num_queries, 4). The
+ coordinates are arranged as (cx, cy, w, h)
+ """
+ output = query
+ intermediate = []
+ intermediate_reference_points = []
+ for layer_id, layer in enumerate(self.layers):
+ if reference_points.shape[-1] == 4:
+ reference_points_input = \
+ reference_points[:, :, None] * \
+ torch.cat([valid_ratios, valid_ratios], -1)[:, None]
+ else:
+ assert reference_points.shape[-1] == 2
+ reference_points_input = \
+ reference_points[:, :, None] * \
+ valid_ratios[:, None]
+ output = layer(
+ output,
+ query_pos=query_pos,
+ value=value,
+ key_padding_mask=key_padding_mask,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reference_points=reference_points_input,
+ **kwargs)
+
+ if reg_branches is not None:
+ tmp_reg_preds = reg_branches[layer_id](output)
+ if reference_points.shape[-1] == 4:
+ new_reference_points = tmp_reg_preds + inverse_sigmoid(
+ reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ else:
+ assert reference_points.shape[-1] == 2
+ new_reference_points = tmp_reg_preds
+ new_reference_points[..., :2] = tmp_reg_preds[
+ ..., :2] + inverse_sigmoid(reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(
+ intermediate_reference_points)
+
+ return output, reference_points
+
+
+class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer):
+ """Encoder layer of Deformable DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize self_attn, ffn, and norms."""
+ self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(2)
+ ]
+ self.norms = ModuleList(norms_list)
+
+
+class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer):
+ """Decoder layer of Deformable DETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize self_attn, cross-attn, ffn, and norms."""
+ self.self_attn = MultiheadAttention(**self.self_attn_cfg)
+ self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(3)
+ ]
+ self.norms = ModuleList(norms_list)
diff --git a/mmpose/models/heads/transformer_heads/transformers/detr_layers.py b/mmpose/models/heads/transformer_heads/transformers/detr_layers.py
new file mode 100644
index 0000000000..d8e6772d7e
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/transformers/detr_layers.py
@@ -0,0 +1,333 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
+import torch
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmengine import ConfigDict
+from mmengine.model import BaseModule, ModuleList
+from torch import Tensor
+
+from mmpose.utils.typing import ConfigType, OptConfigType
+
+
+class DetrTransformerEncoder(BaseModule):
+ """Encoder of DETR.
+
+ Args:
+ num_layers (int): Number of encoder layers.
+ layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
+ layer. All the layers will share the same config.
+ init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ num_layers: int,
+ layer_cfg: ConfigType,
+ init_cfg: OptConfigType = None) -> None:
+
+ super().__init__(init_cfg=init_cfg)
+ self.num_layers = num_layers
+ self.layer_cfg = layer_cfg
+ self._init_layers()
+
+ def _init_layers(self) -> None:
+ """Initialize encoder layers."""
+ self.layers = ModuleList([
+ DetrTransformerEncoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+
+ def forward(self, query: Tensor, query_pos: Tensor,
+ key_padding_mask: Tensor, **kwargs) -> Tensor:
+ """Forward function of encoder.
+
+ Args:
+ query (Tensor): Input queries of encoder, has shape
+ (bs, num_queries, dim).
+ query_pos (Tensor): The positional embeddings of the queries, has
+ shape (bs, num_queries, dim).
+ key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
+ input. ByteTensor, has shape (bs, num_queries).
+
+ Returns:
+ Tensor: Has shape (bs, num_queries, dim) if `batch_first` is
+ `True`, otherwise (num_queries, bs, dim).
+ """
+ for layer in self.layers:
+ query = layer(query, query_pos, key_padding_mask, **kwargs)
+ return query
+
+
+class DetrTransformerDecoder(BaseModule):
+ """Decoder of DETR.
+
+ Args:
+ num_layers (int): Number of decoder layers.
+ layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder
+ layer. All the layers will share the same config.
+ post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the
+ post normalization layer. Defaults to `LN`.
+ return_intermediate (bool, optional): Whether to return outputs of
+ intermediate layers. Defaults to `True`,
+ init_cfg (:obj:`ConfigDict` or dict, optional): the config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ num_layers: int,
+ layer_cfg: ConfigType,
+ post_norm_cfg: OptConfigType = dict(type='LN'),
+ return_intermediate: bool = True,
+ init_cfg: Union[dict, ConfigDict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.layer_cfg = layer_cfg
+ self.num_layers = num_layers
+ self.post_norm_cfg = post_norm_cfg
+ self.return_intermediate = return_intermediate
+ self._init_layers()
+
+ def _init_layers(self) -> None:
+ """Initialize decoder layers."""
+ self.layers = ModuleList([
+ DetrTransformerDecoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+ self.post_norm = build_norm_layer(self.post_norm_cfg,
+ self.embed_dims)[1]
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor,
+ query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor,
+ **kwargs) -> Tensor:
+ """Forward function of decoder
+ Args:
+ query (Tensor): The input query, has shape (bs, num_queries, dim).
+ key (Tensor): The input key, has shape (bs, num_keys, dim).
+ value (Tensor): The input value with the same shape as `key`.
+ query_pos (Tensor): The positional encoding for `query`, with the
+ same shape as `query`.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`.
+ key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
+ input. ByteTensor, has shape (bs, num_value).
+
+ Returns:
+ Tensor: The forwarded results will have shape
+ (num_decoder_layers, bs, num_queries, dim) if
+ `return_intermediate` is `True` else (1, bs, num_queries, dim).
+ """
+ intermediate = []
+ for layer in self.layers:
+ query = layer(
+ query,
+ key=key,
+ value=value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ if self.return_intermediate:
+ intermediate.append(self.post_norm(query))
+ query = self.post_norm(query)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+
+ return query.unsqueeze(0)
+
+
+class DetrTransformerEncoderLayer(BaseModule):
+ """Implements encoder layer in DETR transformer.
+
+ Args:
+ self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
+ attention.
+ ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
+ norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
+ normalization layers. All the layers will share the same
+ config. Defaults to `LN`.
+ init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ self_attn_cfg: OptConfigType = dict(
+ embed_dims=256, num_heads=8, dropout=0.0),
+ ffn_cfg: OptConfigType = dict(
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.,
+ act_cfg=dict(type='ReLU', inplace=True)),
+ norm_cfg: OptConfigType = dict(type='LN'),
+ init_cfg: OptConfigType = None) -> None:
+
+ super().__init__(init_cfg=init_cfg)
+
+ self.self_attn_cfg = self_attn_cfg
+ self.ffn_cfg = ffn_cfg
+ self.norm_cfg = norm_cfg
+ self._init_layers()
+
+ def _init_layers(self) -> None:
+ """Initialize self-attention, FFN, and normalization."""
+ self.self_attn = MultiheadAttention(**self.self_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(2)
+ ]
+ self.norms = ModuleList(norms_list)
+
+ def forward(self, query: Tensor, query_pos: Tensor,
+ key_padding_mask: Tensor, **kwargs) -> Tensor:
+ """Forward function of an encoder layer.
+
+ Args:
+ query (Tensor): The input query, has shape (bs, num_queries, dim).
+ query_pos (Tensor): The positional encoding for query, with
+ the same shape as `query`.
+ key_padding_mask (Tensor): The `key_padding_mask` of `self_attn`
+ input. ByteTensor. has shape (bs, num_queries).
+ Returns:
+ Tensor: forwarded results, has shape (bs, num_queries, dim).
+ """
+ query = self.self_attn(
+ query=query,
+ key=query,
+ value=query,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ query = self.norms[0](query)
+ query = self.ffn(query)
+ query = self.norms[1](query)
+
+ return query
+
+
+class DetrTransformerDecoderLayer(BaseModule):
+ """Implements decoder layer in DETR transformer.
+
+ Args:
+ self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self
+ attention.
+ cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross
+ attention.
+ ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN.
+ norm_cfg (:obj:`ConfigDict` or dict, optional): Config for
+ normalization layers. All the layers will share the same
+ config. Defaults to `LN`.
+ init_cfg (:obj:`ConfigDict` or dict, optional): Config to control
+ the initialization. Defaults to None.
+ """
+
+ def __init__(self,
+ self_attn_cfg: OptConfigType = dict(
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.0,
+ batch_first=True),
+ cross_attn_cfg: OptConfigType = dict(
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.0,
+ batch_first=True),
+ ffn_cfg: OptConfigType = dict(
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ ffn_drop=0.,
+ act_cfg=dict(type='ReLU', inplace=True),
+ ),
+ norm_cfg: OptConfigType = dict(type='LN'),
+ init_cfg: OptConfigType = None) -> None:
+
+ super().__init__(init_cfg=init_cfg)
+
+ self.self_attn_cfg = self_attn_cfg
+ self.cross_attn_cfg = cross_attn_cfg
+ self.ffn_cfg = ffn_cfg
+ self.norm_cfg = norm_cfg
+ self._init_layers()
+
+ def _init_layers(self) -> None:
+ """Initialize self-attention, FFN, and normalization."""
+ self.self_attn = MultiheadAttention(**self.self_attn_cfg)
+ self.cross_attn = MultiheadAttention(**self.cross_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(3)
+ ]
+ self.norms = ModuleList(norms_list)
+
+ def forward(self,
+ query: Tensor,
+ key: Tensor = None,
+ value: Tensor = None,
+ query_pos: Tensor = None,
+ key_pos: Tensor = None,
+ self_attn_mask: Tensor = None,
+ cross_attn_mask: Tensor = None,
+ key_padding_mask: Tensor = None,
+ **kwargs) -> Tensor:
+ """
+ Args:
+ query (Tensor): The input query, has shape (bs, num_queries, dim).
+ key (Tensor, optional): The input key, has shape (bs, num_keys,
+ dim). If `None`, the `query` will be used. Defaults to `None`.
+ value (Tensor, optional): The input value, has the same shape as
+ `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
+ `key` will be used. Defaults to `None`.
+ query_pos (Tensor, optional): The positional encoding for `query`,
+ has the same shape as `query`. If not `None`, it will be added
+ to `query` before forward function. Defaults to `None`.
+ key_pos (Tensor, optional): The positional encoding for `key`, has
+ the same shape as `key`. If not `None`, it will be added to
+ `key` before forward function. If None, and `query_pos` has the
+ same shape as `key`, then `query_pos` will be used for
+ `key_pos`. Defaults to None.
+ self_attn_mask (Tensor, optional): ByteTensor mask, has shape
+ (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
+ (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
+ Defaults to None.
+ key_padding_mask (Tensor, optional): The `key_padding_mask` of
+ `self_attn` input. ByteTensor, has shape (bs, num_value).
+ Defaults to None.
+
+ Returns:
+ Tensor: forwarded results, has shape (bs, num_queries, dim).
+ """
+
+ query = self.self_attn(
+ query=query,
+ key=query,
+ value=query,
+ query_pos=query_pos,
+ key_pos=query_pos,
+ attn_mask=self_attn_mask,
+ **kwargs)
+ query = self.norms[0](query)
+ query = self.cross_attn(
+ query=query,
+ key=key,
+ value=value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ attn_mask=cross_attn_mask,
+ key_padding_mask=key_padding_mask,
+ **kwargs)
+ query = self.norms[1](query)
+ query = self.ffn(query)
+ query = self.norms[2](query)
+
+ return query
diff --git a/mmpose/models/heads/transformer_heads/transformers/utils.py b/mmpose/models/heads/transformer_heads/transformers/utils.py
new file mode 100644
index 0000000000..e9a1d2abaf
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/transformers/utils.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import Linear
+from mmengine.model import BaseModule, ModuleList
+from torch import Tensor
+
+
+def inverse_sigmoid(x: Tensor, eps: float = 1e-3) -> Tensor:
+ """Inverse function of sigmoid.
+
+ Args:
+ x (Tensor): The tensor to do the inverse.
+ eps (float): EPS avoid numerical overflow. Defaults 1e-5.
+ Returns:
+ Tensor: The x has passed the inverse function of sigmoid, has the same
+ shape with input.
+ """
+ x = x.clamp(min=0, max=1)
+ x1 = x.clamp(min=eps)
+ x2 = (1 - x).clamp(min=eps)
+ return torch.log(x1 / x2)
+
+
+class MLP(BaseModule):
+ """Very simple multi-layer perceptron (also called FFN) with relu. Mostly
+ used in DETR series detectors.
+
+ Args:
+ input_dim (int): Feature dim of the input tensor.
+ hidden_dim (int): Feature dim of the hidden layer.
+ output_dim (int): Feature dim of the output tensor.
+ num_layers (int): Number of FFN layers. As the last
+ layer of MLP only contains FFN (Linear).
+ """
+
+ def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
+ num_layers: int) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = ModuleList(
+ Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward function of MLP.
+
+ Args:
+ x (Tensor): The input feature, has shape
+ (num_queries, bs, input_dim).
+ Returns:
+ Tensor: The output feature, has shape
+ (num_queries, bs, output_dim).
+ """
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class PositionEmbeddingSineHW(BaseModule):
+ """This is a more standard version of the position embedding, very similar
+ to the one used by the Attention is all you need paper, generalized to work
+ on images."""
+
+ def __init__(self,
+ num_pos_feats=64,
+ temperatureH=10000,
+ temperatureW=10000,
+ normalize=False,
+ scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperatureH = temperatureH
+ self.temperatureW = temperatureW
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError('normalize should be True if scale is passed')
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, mask: Tensor):
+
+ assert mask is not None
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_tx = torch.arange(
+ self.num_pos_feats, dtype=torch.float32, device=mask.device)
+ dim_tx = self.temperatureW**(2 * (dim_tx // 2) / self.num_pos_feats)
+ pos_x = x_embed[:, :, :, None] / dim_tx
+
+ dim_ty = torch.arange(
+ self.num_pos_feats, dtype=torch.float32, device=mask.device)
+ dim_ty = self.temperatureH**(2 * (dim_ty // 2) / self.num_pos_feats)
+ pos_y = y_embed[:, :, :, None] / dim_ty
+
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+
+ return pos
diff --git a/mmpose/models/necks/__init__.py b/mmpose/models/necks/__init__.py
index b4f9105cb3..ff9cc9c2a7 100644
--- a/mmpose/models/necks/__init__.py
+++ b/mmpose/models/necks/__init__.py
@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .chanel_mapper_neck import ChannelMapper
from .fmap_proc_neck import FeatureMapProcessor
from .fpn import FPN
from .gap_neck import GlobalAveragePooling
from .posewarper_neck import PoseWarperNeck
__all__ = [
- 'GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'FeatureMapProcessor'
+ 'GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'FeatureMapProcessor',
+ 'ChannelMapper'
]
diff --git a/mmpose/models/necks/chanel_mapper_neck.py b/mmpose/models/necks/chanel_mapper_neck.py
new file mode 100644
index 0000000000..f424faf729
--- /dev/null
+++ b/mmpose/models/necks/chanel_mapper_neck.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple, Union
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmengine.model import BaseModule
+from torch import Tensor
+
+from mmpose.registry import MODELS
+from mmpose.utils.typing import OptConfigType, OptMultiConfig
+
+
+@MODELS.register_module()
+class ChannelMapper(BaseModule):
+ """Channel Mapper to reduce/increase channels of backbone features.
+
+ This is used to reduce/increase channels of backbone features.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ kernel_size (int, optional): kernel_size for reducing channels (used
+ at each scale). Default: 3.
+ conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+ convolution layer. Default: None.
+ norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+ normalization layer. Default: None.
+ act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
+ activation layer in ConvModule. Default: dict(type='ReLU').
+ num_outs (int, optional): Number of output feature maps. There would
+ be extra_convs when num_outs larger than the length of in_channels.
+ init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or dict],
+ optional): Initialization config dict.
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = ChannelMapper(in_channels, 11, 3).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+
+ def __init__(
+ self,
+ in_channels: List[int],
+ out_channels: int,
+ kernel_size: int = 3,
+ conv_cfg: OptConfigType = None,
+ norm_cfg: OptConfigType = None,
+ act_cfg: OptConfigType = dict(type='ReLU'),
+ num_outs: int = None,
+ bias: Union[bool, str] = True,
+ init_cfg: OptMultiConfig = dict(
+ type='Xavier', layer='Conv2d', distribution='uniform')
+ ) -> None:
+ super().__init__(init_cfg=init_cfg)
+ assert isinstance(in_channels, list)
+ self.extra_convs = None
+ if num_outs is None:
+ num_outs = len(in_channels)
+ self.convs = nn.ModuleList()
+ for in_channel in in_channels:
+ self.convs.append(
+ ConvModule(
+ in_channel,
+ out_channels,
+ kernel_size,
+ bias=bias,
+ padding=(kernel_size - 1) // 2,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ if num_outs > len(in_channels):
+ self.extra_convs = nn.ModuleList()
+ for i in range(len(in_channels), num_outs):
+ if i == len(in_channels):
+ in_channel = in_channels[-1]
+ else:
+ in_channel = out_channels
+ self.extra_convs.append(
+ ConvModule(
+ in_channel,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ bias=bias,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]:
+ """Forward function."""
+ assert len(inputs) == len(self.convs)
+ outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
+ if self.extra_convs:
+ for i in range(len(self.extra_convs)):
+ if i == 0:
+ outs.append(self.extra_convs[0](inputs[-1]))
+ else:
+ outs.append(self.extra_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/tests/test_datasets/test_transforms/test_bottomup_transforms.py b/tests/test_datasets/test_transforms/test_bottomup_transforms.py
index cded7a6efb..8d9213c729 100644
--- a/tests/test_datasets/test_transforms/test_bottomup_transforms.py
+++ b/tests/test_datasets/test_transforms/test_bottomup_transforms.py
@@ -6,7 +6,9 @@
from mmcv.transforms import Compose
from mmpose.datasets.transforms import (BottomupGetHeatmapMask,
- BottomupRandomAffine, BottomupResize,
+ BottomupRandomAffine,
+ BottomupRandomChoiceResize,
+ BottomupRandomCrop, BottomupResize,
RandomFlip)
from mmpose.testing import get_coco_sample
@@ -145,3 +147,166 @@ def test_transform(self):
self.assertIsInstance(results['input_scale'], np.ndarray)
self.assertEqual(results['img'][0].shape, (256, 256, 3))
self.assertEqual(results['img'][1].shape, (384, 384, 3))
+
+
+class TestBottomupRandomCrop(TestCase):
+
+ def setUp(self):
+ # test invalid crop_type
+ with self.assertRaisesRegex(ValueError, 'Invalid crop_type'):
+ BottomupRandomCrop(crop_size=(10, 10), crop_type='unknown')
+
+ crop_type_list = ['absolute', 'absolute_range']
+ for crop_type in crop_type_list:
+ # test h > 0 and w > 0
+ for crop_size in [(0, 0), (0, 1), (1, 0)]:
+ with self.assertRaises(AssertionError):
+ BottomupRandomCrop(
+ crop_size=crop_size, crop_type=crop_type)
+ # test type(h) = int and type(w) = int
+ for crop_size in [(1.0, 1), (1, 1.0), (1.0, 1.0)]:
+ with self.assertRaises(AssertionError):
+ BottomupRandomCrop(
+ crop_size=crop_size, crop_type=crop_type)
+
+ # test crop_size[0] <= crop_size[1]
+ with self.assertRaises(AssertionError):
+ BottomupRandomCrop(crop_size=(10, 5), crop_type='absolute_range')
+
+ # test h in (0, 1] and w in (0, 1]
+ crop_type_list = ['relative_range', 'relative']
+ for crop_type in crop_type_list:
+ for crop_size in [(0, 1), (1, 0), (1.1, 0.5), (0.5, 1.1)]:
+ with self.assertRaises(AssertionError):
+ BottomupRandomCrop(
+ crop_size=crop_size, crop_type=crop_type)
+
+ self.data_info = get_coco_sample(img_shape=(24, 32))
+
+ def test_transform(self):
+ # test relative and absolute crop
+ src_results = self.data_info
+ target_shape = (12, 16)
+ for crop_type, crop_size in zip(['relative', 'absolute'], [(0.5, 0.5),
+ (16, 12)]):
+ transform = BottomupRandomCrop(
+ crop_size=crop_size, crop_type=crop_type)
+ results = transform(deepcopy(src_results))
+ self.assertEqual(results['img'].shape[:2], target_shape)
+
+ # test absolute_range crop
+ transform = BottomupRandomCrop(
+ crop_size=(10, 20), crop_type='absolute_range')
+ results = transform(deepcopy(src_results))
+ h, w = results['img'].shape[:2]
+ self.assertTrue(10 <= w <= 20)
+ self.assertTrue(10 <= h <= 20)
+ self.assertEqual(results['img_shape'], results['img'].shape[:2])
+ # test relative_range crop
+ transform = BottomupRandomCrop(
+ crop_size=(0.5, 0.5), crop_type='relative_range')
+ results = transform(deepcopy(src_results))
+ h, w = results['img'].shape[:2]
+ self.assertTrue(16 <= w <= 32)
+ self.assertTrue(12 <= h <= 24)
+ self.assertEqual(results['img_shape'], results['img'].shape[:2])
+
+ # test with keypoints, bbox, segmentation
+ src_results = get_coco_sample(img_shape=(10, 10), num_instances=2)
+ segmentation = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8)
+ keypoints = np.ones_like(src_results['keypoints']) * 5
+ src_results['segmentation'] = segmentation
+ src_results['keypoints'] = keypoints
+ transform = BottomupRandomCrop(
+ crop_size=(7, 5),
+ allow_negative_crop=False,
+ recompute_bbox=False,
+ bbox_clip_border=True)
+ results = transform(deepcopy(src_results))
+ h, w = results['img'].shape[:2]
+ self.assertEqual(h, 5)
+ self.assertEqual(w, 7)
+ self.assertEqual(results['bbox'].shape[0], 2)
+ self.assertTrue(results['keypoints_visible'].all())
+ self.assertTupleEqual(results['segmentation'].shape[:2], (5, 7))
+ self.assertEqual(results['img_shape'], results['img'].shape[:2])
+
+ # test bbox_clip_border = False
+ transform = BottomupRandomCrop(
+ crop_size=(10, 11),
+ allow_negative_crop=False,
+ recompute_bbox=True,
+ bbox_clip_border=False)
+ results = transform(deepcopy(src_results))
+ self.assertTrue((results['bbox'] == src_results['bbox']).all())
+
+ # test the crop does not contain any gt-bbox
+ # allow_negative_crop = False
+ img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8)
+ bbox = np.zeros((0, 4), dtype=np.float32)
+ src_results = {'img': img, 'bbox': bbox}
+ transform = BottomupRandomCrop(
+ crop_size=(5, 3), allow_negative_crop=False)
+ results = transform(deepcopy(src_results))
+ self.assertIsNone(results)
+
+ # allow_negative_crop = True
+ img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8)
+ bbox = np.zeros((0, 4), dtype=np.float32)
+ src_results = {'img': img, 'bbox': bbox}
+ transform = BottomupRandomCrop(
+ crop_size=(5, 3), allow_negative_crop=True)
+ results = transform(deepcopy(src_results))
+ self.assertTrue(isinstance(results, dict))
+
+
+class TestBottomupRandomChoiceResize(TestCase):
+
+ def setUp(self):
+ self.data_info = get_coco_sample(img_shape=(300, 400))
+
+ def test_transform(self):
+ results = dict()
+ # test with one scale
+ transform = BottomupRandomChoiceResize(scales=[(1333, 800)])
+ results = deepcopy(self.data_info)
+ results = transform(results)
+ self.assertEqual(results['img'].shape, (800, 1333, 3))
+
+ # test with multi scales
+ _scale_choice = [(1333, 800), (1333, 600)]
+ transform = BottomupRandomChoiceResize(scales=_scale_choice)
+ results = deepcopy(self.data_info)
+ results = transform(results)
+ self.assertIn((results['img'].shape[1], results['img'].shape[0]),
+ _scale_choice)
+
+ # test keep_ratio
+ transform = BottomupRandomChoiceResize(
+ scales=[(900, 600)], resize_type='Resize', keep_ratio=True)
+ results = deepcopy(self.data_info)
+ _input_ratio = results['img'].shape[0] / results['img'].shape[1]
+ results = transform(results)
+ _output_ratio = results['img'].shape[0] / results['img'].shape[1]
+ self.assertLess(abs(_input_ratio - _output_ratio), 1.5 * 1e-3)
+
+ # test clip_object_border
+ bbox = [[200, 150, 600, 450]]
+ transform = BottomupRandomChoiceResize(
+ scales=[(200, 150)], resize_type='Resize', clip_object_border=True)
+ results = deepcopy(self.data_info)
+ results['bbox'] = np.array(bbox)
+ results = transform(results)
+ self.assertEqual(results['img'].shape, (150, 200, 3))
+ self.assertTrue((results['bbox'] == np.array([[100, 75, 200,
+ 150]])).all())
+
+ transform = BottomupRandomChoiceResize(
+ scales=[(200, 150)],
+ resize_type='Resize',
+ clip_object_border=False)
+ results = self.data_info
+ results['bbox'] = np.array(bbox)
+ results = transform(results)
+ assert results['img'].shape == (150, 200, 3)
+ assert np.equal(results['bbox'], np.array([[100, 75, 300, 225]])).all()