Skip to content

Commit

Permalink
[Feature] Support MotionBERT (#2482)
Browse files Browse the repository at this point in the history
  • Loading branch information
LareinaM authored Jul 14, 2023
1 parent 6a23e2c commit b5bb116
Show file tree
Hide file tree
Showing 37 changed files with 1,738 additions and 191 deletions.
36 changes: 16 additions & 20 deletions configs/body_3d_keypoint/pose_lift/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,19 @@ For single-person 3D pose estimation from a monocular camera, existing works can

#### Human3.6m Dataset

| Arch | Receptive Field | MPJPE | P-MPJPE | N-MPJPE | ckpt | log |

| :------------------------------------------------------ | :-------------: | :---: | :-----: | :-----: | :------------------------------------------------------: | :-----------------------------------------------------: |

| [VideoPose3D-supervised](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 27 | 40.1 | 30.1 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) |

| [VideoPose3D-supervised](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 81 | 39.1 | 29.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) |

| [VideoPose3D-supervised](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 243 | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) |

| [VideoPose3D-supervised-CPN](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 1 | 53.0 | 41.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) |

| [VideoPose3D-supervised-CPN](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 243 | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) |

| [VideoPose3D-semi-supervised](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 27 | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) |

| [VideoPose3D-semi-supervised-CPN](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 27 | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) |
| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | Details and Download |
| :-------------------------------------------- | :---: | :-----: | :-----: | :-------------------------------------------: | :------------------------------------------: | :---------------------------------------------: |
| [VideoPose3D-supervised-27frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 40.1 | 30.1 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) |
| [VideoPose3D-supervised-81frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 39.1 | 29.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) |
| [VideoPose3D-supervised-243frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 37.6 | 28.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) |
| [VideoPose3D-supervised-CPN-1frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 53.0 | 41.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) |
| [VideoPose3D-supervised-CPN-243frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 47.9 | 38.0 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) |
| [VideoPose3D-semi-supervised-27frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) |
| [VideoPose3D-semi-supervised-CPN-27frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) |
| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 27.7 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) |
| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 21.6 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) |

*Models with * are converted from the official repo. The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*

## Image-based Single-view 3D Human Body Pose Estimation

Expand All @@ -46,6 +42,6 @@ For single-person 3D pose estimation from a monocular camera, existing works can

#### Human3.6m Dataset

| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log |
| :------------------------------------------------------ | :-------------: | :---: | :-----: | :-----: | :------------------------------------------------------: | :-----------------------------------------------------: |
| [SimpleBaseline3D-tcn](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_simplebaseline3d_8xb64-200e_h36m.py) | 43.4 | 34.3 | /|[ckpt](https://download.openmmlab.com/mmpose/body3d/simple_baseline/simple3Dbaseline_h36m-f0ad73a4_20210419.pth) | [log](https://download.openmmlab.com/mmpose/body3d/simple_baseline/20210415_065056.log.json) |
| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | Details and Download |
| :---------------------------------------- | :---: | :-----: | :-----: | :---------------------------------------: | :---------------------------------------: | :--------------------------------------------------------: |
| [SimpleBaseline3D-tcn](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_simplebaseline3d_8xb64-200e_h36m.py) | 43.4 | 34.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/simple_baseline/simple3Dbaseline_h36m-f0ad73a4_20210419.pth) | [log](https://download.openmmlab.com/mmpose/body3d/simple_baseline/20210415_065056.log.json) | [simplebaseline3d_h36m.md](./h36m/simplebaseline3d_h36m.md) |
53 changes: 53 additions & 0 deletions configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
<!-- [BACKBONE] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2210.06551">MotionBERT (2022)</a></summary>

```bibtex
@misc{Zhu_Ma_Liu_Liu_Wu_Wang_2022,
title={Learning Human Motion Representations: A Unified Perspective},
author={Zhu, Wentao and Ma, Xiaoxuan and Liu, Zhaoyang and Liu, Libin and Wu, Wayne and Wang, Yizhou},
year={2022},
month={Oct},
language={en-US}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://ieeexplore.ieee.org/abstract/document/6682899/">Human3.6M (TPAMI'2014)</a></summary>

```bibtex
@article{h36m_pami,
author = {Ionescu, Catalin and Papava, Dragos and Olaru, Vlad and Sminchisescu, Cristian},
title = {Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human Sensing in Natural Environments},
journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
publisher = {IEEE Computer Society},
volume = {36},
number = {7},
pages = {1325-1339},
month = {jul},
year = {2014}
}
```

</details>

Testing results on Human3.6M dataset with ground truth 2D detections

| Arch | MPJPE | average MPJPE | P-MPJPE | ckpt |
| :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: |
| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 35.3 | 27.7 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) |
| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 27.4 | 21.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) |

Testing results on Human3.6M dataset from the [official repo](https://github.com/Walter0807/MotionBERT) with ground truth 2D detections

| Arch | MPJPE | average MPJPE | P-MPJPE | ckpt |
| :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: |
| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 40.5 | 39.9 | 34.1 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) |
| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 38.2 | 37.7 | 32.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) |

*Models with * are converted from the [official repo](https://github.com/Walter0807/MotionBERT). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
34 changes: 34 additions & 0 deletions configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
Collections:
- Name: MotionBERT
Paper:
Title: "Learning Human Motion Representations: A Unified Perspective"
URL: https://arxiv.org/abs/2210.06551
README: https://github.com/open-mmlab/mmpose/blob/main/docs/en/papers/algorithms/motionbert.md
Models:
- Config: configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert_8xb32-120e_h36m.py
In Collection: MotionBERT
Metadata:
Architecture: &id001
- MotionBERT
Training Data: Human3.6M
Name: vid_pl_motionbert_8xb32-120e_h36m
Results:
- Dataset: Human3.6M
Metrics:
MPJPE: 35.3
P-MPJPE: 27.7
Task: Body 3D Keypoint
Weights: https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth
- Config: configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert_8xb32-120e_h36m.py
In Collection: MotionBERT
Metadata:
Architecture: *id001
Training Data: Human3.6M
Name: vid_pl_motionbert-finetuned_8xb32-120e_h36m
Results:
- Dataset: Human3.6M
Metrics:
MPJPE: 27.5
P-MPJPE: 21.6
Task: Body 3D Keypoint
Weights: https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
_base_ = ['../../../_base_/default_runtime.py']

vis_backends = [
dict(type='LocalVisBackend'),
]
visualizer = dict(
type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# runtime
train_cfg = dict(max_epochs=120, val_interval=10)

# optimizer
optim_wrapper = dict(
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.01))

# learning policy
param_scheduler = [
dict(type='ExponentialLR', gamma=0.99, end=120, by_epoch=True)
]

auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
save_best='MPJPE',
rule='less',
max_keep_ckpts=1),
logger=dict(type='LoggerHook', interval=20),
)

# codec settings
train_codec = dict(
type='MotionBERTLabel',
num_keypoints=17,
concat_vis=True,
rootrel=True,
factor_label=False)
val_codec = dict(
type='MotionBERTLabel', num_keypoints=17, concat_vis=True, rootrel=True)

# model settings
model = dict(
type='PoseLifter',
backbone=dict(
type='DSTFormer',
in_channels=3,
feat_size=512,
depth=5,
num_heads=8,
mlp_ratio=2,
seq_len=243,
att_fuse=True,
),
head=dict(
type='MotionRegressionHead',
in_channels=512,
out_channels=3,
embedding_size=512,
loss=dict(type='MPJPEVelocityJointLoss'),
decoder=val_codec,
),
)

# base dataset settings
dataset_type = 'Human36mDataset'
data_root = 'data/h36m/'

# pipelines
train_pipeline = [
dict(
type='RandomFlipAroundRoot',
keypoints_flip_cfg={},
target_flip_cfg={},
flip_image=True),
dict(type='GenerateTarget', encoder=train_codec),
dict(
type='PackPoseInputs',
meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices',
'factor', 'camera_param'))
]
val_pipeline = [
dict(type='GenerateTarget', encoder=val_codec),
dict(
type='PackPoseInputs',
meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices',
'factor', 'camera_param'))
]

# data loaders
train_dataloader = dict(
batch_size=32,
prefetch_factor=4,
pin_memory=True,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
ann_file='annotation_body3d/fps50/h36m_train.npz',
seq_len=1,
multiple_target=243,
multiple_target_step=81,
camera_param_file='annotation_body3d/cameras.pkl',
data_root=data_root,
data_prefix=dict(img='images/'),
pipeline=train_pipeline,
))

val_dataloader = dict(
batch_size=32,
prefetch_factor=4,
pin_memory=True,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
ann_file='annotation_body3d/fps50/h36m_test.npz',
seq_len=1,
seq_step=1,
multiple_target=243,
camera_param_file='annotation_body3d/cameras.pkl',
data_root=data_root,
data_prefix=dict(img='images/'),
pipeline=val_pipeline,
test_mode=True,
))
test_dataloader = val_dataloader

# evaluators
skip_list = [
'S9_Greet', 'S9_SittingDown', 'S9_Wait_1', 'S9_Greeting', 'S9_Waiting_1'
]
val_evaluator = [
dict(type='MPJPE', mode='mpjpe', skip_list=skip_list),
dict(type='MPJPE', mode='p-mpjpe', skip_list=skip_list)
]
test_evaluator = val_evaluator
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# runtime
train_cfg = dict(max_epochs=80, val_interval=10)
train_cfg = dict(max_epochs=160, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-4))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# runtime
train_cfg = dict(max_epochs=80, val_interval=10)
train_cfg = dict(max_epochs=160, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# runtime
train_cfg = dict(max_epochs=80, val_interval=10)
train_cfg = dict(max_epochs=160, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# runtime
train_cfg = dict(max_epochs=80, val_interval=10)
train_cfg = dict(max_epochs=160, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3))
Expand Down
Loading

0 comments on commit b5bb116

Please sign in to comment.