From 85831b8ae0a869041aca898d29b04018edc8db26 Mon Sep 17 00:00:00 2001
From: Xin Li <7219519+xin-li-67@users.noreply.github.com>
Date: Mon, 24 Jul 2023 19:11:30 +0800
Subject: [PATCH] [Feature][MMSIG] Add UniFormer Pose Estimation to Projects
folder (#2501)
---
projects/uniformer/README.md | 138 ++++
...hm_uniformer-b-8xb128-210e_coco-256x192.py | 135 ++++
...-hm_uniformer-b-8xb32-210e_coco-384x288.py | 134 ++++
...-hm_uniformer-b-8xb32-210e_coco-448x320.py | 134 ++++
...hm_uniformer-s-8xb128-210e_coco-256x192.py | 17 +
...hm_uniformer-s-8xb128-210e_coco-384x288.py | 23 +
...-hm_uniformer-s-8xb64-210e_coco-448x320.py | 22 +
projects/uniformer/models/__init__.py | 1 +
projects/uniformer/models/uniformer.py | 709 ++++++++++++++++++
9 files changed, 1313 insertions(+)
create mode 100644 projects/uniformer/README.md
create mode 100644 projects/uniformer/configs/td-hm_uniformer-b-8xb128-210e_coco-256x192.py
create mode 100644 projects/uniformer/configs/td-hm_uniformer-b-8xb32-210e_coco-384x288.py
create mode 100644 projects/uniformer/configs/td-hm_uniformer-b-8xb32-210e_coco-448x320.py
create mode 100644 projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-256x192.py
create mode 100644 projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-384x288.py
create mode 100644 projects/uniformer/configs/td-hm_uniformer-s-8xb64-210e_coco-448x320.py
create mode 100644 projects/uniformer/models/__init__.py
create mode 100644 projects/uniformer/models/uniformer.py
diff --git a/projects/uniformer/README.md b/projects/uniformer/README.md
new file mode 100644
index 0000000000..6f166f975e
--- /dev/null
+++ b/projects/uniformer/README.md
@@ -0,0 +1,138 @@
+# Pose Estimation with UniFormer
+
+This project implements a top-down, heatmap-based human pose estimator based on the approach described in **UniFormer: Unifying Convolution and Self-attention for Visual Recognition** (TPAMI 2023) and **UniFormer: Unified Transformer for Efficient Spatiotemporal Representation Learning** (ICLR 2022).
+
+## Usage
+
+### Preparation
+
+1. Set Up the Development Environment (an example install sketch follows the list below)
+
+- Python 3.7 or higher
+- PyTorch 1.6 or higher
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.6.0 or higher
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0rc4 or higher
+- [MMDetection](https://github.com/open-mmlab/mmdetection) v3.0.0rc6 or higher
+- [MMPose](https://github.com/open-mmlab/mmpose) v1.0.0rc1 or higher
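+
+One possible way to install these dependencies is with OpenMMLab's `mim` package manager (a sketch; adjust the versions to your environment):
+
+```shell
+pip install -U openmim
+mim install mmengine
+mim install "mmcv>=2.0.0rc4"
+mim install "mmdet>=3.0.0rc6"
+mim install "mmpose>=1.0.0rc1"
+```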
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. **From the `uniformer/` root directory**, run the following command to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+2. Download Pretrained Weights
+
+To run inference with or train the UniFormer pose estimation models, you need two sets of weights: the original UniFormer backbones pretrained on ImageNet-1K, and the weights trained for the downstream pose estimation task. The original ImageNet-1K weights are hosted in SenseTime's [Hugging Face repository](https://huggingface.co/Sense-X/uniformer_image), and the downstream pose estimation weights are hosted on Google Drive and Baiduyun. We have mirrored all of them on the OpenMMLab download servers, so users can obtain them without extra effort. For example, in [`td-hm_uniformer-b-8xb128-210e_coco-256x192.py`](./configs/td-hm_uniformer-b-8xb128-210e_coco-256x192.py#62) the pretrained weight URL is already set, and the weight is downloaded automatically once training or testing starts. The URLs of the downstream task weights are listed in the [benchmark result table](#results).
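+
+For reference, this is how the ImageNet-1K weight is wired into the provided configs (an excerpt of the backbone section of the base config):
+
+```python
+backbone=dict(
+    type='UniFormer',
+    # ... other backbone settings ...
+    init_cfg=dict(
+        type='Pretrained',
+        checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+        'uniformer/uniformer_base_in1k.pth'))
+```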
+
+### Inference
+
+We provide an [inferencer_demo.py](../../demo/inferencer_demo.py) script that developers can use to run quick inference demos. Here is a basic example:
+
+```shell
+python demo/inferencer_demo.py $INPUTS \
+ --pose2d $CONFIG --pose2d-weights $CHECKPOINT \
+ [--show] [--vis-out-dir $VIS_OUT_DIR] [--pred-out-dir $PRED_OUT_DIR]
+```
+
+For more information on using the inferencer, please see [this document](https://mmpose.readthedocs.io/en/latest/user_guides/inference.html#out-of-the-box-inferencer).
+
+Here is a concrete example:
+
+```shell
+python demo/inferencer_demo.py tests/data/coco/000000000785.jpg \
+ --pose2d projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-256x192.py \
+ --pose2d-weights https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_256x192_global_small-d4a7fdac_20230724.pth \
+ --vis-out-dir vis_results
+```
+
+Then you will find the demo results in the `vis_results` folder.
+
+### Training and Testing
+
+1. Data Preparation
+
+Prepare the COCO dataset according to the [instructions](https://mmpose.readthedocs.io/en/latest/dataset_zoo/2d_body_keypoint.html#coco). The layout the configs expect is sketched below.
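+
+The configs in this project read the data from the following layout under `data/coco/` (paths taken from the configs):
+
+```text
+data/coco/
+├── annotations/
+│   ├── person_keypoints_train2017.json
+│   └── person_keypoints_val2017.json
+├── person_detection_results/
+│   └── COCO_val2017_detections_AP_H_56_person.json
+├── train2017/
+└── val2017/
+```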
+
+2. To Train and Test with a Single GPU:
+
+```shell
+python tools/train.py $CONFIG --auto-scale-lr
+```
+
+```shell
+python tools/test.py $CONFIG $CHECKPOINT
+```
+
+3. To Train and Test with Multiple GPUs:
+
+```shell
+bash tools/dist_train.sh $CONFIG $NUM_GPUs --amp
+```
+
+```shell
+bash tools/dist_test.sh $CONFIG $CHECKPOINT $NUM_GPUs --amp
+```
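+
+For example, to evaluate the UniFormer-S 256x192 model on 8 GPUs with the checkpoint from the result table below (a sketch; adjust the GPU count and paths to your setup):
+
+```shell
+bash tools/dist_test.sh \
+    projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-256x192.py \
+    https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_256x192_global_small-d4a7fdac_20230724.pth \
+    8
+```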
+
+## Results
+
+Here are the test results on COCO val2017:
+
+| Model | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | Download |
+| :-----------------------------------------------------------------: | :--------: | :--: | :-------------: | :-------------: | :--: | :-------------: | :--------------------------------------------------------------------: |
+| [UniFormer-S](./configs/td-hm_uniformer-s-8xb128-210e_coco-256x192.py) | 256x192 | 74.0 | 90.2 | 82.1 | 79.5 | 94.1 | [model](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_256x192_global_small-d4a7fdac_20230724.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_256x192_global_small-d4a7fdac_20230724.log.json) |
+| [UniFormer-S](./configs/td-hm_uniformer-s-8xb128-210e_coco-384x288.py) | 384x288 | 75.9 | 90.6 | 83.0 | 81.0 | 94.3 | [model](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_384x288_global_small-7a613f78_20230724.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_384x288_global_small-7a613f78_20230724.log.json) |
+| [UniFormer-S](./configs/td-hm_uniformer-s-8xb64-210e_coco-448x320.py) | 448x320 | 76.2 | 90.6 | 83.2 | 81.4 | 94.4 | [model](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_448x320_global_small-18b760de_20230724.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_448x320_global_small-18b760de_20230724.log.json) |
+| [UniFormer-B](./configs/td-hm_uniformer-b-8xb128-210e_coco-256x192.py) | 256x192 | 75.0 | 90.5 | 83.0 | 80.4 | 94.2 | [model](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_256x192_global_base-1713bcd4_20230724.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_256x192_global_base-1713bcd4_20230724.log.json) |
+| [UniFormer-B](./configs/td-hm_uniformer-b-8xb32-210e_coco-384x288.py) | 384x288 | 76.7 | 90.8 | 84.1 | 81.9 | 94.6 | [model](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_384x288_global_base-c650da38_20230724.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_384x288_global_base-c650da38_20230724.log.json) |
+| [UniFormer-B](./configs/td-hm_uniformer-b-8xb32-210e_coco-448x320.py) | 448x320 | 77.4 | 91.0 | 84.4 | 82.5 | 94.9 | [model](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_448x320_global_base-a05c185f_20230724.pth) \| [log](https://download.openmmlab.com/mmpose/v1/projects/uniformer/top_down_448x320_global_base-a05c185f_20230724.log.json) |
+
+For comparison, here are the test results on COCO val2017 reported by the official UniFormer Pose Estimation repository:
+
+| Backbone    | Input Size | AP   | AP<sup>50</sup> | AP<sup>75</sup> | AR<sup>M</sup> | AR<sup>L</sup> | AR   | Model | Log |
+| :---------- | :--------- | :--- | :-------------- | :-------------- | :------------- | :------------- | :--- | :-------------------------------------------------------- | :------------------------------------------------------- |
+| UniFormer-S | 256x192 | 74.0 | 90.3 | 82.2 | 66.8 | 76.7 | 79.5 | [google](https://drive.google.com/file/d/162R0JuTpf3gpLe1IK6oxRoQK7JSj4ylx/view?usp=sharing) | [google](https://drive.google.com/file/d/15j40u97Db6TA2gMHdn0yFEsDFb5SMBy4/view?usp=sharing) |
+| UniFormer-S | 384x288 | 75.9 | 90.6 | 83.4 | 68.6 | 79.0 | 81.4 | [google](https://drive.google.com/file/d/163vuFkpcgVOthC05jCwjGzo78Nr0eikW/view?usp=sharing) | [google](https://drive.google.com/file/d/15X9M_5cq9RQMgs64Yn9YvV5k5f0zOBHo/view?usp=sharing) |
+| UniFormer-S | 448x320 | 76.2 | 90.6 | 83.2 | 68.6 | 79.4 | 81.4 | [google](https://drive.google.com/file/d/165nQRsT58SXJegcttksHwDn46Fme5dGX/view?usp=sharing) | [google](https://drive.google.com/file/d/15IJjSWp4R5OybMdV2CZEUx_TwXdTMOee/view?usp=sharing) |
+| UniFormer-B | 256x192 | 75.0 | 90.6 | 83.0 | 67.8 | 77.7 | 80.4 | [google](https://drive.google.com/file/d/15tzJaRyEzyWp2mQhpjDbBzuGoyCaJJ-2/view?usp=sharing) | [google](https://drive.google.com/file/d/15jJyTPcJKj_id0PNdytloqt7yjH2M8UR/view?usp=sharing) |
+| UniFormer-B | 384x288 | 76.7 | 90.8 | 84.0 | 69.3 | 79.7 | 81.4 | [google](https://drive.google.com/file/d/15qtUaOR_C7-vooheJE75mhA9oJQt3gSx/view?usp=sharing) | [google](https://drive.google.com/file/d/15L1Uxo_uRSMlGnOvWzAzkJLKX6Qh_xNw/view?usp=sharing) |
+| UniFormer-B | 448x320 | 77.4 | 91.1 | 84.4 | 70.2 | 80.6 | 82.5 | [google](https://drive.google.com/file/d/156iNxetiCk8JJz41aFDmFh9cQbCaMk3D/view?usp=sharing) | [google](https://drive.google.com/file/d/15aRpZc2Tie5gsn3_l-aXto1MrC9wyzMC/view?usp=sharing) |
+
+Note:
+
+1. All the original models are pretrained on ImageNet-1K without Token Labeling and Layer Scale, as mentioned in the [official README](https://github.com/Sense-X/UniFormer/tree/main/pose_estimation). The official team has confirmed that **Token Labeling can largely improve the performance of the downstream tasks**. Developers are welcome to experiment with it on their own.
+2. The original implementation does not **freeze BN in the backbone**. The official team has confirmed that doing so can further improve performance.
+3. To avoid running out of memory, developers can enable `torch.utils.checkpoint` in the config by setting `use_checkpoint=True` and `checkpoint_num=[0, 0, 2, 0]` (the number of checkpointed blocks in each stage); see the snippet after these notes.
+4. We warmly welcome any contributions if you can successfully reproduce the results from the paper!
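+
+A minimal sketch of how the options from note 3 fit into the backbone config (only the relevant keys are shown):
+
+```python
+# Enable gradient checkpointing to trade compute for GPU memory.
+# checkpoint_num[k] is the number of leading blocks checkpointed in stage k.
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        use_checkpoint=True,
+        checkpoint_num=[0, 0, 2, 0]))
+```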
+
+## Citation
+
+If this project benefits your work, please consider citing the original papers:
+
+```bibtex
+@misc{li2022uniformer,
+ title={UniFormer: Unifying Convolution and Self-attention for Visual Recognition},
+ author={Kunchang Li and Yali Wang and Junhao Zhang and Peng Gao and Guanglu Song and Yu Liu and Hongsheng Li and Yu Qiao},
+ year={2022},
+ eprint={2201.09450},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+```bibtex
+@misc{li2022uniformer_video,
+ title={UniFormer: Unified Transformer for Efficient Spatiotemporal Representation Learning},
+ author={Kunchang Li and Yali Wang and Peng Gao and Guanglu Song and Yu Liu and Hongsheng Li and Yu Qiao},
+ year={2022},
+ eprint={2201.04676},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
diff --git a/projects/uniformer/configs/td-hm_uniformer-b-8xb128-210e_coco-256x192.py b/projects/uniformer/configs/td-hm_uniformer-b-8xb128-210e_coco-256x192.py
new file mode 100644
index 0000000000..07f1377842
--- /dev/null
+++ b/projects/uniformer/configs/td-hm_uniformer-b-8xb128-210e_coco-256x192.py
@@ -0,0 +1,135 @@
+_base_ = ['mmpose::_base_/default_runtime.py']
+
+custom_imports = dict(imports='projects.uniformer.models')
+
+# runtime
+train_cfg = dict(max_epochs=210, val_interval=10)
+
+# enable DDP training when pretrained model is used
+find_unused_parameters = True
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+ type='Adam',
+ lr=2e-3,
+))
+
+# learning policy
+param_scheduler = [
+ dict(
+ type='LinearLR', begin=0, end=500, start_factor=0.001,
+ by_epoch=False), # warm-up
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=210,
+ milestones=[170, 200],
+ gamma=0.1,
+ by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=1024)
+
+# hooks
+default_hooks = dict(
+ checkpoint=dict(save_best='coco/AP', rule='greater', interval=5))
+
+# codec settings
+codec = dict(
+ type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='TopdownPoseEstimator',
+ data_preprocessor=dict(
+ type='PoseDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True),
+ backbone=dict(
+ type='UniFormer',
+ embed_dims=[64, 128, 320, 512],
+ depths=[5, 8, 20, 7],
+ head_dim=64,
+ drop_path_rate=0.4,
+        use_checkpoint=False, # whether to use torch.utils.checkpoint
+        use_window=False, # whether to use window MHRA
+        use_hybrid=False, # whether to use hybrid MHRA
+ init_cfg=dict(
+ # Set the path to pretrained backbone here
+ type='Pretrained',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+ 'uniformer/uniformer_base_in1k.pth' # noqa
+ )),
+ head=dict(
+ type='HeatmapHead',
+ in_channels=512,
+ out_channels=17,
+ final_layer=dict(kernel_size=1),
+ loss=dict(type='KeypointMSELoss', use_target_weight=True),
+ decoder=codec),
+ test_cfg=dict(flip_test=True, flip_mode='heatmap', shift_heatmap=True))
+
+# base dataset settings
+dataset_type = 'CocoDataset'
+data_mode = 'topdown'
+data_root = 'data/coco/'
+
+# pipelines
+train_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(type='RandomHalfBody'),
+ dict(type='RandomBBoxTransform'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='GenerateTarget', encoder=codec),
+ dict(type='PackPoseInputs')
+]
+val_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='PackPoseInputs')
+]
+
+# data loaders
+train_dataloader = dict(
+ batch_size=128,
+ num_workers=2,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline=train_pipeline,
+ ))
+val_dataloader = dict(
+ batch_size=256,
+ num_workers=2,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_val2017.json',
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+ data_prefix=dict(img='val2017/'),
+ test_mode=True,
+ pipeline=val_pipeline,
+ ))
+test_dataloader = val_dataloader
+
+# evaluators
+val_evaluator = dict(
+ type='CocoMetric',
+ ann_file=data_root + 'annotations/person_keypoints_val2017.json')
+test_evaluator = val_evaluator
diff --git a/projects/uniformer/configs/td-hm_uniformer-b-8xb32-210e_coco-384x288.py b/projects/uniformer/configs/td-hm_uniformer-b-8xb32-210e_coco-384x288.py
new file mode 100644
index 0000000000..d43061d0cd
--- /dev/null
+++ b/projects/uniformer/configs/td-hm_uniformer-b-8xb32-210e_coco-384x288.py
@@ -0,0 +1,134 @@
+_base_ = ['mmpose::_base_/default_runtime.py']
+
+custom_imports = dict(imports='projects.uniformer.models')
+
+# runtime
+train_cfg = dict(max_epochs=210, val_interval=10)
+
+# enable DDP training when pretrained model is used
+find_unused_parameters = True
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+ type='Adam',
+ lr=5e-4,
+))
+
+# learning policy
+param_scheduler = [
+ dict(
+ type='LinearLR', begin=0, end=500, start_factor=0.001,
+ by_epoch=False), # warm-up
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=210,
+ milestones=[170, 200],
+ gamma=0.1,
+ by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=512)
+
+# hooks
+default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+ type='MSRAHeatmap', input_size=(288, 384), heatmap_size=(72, 96), sigma=3)
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='TopdownPoseEstimator',
+ data_preprocessor=dict(
+ type='PoseDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True),
+ backbone=dict(
+ type='UniFormer',
+ embed_dims=[64, 128, 320, 512],
+ depths=[5, 8, 20, 7],
+ head_dim=64,
+ drop_path_rate=0.4,
+        use_checkpoint=False, # whether to use torch.utils.checkpoint
+        use_window=False, # whether to use window MHRA
+        use_hybrid=False, # whether to use hybrid MHRA
+ init_cfg=dict(
+ # Set the path to pretrained backbone here
+ type='Pretrained',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+ 'uniformer/uniformer_base_in1k.pth' # noqa
+ )),
+ head=dict(
+ type='HeatmapHead',
+ in_channels=512,
+ out_channels=17,
+ final_layer=dict(kernel_size=1),
+ loss=dict(type='KeypointMSELoss', use_target_weight=True),
+ decoder=codec),
+ test_cfg=dict(flip_test=True, flip_mode='heatmap', shift_heatmap=True))
+
+# base dataset settings
+dataset_type = 'CocoDataset'
+data_mode = 'topdown'
+data_root = 'data/coco/'
+
+# pipelines
+train_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(type='RandomHalfBody'),
+ dict(type='RandomBBoxTransform'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='GenerateTarget', encoder=codec),
+ dict(type='PackPoseInputs')
+]
+val_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='PackPoseInputs')
+]
+
+# data loaders
+train_dataloader = dict(
+ batch_size=128,
+ num_workers=2,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline=train_pipeline,
+ ))
+val_dataloader = dict(
+ batch_size=256,
+ num_workers=2,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_val2017.json',
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+ data_prefix=dict(img='val2017/'),
+ test_mode=True,
+ pipeline=val_pipeline,
+ ))
+test_dataloader = val_dataloader
+
+# evaluators
+val_evaluator = dict(
+ type='CocoMetric',
+ ann_file=data_root + 'annotations/person_keypoints_val2017.json')
+test_evaluator = val_evaluator
diff --git a/projects/uniformer/configs/td-hm_uniformer-b-8xb32-210e_coco-448x320.py b/projects/uniformer/configs/td-hm_uniformer-b-8xb32-210e_coco-448x320.py
new file mode 100644
index 0000000000..81554ad27e
--- /dev/null
+++ b/projects/uniformer/configs/td-hm_uniformer-b-8xb32-210e_coco-448x320.py
@@ -0,0 +1,134 @@
+_base_ = ['mmpose::_base_/default_runtime.py']
+
+custom_imports = dict(imports='projects.uniformer.models')
+
+# runtime
+train_cfg = dict(max_epochs=210, val_interval=10)
+
+# enable DDP training when pretrained model is used
+find_unused_parameters = True
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+ type='Adam',
+ lr=5e-4,
+))
+
+# learning policy
+param_scheduler = [
+ dict(
+ type='LinearLR', begin=0, end=500, start_factor=0.001,
+ by_epoch=False), # warm-up
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=210,
+ milestones=[170, 200],
+ gamma=0.1,
+ by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=256)
+
+# hooks
+default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+ type='MSRAHeatmap', input_size=(320, 448), heatmap_size=(80, 112), sigma=3)
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='TopdownPoseEstimator',
+ data_preprocessor=dict(
+ type='PoseDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True),
+ backbone=dict(
+ type='UniFormer',
+ embed_dims=[64, 128, 320, 512],
+ depths=[5, 8, 20, 7],
+ head_dim=64,
+ drop_path_rate=0.55,
+        use_checkpoint=False, # whether to use torch.utils.checkpoint
+        use_window=False, # whether to use window MHRA
+        use_hybrid=False, # whether to use hybrid MHRA
+ init_cfg=dict(
+ # Set the path to pretrained backbone here
+ type='Pretrained',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+ 'uniformer/uniformer_base_in1k.pth' # noqa
+ )),
+ head=dict(
+ type='HeatmapHead',
+ in_channels=512,
+ out_channels=17,
+ final_layer=dict(kernel_size=1),
+ loss=dict(type='KeypointMSELoss', use_target_weight=True),
+ decoder=codec),
+ test_cfg=dict(flip_test=True, flip_mode='heatmap', shift_heatmap=True))
+
+# base dataset settings
+dataset_type = 'CocoDataset'
+data_mode = 'topdown'
+data_root = 'data/coco/'
+
+# pipelines
+train_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(type='RandomHalfBody'),
+ dict(type='RandomBBoxTransform'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='GenerateTarget', encoder=codec),
+ dict(type='PackPoseInputs')
+]
+val_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='PackPoseInputs')
+]
+
+# data loaders
+train_dataloader = dict(
+ batch_size=32,
+ num_workers=2,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline=train_pipeline,
+ ))
+val_dataloader = dict(
+ batch_size=256,
+ num_workers=2,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_val2017.json',
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+ data_prefix=dict(img='val2017/'),
+ test_mode=True,
+ pipeline=val_pipeline,
+ ))
+test_dataloader = val_dataloader
+
+# evaluators
+val_evaluator = dict(
+ type='CocoMetric',
+ ann_file=data_root + 'annotations/person_keypoints_val2017.json')
+test_evaluator = val_evaluator
diff --git a/projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-256x192.py b/projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-256x192.py
new file mode 100644
index 0000000000..54994893dd
--- /dev/null
+++ b/projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-256x192.py
@@ -0,0 +1,17 @@
+_base_ = ['./td-hm_uniformer-b-8xb128-210e_coco-256x192.py']
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=1024)
+
+model = dict(
+ backbone=dict(
+ depths=[3, 4, 8, 3],
+ drop_path_rate=0.2,
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+ 'uniformer/uniformer_small_in1k.pth' # noqa
+ )))
+
+train_dataloader = dict(batch_size=32)
+val_dataloader = dict(batch_size=256)
diff --git a/projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-384x288.py b/projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-384x288.py
new file mode 100644
index 0000000000..59f68946ef
--- /dev/null
+++ b/projects/uniformer/configs/td-hm_uniformer-s-8xb128-210e_coco-384x288.py
@@ -0,0 +1,23 @@
+_base_ = ['./td-hm_uniformer-b-8xb32-210e_coco-384x288.py']
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+ type='Adam',
+ lr=2e-3,
+))
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=1024)
+
+model = dict(
+ backbone=dict(
+ depths=[3, 4, 8, 3],
+ drop_path_rate=0.2,
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+ 'uniformer/uniformer_small_in1k.pth' # noqa
+ )))
+
+train_dataloader = dict(batch_size=128)
+val_dataloader = dict(batch_size=256)
diff --git a/projects/uniformer/configs/td-hm_uniformer-s-8xb64-210e_coco-448x320.py b/projects/uniformer/configs/td-hm_uniformer-s-8xb64-210e_coco-448x320.py
new file mode 100644
index 0000000000..0359ac6d63
--- /dev/null
+++ b/projects/uniformer/configs/td-hm_uniformer-s-8xb64-210e_coco-448x320.py
@@ -0,0 +1,22 @@
+_base_ = ['./td-hm_uniformer-b-8xb32-210e_coco-448x320.py']
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+ type='Adam',
+ lr=1.0e-3,
+))
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=512)
+
+model = dict(
+ backbone=dict(
+ depths=[3, 4, 8, 3],
+ drop_path_rate=0.2,
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+ 'uniformer/uniformer_small_in1k.pth')))
+
+train_dataloader = dict(batch_size=64)
+val_dataloader = dict(batch_size=256)
diff --git a/projects/uniformer/models/__init__.py b/projects/uniformer/models/__init__.py
new file mode 100644
index 0000000000..6256db6f45
--- /dev/null
+++ b/projects/uniformer/models/__init__.py
@@ -0,0 +1 @@
+from .uniformer import * # noqa
diff --git a/projects/uniformer/models/uniformer.py b/projects/uniformer/models/uniformer.py
new file mode 100644
index 0000000000..cea36f061b
--- /dev/null
+++ b/projects/uniformer/models/uniformer.py
@@ -0,0 +1,709 @@
+from collections import OrderedDict
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn.bricks.transformer import build_dropout
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import trunc_normal_
+from mmengine.runner import checkpoint, load_checkpoint
+from mmengine.utils import to_2tuple
+
+from mmpose.models.backbones.base_backbone import BaseBackbone
+from mmpose.registry import MODELS
+from mmpose.utils import get_root_logger
+
+
+class Mlp(BaseModule):
+ """Multilayer perceptron.
+
+ Args:
+ in_features (int): Number of input features.
+ hidden_features (int): Number of hidden features.
+ Defaults to None.
+ out_features (int): Number of output features.
+ Defaults to None.
+        drop_rate (float): Dropout rate. Defaults to 0.0.
+ init_cfg (dict, optional): Config dict for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_features: int,
+ hidden_features: int = None,
+ out_features: int = None,
+ drop_rate: float = 0.,
+ init_cfg: Optional[dict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = nn.GELU()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop_rate)
+
+ def forward(self, x):
+ x = self.act(self.fc1(x))
+ x = self.fc2(self.drop(x))
+ x = self.drop(x)
+ return x
+
+
+class CMlp(BaseModule):
+ """Multilayer perceptron via convolution.
+
+ Args:
+ in_features (int): Number of input features.
+ hidden_features (int): Number of hidden features.
+ Defaults to None.
+ out_features (int): Number of output features.
+ Defaults to None.
+        drop_rate (float): Dropout rate. Defaults to 0.0.
+ init_cfg (dict, optional): Config dict for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ in_features: int,
+ hidden_features: int = None,
+ out_features: int = None,
+ drop_rate: float = 0.,
+ init_cfg: Optional[dict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
+ self.act = nn.GELU()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
+ self.drop = nn.Dropout(drop_rate)
+
+ def forward(self, x):
+ x = self.act(self.fc1(x))
+ x = self.fc2(self.drop(x))
+ x = self.drop(x)
+ return x
+
+
+class CBlock(BaseModule):
+ """Convolution Block.
+
+ Args:
+ embed_dim (int): Number of input features.
+ mlp_ratio (float): Ratio of mlp hidden dimension
+ to embedding dimension. Defaults to 4.
+        drop_rate (float): Dropout rate.
+            Defaults to 0.0.
+        drop_path_rate (float): Stochastic depth rate.
+            Defaults to 0.0.
+ init_cfg (dict, optional): Config dict for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dim: int,
+ mlp_ratio: float = 4.,
+ drop_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ init_cfg: Optional[dict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.pos_embed = nn.Conv2d(
+ embed_dim, embed_dim, 3, padding=1, groups=embed_dim)
+ self.norm1 = nn.BatchNorm2d(embed_dim)
+ self.conv1 = nn.Conv2d(embed_dim, embed_dim, 1)
+ self.conv2 = nn.Conv2d(embed_dim, embed_dim, 1)
+ self.attn = nn.Conv2d(
+ embed_dim, embed_dim, 5, padding=2, groups=embed_dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is
+ # better than dropout here
+ self.drop_path = build_dropout(
+ dict(type='DropPath', drop_prob=drop_path_rate)
+ ) if drop_path_rate > 0. else nn.Identity()
+ self.norm2 = nn.BatchNorm2d(embed_dim)
+ mlp_hidden_dim = int(embed_dim * mlp_ratio)
+ self.mlp = CMlp(
+ in_features=embed_dim,
+ hidden_features=mlp_hidden_dim,
+ drop_rate=drop_rate)
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ x = x + self.drop_path(
+ self.conv2(self.attn(self.conv1(self.norm1(x)))))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class Attention(BaseModule):
+ """Self-Attention.
+
+ Args:
+ embed_dim (int): Number of input features.
+ num_heads (int): Number of attention heads.
+ Defaults to 8.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ Defaults to True.
+ qk_scale (float, optional): Override default qk scale of
+ ``head_dim ** -0.5`` if set. Defaults to None.
+ attn_drop_rate (float): Attention dropout rate.
+ Defaults to 0.0.
+ proj_drop_rate (float): Dropout rate.
+ Defaults to 0.0.
+        init_cfg (dict, optional): Config dict for initialization.
+            Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ qk_scale: float = None,
+ attn_drop_rate: float = 0.,
+ proj_drop_rate: float = 0.,
+ init_cfg: Optional[dict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ self.num_heads = num_heads
+ head_dim = embed_dim // num_heads
+        # NOTE: the scale factor can be set manually via qk_scale to stay
+        # compatible with previously trained weights
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop_rate)
+ self.proj = nn.Linear(embed_dim, embed_dim)
+ self.proj_drop = nn.Dropout(proj_drop_rate)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[
+ 2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class PatchEmbed(BaseModule):
+ """Image to Patch Embedding.
+
+ Args:
+        img_size (int): Input image size.
+            Defaults to 224.
+        patch_size (int): Patch size.
+            Defaults to 16.
+        in_channels (int): Number of input channels.
+            Defaults to 3.
+        embed_dim (int): Number of output channels (embedding dimension).
+            Defaults to 768.
+ init_cfg (dict, optional): Config dict for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ img_size: int = 224,
+ patch_size: int = 16,
+ in_channels: int = 3,
+ embed_dim: int = 768,
+ init_cfg: Optional[dict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (
+ img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.norm = nn.LayerNorm(embed_dim)
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, _, H, W = x.shape
+ x = self.proj(x)
+ B, _, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ return x
+
+
+class SABlock(BaseModule):
+ """Self-Attention Block.
+
+ Args:
+ embed_dim (int): Number of input features.
+ num_heads (int): Number of attention heads.
+ mlp_ratio (float): Ratio of mlp hidden dimension
+ to embedding dimension. Defaults to 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ Defaults to True.
+ qk_scale (float, optional): Override default qk scale of
+ ``head_dim ** -0.5`` if set. Defaults to None.
+        drop_rate (float): Dropout rate. Defaults to 0.0.
+        attn_drop_rate (float): Attention dropout rate. Defaults to 0.0.
+        drop_path_rate (float): Stochastic depth rate.
+            Defaults to 0.0.
+ init_cfg (dict, optional): Config dict for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.,
+ qkv_bias: bool = False,
+ qk_scale: float = None,
+ drop_rate: float = 0.,
+ attn_drop_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ init_cfg: Optional[dict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+
+ self.pos_embed = nn.Conv2d(
+ embed_dim, embed_dim, 3, padding=1, groups=embed_dim)
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.attn = Attention(
+ embed_dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=drop_rate)
+ # NOTE: drop path for stochastic depth,
+ # we shall see if this is better than dropout here
+ self.drop_path = build_dropout(
+ dict(type='DropPath', drop_prob=drop_path_rate)
+ ) if drop_path_rate > 0. else nn.Identity()
+ self.norm2 = nn.LayerNorm(embed_dim)
+ mlp_hidden_dim = int(embed_dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=embed_dim,
+ hidden_features=mlp_hidden_dim,
+ drop_rate=drop_rate)
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ B, N, H, W = x.shape
+ x = x.flatten(2).transpose(1, 2)
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ x = x.transpose(1, 2).reshape(B, N, H, W)
+ return x
+
+
+class WindowSABlock(BaseModule):
+ """Self-Attention Block.
+
+ Args:
+ embed_dim (int): Number of input features.
+ num_heads (int): Number of attention heads.
+ window_size (int): Size of the partition window. Defaults to 14.
+ mlp_ratio (float): Ratio of mlp hidden dimension
+ to embedding dimension. Defaults to 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ Defaults to True.
+ qk_scale (float, optional): Override default qk scale of
+ ``head_dim ** -0.5`` if set. Defaults to None.
+        drop_rate (float): Dropout rate. Defaults to 0.0.
+        attn_drop_rate (float): Attention dropout rate. Defaults to 0.0.
+        drop_path_rate (float): Stochastic depth rate.
+            Defaults to 0.0.
+ init_cfg (dict, optional): Config dict for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dim: int,
+ num_heads: int,
+ window_size: int = 14,
+ mlp_ratio: float = 4.,
+ qkv_bias: bool = False,
+ qk_scale: float = None,
+ drop_rate: float = 0.,
+ attn_drop_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ init_cfg: Optional[dict] = None) -> None:
+ super().__init__(init_cfg=init_cfg)
+        self.window_size = window_size
+ self.pos_embed = nn.Conv2d(
+ embed_dim, embed_dim, 3, padding=1, groups=embed_dim)
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.attn = Attention(
+ embed_dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=drop_rate)
+ # NOTE: drop path for stochastic depth,
+ # we shall see if this is better than dropout here
+ self.drop_path = build_dropout(
+ dict(type='DropPath', drop_prob=drop_path_rate)
+ ) if drop_path_rate > 0. else nn.Identity()
+ self.norm2 = nn.LayerNorm(embed_dim)
+ mlp_hidden_dim = int(embed_dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=embed_dim,
+ hidden_features=mlp_hidden_dim,
+ drop_rate=drop_rate)
+
+ def window_reverse(self, windows, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ window_size = self.window_size
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size,
+ window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+ def window_partition(self, x):
+ """
+ Args:
+ x: (B, H, W, C)
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ window_size = self.window_size
+ x = x.view(B, H // window_size, window_size, W // window_size,
+ window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4,
+ 5).contiguous().view(-1, window_size, window_size,
+ C)
+ return windows
+
+ def forward(self, x):
+ x = x + self.pos_embed(x)
+ x = x.permute(0, 2, 3, 1)
+ B, H, W, C = x.shape
+ shortcut = x
+ x = self.norm1(x)
+
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, H_pad, W_pad, _ = x.shape
+
+ x_windows = self.window_partition(
+ x) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size,
+ C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size,
+ self.window_size, C)
+ x = self.window_reverse(attn_windows, H_pad, W_pad) # B H' W' C
+
+        # remove the padding added before window partition
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ x = x.permute(0, 3, 1, 2).reshape(B, C, H, W)
+ return x
+
+
+@MODELS.register_module()
+class UniFormer(BaseBackbone):
+ """The implementation of Uniformer with downstream pose estimation task.
+
+ UniFormer: Unifying Convolution and Self-attention for Visual Recognition
+ https://arxiv.org/abs/2201.09450
+ UniFormer: Unified Transformer for Efficient Spatiotemporal Representation
+ Learning https://arxiv.org/abs/2201.04676
+
+ Args:
+        depths (List[int]): number of blocks in each stage.
+            Default to [3, 4, 8, 3].
+ img_size (int, tuple): input image size. Default: 224.
+ in_channels (int): number of input channels. Default: 3.
+ num_classes (int): number of classes for classification head. Default
+ to 80.
+ embed_dims (List[int]): embedding dimensions.
+ Default to [64, 128, 320, 512].
+ head_dim (int): dimension of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool, optional): if True, add a learnable bias to query, key,
+ value. Default: True
+ qk_scale (float | None, optional): override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ representation_size (Optional[int]): enable and set representation
+ layer (pre-logits) to this value if set
+ drop_rate (float): dropout rate. Default: 0.
+ attn_drop_rate (float): attention dropout rate. Default: 0.
+ drop_path_rate (float): stochastic depth rate. Default: 0.
+ norm_layer (nn.Module): normalization layer
+        use_checkpoint (bool): whether to use torch.utils.checkpoint
+        checkpoint_num (list): number of checkpointed blocks in each stage
+        use_window (bool): whether to use window MHRA
+        use_hybrid (bool): whether to use hybrid MHRA
+        window_size (int): size of window (>14). Default: 14.
+ init_cfg (dict, optional): Config dict for initialization.
+ Defaults to None.
+ """
+
+ def __init__(
+ self,
+ depths: List[int] = [3, 4, 8, 3],
+ img_size: int = 224,
+ in_channels: int = 3,
+ num_classes: int = 80,
+ embed_dims: List[int] = [64, 128, 320, 512],
+ head_dim: int = 64,
+            mlp_ratio: float = 4.,
+ qkv_bias: bool = True,
+ qk_scale: float = None,
+ representation_size: Optional[int] = None,
+ drop_rate: float = 0.,
+ attn_drop_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ use_checkpoint: bool = False,
+ checkpoint_num=(0, 0, 0, 0),
+ use_window: bool = False,
+ use_hybrid: bool = False,
+ window_size: int = 14,
+ init_cfg: Optional[Union[Dict, List[Dict]]] = [
+ dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
+ dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
+ ]
+ ) -> None:
+ super(UniFormer, self).__init__(init_cfg=init_cfg)
+
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.checkpoint_num = checkpoint_num
+ self.use_window = use_window
+ self.logger = get_root_logger()
+ self.logger.info(f'Use torch.utils.checkpoint: {self.use_checkpoint}')
+ self.logger.info(
+ f'torch.utils.checkpoint number: {self.checkpoint_num}')
+ self.num_features = self.embed_dims = embed_dims
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+ self.patch_embed1 = PatchEmbed(
+ img_size=img_size,
+ patch_size=4,
+ in_channels=in_channels,
+ embed_dim=embed_dims[0])
+ self.patch_embed2 = PatchEmbed(
+ img_size=img_size // 4,
+ patch_size=2,
+ in_channels=embed_dims[0],
+ embed_dim=embed_dims[1])
+ self.patch_embed3 = PatchEmbed(
+ img_size=img_size // 8,
+ patch_size=2,
+ in_channels=embed_dims[1],
+ embed_dim=embed_dims[2])
+ self.patch_embed4 = PatchEmbed(
+ img_size=img_size // 16,
+ patch_size=2,
+ in_channels=embed_dims[2],
+ embed_dim=embed_dims[3])
+
+ self.drop_after_pos = nn.Dropout(drop_rate)
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+ num_heads = [dim // head_dim for dim in embed_dims]
+ self.blocks1 = nn.ModuleList([
+ CBlock(
+ embed_dim=embed_dims[0],
+ mlp_ratio=mlp_ratio,
+ drop_rate=drop_rate,
+ drop_path_rate=dpr[i]) for i in range(depths[0])
+ ])
+ self.norm1 = norm_layer(embed_dims[0])
+ self.blocks2 = nn.ModuleList([
+ CBlock(
+ embed_dim=embed_dims[1],
+ mlp_ratio=mlp_ratio,
+ drop_rate=drop_rate,
+ drop_path_rate=dpr[i + depths[0]]) for i in range(depths[1])
+ ])
+ self.norm2 = norm_layer(embed_dims[1])
+ if self.use_window:
+ self.logger.info('Use local window for all blocks in stage3')
+ self.blocks3 = nn.ModuleList([
+ WindowSABlock(
+ embed_dim=embed_dims[2],
+ num_heads=num_heads[2],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[i + depths[0] + depths[1]])
+ for i in range(depths[2])
+ ])
+ elif use_hybrid:
+ self.logger.info('Use hybrid window for blocks in stage3')
+ block3 = []
+ for i in range(depths[2]):
+ if (i + 1) % 4 == 0:
+ block3.append(
+ SABlock(
+ embed_dim=embed_dims[2],
+ num_heads=num_heads[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[i + depths[0] + depths[1]]))
+ else:
+ block3.append(
+ WindowSABlock(
+ embed_dim=embed_dims[2],
+ num_heads=num_heads[2],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[i + depths[0] + depths[1]]))
+ self.blocks3 = nn.ModuleList(block3)
+ else:
+ self.logger.info('Use global window for all blocks in stage3')
+ self.blocks3 = nn.ModuleList([
+ SABlock(
+ embed_dim=embed_dims[2],
+ num_heads=num_heads[2],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[i + depths[0] + depths[1]])
+ for i in range(depths[2])
+ ])
+ self.norm3 = norm_layer(embed_dims[2])
+ self.blocks4 = nn.ModuleList([
+ SABlock(
+ embed_dim=embed_dims[3],
+ num_heads=num_heads[3],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop_rate=drop_rate,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[i + depths[0] + depths[1] + depths[2]])
+ for i in range(depths[3])
+ ])
+ self.norm4 = norm_layer(embed_dims[3])
+
+ # Representation layer
+ if representation_size:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(
+                OrderedDict([('fc', nn.Linear(embed_dims[-1],
+                                              representation_size)),
+ ('act', nn.Tanh())]))
+ else:
+ self.pre_logits = nn.Identity()
+
+ self.apply(self._init_weights)
+ self.init_weights()
+
+ def init_weights(self):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg['type'] == 'Pretrained'):
+ pretrained = self.init_cfg['checkpoint']
+ load_checkpoint(
+ self,
+ pretrained,
+ map_location='cpu',
+ strict=False,
+ logger=self.logger)
+ self.logger.info(f'Load pretrained model from {pretrained}')
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+        self.head = nn.Linear(
+            self.embed_dims[-1],
+            num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward(self, x):
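+        # Forward through the four stages; the normalized output of each
+        # stage is collected and returned as a tuple of feature maps.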
+ out = []
+ x = self.patch_embed1(x)
+ x = self.drop_after_pos(x)
+ for i, blk in enumerate(self.blocks1):
+ if self.use_checkpoint and i < self.checkpoint_num[0]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm1(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed2(x)
+ for i, blk in enumerate(self.blocks2):
+ if self.use_checkpoint and i < self.checkpoint_num[1]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm2(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed3(x)
+ for i, blk in enumerate(self.blocks3):
+ if self.use_checkpoint and i < self.checkpoint_num[2]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm3(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ x = self.patch_embed4(x)
+ for i, blk in enumerate(self.blocks4):
+ if self.use_checkpoint and i < self.checkpoint_num[3]:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ x_out = self.norm4(x.permute(0, 2, 3, 1))
+ out.append(x_out.permute(0, 3, 1, 2).contiguous())
+ return tuple(out)