From 537f084df14b845eced8452c8d313fde52f795b1 Mon Sep 17 00:00:00 2001 From: LiuYi <1150854440@qq.com> Date: Wed, 13 Sep 2023 13:22:43 +0800 Subject: [PATCH] support edpose --- .../edpose/coco/edpose_coco.md | 59 + .../edpose/coco/edpose_coco.yml | 25 + .../edpose/coco/edpose_res50_coco.py | 217 +++ docs/src/papers/algorithms/edpose.md | 31 + mmpose/codecs/__init__.py | 3 +- mmpose/codecs/edpose_label.py | 170 ++ mmpose/datasets/transforms/__init__.py | 6 +- .../transforms/bottomup_transforms.py | 409 ++++- mmpose/models/data_preprocessors/__init__.py | 4 +- .../data_preprocessors/data_preprocessor.py | 184 +++ mmpose/models/heads/__init__.py | 4 +- .../heads/transformer_heads/__init__.py | 19 + .../base_transformer_head.py | 112 ++ .../heads/transformer_heads/edpose_head.py | 1439 +++++++++++++++++ .../transformers/__init__.py | 17 + .../transformers/deformable_detr_layers.py | 251 +++ .../transformers/detr_layers.py | 333 ++++ .../transformer_heads/transformers/utils.py | 114 ++ mmpose/models/necks/__init__.py | 4 +- mmpose/models/necks/chanel_mapper_neck.py | 109 ++ .../test_bottomup_transforms.py | 167 +- 21 files changed, 3668 insertions(+), 9 deletions(-) create mode 100644 configs/body_2d_keypoint/edpose/coco/edpose_coco.md create mode 100644 configs/body_2d_keypoint/edpose/coco/edpose_coco.yml create mode 100644 configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py create mode 100644 docs/src/papers/algorithms/edpose.md create mode 100644 mmpose/codecs/edpose_label.py create mode 100644 mmpose/models/heads/transformer_heads/__init__.py create mode 100644 mmpose/models/heads/transformer_heads/base_transformer_head.py create mode 100644 mmpose/models/heads/transformer_heads/edpose_head.py create mode 100644 mmpose/models/heads/transformer_heads/transformers/__init__.py create mode 100644 mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py create mode 100644 mmpose/models/heads/transformer_heads/transformers/detr_layers.py create mode 100644 mmpose/models/heads/transformer_heads/transformers/utils.py create mode 100644 mmpose/models/necks/chanel_mapper_neck.py diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_coco.md b/configs/body_2d_keypoint/edpose/coco/edpose_coco.md new file mode 100644 index 0000000000..76dca297f7 --- /dev/null +++ b/configs/body_2d_keypoint/edpose/coco/edpose_coco.md @@ -0,0 +1,59 @@ + + +
+ED-Pose (ICLR'2023)
+
+```bibtex
+@inproceedings{yang2023explicit,
+  title={Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation},
+  author={Jie Yang and Ailing Zeng and Shilong Liu and Feng Li and Ruimao Zhang and Lei Zhang},
+  booktitle={International Conference on Learning Representations},
+  year={2023},
+  url={https://openreview.net/forum?id=s4WVupnJjmX}
+}
+```
+
+ + + +
+ResNet (CVPR'2016)
+
+```bibtex
+@inproceedings{he2016deep,
+  title={Deep residual learning for image recognition},
+  author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
+  booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
+  pages={770--778},
+  year={2016}
+}
+```
+
+ + + +
+COCO (ECCV'2014)
+
+```bibtex
+@inproceedings{lin2014microsoft,
+  title={Microsoft coco: Common objects in context},
+  author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
+  booktitle={European conference on computer vision},
+  pages={740--755},
+  year={2014},
+  organization={Springer}
+}
+```
+
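The results below can be reproduced with MMPose's standard MMEngine entry points. A minimal training sketch, assuming an MMPose 1.x environment with COCO prepared under `data/coco/` as the config expects (the `work_dir` value is only an example):

```python
# Minimal sketch: launch the ED-Pose config from this patch with MMEngine's
# Runner (equivalent to `python tools/train.py <config>`). Assumes MMPose 1.x
# and COCO under data/coco/ as described in the config file.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py')
cfg.work_dir = 'work_dirs/edpose_res50_coco'  # example output directory

runner = Runner.from_cfg(cfg)
runner.train()  # set cfg.load_from = '<ckpt>' and call runner.test() to evaluate
```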
+ +Results on COCO val2017 + +| Arch | BackBone | AP | AP50 | AP75 | AR | AR50 | ckpt | log | +| :-------------------------------------------- | :-------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :--------------------------------------------: | :-------------------------------------------: | +| [edpose_res50_coco](/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py) | ResNet-50 | 0.716 | 0.897 | 0.783 | 0.793 | 0.943 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.json) | +| | | | | | | | | | diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml b/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml new file mode 100644 index 0000000000..4b2901d625 --- /dev/null +++ b/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml @@ -0,0 +1,25 @@ +Collections: +- Name: ED-Pose + Paper: + Title: Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation + URL: https://arxiv.org/pdf/2302.01593.pdf + README: https://github.com/open-mmlab/mmpose/blob/main/docs/src/papers/algorithms/edpose.md +Models: +- Config: configs/body_2d_keypoint/edpose/coco/edpose_resnet50_coco.py + In Collection: ED-Pose + Metadata: + Architecture: &id001 + - ED-Pose + - ResNet + Training Data: COCO + Name: edpose_resnet50_coco + Results: + - Dataset: COCO + Metrics: + AP: 0.716 + AP@0.5: 0.897 + AP@0.75: 0.783 + AR: 0.793 + AR@0.5: 0.943 + Task: Body 2D Keypoint + Weights: https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py b/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py new file mode 100644 index 0000000000..1c78377b8b --- /dev/null +++ b/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py @@ -0,0 +1,217 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +# runtime +train_cfg = dict(max_epochs=140, val_interval=10) + +# optimizer +optim_wrapper = dict(optimizer=dict( + type='Adam', + lr=1e-3, +)) + +# learning policy +param_scheduler = [ + dict( + type='LinearLR', begin=0, end=500, start_factor=0.001, + by_epoch=False), # warm-up + dict( + type='MultiStepLR', + begin=0, + end=140, + milestones=[90, 120], + gamma=0.1, + by_epoch=True) +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=80) + +# hooks +default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater')) + +# codec settings +codec = dict( + type='EDPoseLabel', num_select=50, num_body_points=17, not_to_xyxy=False) + +# model settings +model = dict( + type='BottomupPoseEstimator', + data_preprocessor=dict( + type='BatchShapeDataPreprocessor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + bgr_to_rgb=True, + pad_size_divisor=1, + normalize_bakend='pillow'), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='FrozenBatchNorm2d', requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='ChannelMapper', + in_channels=[512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type='GN', num_groups=32), + num_outs=4), + head=dict( + type='EDPoseHead', + num_queries=900, + num_feature_levels=4, + num_body_points=17, + 
as_two_stage=True, + encoder=dict( + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=4, + num_points=4, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0))), + decoder=dict( + num_layers=6, + embed_dims=256, + layer_cfg=dict( # DeformableDetrTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + batch_first=True), + cross_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=2048, ffn_drop=0.1)), + query_dim=4, + num_feature_levels=4, + num_group=100, + num_dn=100, + num_box_decoder_layers=2, + return_intermediate=True), + out_head=dict(num_classes=2), + positional_encoding=dict( + num_pos_feats=128, + temperatureH=20, + temperatureW=20, + normalize=True), + denosing_cfg=dict( + dn_box_noise_scale=0.4, + dn_label_noise_ratio=0.5, + dn_labelbook_size=100, + dn_attn_mask_type_list=['match2dn', 'dn2dn', 'group2group']), + data_decoder=codec), + test_cfg=dict(Pmultiscale_test=False, flip_test=False, num_select=50), + train_cfg=dict()) + +# enable DDP training when rescore net is used +find_unused_parameters = True + +# base dataset settings +dataset_type = 'CocoDataset' +data_mode = 'bottomup' +data_root = 'data/coco/' + +# pipelines +train_pipeline = [ + dict(type='LoadImage'), + dict(type='RandomFlip', direction='horizontal'), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ], + [ + dict( + type='BottomupRandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='BottomupRandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='BottomupRandomChoiceResize', + scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), + (608, 1333), (640, 1333), (672, 1333), (704, 1333), + (736, 1333), (768, 1333), (800, 1333)], + keep_ratio=True) + ] + ]), + dict(type='PackPoseInputs'), +] + +val_pipeline = [ + dict(type='LoadImage', imdecode_backend='pillow'), + dict( + type='BottomupRandomChoiceResize', + scales=[(800, 1333)], + keep_ratio=True, + backend='pillow'), + dict( + type='PackPoseInputs', + meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape', + 'img_shape', 'input_size', 'input_center', 'input_scale', + 'flip', 'flip_direction', 'flip_indices', 'raw_ann_info', + 'skeleton_links')) +] + +# data loaders +train_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=train_pipeline, + )) +val_dataloader = dict( + batch_size=4, + num_workers=8, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/person_keypoints_val2017.json', + 
data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/person_keypoints_val2017.json', + nms_mode='none', + score_mode='keypoint', +) +test_evaluator = val_evaluator diff --git a/docs/src/papers/algorithms/edpose.md b/docs/src/papers/algorithms/edpose.md new file mode 100644 index 0000000000..07acf2edb5 --- /dev/null +++ b/docs/src/papers/algorithms/edpose.md @@ -0,0 +1,31 @@ +# Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation + + + +
+ED-Pose (ICLR'2023)
+
+```bibtex
+@inproceedings{yang2023explicit,
+  title={Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation},
+  author={Jie Yang and Ailing Zeng and Shilong Liu and Feng Li and Ruimao Zhang and Lei Zhang},
+  booktitle={International Conference on Learning Representations},
+  year={2023},
+  url={https://openreview.net/forum?id=s4WVupnJjmX}
+}
+```
+
+ +## Abstract + + + +This paper presents a novel end-to-end framework with Explicit box Detection for multi-person Pose estimation, called ED-Pose, where it unifies the contextual learning between human-level (global) and keypoint-level (local) information. Different from previous one-stage methods, ED-Pose re-considers this task as two explicit box detection processes with a unified representation and regression supervision. First, we introduce a human detection decoder from encoded tokens to extract global features. It can provide a good initialization for the latter keypoint detection, making the training process converge fast. Second, to bring in contextual information near keypoints, we regard pose estimation as a keypoint box detection problem to learn both box positions and contents for each keypoint. A human-to-keypoint detection decoder adopts an interactive learning strategy between human and keypoint features to further enhance global and local feature aggregation. In general, ED-Pose is conceptually simple without post-processing and dense heatmap supervision. It demonstrates its effectiveness and efficiency compared with both two-stage and one-stage methods. Notably, explicit box detection boosts the pose estimation performance by 4.5 AP on COCO and 9.9 AP on CrowdPose. For the first time, as a fully end-to-end framework with a L1 regression loss, ED-Pose surpasses heatmap-based Top-down methods under the same backbone by 1.2 AP on COCO and achieves the state-of-the-art with 76.6 AP on CrowdPose without bells and whistles. Code is available at https://github.com/IDEA-Research/ED-Pose. + + + +
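For a quick qualitative check, the generic bottom-up inference helpers in `mmpose.apis` can drive the config and checkpoint released with this patch. A minimal sketch, assuming an MMPose 1.x install (`demo.jpg` stands for any test image):

```python
# Minimal sketch: run ED-Pose inference with the generic MMPose 1.x helpers
# (init_model / inference_bottomup); these are not APIs added by this patch.
from mmpose.apis import inference_bottomup, init_model

config = 'configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py'
checkpoint = ('https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/'
              'edpose/coco/edpose_res50_coco_3rdparty.pth')

model = init_model(config, checkpoint, device='cpu')
result = inference_bottomup(model, 'demo.jpg')[0]  # a single PoseDataSample
print(result.pred_instances.keypoints.shape)  # (num_people, 17, 2)
```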
+ +
diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py index 1a48b7f851..224d860457 100644 --- a/mmpose/codecs/__init__.py +++ b/mmpose/codecs/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .associative_embedding import AssociativeEmbedding from .decoupled_heatmap import DecoupledHeatmap +from .edpose_label import EDPoseLabel from .image_pose_lifting import ImagePoseLifting from .integral_regression_label import IntegralRegressionLabel from .megvii_heatmap import MegviiHeatmap @@ -16,5 +17,5 @@ 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel', 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR', 'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting', - 'MotionBERTLabel' + 'MotionBERTLabel', 'EDPoseLabel' ] diff --git a/mmpose/codecs/edpose_label.py b/mmpose/codecs/edpose_label.py new file mode 100644 index 0000000000..bc0bef61e8 --- /dev/null +++ b/mmpose/codecs/edpose_label.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class EDPoseLabel(BaseKeypointCodec): + r"""Generate keypoint and label coordinates for `ED-Pose`_ by + Yang J. et al (2023). + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + + Encoded: + + - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + - keypoints_visible (np.ndarray): Keypoint visibility in shape + (N, K, D) + - area (np.ndarray): Area in shape (N) + - bbox (np.ndarray): Bbox in shape (N, 4) + + Args: + num_select (int): The number of candidate keypoints + num_body_points (int): The Number of keypoints + not_to_xyxy (bool): Whether convert bbox from cxcy to + xyxy. + """ + + auxiliary_encode_keys = {'area', 'bboxes', 'img_shape'} + + def __init__(self, + num_select: int = 100, + num_body_points: int = 17, + not_to_xyxy: bool = False): + super().__init__() + + self.num_select = num_select + self.num_body_points = num_body_points + self.not_to_xyxy = not_to_xyxy + + def encode( + self, + img_shape, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + area: Optional[np.ndarray] = None, + bboxes: Optional[np.ndarray] = None, + ) -> dict: + """Encoding keypoints 、area、bbox from input image space to normalized + space. + + Args: + - keypoints (np.ndarray): Keypoint coordinates in + shape (N, K, D). + - keypoints_visible (np.ndarray): Keypoint visibility in shape + (N, K, D) + - area (np.ndarray): + - bboxes (np.ndarray): + + Returns: + encoded (dict): Contains the following items: + + - keypoint_labels (np.ndarray): The processed keypoints in + shape like (N, K, D). + - keypoints_visible (np.ndarray): Keypoint visibility in shape + (N, K, D) + - area_labels (np.ndarray): The processed target + area in shape (N). + - bboxes_labels: The processed target bbox in + shape (N, 4). 
+        """
+        w, h = img_shape
+
+        if keypoints_visible is None:
+            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+        if bboxes is not None:
+            bboxes = self.box_xyxy_to_cxcywh(bboxes)
+            bboxes_labels = bboxes / np.array([w, h, w, h], dtype=np.float32)
+
+        if area is not None:
+            area_labels = area / (
+                np.array(w, dtype=np.float32) * np.array(h, dtype=np.float32))
+
+        if keypoints is not None:
+            keypoint_labels = keypoints / np.array([w, h], dtype=np.float32)
+
+        encoded = dict(
+            keypoint_labels=keypoint_labels,
+            area_labels=area_labels,
+            bboxes_labels=bboxes_labels,
+            keypoints_visible=keypoints_visible)
+
+        return encoded
+
+    def decode(self, input_shapes: np.ndarray, pred_logits: np.ndarray,
+               pred_boxes: np.ndarray, pred_keypoints: np.ndarray):
+        """Select the final top-k predictions and decode them from the
+        normalized space back to the input image scale.
+
+        Args:
+            input_shapes (np.ndarray): The shape of the resized input image,
+                arranged as (height, width).
+            pred_logits (np.ndarray): Classification scores in shape
+                (num_queries, num_classes).
+            pred_boxes (np.ndarray): Normalized bboxes in shape
+                (num_queries, 4), arranged as (cx, cy, w, h).
+            pred_keypoints (np.ndarray): Normalized keypoints in shape
+                (num_queries, num_body_points * 3).
+
+        Returns:
+            tuple:
+            - keypoints (np.ndarray): Decoded keypoint coordinates in shape
+              (num_select, num_body_points, 2).
+            - scores (np.ndarray): Keypoint scores in shape
+              (num_select, num_body_points).
+            - boxes (np.ndarray): Decoded bboxes in (x1, y1, x2, y2) format,
+              in shape (num_select, 4).
+        """
+
+        num_body_points = self.num_body_points
+
+        prob = pred_logits
+
+        prob_reshaped = prob.reshape(-1)
+        topk_indexes = np.argsort(-prob_reshaped)[:self.num_select]
+        topk_values = np.take_along_axis(prob_reshaped, topk_indexes, axis=0)
+
+        scores = np.tile(topk_values[:, np.newaxis], [1, num_body_points])
+
+        # bbox
+        topk_boxes = topk_indexes // pred_logits.shape[1]
+        if self.not_to_xyxy:
+            boxes = pred_boxes
+        else:
+            x_c, y_c, w, h = np.split(pred_boxes, 4, axis=-1)
+            b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w),
+                 (y_c + 0.5 * h)]
+            boxes = np.concatenate(b, axis=1)
+
+        boxes = np.take_along_axis(
+            boxes, np.tile(topk_boxes[:, np.newaxis], [1, 4]), axis=0)
+
+        # from relative [0, 1] to absolute [0, height] coordinates
+        img_h, img_w = np.split(input_shapes, 2, axis=0)
+        scale_fct = np.hstack([img_w, img_h, img_w, img_h])
+        boxes = boxes * scale_fct[np.newaxis, :]
+
+        # keypoints
+        topk_keypoints = topk_indexes // pred_logits.shape[1]
+        keypoints = np.take_along_axis(
+            pred_keypoints,
+            np.tile(topk_keypoints[:, np.newaxis], [1, num_body_points * 3]),
+            axis=0)
+
+        Z_pred = keypoints[:, :(num_body_points * 2)]
+        V_pred = keypoints[:, (num_body_points * 2):]
+        Z_pred = Z_pred * np.tile(
+            np.hstack([img_w, img_h]), [num_body_points])[np.newaxis, :]
+        keypoints_res = np.zeros_like(keypoints)
+        keypoints_res[..., 0::3] = Z_pred[..., 0::2]
+        keypoints_res[..., 1::3] = Z_pred[..., 1::2]
+        keypoints_res[..., 2::3] = V_pred[..., 0::1]
+
+        keypoint = keypoints_res.reshape(-1, num_body_points, 3)[:, :, :2]
+
+        return keypoint, scores, boxes
+
+    def box_xyxy_to_cxcywh(self, x: np.ndarray) -> np.ndarray:
+        """Convert boxes from (x1, y1, x2, y2) to (cx, cy, w, h) format."""
+        x0, y0, x1, y1 = np.split(x, 4, axis=-1)
+        b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+        return np.concatenate(b, axis=-1)
diff --git a/mmpose/datasets/transforms/__init__.py b/mmpose/datasets/transforms/__init__.py
index 7ccbf7dac2..a1240a1666 100644
--- a/mmpose/datasets/transforms/__init__.py
+++ b/mmpose/datasets/transforms/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
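As a sanity check for the `EDPoseLabel` codec above, the sketch below pushes random, made-up predictions through `decode()` and prints the resulting shapes; the array layouts follow the `decode()` signature documented above.

```python
# Toy example for the EDPoseLabel codec: decode random "predictions" and
# inspect the output shapes. All numbers are made up.
import numpy as np

from mmpose.codecs import EDPoseLabel

codec = EDPoseLabel(num_select=5, num_body_points=17)

rng = np.random.default_rng(0)
num_queries, num_classes = 900, 2
pred_logits = rng.random((num_queries, num_classes))  # classification scores
pred_boxes = rng.random((num_queries, 4))  # normalized (cx, cy, w, h)
pred_keypoints = rng.random((num_queries, 17 * 3))  # 34 xy values + 17 scores
input_shapes = np.array([800, 1216])  # (h, w) of the resized input

keypoints, scores, boxes = codec.decode(input_shapes, pred_logits, pred_boxes,
                                        pred_keypoints)
print(keypoints.shape, scores.shape, boxes.shape)  # (5, 17, 2) (5, 17) (5, 4)
```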
from .bottomup_transforms import (BottomupGetHeatmapMask, BottomupRandomAffine, - BottomupResize) + BottomupRandomChoiceResize, + BottomupRandomCrop, BottomupResize) from .common_transforms import (Albumentation, GenerateTarget, GetBBoxCenterScale, PhotometricDistortion, RandomBBoxTransform, RandomFlip, @@ -16,5 +17,6 @@ 'RandomHalfBody', 'TopdownAffine', 'Albumentation', 'PhotometricDistortion', 'PackPoseInputs', 'LoadImage', 'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize', - 'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot' + 'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot', + 'BottomupRandomCrop', 'BottomupRandomChoiceResize' ] diff --git a/mmpose/datasets/transforms/bottomup_transforms.py b/mmpose/datasets/transforms/bottomup_transforms.py index c31e0ae17d..c6ba5e5bc1 100644 --- a/mmpose/datasets/transforms/bottomup_transforms.py +++ b/mmpose/datasets/transforms/bottomup_transforms.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Sequence, Tuple, Union import cv2 import numpy as np import xtcocotools.mask as cocomask from mmcv.image import imflip_, imresize +from mmcv.image.geometric import imrescale from mmcv.transforms import BaseTransform from mmcv.transforms.utils import cache_randomness from scipy.stats import truncnorm @@ -515,3 +516,409 @@ def transform(self, results: Dict) -> Optional[dict]: results['aug_scale'] = None return results + + +@TRANSFORMS.register_module() +class BottomupRandomCrop(BaseTransform): + """Random crop the image & bboxes & masks. + + The absolute ``crop_size`` is sampled based on ``crop_type`` and + ``image_size``, then the cropped results are generated. + + Required Keys: + + - img + - keypoints + - bbox (optional) + - masks (BitmapMasks | PolygonMasks) (optional) + + Modified Keys: + + - img + - img_shape + - keypoints + - keypoints_visible + - num_keypoints + - bbox (optional) + - bbox_score (optional) + - id (optional) + - category_id (optional) + - raw_ann_info (optional) + - iscrowd (optional) + - segmentation (optional) + - masks (optional) + + Added Keys: + + - warp_mat + + Args: + crop_size (tuple): The relative ratio or absolute pixels of + (width, height). + crop_type (str, optional): One of "relative_range", "relative", + "absolute", "absolute_range". "relative" randomly crops + (h * crop_size[0], w * crop_size[1]) part from an input of size + (h, w). "relative_range" uniformly samples relative crop size from + range [crop_size[0], 1] and [crop_size[1], 1] for height and width + respectively. "absolute" crops from an input with absolute size + (crop_size[0], crop_size[1]). "absolute_range" uniformly samples + crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w + in range [crop_size[0], min(w, crop_size[1])]. + Defaults to "absolute". + allow_negative_crop (bool, optional): Whether to allow a crop that does + not contain any bbox area. Defaults to False. + recompute_bbox (bool, optional): Whether to re-compute the boxes based + on cropped instance masks. Defaults to False. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + + Note: + - If the image is smaller than the absolute crop size, return the + original image. + - If the crop does not contain any gt-bbox region and + ``allow_negative_crop`` is set to False, skip this image. 
+ """ + + def __init__(self, + crop_size: tuple, + crop_type: str = 'absolute', + allow_negative_crop: bool = False, + recompute_bbox: bool = False, + bbox_clip_border: bool = True) -> None: + if crop_type not in [ + 'relative_range', 'relative', 'absolute', 'absolute_range' + ]: + raise ValueError(f'Invalid crop_type {crop_type}.') + if crop_type in ['absolute', 'absolute_range']: + assert crop_size[0] > 0 and crop_size[1] > 0 + assert isinstance(crop_size[0], int) and isinstance( + crop_size[1], int) + if crop_type == 'absolute_range': + assert crop_size[0] <= crop_size[1] + else: + assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1 + self.crop_size = crop_size + self.crop_type = crop_type + self.allow_negative_crop = allow_negative_crop + self.bbox_clip_border = bbox_clip_border + self.recompute_bbox = recompute_bbox + + def _crop_data(self, results: dict, crop_size: Tuple[int, int], + allow_negative_crop: bool) -> Union[dict, None]: + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + crop_size (Tuple[int, int]): Expected absolute size after + cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. + + Returns: + results (Union[dict, None]): Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. None will + be returned when there is no valid bbox after cropping. + """ + assert crop_size[0] > 0 and crop_size[1] > 0 + img = results['img'] + margin_h = max(img.shape[0] - crop_size[0], 0) + margin_w = max(img.shape[1] - crop_size[1], 0) + offset_h, offset_w = self._rand_offset((margin_h, margin_w)) + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] + + # Record the warp matrix for the RandomCrop + warp_mat = np.array([[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]], + dtype=np.float32) + if results.get('warp_mat', None) is None: + results['warp_mat'] = warp_mat + else: + results['warp_mat'] = warp_mat @ results['warp_mat'] + + # crop the image + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + img_shape = img.shape + results['img'] = img + results['img_shape'] = img_shape[:2] + + # crop bboxes accordingly and clip to the image boundary + if results.get('bbox', None) is not None: + distances = (-offset_w, -offset_h) + bboxes = results['bbox'] + bboxes = bboxes + np.tile(np.asarray(distances), 2) + + if self.bbox_clip_border: + bboxes[..., 0::2] = bboxes[..., 0::2].clip(0, img_shape[1]) + bboxes[..., 1::2] = bboxes[..., 1::2].clip(0, img_shape[0]) + + valid_inds = (bboxes[..., 0] < img_shape[1]) & \ + (bboxes[..., 1] < img_shape[0]) & \ + (bboxes[..., 2] > 0) & \ + (bboxes[..., 3] > 0) + + # If the crop does not contain any gt-bbox area and + # allow_negative_crop is False, skip this image. 
+ if (not valid_inds.any() and not allow_negative_crop): + return None + + results['bbox'] = bboxes[valid_inds] + meta_keys = [ + 'bbox_score', 'id', 'category_id', 'raw_ann_info', 'iscrowd' + ] + for key in meta_keys: + if results.get(key): + if isinstance(results[key], list): + results[key] = np.asarray( + results[key])[valid_inds].tolist() + else: + results[key] = results[key][valid_inds] + + if results.get('keypoints', None) is not None: + keypoints = results['keypoints'] + distances = np.asarray(distances).reshape(1, 1, 2) + keypoints = keypoints + distances + if self.bbox_clip_border: + keypoints_outside_x = keypoints[:, :, 0] < 0 + keypoints_outside_y = keypoints[:, :, 1] < 0 + keypoints_outside_width = keypoints[:, :, 0] > img_shape[1] + keypoints_outside_height = keypoints[:, :, + 1] > img_shape[0] + + kpt_outside = np.logical_or.reduce( + (keypoints_outside_x, keypoints_outside_y, + keypoints_outside_width, keypoints_outside_height)) + + results['keypoints_visible'][kpt_outside] *= 0 + keypoints[:, :, 0] = keypoints[:, :, 0].clip(0, img_shape[1]) + keypoints[:, :, 1] = keypoints[:, :, 1].clip(0, img_shape[0]) + results['keypoints'] = keypoints[valid_inds] + results['keypoints_visible'] = results['keypoints_visible'][ + valid_inds] + + if results.get('segmentation', None) is not None: + results['segmentation'] = results['segmentation'][ + crop_y1:crop_y2, crop_x1:crop_x2] + + if results.get('masks', None) is not None: + results['masks'] = results['masks'][valid_inds.nonzero( + )[0]].crop(np.asarray([crop_x1, crop_y1, crop_x2, crop_y2])) + if self.recompute_bbox: + results['bbox'] = results['masks'].get_bboxes( + type(results['bbox'])) + + return results + + @cache_randomness + def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]: + """Randomly generate crop offset. + + Args: + margin (Tuple[int, int]): The upper bound for the offset generated + randomly. + + Returns: + Tuple[int, int]: The random offset for the crop. + """ + margin_h, margin_w = margin + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + + return offset_h, offset_w + + @cache_randomness + def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]: + """Randomly generates the absolute crop size based on `crop_type` and + `image_size`. + + Args: + image_size (Tuple[int, int]): (h, w). + + Returns: + crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels. + """ + h, w = image_size + if self.crop_type == 'absolute': + return min(self.crop_size[1], h), min(self.crop_size[0], w) + elif self.crop_type == 'absolute_range': + crop_h = np.random.randint( + min(h, self.crop_size[0]), + min(h, self.crop_size[1]) + 1) + crop_w = np.random.randint( + min(w, self.crop_size[0]), + min(w, self.crop_size[1]) + 1) + return crop_h, crop_w + elif self.crop_type == 'relative': + crop_w, crop_h = self.crop_size + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + else: + # 'relative_range' + crop_size = np.asarray(self.crop_size, dtype=np.float32) + crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to randomly crop images, bounding boxes, masks, + semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + results (Union[dict, None]): Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. 
None will + be returned when there is no valid bbox after cropping. + """ + image_size = results['img'].shape[:2] + crop_size = self._get_crop_size(image_size) + results = self._crop_data(results, crop_size, self.allow_negative_crop) + return results + + +@TRANSFORMS.register_module() +class BottomupRandomChoiceResize(BaseTransform): + """Resize images & bbox & mask from a list of multiple scales. + + This transform resizes the input image to some scale. Bboxes and masks are + then resized with the same scale factor. Resize scale will be randomly + selected from ``scales``. + + How to choose the target scale to resize the image will follow the rules + below: + + - if `scale` is a list of tuple, the target scale is sampled from the list + uniformally. + - if `scale` is a tuple, the target scale will be set to the tuple. + + Required Keys: + + - img + - bbox + - keypoints + + Modified Keys: + + - img + - img_shape + - bbox + - keypoints + + Added Keys: + + - scale + - scale_factor + - scale_idx + + Args: + scales (Union[list, Tuple]): Images scales for resizing. + + **resize_kwargs: Other keyword arguments for the ``resize_type``. + """ + + def __init__( + self, + scales: Sequence[Union[int, Tuple]], + keep_ratio: bool = False, + clip_object_border: bool = True, + backend: str = 'cv2', + **resize_kwargs, + ) -> None: + super().__init__() + if isinstance(scales, list): + self.scales = scales + else: + self.scales = [scales] + + self.keep_ratio = keep_ratio + self.clip_object_border = clip_object_border + self.backend = backend + + @cache_randomness + def _random_select(self) -> Tuple[int, int]: + """Randomly select an scale from given candidates. + + Returns: + (tuple, int): Returns a tuple ``(scale, scale_dix)``, + where ``scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. 
+ """ + + scale_idx = np.random.randint(len(self.scales)) + scale = self.scales[scale_idx] + return scale, scale_idx + + def _resize_img(self, results: dict) -> None: + """Resize images with ``self.scale``.""" + + if self.keep_ratio: + + img, scale_factor = imrescale( + results['img'], + self.scale, + interpolation='bilinear', + return_scale=True, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = imresize( + results['img'], + self.scale, + interpolation='bilinear', + return_scale=True, + backend=self.backend) + results['img'] = img + results['img_shape'] = img.shape[:2] + results['scale_factor'] = (w_scale, h_scale) + + def _resize_bboxes(self, results: dict) -> None: + """Resize bounding boxes with ``self.scale``.""" + if results.get('bbox', None) is not None: + bboxes = results['bbox'] * np.tile( + np.array(results['scale_factor']), 2) + if self.clip_object_border: + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, + results['img_shape'][1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, + results['img_shape'][0]) + results['bbox'] = bboxes + + def _resize_keypoints(self, results: dict) -> None: + """Resize keypoints with ``self.scale``.""" + if results.get('keypoints', None) is not None: + keypoints = results['keypoints'] + + keypoints[:, :, :2] = keypoints[:, :, :2] * np.array( + results['scale_factor']) + if self.clip_object_border: + keypoints[:, :, 0] = np.clip(keypoints[:, :, 0], 0, + results['img_shape'][1]) + keypoints[:, :, 1] = np.clip(keypoints[:, :, 1], 0, + results['img_shape'][0]) + results['keypoints'] = keypoints + + def transform(self, results: dict) -> dict: + """Apply resize transforms on results from a list of scales. + + Args: + results (dict): Result dict contains the data to transform. + + Returns: + dict: Resized results, 'img', 'bbox', + 'keypoints', 'scale', 'scale_factor', 'img_shape', + and 'keep_ratio' keys are updated in result dict. + """ + + target_scale, scale_idx = self._random_select() + + self.scale = target_scale + self._resize_img(results) + self._resize_bboxes(results) + self._resize_keypoints(results) + + results['scale_idx'] = scale_idx + return results diff --git a/mmpose/models/data_preprocessors/__init__.py b/mmpose/models/data_preprocessors/__init__.py index 7c9bd22e2b..03fd83a30c 100644 --- a/mmpose/models/data_preprocessors/__init__.py +++ b/mmpose/models/data_preprocessors/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .data_preprocessor import PoseDataPreprocessor +from .data_preprocessor import BatchShapeDataPreprocessor, PoseDataPreprocessor -__all__ = ['PoseDataPreprocessor'] +__all__ = ['PoseDataPreprocessor', 'BatchShapeDataPreprocessor'] diff --git a/mmpose/models/data_preprocessors/data_preprocessor.py b/mmpose/models/data_preprocessors/data_preprocessor.py index bcfe54ab59..7ee3a42c04 100644 --- a/mmpose/models/data_preprocessors/data_preprocessor.py +++ b/mmpose/models/data_preprocessors/data_preprocessor.py @@ -1,5 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import math +from numbers import Number +from typing import Optional, Sequence, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional from mmengine.model import ImgDataPreprocessor +from mmengine.model.utils import stack_batch +from mmengine.utils import is_seq_of +from PIL import Image from mmpose.registry import MODELS @@ -7,3 +18,176 @@ @MODELS.register_module() class PoseDataPreprocessor(ImgDataPreprocessor): """Image pre-processor for pose estimation tasks.""" + + +@MODELS.register_module() +class BatchShapeDataPreprocessor(ImgDataPreprocessor): + """Image pre-processor for pose estimation tasks. + + Comparing with the :class:`PoseDataPreprocessor`, + + 1. It will additionally append batch_input_shape + to data_samples considering the DETR-based pose estimation tasks. + + 2. Add a 'pillow backend' pipeline based normalize operation, convert + np.array to PIL.Image, and normalize it through torchvision. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + + Args: + - mean (Sequence[Number], optional): The pixel mean of R, G, B + channels. Defaults to None. + - std (Sequence[Number], optional): The pixel standard deviation + of R, G, B channels. Defaults to None. + - pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + - pad_value (Number): The padded pixel value. Defaults to 0. + - bgr_to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + - rgb_to_bgr (bool): whether to convert image from RGB to RGB. + Defaults to False. + - non_blocking (bool): Whether block current process + when transferring data to device. Defaults to False. + - normalize_bakend (str): choose the normalize backend + in ['cv2', 'pillow'] + """ + + def __init__(self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False, + normalize_bakend: str = 'cv2'): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + bgr_to_rgb=bgr_to_rgb, + rgb_to_bgr=rgb_to_bgr, + non_blocking=non_blocking) + self.normalize_bakend = normalize_bakend + + def forward(self, data: dict, training: bool = False) -> dict: + """Perform normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): Data sampled from dataloader. + training (bool): Whether to enable training time augmentation. + + Returns: + dict: Data in the same format as the model input. + """ + if self.normalize_bakend == 'cv2': + data = super().forward(data=data, training=training) + else: + data = self.normalize_pillow(data=data, training=training) + + inputs, data_samples = data['inputs'], data['data_samples'] + + if data_samples is not None: + # NOTE the batched image size information may be useful, e.g. + # in DETR, this is needed for the construction of masks, which is + # then used for the transformer_head. 
+ batch_input_shape = tuple(inputs[0].size()[-2:]) + for data_sample in data_samples: + + w, h = data_sample.ori_shape + center = np.array([w / 2, h / 2], dtype=np.float32) + scale = np.array([w, h], dtype=np.float32) + data_sample.set_metainfo({ + 'batch_input_shape': batch_input_shape, + 'input_size': data_sample.img_shape, + 'input_center': center, + 'input_scale': scale + }) + return {'inputs': inputs, 'data_samples': data_samples} + + def normalize_pillow(self, + data: dict, + training: bool = False) -> Union[dict, list]: + + data = self.cast_data(data) # type: ignore + _batch_inputs = data['inputs'] + # Process data with `pseudo_collate`. + if is_seq_of(_batch_inputs, torch.Tensor): + batch_inputs = [] + for _batch_input in _batch_inputs: + # channel transform + if self._channel_conversion: + _batch_input = _batch_input[[2, 1, 0], ...] + + _batch_input_array = _batch_input.detach().cpu().numpy( + ).transpose(1, 2, 0) + assert _batch_input_array.dtype == np.uint8, \ + 'Pillow backend only support uint8 type' + pil_image = Image.fromarray(_batch_input_array) + _batch_input = torchvision.transforms.functional.to_tensor( + pil_image).to(_batch_input.device) + + # Normalization. + if self._enable_normalize: + if self.mean.shape[0] == 3: + assert _batch_input.dim( + ) == 3 and _batch_input.shape[0] == 3, ( + 'If the mean has 3 values, the input tensor ' + 'should in shape of (3, H, W), but got the tensor ' + f'with shape {_batch_input.shape}') + _batch_input = torchvision.transforms.functional.normalize( + _batch_input, mean=self.mean, std=self.std) + batch_inputs.append(_batch_input) + # Pad and stack Tensor. + batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor, + self.pad_value) + # Process data with `default_collate`. + elif isinstance(_batch_inputs, torch.Tensor): + assert _batch_inputs.dim() == 4, ( + 'The input of `ImgDataPreprocessor` should be a NCHW tensor ' + 'or a list of tensor, but got a tensor with shape: ' + f'{_batch_inputs.shape}') + if self._channel_conversion: + _batch_inputs = _batch_inputs[:, [2, 1, 0], ...] 
+ # Convert to float after channel conversion to ensure + # efficiency + _batch_inputs_array = _batch_inputs.detach().cpu().numpy( + ).transpose(0, 2, 3, 1) + assert _batch_inputs.dtype == np.uint8, \ + 'Pillow backend only support uint8 type' + pil_image = Image.fromarray(_batch_inputs_array) + _batch_inputs = torchvision.transforms.functional.to_tensor( + pil_image).to(_batch_inputs.device) + + if self._enable_normalize: + _batch_inputs = torchvision.transforms.functional.normalize( + _batch_inputs, + mean=(self.mean / 255).tolist(), + std=(self.std / 255).tolist()) + + h, w = _batch_inputs.shape[2:] + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h), + 'constant', self.pad_value) + else: + raise TypeError('Output of `cast_data` should be a dict of ' + 'list/tuple with inputs and data_samples, ' + f'but got {type(data)}: {data}') + data['inputs'] = batch_inputs + data.setdefault('data_samples', None) + return data diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py index ef0e17d98e..2a2a49b0dc 100644 --- a/mmpose/models/heads/__init__.py +++ b/mmpose/models/heads/__init__.py @@ -8,11 +8,13 @@ MotionRegressionHead, RegressionHead, RLEHead, TemporalRegressionHead, TrajectoryRegressionHead) +from .transformer_heads import EDPoseHead, FrozenBatchNorm2d __all__ = [ 'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead', 'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead', 'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead', 'CIDHead', 'RTMCCHead', 'TemporalRegressionHead', - 'TrajectoryRegressionHead', 'MotionRegressionHead' + 'TrajectoryRegressionHead', 'MotionRegressionHead', 'EDPoseHead', + 'FrozenBatchNorm2d' ] diff --git a/mmpose/models/heads/transformer_heads/__init__.py b/mmpose/models/heads/transformer_heads/__init__.py new file mode 100644 index 0000000000..f6d58126ef --- /dev/null +++ b/mmpose/models/heads/transformer_heads/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .edpose_head import EDPoseHead, FrozenBatchNorm2d +from .transformers import (MLP, DeformableDetrTransformerDecoder, + DeformableDetrTransformerDecoderLayer, + DeformableDetrTransformerEncoder, + DeformableDetrTransformerEncoderLayer, + DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer, + PositionEmbeddingSineHW, inverse_sigmoid) + +__all__ = [ + 'EDPoseHead', 'FrozenBatchNorm2d', 'DetrTransformerEncoder', + 'DetrTransformerDecoder', 'DetrTransformerEncoderLayer', + 'DetrTransformerDecoderLayer', 'DeformableDetrTransformerEncoder', + 'DeformableDetrTransformerDecoder', + 'DeformableDetrTransformerEncoderLayer', + 'DeformableDetrTransformerDecoderLayer', 'inverse_sigmoid', + 'PositionEmbeddingSineHW', 'MLP' +] diff --git a/mmpose/models/heads/transformer_heads/base_transformer_head.py b/mmpose/models/heads/transformer_heads/base_transformer_head.py new file mode 100644 index 0000000000..bf94cc03b6 --- /dev/null +++ b/mmpose/models/heads/transformer_heads/base_transformer_head.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import abstractmethod +from typing import Dict, Tuple + +from torch import Tensor + +from mmpose.models.utils.tta import flip_coordinates +from mmpose.registry import MODELS +from mmpose.utils.typing import (Features, OptConfigType, OptMultiConfig, + OptSampleList, Predictions) +from ..base_head import BaseHead + + +@MODELS.register_module() +class TransformerHead(BaseHead): + r"""Implementation of `Deformable DETR: Deformable Transformers for + End-to-End Object Detection `_ + + Code is modified from the `official github repo + `_. + + Args: + encoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer encoder. Defaults to None. + decoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer decoder. Defaults to None. + out_head (:obj:`ConfigDict` or dict, optional): Config for the + bounding final out head module. Defaults to None. + positional_encoding (:obj:`ConfigDict` or dict): Config for + transformer position encoding. Defaults None + num_queries (int): Number of query in Transformer. + """ + _version = 2 + + def __init__(self, + encoder: OptConfigType = None, + decoder: OptConfigType = None, + out_head: OptConfigType = None, + positional_encoding: OptConfigType = None, + num_queries: int = 100, + loss: OptConfigType = None, + init_cfg: OptMultiConfig = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.encoder_cfg = encoder + self.decoder_cfg = decoder + self.out_head_cfg = out_head + self.positional_encoding_cfg = positional_encoding + self.num_queries = num_queries + + def forward(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Dict: + """Forward the network.""" + encoder_outputs_dict = self.forward_encoder(feats, batch_data_samples) + + decoder_outputs_dict = self.forward_decoder(**encoder_outputs_dict) + + head_outputs_dict = self.forward_out_head(batch_data_samples, + **decoder_outputs_dict) + return head_outputs_dict + + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: OptConfigType = {}) -> Predictions: + """Predict results from features.""" + + if test_cfg.get('flip_test', False): + # TTA: flip test -> feats = [orig, flipped] + assert isinstance(feats, list) and len(feats) == 2 + flip_indices = batch_data_samples[0].metainfo['flip_indices'] + input_size = batch_data_samples[0].metainfo['input_size'] + _feats, _feats_flip = feats + + _batch_coords = self.forward(_feats) + _batch_coords_flip = flip_coordinates( + self.forward(_feats_flip), + flip_indices=flip_indices, + shift_coords=test_cfg.get('shift_coords', True), + input_size=input_size) + batch_coords = (_batch_coords + _batch_coords_flip) * 0.5 + else: + batch_coords = self.forward(feats, batch_data_samples) # (B, K, D) + + return batch_coords + + def loss(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + pass + + @abstractmethod + def forward_encoder(self, feat: Tensor, feat_mask: Tensor, + feat_pos: Tensor, **kwargs) -> Dict: + pass + + @abstractmethod + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + **kwargs) -> Dict: + pass + + @abstractmethod + def forward_out_head(self, query: Tensor, query_pos: Tensor, + memory: Tensor, **kwargs) -> Dict: + pass diff --git a/mmpose/models/heads/transformer_heads/edpose_head.py b/mmpose/models/heads/transformer_heads/edpose_head.py new file mode 100644 index 
0000000000..f2dd154563 --- /dev/null +++ b/mmpose/models/heads/transformer_heads/edpose_head.py @@ -0,0 +1,1439 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import BaseModule, ModuleList, constant_init +from mmengine.structures import InstanceData +from torch import Tensor, nn + +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, Features, OptConfigType, + OptSampleList, Predictions) +from .base_transformer_head import TransformerHead +from .transformers.deformable_detr_layers import ( + DeformableDetrTransformerDecoderLayer, DeformableDetrTransformerEncoder) +from .transformers.utils import MLP, PositionEmbeddingSineHW, inverse_sigmoid + + +@MODELS.register_module() +class FrozenBatchNorm2d(BaseModule): + """BatchNorm2d where the batch statistics and the affine parameters are + fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, without + which any other models than torchvision.models.resnet[18,34,50,101] produce + nans. + """ + + def __init__(self, n, eps: int = 1e-5): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer('weight', torch.ones(n)) + self.register_buffer('bias', torch.zeros(n)) + self.register_buffer('running_mean', torch.zeros(n)) + self.register_buffer('running_var', torch.ones(n)) + self.eps = eps + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, + self)._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, + unexpected_keys, error_msgs) + + def forward(self, x): + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + scale = w * (rv + self.eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class EDPoseDecoder(BaseModule): + """Transformer decoder of EDPose: `Explicit Box Detection Unifies End-to- + End Multi-Person Pose Estimation. + + Args: + - layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. + - num_layers (int): Number of decoder layers. + - return_intermediate (bool, optional): Whether to return outputs of + intermediate layers. Defaults to `True`. + - embed_dims (int): Dims of embed. + - query_dim (int): Dims of queries. + - num_feature_levels (int): Number of feature levels. + - num_box_decoder_layers (int): Number of box decoder layers. + - num_body_points (int): Number of datasets' body keypoints. + - num_dn (int): Number of denosing points. + - num_group (int): Number of decoder layers. 
+ """ + + def __init__(self, + layer_cfg, + num_layers, + return_intermediate, + embed_dims: int = 256, + query_dim=4, + num_feature_levels=1, + num_box_decoder_layers=2, + num_body_points=17, + num_dn=100, + num_group=100): + super().__init__() + + self.layer_cfg = layer_cfg + self.num_layers = num_layers + self.embed_dims = embed_dims + + assert return_intermediate, 'support return_intermediate only' + self.return_intermediate = return_intermediate + + assert query_dim in [ + 2, 4 + ], 'query_dim should be 2/4 but {}'.format(query_dim) + self.query_dim = query_dim + + self.num_feature_levels = num_feature_levels + + self.layers = ModuleList([ + DeformableDetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.norm = nn.LayerNorm(self.embed_dims) + + self.ref_point_head = MLP(self.query_dim // 2 * self.embed_dims, + self.embed_dims, self.embed_dims, 2) + + self.num_body_points = num_body_points + self.query_scale = None + self.bbox_embed = None + self.class_embed = None + self.pose_embed = None + self.pose_hw_embed = None + self.num_box_decoder_layers = num_box_decoder_layers + self.box_pred_damping = None + self.num_group = num_group + self.rm_detach = None + self.num_dn = num_dn + self.hw = nn.Embedding(self.num_body_points, 2) + self.keypoint_embed = nn.Embedding(self.num_body_points, embed_dims) + self.kpt_index = [ + x for x in range(self.num_group * (self.num_body_points + 1)) + if x % (self.num_body_points + 1) != 0 + ] + + def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor, + reference_points: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + humandet_attn_mask: Tensor, human2pose_attn_mask: Tensor, + **kwargs) -> Tuple[Tensor]: + """Forward function of decoder + Args: + query (Tensor): The input queries, has shape (bs, num_queries, + dim). + value (Tensor): The input values, has shape (bs, num_value, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has + shape (bs, num_queries, 2) with the last dimension arranged + as (cx, cy). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`, optional): Used for refining + the regression results. + + Returns: + tuple[Tensor]: Outputs of Deformable Transformer Decoder. + + - output (Tensor): Output embeddings of the last decoder, has + shape (num_queries, bs, embed_dims) when `return_intermediate` + is `False`. Otherwise, Intermediate output embeddings of all + decoder layers, has shape (num_decoder_layers, num_queries, bs, + embed_dims). + - reference_points (Tensor): The reference of the last decoder + layer, has shape (bs, num_queries, 4) when `return_intermediate` + is `False`. Otherwise, Intermediate references of all decoder + layers, has shape (num_decoder_layers, bs, num_queries, 4). 
The + coordinates are arranged as (cx, cy, w, h) + """ + output = query + attn_mask = humandet_attn_mask + intermediate = [] + intermediate_reference_points = [reference_points] + effect_num_dn = self.num_dn if self.training else 0 + inter_select_number = self.num_group + for layer_id, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * \ + torch.cat([valid_ratios, valid_ratios], -1)[None, :] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * \ + valid_ratios[None, :] + + query_sine_embed = self.get_proposal_pos_embed( + reference_points_input[:, :, 0, :]) # nq, bs, 256*2 + query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256 + + output = layer( + output.transpose(0, 1), + query_pos=query_pos.transpose(0, 1), + value=value.transpose(0, 1), + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input.transpose( + 0, 1).contiguous(), + self_attn_mask=attn_mask, + **kwargs) + output = output.transpose(0, 1) + intermediate.append(self.norm(output)) + + # human update + if layer_id < self.num_box_decoder_layers: + delta_unsig = self.bbox_embed[layer_id](output) + new_reference_points = delta_unsig + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + + # query expansion + if layer_id == self.num_box_decoder_layers - 1: + dn_output = output[:effect_num_dn] + dn_new_reference_points = new_reference_points[:effect_num_dn] + class_unselected = self.class_embed[layer_id]( + output)[effect_num_dn:] + topk_proposals = torch.topk( + class_unselected.max(-1)[0], inter_select_number, dim=0)[1] + new_reference_points_for_box = torch.gather( + new_reference_points[effect_num_dn:], 0, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + new_output_for_box = torch.gather( + output[effect_num_dn:], 0, + topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims)) + bs = new_output_for_box.shape[1] + new_output_for_keypoint = new_output_for_box[:, None, :, :] \ + + self.keypoint_embed.weight[None, :, None, :] + if self.num_body_points == 17: + delta_xy = self.pose_embed[-1](new_output_for_keypoint)[ + ..., :2] + else: + delta_xy = self.pose_embed[0](new_output_for_keypoint)[ + ..., :2] + keypoint_xy = (inverse_sigmoid( + new_reference_points_for_box[..., :2][:, None]) + + delta_xy).sigmoid() + num_queries, _, bs, _ = keypoint_xy.shape + keypoint_wh_weight = self.hw.weight.unsqueeze(0).unsqueeze( + -2).repeat(num_queries, 1, bs, 1).sigmoid() + keypoint_wh = keypoint_wh_weight * \ + new_reference_points_for_box[..., 2:][:, None] + new_reference_points_for_keypoint = torch.cat( + (keypoint_xy, keypoint_wh), dim=-1) + new_reference_points = torch.cat( + (new_reference_points_for_box.unsqueeze(1), + new_reference_points_for_keypoint), + dim=1).flatten(0, 1) + output = torch.cat( + (new_output_for_box.unsqueeze(1), new_output_for_keypoint), + dim=1).flatten(0, 1) + new_reference_points = torch.cat( + (dn_new_reference_points, new_reference_points), dim=0) + output = torch.cat((dn_output, output), dim=0) + attn_mask = human2pose_attn_mask + + # human-to-keypoints update + if layer_id >= self.num_box_decoder_layers: + effect_num_dn = self.num_dn if self.training else 0 + inter_select_number = self.num_group + ref_before_sigmoid = inverse_sigmoid(reference_points) + output_bbox_dn = output[:effect_num_dn] + 
output_bbox_norm = output[effect_num_dn:][0::( + self.num_body_points + 1)] + ref_before_sigmoid_bbox_dn = \ + ref_before_sigmoid[:effect_num_dn] + ref_before_sigmoid_bbox_norm = \ + ref_before_sigmoid[effect_num_dn:][0::( + self.num_body_points + 1)] + delta_unsig_dn = self.bbox_embed[layer_id](output_bbox_dn) + delta_unsig_norm = self.bbox_embed[layer_id](output_bbox_norm) + outputs_unsig_dn = delta_unsig_dn + ref_before_sigmoid_bbox_dn + outputs_unsig_norm = delta_unsig_norm + \ + ref_before_sigmoid_bbox_norm + new_reference_points_for_box_dn = outputs_unsig_dn.sigmoid() + new_reference_points_for_box_norm = outputs_unsig_norm.sigmoid( + ) + output_kpt = output[effect_num_dn:].index_select( + 0, torch.tensor(self.kpt_index, device=output.device)) + delta_xy_unsig = self.pose_embed[layer_id - + self.num_box_decoder_layers]( + output_kpt) + outputs_unsig = ref_before_sigmoid[ + effect_num_dn:].index_select( + 0, torch.tensor(self.kpt_index, + device=output.device)).clone() + delta_hw_unsig = self.pose_hw_embed[ + layer_id - self.num_box_decoder_layers]( + output_kpt) + outputs_unsig[..., :2] += delta_xy_unsig[..., :2] + outputs_unsig[..., 2:] += delta_hw_unsig + new_reference_points_for_keypoint = outputs_unsig.sigmoid() + bs = new_reference_points_for_box_norm.shape[1] + new_reference_points_norm = torch.cat( + (new_reference_points_for_box_norm.unsqueeze(1), + new_reference_points_for_keypoint.view( + -1, self.num_body_points, bs, 4)), + dim=1).flatten(0, 1) + new_reference_points = torch.cat( + (new_reference_points_for_box_dn, + new_reference_points_norm), + dim=0) + + reference_points = new_reference_points.detach() + intermediate_reference_points.append(reference_points) + + return [[itm_out.transpose(0, 1) for itm_out in intermediate], + [ + itm_refpoint.transpose(0, 1) + for itm_refpoint in intermediate_reference_points + ]] + + @staticmethod + def get_proposal_pos_embed(pos_tensor: Tensor, + temperature: int = 10000, + num_pos_feats: int = 128) -> Tensor: + """Get the position embedding of the proposal. + + Args: + pos_tensor (Tensor): Not normalized proposals, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). + temperature (int, optional): The temperature used for scaling the + position embedding. Defaults to 10000. + num_pos_feats (int, optional): The feature dimension for each + position along x, y, w, and h-axis. Note the final returned + dimension for each position is 4 times of num_pos_feats. + Default to 128. 
+ + Returns: + Tensor: The position embedding of proposal, has shape + (bs, num_queries, num_pos_feats * 4), with the last dimension + arranged as (cx, cy, w, h) + """ + + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), + dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), + dim=3).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack( + (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), + dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack( + (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), + dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError('Unknown pos_tensor shape(-1):{}'.format( + pos_tensor.size(-1))) + return pos + + +class EDPoseOutHead(BaseModule): + """Final Head of EDPose: `Explicit Box Detection Unifies End-to-End Multi- + Person Pose Estimation. + + Args: + - num_classes (int): The number of classes. + - num_body_points (int): The number of datasets' body keypoints. + - num_queries (int): The number of queries. + - cls_no_bias (bool): Weather add the bias to class embed. + - embed_dims (int): The dims of embed. + - as_two_stage (bool, optional): Whether to generate the proposal + from the outputs of encoder. Defaults to `False`. + - refine_queries_num (int): The number of refines queries after + decoders. + - num_box_decoder_layers (int): The number of bbox decoder layer. + - num_group (int): The number of groups. + - num_pred_layer (int): The number of the prediction layers. + Defaults to 6. + - dec_pred_class_embed_share (bool): Whether to share parameters + for all the class prediction layers. Defaults to `False`. + - dec_pred_bbox_embed_share (bool): Whether to share parameters + for all the bbox prediction layers. Defaults to `False`. + - dec_pred_pose_embed_share (bool): Whether to share parameters + for all the pose prediction layers. Defaults to `False`. 
+ """ + + def __init__(self, + num_classes, + num_body_points: int = 17, + num_queries: int = 900, + cls_no_bias: bool = False, + embed_dims: int = 256, + as_two_stage: bool = False, + refine_queries_num: int = 100, + num_box_decoder_layers: int = 2, + num_group: int = 100, + num_pred_layer: int = 6, + dec_pred_class_embed_share: bool = False, + dec_pred_bbox_embed_share: bool = False, + dec_pred_pose_embed_share: bool = False, + **kwargs): + super().__init__() + self.embed_dims = embed_dims + self.as_two_stage = as_two_stage + self.num_classes = num_classes + self.refine_queries_num = refine_queries_num + self.num_box_decoder_layers = num_box_decoder_layers + self.num_body_points = num_body_points + self.num_queries = num_queries + + # prepare pred layers + self.dec_pred_class_embed_share = dec_pred_class_embed_share + self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share + self.dec_pred_pose_embed_share = dec_pred_pose_embed_share + # prepare class & box embed + _class_embed = nn.Linear( + self.embed_dims, self.num_classes, bias=(not cls_no_bias)) + if not cls_no_bias: + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + _class_embed.bias.data = torch.ones(self.num_classes) * bias_value + + _bbox_embed = MLP(self.embed_dims, self.embed_dims, 4, 3) + _pose_embed = MLP(self.embed_dims, self.embed_dims, 2, 3) + _pose_hw_embed = MLP(self.embed_dims, self.embed_dims, 2, 3) + + self.num_group = num_group + if dec_pred_bbox_embed_share: + box_embed_layerlist = [_bbox_embed for i in range(num_pred_layer)] + else: + box_embed_layerlist = [ + copy.deepcopy(_bbox_embed) for i in range(num_pred_layer) + ] + if dec_pred_class_embed_share: + class_embed_layerlist = [ + _class_embed for i in range(num_pred_layer) + ] + else: + class_embed_layerlist = [ + copy.deepcopy(_class_embed) for i in range(num_pred_layer) + ] + + if num_body_points == 17: + if dec_pred_pose_embed_share: + pose_embed_layerlist = [ + _pose_embed + for i in range(num_pred_layer - num_box_decoder_layers + 1) + ] + else: + pose_embed_layerlist = [ + copy.deepcopy(_pose_embed) + for i in range(num_pred_layer - num_box_decoder_layers + 1) + ] + else: + if dec_pred_pose_embed_share: + pose_embed_layerlist = [ + _pose_embed + for i in range(num_pred_layer - num_box_decoder_layers) + ] + else: + pose_embed_layerlist = [ + copy.deepcopy(_pose_embed) + for i in range(num_pred_layer - num_box_decoder_layers) + ] + + pose_hw_embed_layerlist = [ + _pose_hw_embed + for i in range(num_pred_layer - num_box_decoder_layers) + ] + self.bbox_embed = nn.ModuleList(box_embed_layerlist) + self.class_embed = nn.ModuleList(class_embed_layerlist) + self.pose_embed = nn.ModuleList(pose_embed_layerlist) + self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist) + + def init_weights(self) -> None: + """Initialize weights of the Deformable DETR head.""" + + for m in self.bbox_embed: + constant_init(m[-1], 0, bias=0) + for m in self.pose_embed: + constant_init(m[-1], 0, bias=0) + + def forward(self, hidden_states: List[Tensor], references: List[Tensor], + mask_dict: Dict, hidden_states_enc: Tensor, + referens_enc: Tensor, batch_data_samples) -> Dict: + """Forward function. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries, dim). + references (list[Tensor]): List of the reference from the decoder. + + Returns: + tuple[Tensor]: results of head containing the following tensor. 
+ + - pred_logits (Tensor): Outputs from the + classification head, the socres of every bboxes. + - pred_boxes (Tensor): The output boxes. + - pred_keypoints (Tensor): The output keypoints. + """ + # update human boxes + effec_dn_num = self.refine_queries_num if self.training else 0 + outputs_coord_list = [] + outputs_class = [] + for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_cls_embed, + layer_hs) in enumerate( + zip(references[:-1], self.bbox_embed, + self.class_embed, hidden_states)): + if dec_lid < self.num_box_decoder_layers: + layer_delta_unsig = layer_bbox_embed(layer_hs) + layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid( + layer_ref_sig) + layer_outputs_unsig = layer_outputs_unsig.sigmoid() + layer_cls = layer_cls_embed(layer_hs) + outputs_coord_list.append(layer_outputs_unsig) + outputs_class.append(layer_cls) + else: + layer_hs_bbox_dn = layer_hs[:, :effec_dn_num, :] + layer_hs_bbox_norm = \ + layer_hs[:, effec_dn_num:, :][:, 0::( + self.num_body_points + 1), :] + bs = layer_ref_sig.shape[0] + ref_before_sigmoid_bbox_dn = \ + layer_ref_sig[:, : effec_dn_num, :] + ref_before_sigmoid_bbox_norm = \ + layer_ref_sig[:, effec_dn_num:, :][:, 0::( + self.num_body_points + 1), :] + layer_delta_unsig_dn = layer_bbox_embed(layer_hs_bbox_dn) + layer_delta_unsig_norm = layer_bbox_embed(layer_hs_bbox_norm) + layer_outputs_unsig_dn = layer_delta_unsig_dn + \ + inverse_sigmoid(ref_before_sigmoid_bbox_dn) + layer_outputs_unsig_dn = layer_outputs_unsig_dn.sigmoid() + layer_outputs_unsig_norm = layer_delta_unsig_norm + \ + inverse_sigmoid(ref_before_sigmoid_bbox_norm) + layer_outputs_unsig_norm = layer_outputs_unsig_norm.sigmoid() + layer_outputs_unsig = torch.cat( + (layer_outputs_unsig_dn, layer_outputs_unsig_norm), dim=1) + layer_cls_dn = layer_cls_embed(layer_hs_bbox_dn) + layer_cls_norm = layer_cls_embed(layer_hs_bbox_norm) + layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) + outputs_class.append(layer_cls) + outputs_coord_list.append(layer_outputs_unsig) + + # update keypoints boxes + outputs_keypoints_list = [] + kpt_index = [ + x for x in range(self.num_group * (self.num_body_points + 1)) + if x % (self.num_body_points + 1) != 0 + ] + for dec_lid, (layer_ref_sig, layer_hs) in enumerate( + zip(references[:-1], hidden_states)): + if dec_lid < self.num_box_decoder_layers: + assert isinstance(layer_hs, torch.Tensor) + bs = layer_hs.shape[0] + layer_res = layer_hs.new_zeros( + (bs, self.num_queries, self.num_body_points * 3)) + outputs_keypoints_list.append(layer_res) + else: + bs = layer_ref_sig.shape[0] + layer_hs_kpt = \ + layer_hs[:, effec_dn_num:, :].index_select( + 1, torch.tensor(kpt_index, device=layer_hs.device)) + delta_xy_unsig = self.pose_embed[dec_lid - + self.num_box_decoder_layers]( + layer_hs_kpt) + layer_ref_sig_kpt = \ + layer_ref_sig[:, effec_dn_num:, :].index_select( + 1, torch.tensor(kpt_index, device=layer_hs.device)) + layer_outputs_unsig_keypoints = delta_xy_unsig + \ + inverse_sigmoid(layer_ref_sig_kpt[..., :2]) + vis_xy_unsig = torch.ones_like( + layer_outputs_unsig_keypoints, + device=layer_outputs_unsig_keypoints.device) + xyv = torch.cat((layer_outputs_unsig_keypoints, + vis_xy_unsig[:, :, 0].unsqueeze(-1)), + dim=-1) + xyv = xyv.sigmoid() + layer_res = xyv.reshape( + (bs, self.num_group, self.num_body_points, + 3)).flatten(2, 3) + layer_res = self.keypoint_xyzxyz_to_xyxyzz(layer_res) + outputs_keypoints_list.append(layer_res) + + dn_mask_dict = mask_dict + if self.refine_queries_num > 0 and dn_mask_dict is not None: + outputs_class, 
outputs_coord_list, outputs_keypoints_list = \
+                self.dn_post_process2(
+                    outputs_class, outputs_coord_list,
+                    outputs_keypoints_list, dn_mask_dict
+                )
+            # TODO: the denoising strategies are only applied in the
+            # training stage
+
+        for idx, (_out_class, _out_bbox, _out_keypoint) in enumerate(
+                zip(outputs_class, outputs_coord_list,
+                    outputs_keypoints_list)):
+            assert _out_class.shape[1] == \
+                _out_bbox.shape[1] == _out_keypoint.shape[1]
+
+        out = {
+            'pred_logits': outputs_class[-1],
+            'pred_boxes': outputs_coord_list[-1],
+            'pred_keypoints': outputs_keypoints_list[-1]
+        }
+
+        # TODO: if refine_queries_num and dn_mask_dict are used:
+
+        return out
+
+    def keypoint_xyzxyz_to_xyxyzz(self, keypoints: torch.Tensor):
+        """Convert keypoints from ``(x1, y1, z1, ..., xn, yn, zn)`` layout
+        to ``(x1, y1, ..., xn, yn, z1, ..., zn)`` layout.
+
+        Args:
+            keypoints (torch.Tensor): Keypoints with shape
+                (..., num_keypoints * 3), e.g. (..., 51) for the 17 COCO
+                keypoints.
+        """
+        res = torch.zeros_like(keypoints)
+        num_points = keypoints.shape[-1] // 3
+        res[..., 0:2 * num_points:2] = keypoints[..., 0::3]
+        res[..., 1:2 * num_points:2] = keypoints[..., 1::3]
+        res[..., 2 * num_points:] = keypoints[..., 2::3]
+        return res
+
+
+@MODELS.register_module()
+class EDPoseHead(TransformerHead):
+    """Head introduced in `Explicit Box Detection Unifies End-to-End Multi-
+    Person Pose Estimation`_ by J. Yang et al. (2023). The head is composed
+    of an encoder, a decoder and an output head.
+
+    Code is modified from the `official github repo
+    `_.
+
+    More details can be found in the `paper
+    `_.
+
+    Args:
+        num_queries (int): Number of queries in the Transformer.
+        num_feature_levels (int, optional): Number of feature levels.
+            Defaults to 4.
+        as_two_stage (bool, optional): Whether to generate the proposals
+            from the outputs of the encoder. Defaults to `False`.
+        encoder (:obj:`ConfigDict` or dict, optional): Config of the
+            Transformer encoder. Defaults to None.
+        decoder (:obj:`ConfigDict` or dict, optional): Config of the
+            Transformer decoder. Defaults to None.
+        out_head (:obj:`ConfigDict` or dict, optional): Config of the
+            final output head module. Defaults to None.
+        positional_encoding (:obj:`ConfigDict` or dict): Config for the
+            transformer position encoding. Defaults to None.
+        denosing_cfg (:obj:`ConfigDict` or dict, optional): Config of the
+            human query denoising training strategy.
+        data_decoder (:obj:`ConfigDict` or dict, optional): Config of the
+            data decoder which transforms the results from the output space
+            to the input space.
+        dec_pred_class_embed_share (bool): Whether to share the class embed
+            layer. Defaults to False.
+        dec_pred_bbox_embed_share (bool): Whether to share the bbox embed
+            layer. Defaults to False.
+        refine_queries_num (int): Number of refined human content queries
+            and their position queries.
+        two_stage_keep_all_tokens (bool): Whether to keep all tokens.
+ """ + + def __init__(self, + num_queries: int = 100, + num_feature_levels: int = 4, + num_body_points: int = 17, + as_two_stage: bool = False, + encoder: OptConfigType = None, + decoder: OptConfigType = None, + out_head: OptConfigType = None, + positional_encoding: OptConfigType = None, + data_decoder: OptConfigType = None, + denosing_cfg: OptConfigType = None, + dec_pred_class_embed_share: bool = False, + dec_pred_bbox_embed_share: bool = False, + refine_queries_num: int = 100, + two_stage_keep_all_tokens: bool = False) -> None: + + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.refine_queries_num = refine_queries_num + self.dec_pred_class_embed_share = dec_pred_class_embed_share + self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share + self.two_stage_keep_all_tokens = two_stage_keep_all_tokens + self.num_heads = decoder['layer_cfg']['self_attn_cfg']['num_heads'] + self.num_group = decoder['num_group'] + self.num_body_points = num_body_points + self.denosing_cfg = denosing_cfg + if data_decoder is not None: + self.data_decoder = KEYPOINT_CODECS.build(data_decoder) + else: + self.data_decoder = None + + super().__init__( + encoder=encoder, + decoder=decoder, + out_head=out_head, + positional_encoding=positional_encoding, + num_queries=num_queries) + + self.positional_encoding = PositionEmbeddingSineHW( + **self.positional_encoding_cfg) + self.encoder = DeformableDetrTransformerEncoder(**self.encoder_cfg) + self.decoder = EDPoseDecoder( + num_body_points=num_body_points, **self.decoder_cfg) + self.out_head = EDPoseOutHead( + num_body_points=num_body_points, + as_two_stage=as_two_stage, + refine_queries_num=refine_queries_num, + **self.out_head_cfg, + **self.decoder_cfg) + + self.embed_dims = self.encoder.embed_dims + self.label_enc = nn.Embedding( + self.denosing_cfg['dn_labelbook_size'] + 1, self.embed_dims) + + if not self.as_two_stage: + self.query_embedding = nn.Embedding(self.num_queries, + self.embed_dims) + self.refpoint_embedding = nn.Embedding(self.num_queries, 4) + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + + self.decoder.bbox_embed = self.out_head.bbox_embed + self.decoder.pose_embed = self.out_head.pose_embed + self.decoder.pose_hw_embed = self.out_head.pose_hw_embed + self.decoder.class_embed = self.out_head.class_embed + + if self.as_two_stage: + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + if dec_pred_class_embed_share and dec_pred_bbox_embed_share: + self.enc_out_bbox_embed = self.out_head.bbox_embed[0] + else: + self.enc_out_bbox_embed = copy.deepcopy( + self.out_head.bbox_embed[0]) + + if dec_pred_class_embed_share and dec_pred_bbox_embed_share: + self.enc_out_class_embed = self.out_head.class_embed[0] + else: + self.enc_out_class_embed = copy.deepcopy( + self.out_head.class_embed[0]) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super().init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if self.as_two_stage: + nn.init.xavier_uniform_(self.memory_trans_fc.weight) + + nn.init.normal_(self.level_embed) + + def pre_transformer(self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None + ) -> Tuple[Dict]: + """Process image features before 
feeding them to the transformer. + + Args: + img_feats (tuple[Tensor]): Multi-level features that may have + different resolutions, output from neck. Each feature has + shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict]: The first dict contains the inputs of encoder and the + second dict contains the inputs of decoder. + + - encoder_inputs_dict (dict): The keyword args dictionary of + `self.encoder()`. + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'memory_mask'. + """ + batch_size = img_feats[0].size(0) + # construct binary masks for the transformer. + assert batch_data_samples is not None + batch_input_shape = batch_data_samples[0].batch_input_shape + img_shape_list = [sample.img_shape for sample in batch_data_samples] + input_img_h, input_img_w = batch_input_shape + masks = img_feats[0].new_ones((batch_size, input_img_h, input_img_w)) + for img_id in range(batch_size): + img_h, img_w = img_shape_list[img_id] + masks[img_id, :img_h, :img_w] = 0 + # NOTE following the official DETR repo, non-zero values representing + # ignored positions, while zero values means valid positions. + + mlvl_masks = [] + mlvl_pos_embeds = [] + for feat in img_feats: + mlvl_masks.append( + F.interpolate(masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1])) + + feat_flatten = [] + lvl_pos_embed_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(img_feats, mlvl_masks, mlvl_pos_embeds)): + batch_size, c, h, w = feat.shape + # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c] + feat = feat.view(batch_size, c, -1).permute(0, 2, 1) + pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] + mask = mask.flatten(1) + spatial_shape = (h, w) + + feat_flatten.append(feat) + lvl_pos_embed_flatten.append(lvl_pos_embed) + mask_flatten.append(mask) + spatial_shapes.append(spatial_shape) + + # (bs, num_feat_points, dim) + feat_flatten = torch.cat(feat_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl) + mask_flatten = torch.cat(mask_flatten, 1) + + spatial_shapes = torch.as_tensor( # (num_level, 2) + spatial_shapes, + dtype=torch.long, + device=feat_flatten.device) + level_start_index = torch.cat(( + spatial_shapes.new_zeros((1, )), # (num_level) + spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( # (bs, num_level, 2) + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + if self.refine_queries_num > 0 or batch_data_samples is not None: + input_query_label, input_query_bbox, humandet_attn_mask, \ + human2pose_attn_mask, mask_dict =\ + self.prepare_for_denosing( + batch_data_samples, + device=img_feats[0].device) + else: + assert batch_data_samples is None + input_query_bbox = input_query_label = \ + humandet_attn_mask = human2pose_attn_mask = mask_dict = None + + encoder_inputs_dict = dict( + query=feat_flatten, + query_pos=lvl_pos_embed_flatten, + key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) 
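+        # The flattened multi-scale features (together with their positional
+        # embeddings and padding masks) are consumed by the deformable
+        # encoder, while the decoder additionally receives the denoising
+        # queries and the group-wise attention masks prepared above.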
+ decoder_inputs_dict = dict( + memory_mask=mask_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + humandet_attn_mask=humandet_attn_mask, + human2pose_attn_mask=human2pose_attn_mask, + input_query_bbox=input_query_bbox, + input_query_label=input_query_label, + mask_dict=mask_dict) + return encoder_inputs_dict, decoder_inputs_dict + + def forward_encoder(self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Dict: + """Forward with Transformer encoder. + + The forward procedure is defined as: + 'pre_transformer' -> 'encoder' + + Args: + img_feats (tuple[Tensor]): Multi-level features that may have + different resolutions, output from neck. Each feature has + shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + dict: The dictionary of encoder outputs, which includes the + `memory` of the encoder output. + """ + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + + memory = self.encoder(**encoder_inputs_dict) + encoder_outputs_dict = dict(memory=memory, **decoder_inputs_dict) + return encoder_outputs_dict + + def pre_decoder(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor, input_query_bbox: Tensor, + input_query_label: Tensor) -> Tuple[Dict, Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query` and `reference_points`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). It will only be used when + `as_two_stage` is `True`. + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + It will only be used when `as_two_stage` is `True`. + input_query_bbox (Tensor): Denosing bbox query for training. + input_query_label (Tensor): Denosing label query for training. + + Returns: + tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict. + + - decoder_inputs_dict (dict): The keyword dictionary args of + `self.decoder()`. + - head_inputs_dict (dict): The keyword dictionary args of the + bbox_head functions. 
+ """ + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + enc_outputs_class = self.enc_out_class_embed(output_memory) + enc_outputs_coord_unact = self.enc_out_bbox_embed( + output_memory) + output_proposals + + topk_proposals = torch.topk( + enc_outputs_class.max(-1)[0], self.num_queries, dim=1)[1] + topk_coords_undetach = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_undetach.detach() + reference_points = topk_coords_unact.sigmoid() + + query_undetach = torch.gather( + output_memory, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims)) + query = query_undetach.detach() + + if input_query_bbox is not None: + reference_points = torch.cat( + [input_query_bbox, topk_coords_unact], dim=1).sigmoid() + query = torch.cat([input_query_label, query], dim=1) + if self.two_stage_keep_all_tokens: + hidden_states_enc = output_memory.unsqueeze(0) + referens_enc = enc_outputs_coord_unact.unsqueeze(0) + else: + hidden_states_enc = query_undetach.unsqueeze(0) + referens_enc = topk_coords_undetach.sigmoid().unsqueeze(0) + else: + hidden_states_enc, referens_enc = None, None + query = self.query_embedding.weight[:, None, :].repeat( + 1, bs, 1).transpose(0, 1) + reference_points = \ + self.refpoint_embedding.weight[:, None, :].repeat(1, bs, 1) + + if input_query_bbox is not None: + reference_points = torch.cat( + [input_query_bbox, reference_points], dim=1) + query = torch.cat([input_query_label, query], dim=1) + reference_points = reference_points.sigmoid() + + decoder_inputs_dict = dict( + query=query, reference_points=reference_points) + head_inputs_dict = dict( + hidden_states_enc=hidden_states_enc, referens_enc=referens_enc) + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor, level_start_index: Tensor, + valid_ratios: Tensor, humandet_attn_mask: Tensor, + human2pose_attn_mask: Tensor, input_query_bbox: Tensor, + input_query_label: Tensor, mask_dict: Dict) -> Dict: + """Forward with Transformer decoder. + + The forward procedure is defined as: + 'pre_decoder' -> 'decoder' + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + humandet_attn_mask (Tensor): Human attention mask. + human2pose_attn_mask (Tensor): Human to pose attention mask. + input_query_bbox (Tensor): Denosing bbox query for training. + input_query_label (Tensor): Denosing label query for training. + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output and `references` including + the initial and intermediate reference_points. 
+        """
+        decoder_in, head_in = self.pre_decoder(memory, memory_mask,
+                                               spatial_shapes,
+                                               input_query_bbox,
+                                               input_query_label)
+
+        inter_states, inter_references = self.decoder(
+            query=decoder_in['query'].transpose(0, 1),
+            value=memory.transpose(0, 1),
+            key_padding_mask=memory_mask,  # for cross_attn
+            reference_points=decoder_in['reference_points'].transpose(0, 1),
+            spatial_shapes=spatial_shapes,
+            level_start_index=level_start_index,
+            valid_ratios=valid_ratios,
+            humandet_attn_mask=humandet_attn_mask,
+            human2pose_attn_mask=human2pose_attn_mask)
+        references = inter_references
+        decoder_outputs_dict = dict(
+            hidden_states=inter_states,
+            references=references,
+            mask_dict=mask_dict)
+        decoder_outputs_dict.update(head_in)
+        return decoder_outputs_dict
+
+    def forward_out_head(self, batch_data_samples: OptSampleList,
+                         hidden_states: List[Tensor],
+                         references: List[Tensor], mask_dict: Dict,
+                         hidden_states_enc: Tensor,
+                         referens_enc: Tensor) -> Tuple[Tensor]:
+        """Forward function."""
+        out = self.out_head(hidden_states, references, mask_dict,
+                            hidden_states_enc, referens_enc,
+                            batch_data_samples)
+        return out
+
+    def forward(self,
+                feats: Tuple[Tensor],
+                batch_data_samples: OptSampleList = None) -> Dict:
+        """Forward the network."""
+        encoder_outputs_dict = self.forward_encoder(feats, batch_data_samples)
+
+        decoder_outputs_dict = self.forward_decoder(**encoder_outputs_dict)
+
+        head_outputs_dict = self.forward_out_head(batch_data_samples,
+                                                  **decoder_outputs_dict)
+        return head_outputs_dict
+
+    def predict(self,
+                feats: Features,
+                batch_data_samples: OptSampleList,
+                test_cfg: ConfigType = {}) -> Predictions:
+        """Predict results from features."""
+        input_shapes = np.array(
+            [d.metainfo['input_size'] for d in batch_data_samples])
+
+        if test_cfg.get('flip_test', False):
+            raise NotImplementedError(
+                'flip test is currently not supported by EDPoseHead')
+        else:
+            batch_out = self.forward(feats, batch_data_samples)  # (B, K, D)
+
+        pred = self.decode(input_shapes, **batch_out)
+        return pred
+
+    def decode(self, input_shapes: np.ndarray, pred_logits: Tensor,
+               pred_boxes: Tensor, pred_keypoints: Tensor):
+        """Select the final top-k keypoints and decode the results from the
+        normalized coordinate space back to the original input size.
+
+        Args:
+            input_shapes (np.ndarray): The input image sizes of the batch.
+            pred_logits (Tensor): The predicted classification scores.
+            pred_boxes (Tensor): The predicted bounding boxes.
+            pred_keypoints (Tensor): The predicted keypoints.
+
+        Returns:
+            list[:obj:`InstanceData`]: The decoded keypoints, keypoint scores
+            and bounding boxes of each sample.
+        """
+
+        if self.data_decoder is None:
+            raise RuntimeError(
+                f'The data decoder has not been set in '
+                f'{self.__class__.__name__}. Please set the data decoder '
+                'configs in the init parameters to enable the head methods '
+                '`head.predict()` and `head.decode()`')
+
+        preds = []
+        batch_size = input_shapes.shape[0]
+
+        pred_logits = pred_logits.sigmoid()
+        pred_logits, pred_boxes, pred_keypoints = to_numpy(
+            [pred_logits, pred_boxes, pred_keypoints])
+
+        for b in range(batch_size):
+            keypoint, keypoint_score, bbox = self.data_decoder.decode(
+                input_shapes[b], pred_logits[b], pred_boxes[b],
+                pred_keypoints[b])
+
+            # pack outputs
+            preds.append(
+                InstanceData(
+                    keypoints=keypoint,
+                    keypoint_scores=keypoint_score,
+                    boxes=bbox))
+
+        return preds
+
+    @staticmethod
+    def get_valid_ratio(mask: Tensor) -> Tensor:
+        """Get the valid ratios of the feature map in a level.
+
+        ..
code:: text + + |---> valid_W <---| + ---+-----------------+-----+--- + A | | | A + | | | | | + | | | | | + valid_H | | | | + | | | | H + | | | | | + V | | | | + ---+-----------------+ | | + | | V + +-----------------------+--- + |---------> W <---------| + + The valid_ratios are defined as: + r_h = valid_H / H, r_w = valid_W / W + They are the factors to re-normalize the relative coordinates of the + image to the relative coordinates of the current level feature map. + + Args: + mask (Tensor): Binary mask of a feature map, has shape (bs, H, W). + + Returns: + Tensor: valid ratios [r_w, r_h] of a feature map, has shape (1, 2). + """ + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def gen_encoder_output_proposals(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor + ) -> Tuple[Tensor, Tensor]: + """Generate proposals from encoded memory. The function will only be + used when `as_two_stage` is `True`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + + Returns: + tuple: A tuple of transformed memory and proposals. + + - output_memory (Tensor): The transformed memory for obtaining + top-k proposals, has shape (bs, num_feat_points, dim). + - output_proposals (Tensor): The inverse-normalized proposal, has + shape (batch_size, num_keys, 4) with the last dimension arranged + as (cx, cy, w, h). 
+ """ + bs = memory.size(0) + proposals = [] + _cur = 0 # start index in the sequence of the current level + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_mask[:, + _cur:(_cur + H * W)].view(bs, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace( + 0, W - 1, W, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(bs, -1, 4) + proposals.append(proposal) + _cur += (H * W) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & + (output_proposals < 0.99)).all( + -1, keepdim=True) + # inverse_sigmoid + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill( + memory_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, + float(0)) + output_memory = self.memory_trans_fc(output_memory) + output_memory = self.memory_trans_norm(output_memory) + # [bs, sum(hw), 2] + return output_memory, output_proposals + + @property + def default_init_cfg(self): + init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] + return init_cfg + + def prepare_for_denosing(self, targets: OptSampleList, device): + """prepare for dn components in forward function.""" + if not self.training: + bs = len(targets) + attn_mask_infere = torch.zeros( + bs, + self.num_heads, + self.num_group * (self.num_body_points + 1), + self.num_group * (self.num_body_points + 1), + device=device, + dtype=torch.bool) + group_bbox_kpt = (self.num_body_points + 1) + kpt_index = [ + x for x in range(self.num_group * (self.num_body_points + 1)) + if x % (self.num_body_points + 1) == 0 + ] + for matchj in range(self.num_group * (self.num_body_points + 1)): + sj = (matchj // group_bbox_kpt) * group_bbox_kpt + ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt + if sj > 0: + attn_mask_infere[:, :, matchj, :sj] = True + if ej < self.num_group * (self.num_body_points + 1): + attn_mask_infere[:, :, matchj, ej:] = True + for match_x in range(self.num_group * (self.num_body_points + 1)): + if match_x % group_bbox_kpt == 0: + attn_mask_infere[:, :, match_x, kpt_index] = False + + attn_mask_infere = attn_mask_infere.flatten(0, 1) + return None, None, None, attn_mask_infere, None + + # targets, dn_scalar, noise_scale = dn_args + device = targets[0]['boxes'].device + bs = len(targets) + refine_queries_num = self.refine_queries_num + + # gather gt boxes and labels + gt_boxes = [t['boxes'] for t in targets] + gt_labels = [t['labels'] for t in targets] + gt_keypoints = [t['keypoints'] for t in targets] + + # repeat them + def get_indices_for_repeat(now_num, target_num, device='cuda'): + """ + Input: + - now_num: int + - target_num: int + Output: + - indices: tensor[target_num] + """ + out_indice = [] + base_indice = torch.arange(now_num).to(device) + multiplier = 
target_num // now_num + out_indice.append(base_indice.repeat(multiplier)) + residue = target_num % now_num + out_indice.append(base_indice[torch.randint( + 0, now_num, (residue, ), device=device)]) + return torch.cat(out_indice) + + gt_boxes_expand = [] + gt_labels_expand = [] + gt_keypoints_expand = [] + for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( + zip(gt_boxes, gt_labels, gt_keypoints)): + num_gt_i = gt_boxes_i.shape[0] + if num_gt_i > 0: + indices = get_indices_for_repeat(num_gt_i, refine_queries_num, + device) + gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 + gt_labels_expand_i = gt_labels_i[indices] + gt_keypoints_expand_i = gt_keypoint_i[indices] + else: + # all negative samples when no gt boxes + gt_boxes_expand_i = torch.rand( + refine_queries_num, 4, device=device) + gt_labels_expand_i = torch.ones( + refine_queries_num, dtype=torch.int64, + device=device) * int(self.num_classes) + gt_keypoints_expand_i = torch.rand( + refine_queries_num, + self.num_body_points * 3, + device=device) + gt_boxes_expand.append(gt_boxes_expand_i) + gt_labels_expand.append(gt_labels_expand_i) + gt_keypoints_expand.append(gt_keypoints_expand_i) + gt_boxes_expand = torch.stack(gt_boxes_expand) + gt_labels_expand = torch.stack(gt_labels_expand) + gt_keypoints_expand = torch.stack(gt_keypoints_expand) + knwon_boxes_expand = gt_boxes_expand.clone() + knwon_labels_expand = gt_labels_expand.clone() + + # add noise + if self.denosing_cfg['dn_label_noise_ratio'] > 0: + prob = torch.rand_like(knwon_labels_expand.float()) + chosen_indice = prob < self.denosing_cfg['dn_label_noise_ratio'] + new_label = torch.randint_like( + knwon_labels_expand[chosen_indice], 0, + self.dn_labelbook_size) # randomly put a new one here + knwon_labels_expand[chosen_indice] = new_label + + if self.denosing_cfg['dn_box_noise_scale'] > 0: + diff = torch.zeros_like(knwon_boxes_expand) + diff[..., :2] = knwon_boxes_expand[..., 2:] / 2 + diff[..., 2:] = knwon_boxes_expand[..., 2:] + knwon_boxes_expand += torch.mul( + (torch.rand_like(knwon_boxes_expand) * 2 - 1.0), + diff) * self.denosing_cfg['dn_box_noise_scale'] + knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0) + + input_query_label = self.label_enc(knwon_labels_expand) + input_query_bbox = inverse_sigmoid(knwon_boxes_expand) + + # prepare mask + + if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']: + attn_mask = torch.zeros( + bs, + self.num_heads, + refine_queries_num + self.num_queries, + refine_queries_num + self.num_queries, + device=device, + dtype=torch.bool) + attn_mask[:, :, refine_queries_num:, :refine_queries_num] = True + for idx, (gt_boxes_i, + gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): + num_gt_i = gt_boxes_i.shape[0] + if num_gt_i == 0: + continue + for matchi in range(refine_queries_num): + si = (matchi // num_gt_i) * num_gt_i + ei = (matchi // num_gt_i + 1) * num_gt_i + if si > 0: + attn_mask[idx, :, matchi, :si] = True + if ei < refine_queries_num: + attn_mask[idx, :, matchi, ei:refine_queries_num] = True + attn_mask = attn_mask.flatten(0, 1) + + if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']: + attn_mask2 = torch.zeros( + bs, + self.num_heads, + refine_queries_num + self.num_group * + (self.num_body_points + 1), + refine_queries_num + self.num_group * + (self.num_body_points + 1), + device=device, + dtype=torch.bool) + attn_mask2[:, :, refine_queries_num:, :refine_queries_num] = True + group_bbox_kpt = (self.num_body_points + 1) + kpt_index = [ + x for x in range(self.num_group * 
(self.num_body_points + 1))
+                if x % (self.num_body_points + 1) == 0
+            ]
+            for matchj in range(self.num_group * (self.num_body_points + 1)):
+                sj = (matchj // group_bbox_kpt) * group_bbox_kpt
+                ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt
+                if sj > 0:
+                    attn_mask2[:, :, refine_queries_num:,
+                               refine_queries_num:][:, :, matchj, :sj] = True
+                if ej < self.num_group * (self.num_body_points + 1):
+                    attn_mask2[:, :, refine_queries_num:,
+                               refine_queries_num:][:, :, matchj, ej:] = True
+
+            for match_x in range(self.num_group * (self.num_body_points + 1)):
+                if match_x % group_bbox_kpt == 0:
+                    attn_mask2[:, :, refine_queries_num:,
+                               refine_queries_num:][:, :, match_x,
+                                                    kpt_index] = False
+
+            for idx, (gt_boxes_i,
+                      gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)):
+                num_gt_i = gt_boxes_i.shape[0]
+                if num_gt_i == 0:
+                    continue
+                for matchi in range(refine_queries_num):
+                    si = (matchi // num_gt_i) * num_gt_i
+                    ei = (matchi // num_gt_i + 1) * num_gt_i
+                    if si > 0:
+                        attn_mask2[idx, :, matchi, :si] = True
+                    if ei < refine_queries_num:
+                        attn_mask2[idx, :, matchi,
+                                   ei:refine_queries_num] = True
+            attn_mask2 = attn_mask2.flatten(0, 1)
+
+        mask_dict = {
+            'pad_size': refine_queries_num,
+            'known_bboxs': gt_boxes_expand,
+            'known_labels': gt_labels_expand,
+            'known_keypoints': gt_keypoints_expand
+        }
+
+        return input_query_label, input_query_bbox, \
+            attn_mask, attn_mask2, mask_dict
+
+    def loss(self,
+             feats: Tuple[Tensor],
+             batch_data_samples: OptSampleList,
+             train_cfg: OptConfigType = {}) -> dict:
+        """Calculate losses from a batch of inputs and data samples."""
+
+        raise NotImplementedError(
+            'the loss function of EDPoseHead has not been implemented yet')
diff --git a/mmpose/models/heads/transformer_heads/transformers/__init__.py b/mmpose/models/heads/transformer_heads/transformers/__init__.py
new file mode 100644
index 0000000000..d678b2522c
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/transformers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .deformable_detr_layers import (DeformableDetrTransformerDecoder,
+                                     DeformableDetrTransformerDecoderLayer,
+                                     DeformableDetrTransformerEncoder,
+                                     DeformableDetrTransformerEncoderLayer)
+from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer,
+                          DetrTransformerEncoder, DetrTransformerEncoderLayer)
+from .utils import MLP, PositionEmbeddingSineHW, inverse_sigmoid
+
+__all__ = [
+    'DetrTransformerEncoder', 'DetrTransformerDecoder',
+    'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer',
+    'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder',
+    'DeformableDetrTransformerEncoderLayer',
+    'DeformableDetrTransformerDecoderLayer', 'inverse_sigmoid',
+    'PositionEmbeddingSineHW', 'MLP'
+]
diff --git a/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py b/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py
new file mode 100644
index 0000000000..49dd526092
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py
@@ -0,0 +1,251 @@
+# Copyright (c) OpenMMLab. All rights reserved.
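+# The classes below specialize the plain DETR layers with multi-scale
+# deformable attention. A minimal usage sketch (illustrative only; the
+# config values here are assumptions, not the ones used by ED-Pose):
+#
+#     encoder = DeformableDetrTransformerEncoder(
+#         num_layers=6,
+#         layer_cfg=dict(
+#             self_attn_cfg=dict(embed_dims=256, num_levels=4),
+#             ffn_cfg=dict(embed_dims=256, feedforward_channels=1024)))
+#
+# where ``self_attn_cfg`` is forwarded to ``MultiScaleDeformableAttention``
+# and ``ffn_cfg`` to mmcv's ``FFN``.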
+from typing import Optional, Tuple, Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import ModuleList +from torch import Tensor, nn + +from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer) +from .utils import inverse_sigmoid + + +class DeformableDetrTransformerEncoder(DetrTransformerEncoder): + """Transformer encoder of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + **kwargs) -> Tensor: + """Forward function of Transformer encoder. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, has shape + (bs, num_queries, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (bs, num_queries). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + Tensor: Output queries of Transformer encoder, which is also + called 'encoder output embeddings' or 'memory', has shape + (bs, num_queries, dim) + """ + reference_points = self.get_encoder_reference_points( + spatial_shapes, valid_ratios, device=query.device) + for layer in self.layers: + query = layer( + query=query, + query_pos=query_pos, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points, + **kwargs) + return query + + @staticmethod + def get_encoder_reference_points(spatial_shapes: Tensor, + valid_ratios: Tensor, + device: Union[torch.device, + str]) -> Tensor: + """Get the reference points used in encoder. + + Args: + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + device (obj:`device` or str): The device acquired by the + `reference_points`. + + Returns: + Tensor: Reference points used in decoder, has shape (bs, length, + num_levels, 2). 
+ """ + + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace( + 0.5, W - 0.5, W, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + # [bs, sum(hw), num_level, 2] + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + +class DeformableDetrTransformerDecoder(DetrTransformerDecoder): + """Transformer Decoder of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + if self.post_norm_cfg is not None: + raise ValueError('There is not post_norm in ' + f'{self._get_name()}') + + def forward(self, + query: Tensor, + query_pos: Tensor, + value: Tensor, + key_padding_mask: Tensor, + reference_points: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + reg_branches: Optional[nn.Module] = None, + **kwargs) -> Tuple[Tensor]: + """Forward function of Transformer decoder. + + Args: + query (Tensor): The input queries, has shape (bs, num_queries, + dim). + query_pos (Tensor): The input positional query, has shape + (bs, num_queries, dim). It will be added to `query` before + forward function. + value (Tensor): The input values, has shape (bs, num_value, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has + shape (bs, num_queries, 2) with the last dimension arranged + as (cx, cy). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`, optional): Used for refining + the regression results. Only would be passed when + `with_box_refine` is `True`, otherwise would be `None`. + + Returns: + tuple[Tensor]: Outputs of Deformable Transformer Decoder. + + - output (Tensor): Output embeddings of the last decoder, has + shape (num_queries, bs, embed_dims) when `return_intermediate` + is `False`. Otherwise, Intermediate output embeddings of all + decoder layers, has shape (num_decoder_layers, num_queries, bs, + embed_dims). + - reference_points (Tensor): The reference of the last decoder + layer, has shape (bs, num_queries, 4) when `return_intermediate` + is `False`. Otherwise, Intermediate references of all decoder + layers, has shape (num_decoder_layers, bs, num_queries, 4). 
The + coordinates are arranged as (cx, cy, w, h) + """ + output = query + intermediate = [] + intermediate_reference_points = [] + for layer_id, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * \ + torch.cat([valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * \ + valid_ratios[:, None] + output = layer( + output, + query_pos=query_pos, + value=value, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input, + **kwargs) + + if reg_branches is not None: + tmp_reg_preds = reg_branches[layer_id](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp_reg_preds + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp_reg_preds + new_reference_points[..., :2] = tmp_reg_preds[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer): + """Encoder layer of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize self_attn, ffn, and norms.""" + self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(2) + ] + self.norms = ModuleList(norms_list) + + +class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): + """Decoder layer of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize self_attn, cross-attn, ffn, and norms.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) diff --git a/mmpose/models/heads/transformer_heads/transformers/detr_layers.py b/mmpose/models/heads/transformer_heads/transformers/detr_layers.py new file mode 100644 index 0000000000..d8e6772d7e --- /dev/null +++ b/mmpose/models/heads/transformer_heads/transformers/detr_layers.py @@ -0,0 +1,333 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine import ConfigDict +from mmengine.model import BaseModule, ModuleList +from torch import Tensor + +from mmpose.utils.typing import ConfigType, OptConfigType + + +class DetrTransformerEncoder(BaseModule): + """Encoder of DETR. + + Args: + num_layers (int): Number of encoder layers. + layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. 
+ init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + num_layers: int, + layer_cfg: ConfigType, + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + self.num_layers = num_layers + self.layer_cfg = layer_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs) -> Tensor: + """Forward function of encoder. + + Args: + query (Tensor): Input queries of encoder, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional embeddings of the queries, has + shape (bs, num_queries, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (bs, num_queries). + + Returns: + Tensor: Has shape (bs, num_queries, dim) if `batch_first` is + `True`, otherwise (num_queries, bs, dim). + """ + for layer in self.layers: + query = layer(query, query_pos, key_padding_mask, **kwargs) + return query + + +class DetrTransformerDecoder(BaseModule): + """Decoder of DETR. + + Args: + num_layers (int): Number of decoder layers. + layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. + post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the + post normalization layer. Defaults to `LN`. + return_intermediate (bool, optional): Whether to return outputs of + intermediate layers. Defaults to `True`, + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + num_layers: int, + layer_cfg: ConfigType, + post_norm_cfg: OptConfigType = dict(type='LN'), + return_intermediate: bool = True, + init_cfg: Union[dict, ConfigDict] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.layer_cfg = layer_cfg + self.num_layers = num_layers + self.post_norm_cfg = post_norm_cfg + self.return_intermediate = return_intermediate + self._init_layers() + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + DetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + self.post_norm = build_norm_layer(self.post_norm_cfg, + self.embed_dims)[1] + + def forward(self, query: Tensor, key: Tensor, value: Tensor, + query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor, + **kwargs) -> Tensor: + """Forward function of decoder + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor): The input key, has shape (bs, num_keys, dim). + value (Tensor): The input value with the same shape as `key`. + query_pos (Tensor): The positional encoding for `query`, with the + same shape as `query`. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + + Returns: + Tensor: The forwarded results will have shape + (num_decoder_layers, bs, num_queries, dim) if + `return_intermediate` is `True` else (1, bs, num_queries, dim). 
+ """ + intermediate = [] + for layer in self.layers: + query = layer( + query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + key_padding_mask=key_padding_mask, + **kwargs) + if self.return_intermediate: + intermediate.append(self.post_norm(query)) + query = self.post_norm(query) + + if self.return_intermediate: + return torch.stack(intermediate) + + return query.unsqueeze(0) + + +class DetrTransformerEncoderLayer(BaseModule): + """Implements encoder layer in DETR transformer. + + Args: + self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self + attention. + ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config for + normalization layers. All the layers will share the same + config. Defaults to `LN`. + init_cfg (:obj:`ConfigDict` or dict, optional): Config to control + the initialization. Defaults to None. + """ + + def __init__(self, + self_attn_cfg: OptConfigType = dict( + embed_dims=256, num_heads=8, dropout=0.0), + ffn_cfg: OptConfigType = dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True)), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + + self.self_attn_cfg = self_attn_cfg + self.ffn_cfg = ffn_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize self-attention, FFN, and normalization.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(2) + ] + self.norms = ModuleList(norms_list) + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs) -> Tensor: + """Forward function of an encoder layer. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, with + the same shape as `query`. + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor. has shape (bs, num_queries). + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). + """ + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[0](query) + query = self.ffn(query) + query = self.norms[1](query) + + return query + + +class DetrTransformerDecoderLayer(BaseModule): + """Implements decoder layer in DETR transformer. + + Args: + self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self + attention. + cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross + attention. + ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config for + normalization layers. All the layers will share the same + config. Defaults to `LN`. + init_cfg (:obj:`ConfigDict` or dict, optional): Config to control + the initialization. Defaults to None. 
+ """ + + def __init__(self, + self_attn_cfg: OptConfigType = dict( + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg: OptConfigType = dict( + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg: OptConfigType = dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True), + ), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + + self.self_attn_cfg = self_attn_cfg + self.cross_attn_cfg = cross_attn_cfg + self.ffn_cfg = ffn_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize self-attention, FFN, and normalization.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiheadAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) + + def forward(self, + query: Tensor, + key: Tensor = None, + value: Tensor = None, + query_pos: Tensor = None, + key_pos: Tensor = None, + self_attn_mask: Tensor = None, + cross_attn_mask: Tensor = None, + key_padding_mask: Tensor = None, + **kwargs) -> Tensor: + """ + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor, optional): The input key, has shape (bs, num_keys, + dim). If `None`, the `query` will be used. Defaults to `None`. + value (Tensor, optional): The input value, has the same shape as + `key`, as in `nn.MultiheadAttention.forward`. If `None`, the + `key` will be used. Defaults to `None`. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be added + to `query` before forward function. Defaults to `None`. + key_pos (Tensor, optional): The positional encoding for `key`, has + the same shape as `key`. If not `None`, it will be added to + `key` before forward function. If None, and `query_pos` has the + same shape as `key`, then `query_pos` will be used for + `key_pos`. Defaults to None. + self_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + cross_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor, optional): The `key_padding_mask` of + `self_attn` input. ByteTensor, has shape (bs, num_value). + Defaults to None. + + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). 
+        """
+
+        query = self.self_attn(
+            query=query,
+            key=query,
+            value=query,
+            query_pos=query_pos,
+            key_pos=query_pos,
+            attn_mask=self_attn_mask,
+            **kwargs)
+        query = self.norms[0](query)
+        query = self.cross_attn(
+            query=query,
+            key=key,
+            value=value,
+            query_pos=query_pos,
+            key_pos=key_pos,
+            attn_mask=cross_attn_mask,
+            key_padding_mask=key_padding_mask,
+            **kwargs)
+        query = self.norms[1](query)
+        query = self.ffn(query)
+        query = self.norms[2](query)
+
+        return query
diff --git a/mmpose/models/heads/transformer_heads/transformers/utils.py b/mmpose/models/heads/transformer_heads/transformers/utils.py
new file mode 100644
index 0000000000..e9a1d2abaf
--- /dev/null
+++ b/mmpose/models/heads/transformer_heads/transformers/utils.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn.functional as F
+from mmcv.cnn import Linear
+from mmengine.model import BaseModule, ModuleList
+from torch import Tensor
+
+
+def inverse_sigmoid(x: Tensor, eps: float = 1e-3) -> Tensor:
+    """Inverse function of sigmoid.
+
+    Args:
+        x (Tensor): The tensor to do the inverse.
+        eps (float): Epsilon to avoid numerical overflow. Defaults to 1e-3.
+    Returns:
+        Tensor: The inverse sigmoid (logit) of the input, with the same
+        shape as the input.
+    """
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1 / x2)
+
+
+class MLP(BaseModule):
+    """Very simple multi-layer perceptron (also called FFN) with ReLU
+    activations, mostly used in DETR-series detectors.
+
+    Args:
+        input_dim (int): Feature dim of the input tensor.
+        hidden_dim (int): Feature dim of the hidden layers.
+        output_dim (int): Feature dim of the output tensor.
+        num_layers (int): Number of layers. The last layer is a
+            plain Linear projection without activation.
+    """
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
+                 num_layers: int) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = ModuleList(
+            Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward function of MLP.
+
+        Args:
+            x (Tensor): The input feature, has shape
+                (num_queries, bs, input_dim).
+        Returns:
+            Tensor: The output feature, has shape
+                (num_queries, bs, output_dim).
+ """ + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class PositionEmbeddingSineHW(BaseModule): + """This is a more standard version of the position embedding, very similar + to the one used by the Attention is all you need paper, generalized to work + on images.""" + + def __init__(self, + num_pos_feats=64, + temperatureH=10000, + temperatureW=10000, + normalize=False, + scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperatureH = temperatureH + self.temperatureW = temperatureW + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError('normalize should be True if scale is passed') + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, mask: Tensor): + + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_tx = torch.arange( + self.num_pos_feats, dtype=torch.float32, device=mask.device) + dim_tx = self.temperatureW**(2 * (dim_tx // 2) / self.num_pos_feats) + pos_x = x_embed[:, :, :, None] / dim_tx + + dim_ty = torch.arange( + self.num_pos_feats, dtype=torch.float32, device=mask.device) + dim_ty = self.temperatureH**(2 * (dim_ty // 2) / self.num_pos_feats) + pos_y = y_embed[:, :, :, None] / dim_ty + + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + + return pos diff --git a/mmpose/models/necks/__init__.py b/mmpose/models/necks/__init__.py index b4f9105cb3..ff9cc9c2a7 100644 --- a/mmpose/models/necks/__init__.py +++ b/mmpose/models/necks/__init__.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .chanel_mapper_neck import ChannelMapper from .fmap_proc_neck import FeatureMapProcessor from .fpn import FPN from .gap_neck import GlobalAveragePooling from .posewarper_neck import PoseWarperNeck __all__ = [ - 'GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'FeatureMapProcessor' + 'GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'FeatureMapProcessor', + 'ChannelMapper' ] diff --git a/mmpose/models/necks/chanel_mapper_neck.py b/mmpose/models/necks/chanel_mapper_neck.py new file mode 100644 index 0000000000..f424faf729 --- /dev/null +++ b/mmpose/models/necks/chanel_mapper_neck.py @@ -0,0 +1,109 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule +from torch import Tensor + +from mmpose.registry import MODELS +from mmpose.utils.typing import OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class ChannelMapper(BaseModule): + """Channel Mapper to reduce/increase channels of backbone features. + + This is used to reduce/increase channels of backbone features. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + kernel_size (int, optional): kernel_size for reducing channels (used + at each scale). Default: 3. + conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + convolution layer. 
Default: None. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + normalization layer. Default: None. + act_cfg (:obj:`ConfigDict` or dict, optional): Config dict for + activation layer in ConvModule. Default: dict(type='ReLU'). + num_outs (int, optional): Number of output feature maps. There would + be extra_convs when num_outs larger than the length of in_channels. + init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or dict], + optional): Initialization config dict. + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = ChannelMapper(in_channels, 11, 3).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__( + self, + in_channels: List[int], + out_channels: int, + kernel_size: int = 3, + conv_cfg: OptConfigType = None, + norm_cfg: OptConfigType = None, + act_cfg: OptConfigType = dict(type='ReLU'), + num_outs: int = None, + bias: Union[bool, str] = True, + init_cfg: OptMultiConfig = dict( + type='Xavier', layer='Conv2d', distribution='uniform') + ) -> None: + super().__init__(init_cfg=init_cfg) + assert isinstance(in_channels, list) + self.extra_convs = None + if num_outs is None: + num_outs = len(in_channels) + self.convs = nn.ModuleList() + for in_channel in in_channels: + self.convs.append( + ConvModule( + in_channel, + out_channels, + kernel_size, + bias=bias, + padding=(kernel_size - 1) // 2, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + if num_outs > len(in_channels): + self.extra_convs = nn.ModuleList() + for i in range(len(in_channels), num_outs): + if i == len(in_channels): + in_channel = in_channels[-1] + else: + in_channel = out_channels + self.extra_convs.append( + ConvModule( + in_channel, + out_channels, + 3, + stride=2, + padding=1, + bias=bias, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + def forward(self, inputs: Tuple[Tensor]) -> Tuple[Tensor]: + """Forward function.""" + assert len(inputs) == len(self.convs) + outs = [self.convs[i](inputs[i]) for i in range(len(inputs))] + if self.extra_convs: + for i in range(len(self.extra_convs)): + if i == 0: + outs.append(self.extra_convs[0](inputs[-1])) + else: + outs.append(self.extra_convs[i](outs[-1])) + return tuple(outs) diff --git a/tests/test_datasets/test_transforms/test_bottomup_transforms.py b/tests/test_datasets/test_transforms/test_bottomup_transforms.py index cded7a6efb..8d9213c729 100644 --- a/tests/test_datasets/test_transforms/test_bottomup_transforms.py +++ b/tests/test_datasets/test_transforms/test_bottomup_transforms.py @@ -6,7 +6,9 @@ from mmcv.transforms import Compose from mmpose.datasets.transforms import (BottomupGetHeatmapMask, - BottomupRandomAffine, BottomupResize, + BottomupRandomAffine, + BottomupRandomChoiceResize, + BottomupRandomCrop, BottomupResize, RandomFlip) from mmpose.testing import get_coco_sample @@ -145,3 +147,166 @@ def test_transform(self): self.assertIsInstance(results['input_scale'], np.ndarray) self.assertEqual(results['img'][0].shape, (256, 256, 3)) self.assertEqual(results['img'][1].shape, (384, 384, 3)) + + +class TestBottomupRandomCrop(TestCase): + + def setUp(self): + # test 
invalid crop_type + with self.assertRaisesRegex(ValueError, 'Invalid crop_type'): + BottomupRandomCrop(crop_size=(10, 10), crop_type='unknown') + + crop_type_list = ['absolute', 'absolute_range'] + for crop_type in crop_type_list: + # test h > 0 and w > 0 + for crop_size in [(0, 0), (0, 1), (1, 0)]: + with self.assertRaises(AssertionError): + BottomupRandomCrop( + crop_size=crop_size, crop_type=crop_type) + # test type(h) = int and type(w) = int + for crop_size in [(1.0, 1), (1, 1.0), (1.0, 1.0)]: + with self.assertRaises(AssertionError): + BottomupRandomCrop( + crop_size=crop_size, crop_type=crop_type) + + # test crop_size[0] <= crop_size[1] + with self.assertRaises(AssertionError): + BottomupRandomCrop(crop_size=(10, 5), crop_type='absolute_range') + + # test h in (0, 1] and w in (0, 1] + crop_type_list = ['relative_range', 'relative'] + for crop_type in crop_type_list: + for crop_size in [(0, 1), (1, 0), (1.1, 0.5), (0.5, 1.1)]: + with self.assertRaises(AssertionError): + BottomupRandomCrop( + crop_size=crop_size, crop_type=crop_type) + + self.data_info = get_coco_sample(img_shape=(24, 32)) + + def test_transform(self): + # test relative and absolute crop + src_results = self.data_info + target_shape = (12, 16) + for crop_type, crop_size in zip(['relative', 'absolute'], [(0.5, 0.5), + (16, 12)]): + transform = BottomupRandomCrop( + crop_size=crop_size, crop_type=crop_type) + results = transform(deepcopy(src_results)) + self.assertEqual(results['img'].shape[:2], target_shape) + + # test absolute_range crop + transform = BottomupRandomCrop( + crop_size=(10, 20), crop_type='absolute_range') + results = transform(deepcopy(src_results)) + h, w = results['img'].shape[:2] + self.assertTrue(10 <= w <= 20) + self.assertTrue(10 <= h <= 20) + self.assertEqual(results['img_shape'], results['img'].shape[:2]) + # test relative_range crop + transform = BottomupRandomCrop( + crop_size=(0.5, 0.5), crop_type='relative_range') + results = transform(deepcopy(src_results)) + h, w = results['img'].shape[:2] + self.assertTrue(16 <= w <= 32) + self.assertTrue(12 <= h <= 24) + self.assertEqual(results['img_shape'], results['img'].shape[:2]) + + # test with keypoints, bbox, segmentation + src_results = get_coco_sample(img_shape=(10, 10), num_instances=2) + segmentation = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8) + keypoints = np.ones_like(src_results['keypoints']) * 5 + src_results['segmentation'] = segmentation + src_results['keypoints'] = keypoints + transform = BottomupRandomCrop( + crop_size=(7, 5), + allow_negative_crop=False, + recompute_bbox=False, + bbox_clip_border=True) + results = transform(deepcopy(src_results)) + h, w = results['img'].shape[:2] + self.assertEqual(h, 5) + self.assertEqual(w, 7) + self.assertEqual(results['bbox'].shape[0], 2) + self.assertTrue(results['keypoints_visible'].all()) + self.assertTupleEqual(results['segmentation'].shape[:2], (5, 7)) + self.assertEqual(results['img_shape'], results['img'].shape[:2]) + + # test bbox_clip_border = False + transform = BottomupRandomCrop( + crop_size=(10, 11), + allow_negative_crop=False, + recompute_bbox=True, + bbox_clip_border=False) + results = transform(deepcopy(src_results)) + self.assertTrue((results['bbox'] == src_results['bbox']).all()) + + # test the crop does not contain any gt-bbox + # allow_negative_crop = False + img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8) + bbox = np.zeros((0, 4), dtype=np.float32) + src_results = {'img': img, 'bbox': bbox} + transform = BottomupRandomCrop( + crop_size=(5, 3), 
allow_negative_crop=False)
+        results = transform(deepcopy(src_results))
+        self.assertIsNone(results)
+
+        # allow_negative_crop = True
+        img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8)
+        bbox = np.zeros((0, 4), dtype=np.float32)
+        src_results = {'img': img, 'bbox': bbox}
+        transform = BottomupRandomCrop(
+            crop_size=(5, 3), allow_negative_crop=True)
+        results = transform(deepcopy(src_results))
+        self.assertTrue(isinstance(results, dict))
+
+
+class TestBottomupRandomChoiceResize(TestCase):
+
+    def setUp(self):
+        self.data_info = get_coco_sample(img_shape=(300, 400))
+
+    def test_transform(self):
+        results = dict()
+        # test with one scale
+        transform = BottomupRandomChoiceResize(scales=[(1333, 800)])
+        results = deepcopy(self.data_info)
+        results = transform(results)
+        self.assertEqual(results['img'].shape, (800, 1333, 3))
+
+        # test with multiple scales
+        _scale_choice = [(1333, 800), (1333, 600)]
+        transform = BottomupRandomChoiceResize(scales=_scale_choice)
+        results = deepcopy(self.data_info)
+        results = transform(results)
+        self.assertIn((results['img'].shape[1], results['img'].shape[0]),
+                      _scale_choice)
+
+        # test keep_ratio
+        transform = BottomupRandomChoiceResize(
+            scales=[(900, 600)], resize_type='Resize', keep_ratio=True)
+        results = deepcopy(self.data_info)
+        _input_ratio = results['img'].shape[0] / results['img'].shape[1]
+        results = transform(results)
+        _output_ratio = results['img'].shape[0] / results['img'].shape[1]
+        self.assertLess(abs(_input_ratio - _output_ratio), 1.5 * 1e-3)
+
+        # test clip_object_border
+        bbox = [[200, 150, 600, 450]]
+        transform = BottomupRandomChoiceResize(
+            scales=[(200, 150)], resize_type='Resize', clip_object_border=True)
+        results = deepcopy(self.data_info)
+        results['bbox'] = np.array(bbox)
+        results = transform(results)
+        self.assertEqual(results['img'].shape, (150, 200, 3))
+        self.assertTrue((results['bbox'] == np.array([[100, 75, 200,
+                                                       150]])).all())
+
+        # test clip_object_border=False
+        transform = BottomupRandomChoiceResize(
+            scales=[(200, 150)],
+            resize_type='Resize',
+            clip_object_border=False)
+        results = deepcopy(self.data_info)
+        results['bbox'] = np.array(bbox)
+        results = transform(results)
+        self.assertEqual(results['img'].shape, (150, 200, 3))
+        self.assertTrue((results['bbox'] == np.array([[100, 75, 300,
+                                                       225]])).all())
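
A minimal usage sketch of the helpers added in `transformers/utils.py` (not part of the patch; it assumes the new module path `mmpose.models.heads.transformer_heads.transformers.utils` is importable once the package `__init__` files are in place): `inverse_sigmoid` should invert `torch.sigmoid` up to its clamping `eps`, `MLP` should project per-query features, and `PositionEmbeddingSineHW` should turn a `(bs, H, W)` padding mask into a `(bs, 2 * num_pos_feats, H, W)` positional embedding.

```python
import torch

# Assumed import path, following the file location introduced in this patch.
from mmpose.models.heads.transformer_heads.transformers.utils import (
    MLP, PositionEmbeddingSineHW, inverse_sigmoid)

# inverse_sigmoid is the logit function with clamping, so sigmoid() should
# recover the input up to the clamping eps (1e-3 by default).
x = torch.rand(2, 5)
assert torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-2)

# A 3-layer MLP projects per-query features, e.g. 256-d embeddings to 4 dims.
mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
assert mlp(torch.rand(100, 2, 256)).shape == (100, 2, 4)

# The sine embedding consumes a padding mask (True marks padded pixels) and
# returns a sine/cosine feature of size 2 * num_pos_feats per location.
pos_embed = PositionEmbeddingSineHW(num_pos_feats=128, normalize=True)
mask = torch.zeros(1, 32, 48, dtype=torch.bool)  # no padded pixels
pos = pos_embed(mask)
assert pos.shape == (1, 256, 32, 48)
```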