From 11802069dc807f885c087b5377a53779999d6b86 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 29 May 2023 10:24:19 +0800 Subject: [PATCH 01/30] add head and backbone --- mmpose/models/backbones/__init__.py | 3 +- mmpose/models/backbones/dstformer.py | 309 ++++++++++++++++++ mmpose/models/heads/__init__.py | 5 +- .../models/heads/regression_heads/__init__.py | 10 +- .../motion_regression_head.py | 157 +++++++++ .../test_backbones/test_dstformer.py | 36 ++ .../test_motion_regression_head.py | 54 +++ 7 files changed, 565 insertions(+), 9 deletions(-) create mode 100644 mmpose/models/backbones/dstformer.py create mode 100644 mmpose/models/heads/regression_heads/motion_regression_head.py create mode 100644 tests/test_models/test_backbones/test_dstformer.py create mode 100644 tests/test_models/test_heads/test_regression_heads/test_motion_regression_head.py diff --git a/mmpose/models/backbones/__init__.py b/mmpose/models/backbones/__init__.py index cb2498560a..563264eecf 100644 --- a/mmpose/models/backbones/__init__.py +++ b/mmpose/models/backbones/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .alexnet import AlexNet from .cpm import CPM +from .dstformer import DSTFormer from .hourglass import HourglassNet from .hourglass_ae import HourglassAENet from .hrformer import HRFormer @@ -33,5 +34,5 @@ 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN', 'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3', 'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer', - 'PyramidVisionTransformerV2', 'SwinTransformer' + 'PyramidVisionTransformerV2', 'SwinTransformer', 'DSTFormer' ] diff --git a/mmpose/models/backbones/dstformer.py b/mmpose/models/backbones/dstformer.py new file mode 100644 index 0000000000..d1f2833124 --- /dev/null +++ b/mmpose/models/backbones/dstformer.py @@ -0,0 +1,309 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule, constant_init +from timm.models.layers import trunc_normal_ + +from mmpose.registry import MODELS +from .base_backbone import BaseBackbone + + +class Attention(BaseModule): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + mode='spatial'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.mode = mode + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj_drop = nn.Dropout(proj_drop) + + self.attn_count_s = None + self.attn_count_t = None + + def forward(self, x, seq_len=1): + B, N, C = x.shape + + if self.mode == 'temporal': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // + self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_temporal(q, k, v, seq_len=seq_len) + elif self.mode == 'spatial': + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // + self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + x = self.forward_spatial(q, k, v) + else: + raise NotImplementedError(self.mode) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def forward_spatial(self, q, k, v): + B, _, N, C = q.shape + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v + x = x.transpose(1, 2).reshape(B, N, C * self.num_heads) + return x + + def forward_temporal(self, q, k, v, seq_len=8): + B, _, N, C = q.shape + qt = q.reshape(-1, seq_len, self.num_heads, N, + C).permute(0, 2, 3, 1, 4) # (B, H, N, T, C) + kt = k.reshape(-1, seq_len, self.num_heads, N, + C).permute(0, 2, 3, 1, 4) # (B, H, N, T, C) + vt = v.reshape(-1, seq_len, self.num_heads, N, + C).permute(0, 2, 3, 1, 4) # (B, H, N, T, C) + + attn = (qt @ kt.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = attn @ vt # (B, H, N, T, C) + x = x.permute(0, 3, 2, 1, 4).reshape(B, N, C * self.num_heads) + return x + + +class AttentionBlock(BaseModule): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + mlp_out_ratio=1., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + st_mode='st', + att_fuse=False): + super().__init__() + + self.st_mode = st_mode + self.norm1_s = nn.LayerNorm(dim) + self.norm1_t = nn.LayerNorm(dim) + + self.attn_s = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + mode='spatial') + self.attn_t = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + mode='temporal') + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2_s = nn.LayerNorm(dim) + self.norm2_t = nn.LayerNorm(dim) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_out_dim = int(dim * mlp_out_ratio) + self.mlp_s = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), nn.GELU(), + nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(0)) + self.mlp_t = nn.Sequential( + nn.Linear(dim, mlp_hidden_dim), nn.GELU(), + nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(0)) + + self.att_fuse = att_fuse + if self.att_fuse: + self.attn_regress = nn.Linear(dim * 2, dim * 2) + + def forward(self, x, seq_len=1): + if self.st_mode == 'st': + x = x + self.drop_path(self.attn_s(self.norm1_s(x), seq_len)) + x = x + self.drop_path(self.mlp_s(self.norm2_s(x))) + x = x + self.drop_path(self.attn_t(self.norm1_t(x), seq_len)) + x = x + self.drop_path(self.mlp_t(self.norm2_t(x))) + elif self.st_mode == 'ts': + x = x + self.drop_path(self.attn_t(self.norm1_t(x), seq_len)) + x = x + self.drop_path(self.mlp_t(self.norm2_t(x))) + x = x + self.drop_path(self.attn_s(self.norm1_s(x), seq_len)) + x = x + self.drop_path(self.mlp_s(self.norm2_s(x))) + else: + raise NotImplementedError(self.st_mode) + return x + + +@MODELS.register_module() +class DSTFormer(BaseBackbone): + """Dual-stream Spatio-temporal Transformer Module. + + Args: + in_channels (int): Number of input channels. + feat_size: Number of feature channels. Default: 256. + depth: The network depth. Default: 5. + num_heads: Number of heads in multi-Head self-attention blocks. + Default: 8. + mlp_ratio (int, optional): The expansion ratio of FFN. Default: 4. + num_keypoints: num_keypoints (int): Number of keypoints. Default: 17. + seq_len: The sequence length. Default: 243. + qkv_bias (bool, optional): If True, add a learnable bias to q, k, v. + Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float, optional): Dropout ratio of input. Default: 0. + attn_drop_rate (float, optional): Dropout ratio of attention weight. + Default: 0. + drop_path_rate (float, optional): Stochastic depth rate. Default: 0. + att_fuse: Whether to fuse the results of attention blocks. + Default: True. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Example: + >>> from mmpose.models import DSTFormer + >>> import torch + >>> self = DSTFormer(in_channels=3) + >>> self.eval() + >>> inputs = torch.rand(1, 2, 17, 3) + >>> level_outputs = self.forward(inputs) + >>> print(tuple(level_outputs.shape)) + (1, 2, 17, 512) + """ + + def __init__(self, + in_channels, + feat_size=256, + depth=5, + num_heads=8, + mlp_ratio=4, + num_keypoints=17, + seq_len=243, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + att_fuse=True, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.in_channels = in_channels + self.feat_size = feat_size + + self.joints_embed = nn.Linear(in_channels, feat_size) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + self.blocks_st = nn.ModuleList([ + AttentionBlock( + dim=feat_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + st_mode='st') for i in range(depth) + ]) + self.blocks_ts = nn.ModuleList([ + AttentionBlock( + dim=feat_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + st_mode='ts') for i in range(depth) + ]) + + self.norm = nn.LayerNorm(feat_size) + + self.temp_embed = nn.Parameter(torch.zeros(1, seq_len, 1, feat_size)) + self.spat_embed = nn.Parameter( + torch.zeros(1, num_keypoints, feat_size)) + + trunc_normal_(self.temp_embed, std=.02) + trunc_normal_(self.spat_embed, std=.02) + + self.att_fuse = att_fuse + if self.att_fuse: + self.attn_regress = nn.ModuleList( + [nn.Linear(feat_size * 2, 2) for i in range(depth)]) + for i in range(depth): + self.attn_regress[i].weight.data.fill_(0) + self.attn_regress[i].bias.data.fill_(0.5) + + def forward(self, x): + if len(x.shape) == 3: + x = x[None, :] + assert len(x.shape) == 4 + + B, F, K, C = x.shape + x = x.reshape(-1, K, C) + BF = x.shape[0] + x = self.joints_embed(x) # (BF, K, feat_size) + x = x + self.spat_embed + _, K, C = x.shape + x = x.reshape(-1, F, K, C) + self.temp_embed[:, :F, :, :] + x = x.reshape(BF, K, C) # (BF, K, feat_size) + x = self.pos_drop(x) + + for idx, (blk_st, + blk_ts) in enumerate(zip(self.blocks_st, self.blocks_ts)): + x_st = blk_st(x, F) + x_ts = blk_ts(x, F) + if self.att_fuse: + att = self.attn_regress[idx] + alpha = torch.cat([x_st, x_ts], dim=-1) + BF, K = alpha.shape[:2] + alpha = att(alpha) + alpha = alpha.softmax(dim=-1) + x = x_st * alpha[:, :, 0:1] + x_ts * alpha[:, :, 1:2] + else: + x = (x_st + x_ts) * 0.5 + x = self.norm(x) + x = x.reshape(B, F, K, -1) # (B, F, K, feat_size) + return tuple(x) + + def init_weights(self): + """Initialize the weights in backbone.""" + super(DSTFormer, self).init_weights() + + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + return + + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + constant_init(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + constant_init(m.bias, 0) + constant_init(m.weight, 1.0) diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py index e01f2269e3..ef0e17d98e 100644 --- a/mmpose/models/heads/__init__.py +++ b/mmpose/models/heads/__init__.py @@ -5,7 +5,8 @@ HeatmapHead, MSPNHead, ViPNASHead) from .hybrid_heads import DEKRHead, VisPredictHead from .regression_heads import (DSNTHead, IntegralRegressionHead, - RegressionHead, RLEHead, TemporalRegressionHead, + MotionRegressionHead, RegressionHead, RLEHead, + TemporalRegressionHead, TrajectoryRegressionHead) __all__ = [ @@ -13,5 +14,5 @@ 'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead', 'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead', 'CIDHead', 'RTMCCHead', 'TemporalRegressionHead', - 'TrajectoryRegressionHead' + 'TrajectoryRegressionHead', 'MotionRegressionHead' ] diff --git a/mmpose/models/heads/regression_heads/__init__.py b/mmpose/models/heads/regression_heads/__init__.py index ce9cd5e1b0..729d193b51 100644 --- a/mmpose/models/heads/regression_heads/__init__.py +++ b/mmpose/models/heads/regression_heads/__init__.py @@ -1,16 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. from .dsnt_head import DSNTHead from .integral_regression_head import IntegralRegressionHead +from .motion_regression_head import MotionRegressionHead from .regression_head import RegressionHead from .rle_head import RLEHead from .temporal_regression_head import TemporalRegressionHead from .trajectory_regression_head import TrajectoryRegressionHead __all__ = [ - 'RegressionHead', - 'IntegralRegressionHead', - 'DSNTHead', - 'RLEHead', - 'TemporalRegressionHead', - 'TrajectoryRegressionHead', + 'RegressionHead', 'IntegralRegressionHead', 'DSNTHead', 'RLEHead', + 'TemporalRegressionHead', 'TrajectoryRegressionHead', + 'MotionRegressionHead' ] diff --git a/mmpose/models/heads/regression_heads/motion_regression_head.py b/mmpose/models/heads/regression_heads/motion_regression_head.py new file mode 100644 index 0000000000..57699c6bdd --- /dev/null +++ b/mmpose/models/heads/regression_heads/motion_regression_head.py @@ -0,0 +1,157 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import Tuple + +import numpy as np +import torch +from torch import Tensor, nn + +from mmpose.evaluation.functional import keypoint_pck_accuracy +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList, + Predictions) +from ..base_head import BaseHead + + +@MODELS.register_module() +class MotionRegressionHead(BaseHead): + """Regression head of `MotionBERT`_ by Zhu et al (2022). + + Args: + in_channels (int): Number of input channels. Default: 256. + out_channels (int): Number of output channels. Default: 3. + embedding_size (int): Number of embedding channels. Default: 512. + loss (Config): Config for keypoint loss. Defaults to use + :class:`MSELoss` + decoder (Config, optional): The decoder config that controls decoding + keypoint coordinates from the network output. Defaults to ``None`` + init_cfg (Config, optional): Config to control the initialization. See + :attr:`default_init_cfg` for default settings + + .. _`MotionBERT`: https://arxiv.org/abs/2210.06551 + """ + + _version = 2 + + def __init__(self, + in_channels: int = 256, + out_channels: int = 3, + embedding_size: int = 512, + loss: ConfigType = dict( + type='MSELoss', use_target_weight=True), + decoder: OptConfigType = None, + init_cfg: OptConfigType = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.in_channels = in_channels + self.out_channels = out_channels + self.loss_module = MODELS.build(loss) + if decoder is not None: + self.decoder = KEYPOINT_CODECS.build(decoder) + else: + self.decoder = None + + # Define fully-connected layers + self.pre_logits = nn.Sequential( + OrderedDict([('fc', nn.Linear(in_channels, embedding_size)), + ('act', nn.Tanh())])) + self.fc = nn.Linear( + embedding_size, + out_channels) if embedding_size > 0 else nn.Identity() + + def forward(self, feats: Tuple[Tensor]) -> Tensor: + """Forward the network. The input is multi scale feature maps and the + output is the coordinates. + + Args: + feats (Tuple[Tensor]): Multi scale feature maps. + + Returns: + Tensor: Output coordinates (and sigmas[optional]). + """ + x = feats[-1] # (B, F, K, in_channels) + x = self.pre_logits(x) # (B, F, K, embedding_size) + x = self.fc(x) # (B, F, K, out_channels) + + return x + + def predict(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from outputs. + + Returns: + preds (sequence[InstanceData]): Prediction results. + Each contains the following fields: + + - keypoints: Predicted keypoints of shape (B, N, K, D). + - keypoint_scores: Scores of predicted keypoints of shape + (B, N, K). + """ + + batch_coords = self.forward(feats) # (B, K, D) + + # Restore global position with target_root + target_root = batch_data_samples[0].metainfo.get('target_root', None) + if target_root is not None: + target_root = torch.stack([ + torch.from_numpy(b.metainfo['target_root']) + for b in batch_data_samples + ]) + else: + target_root = torch.stack([ + torch.empty((0), dtype=torch.float32) + for _ in batch_data_samples[0].metainfo + ]) + + preds = self.decode((batch_coords, target_root)) + + return preds + + def loss(self, + inputs: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: ConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + pred_outputs = self.forward(inputs) + + lifting_target_label = torch.cat([ + d.gt_instance_labels.lifting_target_label + for d in batch_data_samples + ]) + lifting_target_weights = torch.cat([ + d.gt_instance_labels.lifting_target_weights + for d in batch_data_samples + ]) + + # calculate losses + losses = dict() + loss = self.loss_module(pred_outputs, lifting_target_label, + lifting_target_weights.unsqueeze(-1)) + + losses.update(loss_pose3d=loss) + + # calculate accuracy + _, avg_acc, _ = keypoint_pck_accuracy( + pred=to_numpy(pred_outputs), + gt=to_numpy(lifting_target_label), + mask=to_numpy(lifting_target_weights) > 0, + thr=0.05, + norm_factor=np.ones((pred_outputs.size(0), 3), dtype=np.float32)) + + mpjpe_pose = torch.tensor(avg_acc, device=lifting_target_label.device) + losses.update(mpjpe=mpjpe_pose) + + return losses + + @property + def default_init_cfg(self): + init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] + return init_cfg diff --git a/tests/test_models/test_backbones/test_dstformer.py b/tests/test_models/test_backbones/test_dstformer.py new file mode 100644 index 0000000000..966ed6f49b --- /dev/null +++ b/tests/test_models/test_backbones/test_dstformer.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmpose.models.backbones import DSTFormer +from mmpose.models.backbones.dstformer import AttentionBlock + + +class TestDSTFormer(TestCase): + + def test_attention_block(self): + # BasicTemporalBlock with causal == False + block = AttentionBlock(dim=256, num_heads=2) + x = torch.rand(2, 17, 256) + x_out = block(x) + self.assertEqual(x_out.shape, torch.Size([2, 17, 256])) + + def test_DSTFormer(self): + # Test DSTFormer with depth=2 + model = DSTFormer(in_channels=3, depth=2, seq_len=2) + pose3d = torch.rand((1, 2, 17, 3)) + feat = model(pose3d) + self.assertEqual(feat[0].shape, (2, 17, 256)) + + # Test DSTFormer with depth=4 and qkv_bias=False + model = DSTFormer(in_channels=3, depth=4, seq_len=2, qkv_bias=False) + pose3d = torch.rand((1, 2, 17, 3)) + feat = model(pose3d) + self.assertEqual(feat[0].shape, (2, 17, 256)) + + # Test DSTFormer with depth=4 and att_fuse=False + model = DSTFormer(in_channels=3, depth=4, seq_len=2, att_fuse=False) + pose3d = torch.rand((1, 2, 17, 3)) + feat = model(pose3d) + self.assertEqual(feat[0].shape, (2, 17, 256)) diff --git a/tests/test_models/test_heads/test_regression_heads/test_motion_regression_head.py b/tests/test_models/test_heads/test_regression_heads/test_motion_regression_head.py new file mode 100644 index 0000000000..1234ca8ef2 --- /dev/null +++ b/tests/test_models/test_heads/test_regression_heads/test_motion_regression_head.py @@ -0,0 +1,54 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple +from unittest import TestCase + +import torch +from mmengine.structures import InstanceData + +from mmpose.models.heads import MotionRegressionHead +from mmpose.testing import get_packed_inputs + + +class TestMotionRegressionHead(TestCase): + + def _get_feats( + self, + batch_size: int = 2, + feat_shapes: List[Tuple[int, int, int]] = [(32, 1, 1)], + ): + + feats = [ + torch.rand((batch_size, ) + shape, dtype=torch.float32) + for shape in feat_shapes + ] + return feats + + def test_init(self): + + head = MotionRegressionHead(in_channels=1024) + self.assertEqual(head.fc.weight.shape, (3, 512)) + self.assertIsNone(head.decoder) + + # w/ decoder + head = MotionRegressionHead( + in_channels=1024, + decoder=dict(type='VideoPoseLifting', num_keypoints=17), + ) + self.assertIsNotNone(head.decoder) + + def test_predict(self): + decoder_cfg = dict(type='VideoPoseLifting', num_keypoints=17) + + head = MotionRegressionHead( + in_channels=1024, + decoder=decoder_cfg, + ) + + feats = self._get_feats(batch_size=4, feat_shapes=[(2, 17, 1024)]) + batch_data_samples = get_packed_inputs( + batch_size=2, with_heatmap=False)['data_samples'] + preds = head.predict(feats, batch_data_samples) + + self.assertTrue(len(preds), 2) + self.assertIsInstance(preds[0], InstanceData) + self.assertEqual(preds[0].keypoints.shape, (1, 2, 17, 3)) From 94cb13f599316eaeb9230fbe8c365c5641f1a73f Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 29 May 2023 15:16:03 +0800 Subject: [PATCH 02/30] modify existing modules, add config --- .../h36m/vid_pl_motionbert_8xb32-120e_h36m.py | 102 ++++++++++++++++++ mmpose/codecs/image_pose_lifting.py | 20 ++++ mmpose/codecs/video_pose_lifting.py | 20 ++++ .../datasets/base/base_mocap_dataset.py | 21 +++- .../datasets/datasets/body3d/h36m_dataset.py | 20 +++- mmpose/datasets/transforms/formatting.py | 2 - mmpose/models/backbones/dstformer.py | 11 +- tests/test_codecs/test_image_pose_lifting.py | 6 +- tests/test_codecs/test_video_pose_lifting.py | 13 ++- 9 files changed, 201 insertions(+), 14 deletions(-) create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py new file mode 100644 index 0000000000..500b6c9ea7 --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py @@ -0,0 +1,102 @@ +_base_ = ['../../../_base_/default_runtime.py'] + +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# runtime +train_cfg = None + +# optimizer + +# learning policy + +auto_scale_lr = dict(base_batch_size=512) + +# hooks +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + save_best='MPJPE', + rule='less', + max_keep_ckpts=1), + logger=dict(type='LoggerHook', interval=20), +) + +# codec settings +codec = dict( + type='VideoPoseLifting', + num_keypoints=17, + zero_center=True, + root_index=0, + remove_root=False, + reshape_keypoints=False, + concat_vis=True) + +# model settings +model = dict( + type='PoseLifter', + backbone=dict( + type='DSTFormer', + in_channels=3, + feat_size=512, + depth=5, + num_heads=8, + mlp_ratio=2, + seq_len=243, + att_fuse=True, + ), + head=dict( + type='MotionRegressionHead', + in_channels=512, + out_channels=3, + embedding_size=512, + loss=dict(type='MPJPELoss'), + decoder=codec, + ), +) + +# base dataset settings +dataset_type = 'Human36mDataset' +data_root = 'data/h36m/' + +# pipelines +val_pipeline = [ + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'target_root')) +] + +# data loaders +val_dataloader = dict( + batch_size=32, + num_workers=2, + prefetch_factor=4, + pin_memory=True, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_test.npz', + seq_len=1, + seq_step=1, + merge_seq=243, + pad_video_seq=True, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=val_pipeline, + test_mode=True, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = [ + dict(type='MPJPE', mode='mpjpe'), + dict(type='MPJPE', mode='p-mpjpe') +] +test_evaluator = val_evaluator diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 64bf925997..91d1db6f52 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -25,6 +25,10 @@ class ImagePoseLifting(BaseKeypointCodec): Default: ``False``. save_index (bool): If true, store the root position separated from the original pose. Default: ``False``. + reshape_keypoints (bool): If true, reshape the keypoints into shape + (-1, N). Default: ``True``. + concat_vis (bool): If true, concat the visibility item of keypoints. + Default: ``False``. keypoints_mean (np.ndarray, optional): Mean values of keypoints coordinates in shape (K, D). keypoints_std (np.ndarray, optional): Std values of keypoints @@ -42,6 +46,8 @@ def __init__(self, root_index: int, remove_root: bool = False, save_index: bool = False, + reshape_keypoints: bool = True, + concat_vis: bool = False, keypoints_mean: Optional[np.ndarray] = None, keypoints_std: Optional[np.ndarray] = None, target_mean: Optional[np.ndarray] = None, @@ -52,6 +58,8 @@ def __init__(self, self.root_index = root_index self.remove_root = remove_root self.save_index = save_index + self.reshape_keypoints = reshape_keypoints + self.concat_vis = concat_vis if keypoints_mean is not None and keypoints_std is not None: assert keypoints_mean.shape == keypoints_std.shape if target_mean is not None and target_std is not None: @@ -163,7 +171,19 @@ def encode(self, if keypoint_labels.ndim == 2: keypoint_labels = keypoint_labels[None, ...] + if self.concat_vis: + keypoints_visible_ = keypoints_visible + if keypoints_visible.ndim == 2: + keypoints_visible_ = keypoints_visible[..., None] + keypoint_labels = np.concatenate( + (keypoint_labels, keypoints_visible_), axis=2) + + if self.reshape_keypoints: + N = keypoint_labels.shape[0] + keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) + encoded['keypoint_labels'] = keypoint_labels + encoded['keypoints_visible'] = keypoints_visible encoded['lifting_target_label'] = lifting_target_label encoded['lifting_target_weights'] = lifting_target_weights encoded['trajectory_weights'] = trajectory_weights diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py index 56cf35fa2d..0092e77db2 100644 --- a/mmpose/codecs/video_pose_lifting.py +++ b/mmpose/codecs/video_pose_lifting.py @@ -30,6 +30,10 @@ class VideoPoseLifting(BaseKeypointCodec): save_index (bool): If true, store the root position separated from the original pose, only takes effect if ``remove_root`` is ``True``. Default: ``False``. + reshape_keypoints (bool): If true, reshape the keypoints into shape + (-1, N). Default: ``True``. + concat_vis (bool): If true, concat the visibility item of keypoints. + Default: ``False``. normalize_camera (bool): Whether to normalize camera intrinsics. Default: ``False``. """ @@ -44,6 +48,8 @@ def __init__(self, root_index: int = 0, remove_root: bool = False, save_index: bool = False, + reshape_keypoints: bool = True, + concat_vis: bool = False, normalize_camera: bool = False): super().__init__() @@ -52,6 +58,8 @@ def __init__(self, self.root_index = root_index self.remove_root = remove_root self.save_index = save_index + self.reshape_keypoints = reshape_keypoints + self.concat_vis = concat_vis self.normalize_camera = normalize_camera def encode(self, @@ -167,7 +175,19 @@ def encode(self, _camera_param['c'] = (_camera_param['c'] - center[:, None]) / scale encoded['camera_param'] = _camera_param + if self.concat_vis: + keypoints_visible_ = keypoints_visible + if keypoints_visible.ndim == 2: + keypoints_visible_ = keypoints_visible[..., None] + keypoint_labels = np.concatenate( + (keypoint_labels, keypoints_visible_), axis=2) + + if self.reshape_keypoints: + N = keypoint_labels.shape[0] + keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) + encoded['keypoint_labels'] = keypoint_labels + encoded['keypoints_visible'] = keypoints_visible encoded['lifting_target_label'] = lifting_target_label encoded['lifting_target_weights'] = lifting_target_weights encoded['trajectory_weights'] = trajectory_weights diff --git a/mmpose/datasets/datasets/base/base_mocap_dataset.py b/mmpose/datasets/datasets/base/base_mocap_dataset.py index d671a6ae94..81ec986682 100644 --- a/mmpose/datasets/datasets/base/base_mocap_dataset.py +++ b/mmpose/datasets/datasets/base/base_mocap_dataset.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import itertools import os.path as osp from copy import deepcopy from itertools import filterfalse, groupby @@ -21,6 +22,8 @@ class BaseMocapDataset(BaseDataset): Args: ann_file (str): Annotation file path. Default: ''. seq_len (int): Number of frames in a sequence. Default: 1. + merge_seq (int): If larger than 0, merge every ``merge_seq`` sequence + together. Default: 0. causal (bool): If set to ``True``, the rightmost input frame will be the target frame. Otherwise, the middle input frame will be the target frame. Default: ``True``. @@ -63,6 +66,7 @@ class BaseMocapDataset(BaseDataset): def __init__(self, ann_file: str = '', seq_len: int = 1, + merge_seq: int = 0, causal: bool = True, subset_frac: float = 1.0, camera_param_file: Optional[str] = None, @@ -102,6 +106,8 @@ def __init__(self, self.seq_len = seq_len self.causal = causal + self.merge_seq = merge_seq + assert 0 < subset_frac <= 1, ( f'Unsupported `subset_frac` {subset_frac}. Supported range ' 'is (0, 1].') @@ -241,6 +247,17 @@ def get_sequence_indices(self) -> List[List[int]]: sequence_indices = [[idx] for idx in range(num_imgs)] else: raise NotImplementedError('Multi-frame data sample unsupported!') + + if self.merge_seq > 0: + sequence_indices_merged = [] + for i in range(0, len(sequence_indices), self.merge_seq): + if i + self.merge_seq > len(sequence_indices): + break + sequence_indices_merged.append( + list( + itertools.chain.from_iterable( + sequence_indices[i:i + self.merge_seq]))) + sequence_indices = sequence_indices_merged return sequence_indices def _load_annotations(self) -> Tuple[List[dict], List[dict]]: @@ -274,7 +291,9 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: image_list = [] for idx, frame_ids in enumerate(self.sequence_indices): - assert len(frame_ids) == self.seq_len + assert len(frame_ids) == ( + self.merge_seq if self.merge_seq * + self.seq_len else self.seq_len) _img_names = img_names[frame_ids] diff --git a/mmpose/datasets/datasets/body3d/h36m_dataset.py b/mmpose/datasets/datasets/body3d/h36m_dataset.py index 60094aa254..9150416239 100644 --- a/mmpose/datasets/datasets/body3d/h36m_dataset.py +++ b/mmpose/datasets/datasets/body3d/h36m_dataset.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import itertools import os.path as osp from collections import defaultdict from typing import Callable, List, Optional, Sequence, Tuple, Union @@ -45,6 +46,8 @@ class Human36mDataset(BaseMocapDataset): seq_len (int): Number of frames in a sequence. Default: 1. seq_step (int): The interval for extracting frames from the video. Default: 1. + merge_seq (int): If larger than 0, merge every ``merge_seq`` sequence + together. Default: 0. pad_video_seq (bool): Whether to pad the video so that poses will be predicted for every frame in the video. Default: ``False``. causal (bool): If set to ``True``, the rightmost input frame will be @@ -104,6 +107,7 @@ def __init__(self, ann_file: str = '', seq_len: int = 1, seq_step: int = 1, + merge_seq: int = 0, pad_video_seq: bool = False, causal: bool = True, subset_frac: float = 1.0, @@ -141,6 +145,7 @@ def __init__(self, super().__init__( ann_file=ann_file, seq_len=seq_len, + merge_seq=merge_seq, causal=causal, subset_frac=subset_frac, camera_param_file=camera_param_file, @@ -205,7 +210,20 @@ def get_sequence_indices(self) -> List[List[int]]: start = np.random.randint(0, len(sequence_indices) - subset_size + 1) end = start + subset_size - return sequence_indices[start:end] + sequence_indices = sequence_indices[start:end] + + if self.merge_seq > 0: + sequence_indices_merged = [] + for i in range(0, len(sequence_indices), self.merge_seq): + if i + self.merge_seq > len(sequence_indices): + break + sequence_indices_merged.append( + list( + itertools.chain.from_iterable( + sequence_indices[i:i + self.merge_seq]))) + sequence_indices = sequence_indices_merged + + return sequence_indices def _load_annotations(self) -> Tuple[List[dict], List[dict]]: instance_list, image_list = super()._load_annotations() diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index 05aeef179f..38e7fbc3fb 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -51,8 +51,6 @@ def keypoints_to_tensor(keypoints: Union[np.ndarray, Sequence[np.ndarray]] """ if isinstance(keypoints, np.ndarray): keypoints = np.ascontiguousarray(keypoints) - N = keypoints.shape[0] - keypoints = keypoints.transpose(1, 2, 0).reshape(-1, N) tensor = torch.from_numpy(keypoints).contiguous() else: assert is_seq_of(keypoints, np.ndarray) diff --git a/mmpose/models/backbones/dstformer.py b/mmpose/models/backbones/dstformer.py index d1f2833124..317fa55a99 100644 --- a/mmpose/models/backbones/dstformer.py +++ b/mmpose/models/backbones/dstformer.py @@ -95,8 +95,7 @@ def __init__(self, drop=0., attn_drop=0., drop_path=0., - st_mode='st', - att_fuse=False): + st_mode='st'): super().__init__() self.st_mode = st_mode @@ -129,14 +128,10 @@ def __init__(self, mlp_out_dim = int(dim * mlp_out_ratio) self.mlp_s = nn.Sequential( nn.Linear(dim, mlp_hidden_dim), nn.GELU(), - nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(0)) + nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(drop)) self.mlp_t = nn.Sequential( nn.Linear(dim, mlp_hidden_dim), nn.GELU(), - nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(0)) - - self.att_fuse = att_fuse - if self.att_fuse: - self.attn_regress = nn.Linear(dim * 2, dim * 2) + nn.Linear(mlp_hidden_dim, mlp_out_dim), nn.Dropout(drop)) def forward(self, x, seq_len=1): if self.st_mode == 'st': diff --git a/tests/test_codecs/test_image_pose_lifting.py b/tests/test_codecs/test_image_pose_lifting.py index bb94786c32..e4663cc1b4 100644 --- a/tests/test_codecs/test_image_pose_lifting.py +++ b/tests/test_codecs/test_image_pose_lifting.py @@ -30,7 +30,11 @@ def setUp(self) -> None: encoded_wo_sigma=encoded_wo_sigma) def build_pose_lifting_label(self, **kwargs): - cfg = dict(type='ImagePoseLifting', num_keypoints=17, root_index=0) + cfg = dict( + type='ImagePoseLifting', + num_keypoints=17, + root_index=0, + reshape_keypoints=False) cfg.update(kwargs) return KEYPOINT_CODECS.build(cfg) diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py index cc58292d0c..d1cca17bbd 100644 --- a/tests/test_codecs/test_video_pose_lifting.py +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -19,7 +19,8 @@ def get_camera_param(self, imgname, camera_param) -> dict: return camera_param[(subj, camera)] def build_pose_lifting_label(self, **kwargs): - cfg = dict(type='VideoPoseLifting', num_keypoints=17) + cfg = dict( + type='VideoPoseLifting', num_keypoints=17, reshape_keypoints=False) cfg.update(kwargs) return KEYPOINT_CODECS.build(cfg) @@ -76,6 +77,16 @@ def test_encode(self): self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + # test reshape_keypoints + codec = self.build_pose_lifting_label(reshape_keypoints=True) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (34, 1)) + self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) + self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + # test removing root codec = self.build_pose_lifting_label( remove_root=True, save_index=True) From f782c58d0d82561e8c57c463b7ff3469ef23ef1e Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 29 May 2023 15:20:50 +0800 Subject: [PATCH 03/30] add docs (pending model and results) --- .../video_pose_lift/h36m/motionbert_h36m.md | 43 +++++++++++++++++++ .../video_pose_lift/h36m/motionbert_h36m.yml | 21 +++++++++ 2 files changed, 64 insertions(+) create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md create mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md new file mode 100644 index 0000000000..698464076c --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md @@ -0,0 +1,43 @@ + + +
+MotionBERT (2022) + +```bibtex + @misc{Zhu_Ma_Liu_Liu_Wu_Wang_2022, + title={Learning Human Motion Representations: A Unified Perspective}, + author={Zhu, Wentao and Ma, Xiaoxuan and Liu, Zhaoyang and Liu, Libin and Wu, Wayne and Wang, Yizhou}, + year={2022}, + month={Oct}, + language={en-US} + } +``` + +
+ + + +
+Human3.6M (TPAMI'2014) + +```bibtex +@article{h36m_pami, +author = {Ionescu, Catalin and Papava, Dragos and Olaru, Vlad and Sminchisescu, Cristian}, +title = {Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human Sensing in Natural Environments}, +journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, +publisher = {IEEE Computer Society}, +volume = {36}, +number = {7}, +pages = {1325-1339}, +month = {jul}, +year = {2014} +} +``` + +
+ +Testing results on Human3.6M dataset with ground truth 2D detections + +| Arch | MPJPE | P-MPJPE | ckpt | log | +| :------------------------------------------------------------------------------------------------ | :---: | :-----: | :--------: | :-------: | +| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | | | [ckpt](<>) | [log](<>) | diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml new file mode 100644 index 0000000000..dbd1b22b90 --- /dev/null +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml @@ -0,0 +1,21 @@ +Collections: +- Name: MotionBERT + Paper: + Title: "Learning Human Motion Representations: A Unified Perspective" + URL: https://arxiv.org/abs/2210.06551 + README: https://github.com/open-mmlab/mmpose/blob/main/docs/en/papers/algorithms/motionbert.md +Models: +- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py + In Collection: MotionBERT + Metadata: + Architecture: &id001 + - MotionBERT + Training Data: Human3.6M + Name: vid_pl_motionbert_8xb32-120e_h36m + Results: + - Dataset: Human3.6M + Metrics: + MPJPE: + P-MPJPE: + Task: Body 3D Keypoint + Weights: From 50b4bba9b930e068a4480fa3e532956647e59da2 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 31 May 2023 10:28:43 +0800 Subject: [PATCH 04/30] add results --- .../video_pose_lift/h36m/motionbert_h36m.md | 6 +++--- .../video_pose_lift/h36m/motionbert_h36m.yml | 5 +++-- .../h36m/vid_pl_motionbert_8xb32-120e_h36m.py | 3 ++- mmpose/evaluation/metrics/keypoint_3d_metrics.py | 3 ++- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md index 698464076c..882b74b756 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md @@ -38,6 +38,6 @@ year = {2014} Testing results on Human3.6M dataset with ground truth 2D detections -| Arch | MPJPE | P-MPJPE | ckpt | log | -| :------------------------------------------------------------------------------------------------ | :---: | :-----: | :--------: | :-------: | -| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | | | [ckpt](<>) | [log](<>) | +| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | +| :------------------------------------------------------------------------------------------------ | :---: | :-----: | :-----: | :--------: | :-------: | +| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 34.87 | 14.95 | 34.02 | [ckpt](<>) | [log](<>) | diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml index dbd1b22b90..c48beb3973 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml @@ -15,7 +15,8 @@ Models: Results: - Dataset: Human3.6M Metrics: - MPJPE: - P-MPJPE: + MPJPE: 34.87 + P-MPJPE: 14.95 + N-MPJPE: 34.02 Task: Body 3D Keypoint Weights: diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py index 500b6c9ea7..78245c4db0 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py @@ -97,6 +97,7 @@ # evaluators val_evaluator = [ dict(type='MPJPE', mode='mpjpe'), - dict(type='MPJPE', mode='p-mpjpe') + dict(type='MPJPE', mode='p-mpjpe'), + dict(type='MPJPE', mode='n-mpjpe') ] test_evaluator = val_evaluator diff --git a/mmpose/evaluation/metrics/keypoint_3d_metrics.py b/mmpose/evaluation/metrics/keypoint_3d_metrics.py index e945650c30..e4645e3c4c 100644 --- a/mmpose/evaluation/metrics/keypoint_3d_metrics.py +++ b/mmpose/evaluation/metrics/keypoint_3d_metrics.py @@ -126,6 +126,7 @@ def compute_metrics(self, results: list) -> Dict[str, float]: for action_category, indices in action_category_indices.items(): metrics[f'{error_name}_{action_category}'] = keypoint_mpjpe( - pred_coords[indices], gt_coords[indices], mask[indices]) + pred_coords[indices], gt_coords[indices], mask[indices], + self.ALIGNMENT[self.mode]) return metrics From 1412ee7c6b1bb4338e76e841d66ab590df7ef9fd Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 31 May 2023 14:56:30 +0800 Subject: [PATCH 05/30] add finetune results, modify visualizer --- .../video_pose_lift/h36m/motionbert_h36m.md | 7 ++++--- .../video_pose_lift/h36m/motionbert_h36m.yml | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md index 882b74b756..fcce90e80a 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md @@ -38,6 +38,7 @@ year = {2014} Testing results on Human3.6M dataset with ground truth 2D detections -| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | -| :------------------------------------------------------------------------------------------------ | :---: | :-----: | :-----: | :--------: | :-------: | -| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 34.87 | 14.95 | 34.02 | [ckpt](<>) | [log](<>) | +| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | +| :---------------------------------------------------------------------------------------------------------- | :---: | :-----: | :-----: | :--------: | :-------: | +| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 34.87 | 14.95 | 34.02 | [ckpt](<>) | [log](<>) | +| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 34.97 | 14.94 | 34.11 | [ckpt](<>) | [log](<>) | diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml index c48beb3973..2a386ff01f 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml @@ -20,3 +20,17 @@ Models: N-MPJPE: 34.02 Task: Body 3D Keypoint Weights: +- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py + In Collection: MotionBERT + Metadata: + Architecture: *id001 + Training Data: Human3.6M + Name: vid_pl_motionbert-finetuned_8xb32-120e_h36m + Results: + - Dataset: Human3.6M + Metrics: + MPJPE: 34.97 + P-MPJPE: 14.94 + N-MPJPE: 34.11 + Task: Body 3D Keypoint + Weights: From 5a694c5f62df799f2d00d127b2cac089b8acd0a8 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 5 Jun 2023 14:54:54 +0800 Subject: [PATCH 06/30] refactor target and fix bugs --- .../h36m/vid_pl_motionbert_8xb32-120e_h36m.py | 5 +- mmpose/apis/inference_3d.py | 5 +- mmpose/codecs/image_pose_lifting.py | 10 ++-- mmpose/codecs/video_pose_lifting.py | 10 ++-- .../datasets/base/base_mocap_dataset.py | 10 ++-- .../evaluation/metrics/keypoint_3d_metrics.py | 14 ++--- mmpose/models/backbones/dstformer.py | 6 +-- .../motion_regression_head.py | 5 +- .../test_metrics/test_keypoint_3d_metrics.py | 4 +- .../test_motion_regression_head.py | 54 ------------------- 10 files changed, 36 insertions(+), 87 deletions(-) delete mode 100644 tests/test_models/test_heads/test_regression_heads/test_motion_regression_head.py diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py index 78245c4db0..d800f23601 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py @@ -75,8 +75,6 @@ val_dataloader = dict( batch_size=32, num_workers=2, - prefetch_factor=4, - pin_memory=True, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), dataset=dict( @@ -97,7 +95,6 @@ # evaluators val_evaluator = [ dict(type='MPJPE', mode='mpjpe'), - dict(type='MPJPE', mode='p-mpjpe'), - dict(type='MPJPE', mode='n-mpjpe') + # dict(type='MPJPE', mode='p-mpjpe') ] test_evaluator = val_evaluator diff --git a/mmpose/apis/inference_3d.py b/mmpose/apis/inference_3d.py index d5bb753945..8725b27caa 100644 --- a/mmpose/apis/inference_3d.py +++ b/mmpose/apis/inference_3d.py @@ -316,8 +316,9 @@ def inference_pose_lifter_model(model, T, K, ), dtype=np.float32) - data_info['lifting_target'] = np.zeros((K, 3), dtype=np.float32) - data_info['lifting_target_visible'] = np.ones((K, 1), dtype=np.float32) + data_info['lifting_target'] = np.zeros((1, K, 3), dtype=np.float32) + data_info['lifting_target_visible'] = np.ones((1, K, 1), + dtype=np.float32) if image_size is not None: assert len(image_size) == 2 diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 91d1db6f52..58d4eb7c23 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -81,9 +81,9 @@ def encode(self, keypoints_visible (np.ndarray, optional): Keypoint visibilities in shape (N, K). lifting_target (np.ndarray, optional): 3d target coordinate in - shape (K, C). + shape (T, K, C). lifting_target_visible (np.ndarray, optional): Target coordinate in - shape (K, ). + shape (T, K, ). Returns: encoded (dict): Contains the following items: @@ -112,7 +112,7 @@ def encode(self, keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) if lifting_target is None: - lifting_target = keypoints[0] + lifting_target = [keypoints[0]] # set initial value for `lifting_target_weights` # and `trajectory_weights` @@ -134,7 +134,7 @@ def encode(self, f'Got invalid joint shape {lifting_target.shape}' root = lifting_target[..., self.root_index, :] - lifting_target_label = lifting_target - root + lifting_target_label = lifting_target - root[:, None] if self.remove_root: lifting_target_label = np.delete( @@ -214,7 +214,7 @@ def decode(self, keypoints = keypoints * self.target_std + self.target_mean if target_root.size > 0: - keypoints = keypoints + np.expand_dims(target_root, axis=0) + keypoints = keypoints + target_root if self.remove_root: keypoints = np.insert( keypoints, self.root_index, target_root, axis=1) diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py index 0092e77db2..ed34b7bf06 100644 --- a/mmpose/codecs/video_pose_lifting.py +++ b/mmpose/codecs/video_pose_lifting.py @@ -75,9 +75,9 @@ def encode(self, keypoints_visible (np.ndarray, optional): Keypoint visibilities in shape (N, K). lifting_target (np.ndarray, optional): 3d target coordinate in - shape (K, C). + shape (T, K, C). lifting_target_visible (np.ndarray, optional): Target coordinate in - shape (K, ). + shape (T, K, ). camera_param (dict, optional): The camera parameter dictionary. Returns: @@ -109,7 +109,7 @@ def encode(self, keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) if lifting_target is None: - lifting_target = keypoints[0] + lifting_target = [keypoints[0]] # set initial value for `lifting_target_weights` # and `trajectory_weights` @@ -136,7 +136,7 @@ def encode(self, f'Got invalid joint shape {lifting_target.shape}' root = lifting_target[..., self.root_index, :] - lifting_target_label = lifting_target_label - root + lifting_target_label = lifting_target_label - root[:, None] encoded['target_root'] = root if self.remove_root: @@ -213,7 +213,7 @@ def decode(self, keypoints = encoded.copy() if target_root.size > 0: - keypoints = keypoints + np.expand_dims(target_root, axis=0) + keypoints = keypoints + target_root if self.remove_root: keypoints = np.insert( keypoints, self.root_index, target_root, axis=1) diff --git a/mmpose/datasets/datasets/base/base_mocap_dataset.py b/mmpose/datasets/datasets/base/base_mocap_dataset.py index 81ec986682..7724c5e436 100644 --- a/mmpose/datasets/datasets/base/base_mocap_dataset.py +++ b/mmpose/datasets/datasets/base/base_mocap_dataset.py @@ -107,6 +107,8 @@ def __init__(self, self.causal = causal self.merge_seq = merge_seq + if self.merge_seq: + assert (self.seq_len == 1) assert 0 < subset_frac <= 1, ( f'Unsupported `subset_frac` {subset_frac}. Supported range ' @@ -292,8 +294,8 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: for idx, frame_ids in enumerate(self.sequence_indices): assert len(frame_ids) == ( - self.merge_seq if self.merge_seq * - self.seq_len else self.seq_len) + self.merge_seq * + self.seq_len if self.merge_seq else self.seq_len) _img_names = img_names[frame_ids] @@ -305,7 +307,9 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: keypoints_3d = _keypoints_3d[..., :3] keypoints_3d_visible = _keypoints_3d[..., 3] - target_idx = -1 if self.causal else int(self.seq_len) // 2 + target_idx = [-1] if self.causal else [int(self.seq_len) // 2] + if self.merge_seq: + target_idx = list(range(self.merge_seq)) instance_info = { 'num_keypoints': num_keypoints, diff --git a/mmpose/evaluation/metrics/keypoint_3d_metrics.py b/mmpose/evaluation/metrics/keypoint_3d_metrics.py index e4645e3c4c..9e8ab537a3 100644 --- a/mmpose/evaluation/metrics/keypoint_3d_metrics.py +++ b/mmpose/evaluation/metrics/keypoint_3d_metrics.py @@ -67,16 +67,18 @@ def process(self, data_batch: Sequence[dict], the model. """ for data_sample in data_samples: - # predicted keypoints coordinates, [1, K, D] + # predicted keypoints coordinates, [1, T, K, D] pred_coords = data_sample['pred_instances']['keypoints'] + if pred_coords.ndim == 4: + pred_coords = np.squeeze(pred_coords, axis=0) # ground truth data_info gt = data_sample['gt_instances'] - # ground truth keypoints coordinates, [1, K, D] + # ground truth keypoints coordinates, [T, K, D] gt_coords = gt['lifting_target'] - # ground truth keypoints_visible, [1, K, 1] + # ground truth keypoints_visible, [T, K, 1] mask = gt['lifting_target_visible'].astype(bool).reshape(1, -1) # instance action - img_path = data_sample['target_img_path'] + img_path = data_sample['target_img_path'][0] _, rest = osp.basename(img_path).split('_', 1) action, _ = rest.split('.', 1) @@ -104,10 +106,8 @@ def compute_metrics(self, results: list) -> Dict[str, float]: # pred_coords: [N, K, D] pred_coords = np.concatenate( [result['pred_coords'] for result in results]) - if pred_coords.ndim == 4 and pred_coords.shape[1] == 1: - pred_coords = np.squeeze(pred_coords, axis=1) # gt_coords: [N, K, D] - gt_coords = np.stack([result['gt_coords'] for result in results]) + gt_coords = np.concatenate([result['gt_coords'] for result in results]) # mask: [N, K] mask = np.concatenate([result['mask'] for result in results]) # action_category_indices: Dict[List[int]] diff --git a/mmpose/models/backbones/dstformer.py b/mmpose/models/backbones/dstformer.py index 317fa55a99..4175b34053 100644 --- a/mmpose/models/backbones/dstformer.py +++ b/mmpose/models/backbones/dstformer.py @@ -282,9 +282,9 @@ def forward(self, x): x = x_st * alpha[:, :, 0:1] + x_ts * alpha[:, :, 1:2] else: x = (x_st + x_ts) * 0.5 - x = self.norm(x) - x = x.reshape(B, F, K, -1) # (B, F, K, feat_size) - return tuple(x) + x = self.norm(x) # (BF, K, feat_size) + x = x.reshape(B, F, K, -1) + return x def init_weights(self): """Initialize the weights in backbone.""" diff --git a/mmpose/models/heads/regression_heads/motion_regression_head.py b/mmpose/models/heads/regression_heads/motion_regression_head.py index 57699c6bdd..4a5e51f472 100644 --- a/mmpose/models/heads/regression_heads/motion_regression_head.py +++ b/mmpose/models/heads/regression_heads/motion_regression_head.py @@ -74,7 +74,7 @@ def forward(self, feats: Tuple[Tensor]) -> Tensor: Returns: Tensor: Output coordinates (and sigmas[optional]). """ - x = feats[-1] # (B, F, K, in_channels) + x = feats # (B, F, K, in_channels) x = self.pre_logits(x) # (B, F, K, embedding_size) x = self.fc(x) # (B, F, K, out_channels) @@ -107,8 +107,9 @@ def predict(self, else: target_root = torch.stack([ torch.empty((0), dtype=torch.float32) - for _ in batch_data_samples[0].metainfo + for _ in batch_data_samples ]) + target_root = target_root[..., None, :] preds = self.decode((batch_coords, target_root)) diff --git a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py index 8289b09d0f..d51d493cbc 100644 --- a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py +++ b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py @@ -20,9 +20,9 @@ def setUp(self): for i in range(self.batch_size): gt_instances = InstanceData() keypoints = np.random.random((1, num_keypoints, 3)) - gt_instances.lifting_target = np.random.random((num_keypoints, 3)) + gt_instances.lifting_target = keypoints gt_instances.lifting_target_visible = np.ones( - (num_keypoints, 1)).astype(bool) + (1, num_keypoints, 1)).astype(bool) pred_instances = InstanceData() pred_instances.keypoints = keypoints + np.random.normal( diff --git a/tests/test_models/test_heads/test_regression_heads/test_motion_regression_head.py b/tests/test_models/test_heads/test_regression_heads/test_motion_regression_head.py deleted file mode 100644 index 1234ca8ef2..0000000000 --- a/tests/test_models/test_heads/test_regression_heads/test_motion_regression_head.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Tuple -from unittest import TestCase - -import torch -from mmengine.structures import InstanceData - -from mmpose.models.heads import MotionRegressionHead -from mmpose.testing import get_packed_inputs - - -class TestMotionRegressionHead(TestCase): - - def _get_feats( - self, - batch_size: int = 2, - feat_shapes: List[Tuple[int, int, int]] = [(32, 1, 1)], - ): - - feats = [ - torch.rand((batch_size, ) + shape, dtype=torch.float32) - for shape in feat_shapes - ] - return feats - - def test_init(self): - - head = MotionRegressionHead(in_channels=1024) - self.assertEqual(head.fc.weight.shape, (3, 512)) - self.assertIsNone(head.decoder) - - # w/ decoder - head = MotionRegressionHead( - in_channels=1024, - decoder=dict(type='VideoPoseLifting', num_keypoints=17), - ) - self.assertIsNotNone(head.decoder) - - def test_predict(self): - decoder_cfg = dict(type='VideoPoseLifting', num_keypoints=17) - - head = MotionRegressionHead( - in_channels=1024, - decoder=decoder_cfg, - ) - - feats = self._get_feats(batch_size=4, feat_shapes=[(2, 17, 1024)]) - batch_data_samples = get_packed_inputs( - batch_size=2, with_heatmap=False)['data_samples'] - preds = head.predict(feats, batch_data_samples) - - self.assertTrue(len(preds), 2) - self.assertIsInstance(preds[0], InstanceData) - self.assertEqual(preds[0].keypoints.shape, (1, 2, 17, 3)) From 4b722434f3a5c45b8418f55baf6945e96eb73e5b Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 7 Jun 2023 15:08:13 +0800 Subject: [PATCH 07/30] add codec, fix bugs --- mmpose/codecs/__init__.py | 4 +- mmpose/codecs/image_pose_lifting.py | 8 +- mmpose/codecs/mono_pose_lifting.py | 223 ++++++++++++++++++ mmpose/codecs/video_pose_lifting.py | 6 +- .../evaluation/metrics/keypoint_3d_metrics.py | 3 +- mmpose/models/backbones/dstformer.py | 10 +- tests/test_codecs/test_image_pose_lifting.py | 61 +++-- tests/test_codecs/test_mono_pose_lifting.py | 200 ++++++++++++++++ tests/test_codecs/test_video_pose_lifting.py | 73 ++++-- 9 files changed, 527 insertions(+), 61 deletions(-) create mode 100644 mmpose/codecs/mono_pose_lifting.py create mode 100644 tests/test_codecs/test_mono_pose_lifting.py diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py index cdbd8feb0c..9c17513afe 100644 --- a/mmpose/codecs/__init__.py +++ b/mmpose/codecs/__init__.py @@ -4,6 +4,7 @@ from .image_pose_lifting import ImagePoseLifting from .integral_regression_label import IntegralRegressionLabel from .megvii_heatmap import MegviiHeatmap +from .mono_pose_lifting import MonoPoseLifting from .msra_heatmap import MSRAHeatmap from .regression_label import RegressionLabel from .simcc_label import SimCCLabel @@ -14,5 +15,6 @@ __all__ = [ 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel', 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR', - 'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting' + 'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting', + 'MonoPoseLifting' ] diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 58d4eb7c23..852d583c39 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -139,8 +139,8 @@ def encode(self, if self.remove_root: lifting_target_label = np.delete( lifting_target_label, self.root_index, axis=-2) - assert lifting_target_weights.ndim in {1, 2} - axis_to_remove = -2 if lifting_target_weights.ndim == 2 else -1 + assert lifting_target_weights.ndim in {2, 3} + axis_to_remove = -2 if lifting_target_weights.ndim == 3 else -1 lifting_target_weights = np.delete( lifting_target_weights, self.root_index, axis=axis_to_remove) # Add a flag to avoid latter transforms that rely on the root @@ -210,10 +210,10 @@ def decode(self, keypoints = encoded.copy() if self.target_mean is not None and self.target_std is not None: - assert self.target_mean.shape == keypoints.shape[1:] + assert self.target_mean.shape == keypoints.shape keypoints = keypoints * self.target_std + self.target_mean - if target_root.size > 0: + if target_root is not None and target_root.size > 0: keypoints = keypoints + target_root if self.remove_root: keypoints = np.insert( diff --git a/mmpose/codecs/mono_pose_lifting.py b/mmpose/codecs/mono_pose_lifting.py new file mode 100644 index 0000000000..db686a848b --- /dev/null +++ b/mmpose/codecs/mono_pose_lifting.py @@ -0,0 +1,223 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from copy import deepcopy +from typing import Optional, Tuple + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class MonoPoseLifting(BaseKeypointCodec): + r"""Generate keypoint coordinates for pose lifter. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - pose-lifitng target dimension: C + + Args: + num_keypoints (int): The number of keypoints in the dataset. + zero_center: Whether to zero-center the target around root. Default: + ``True``. + root_index (int): Root keypoint index in the pose. Default: 0. + remove_root (bool): If true, remove the root keypoint from the pose. + Default: ``False``. + save_index (bool): If true, store the root position separated from the + original pose, only takes effect if ``remove_root`` is ``True``. + Default: ``False``. + concat_vis (bool): If true, concat the visibility item of keypoints. + Default: ``False``. + """ + + auxiliary_encode_keys = { + 'lifting_target', 'lifting_target_visible', 'camera_param' + } + + def __init__(self, + num_keypoints: int, + zero_center: bool = True, + root_index: int = 0, + remove_root: bool = False, + save_index: bool = False, + concat_vis: bool = False): + super().__init__() + + self.num_keypoints = num_keypoints + self.zero_center = zero_center + self.root_index = root_index + self.remove_root = remove_root + self.save_index = save_index + self.concat_vis = concat_vis + + def encode(self, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + lifting_target: Optional[np.ndarray] = None, + lifting_target_visible: Optional[np.ndarray] = None, + camera_param: Optional[dict] = None) -> dict: + """Encoding keypoints from input image space to normalized space. + + Args: + keypoints (np.ndarray): Keypoint coordinates in shape (B, T, K, D). + keypoints_visible (np.ndarray, optional): Keypoint visibilities in + shape (B, T, K). + lifting_target (np.ndarray, optional): 3d target coordinate in + shape (T, K, C). + lifting_target_visible (np.ndarray, optional): Target coordinate in + shape (T, K, ). + camera_param (dict, optional): The camera parameter dictionary. + + Returns: + encoded (dict): Contains the following items: + + - keypoint_labels (np.ndarray): The processed keypoints in + shape (K * D, N) where D is 2 for 2d coordinates. + - lifting_target_label: The processed target coordinate in + shape (K, C) or (K-1, C). + - lifting_target_weights (np.ndarray): The target weights in + shape (K, ) or (K-1, ). + - trajectory_weights (np.ndarray): The trajectory weights in + shape (K, ). + + In addition, there are some optional items it may contain: + + - target_root (np.ndarray): The root coordinate of target in + shape (C, ). Exists if ``self.zero_center`` is ``True``. + - target_root_removed (bool): Indicate whether the root of + pose-lifitng target is removed. Exists if + ``self.remove_root`` is ``True``. + - target_root_index (int): An integer indicating the index of + root. Exists if ``self.remove_root`` and ``self.save_index`` + are ``True``. + - camera_param (dict): The updated camera parameter dictionary. + Exists if ``self.normalize_camera`` is ``True``. + """ + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + if lifting_target is None: + lifting_target = [keypoints[0]] + + # set initial value for `lifting_target_weights` + # and `trajectory_weights` + if lifting_target_visible is None: + lifting_target_visible = np.ones( + lifting_target.shape[:-1], dtype=np.float32) + lifting_target_weights = lifting_target_visible + trajectory_weights = (1 / lifting_target[:, 2]) + else: + valid = lifting_target_visible > 0.5 + lifting_target_weights = np.where(valid, 1., 0.).astype(np.float32) + trajectory_weights = lifting_target_weights + + if camera_param is None: + camera_param = dict() + + encoded = dict() + + lifting_target_label = lifting_target.copy() + # Zero-center the target pose around a given root keypoint + if self.zero_center: + assert (lifting_target.ndim >= 2 and + lifting_target.shape[-2] > self.root_index), \ + f'Got invalid joint shape {lifting_target.shape}' + + root = lifting_target[..., self.root_index, :] + lifting_target_label = lifting_target_label - root[:, None] + encoded['target_root'] = root + + if self.remove_root: + lifting_target_label = np.delete( + lifting_target_label, self.root_index, axis=-2) + assert lifting_target_weights.ndim == 2 + lifting_target_weights = np.delete( + lifting_target_weights, self.root_index, axis=-1) + # Add a flag to avoid latter transforms that rely on the root + # joint or the original joint index + encoded['target_root_removed'] = True + + # Save the root index for restoring the global pose + if self.save_index: + encoded['target_root_index'] = self.root_index + + # Normalize the 2D keypoint coordinate with image width and height + _camera_param = deepcopy(camera_param) + assert 'w' in _camera_param and 'h' in _camera_param + w, h = _camera_param['w'], _camera_param['h'] + keypoint_labels = keypoints.copy() + keypoint_labels[:, :, :2] = keypoint_labels[:, :, :2] / w * 2 - [ + 1, h / w + ] + keypoint_labels[:, :, 2:] = keypoint_labels[:, :, 2:] / w * 2 + + assert keypoint_labels.ndim in {2, 3} + if keypoint_labels.ndim == 2: + keypoint_labels = keypoint_labels[None, ...] + + if self.concat_vis: + keypoints_visible_ = keypoints_visible + if keypoints_visible.ndim == 2: + keypoints_visible_ = keypoints_visible[..., None] + keypoint_labels = np.concatenate( + (keypoint_labels, keypoints_visible_), axis=2) + + encoded['keypoint_labels'] = keypoint_labels + encoded['keypoints_visible'] = keypoints_visible + encoded['lifting_target_label'] = lifting_target_label + encoded['lifting_target_weights'] = lifting_target_weights + encoded['trajectory_weights'] = trajectory_weights + + return encoded + + def decode( + self, + encoded: np.ndarray, + target_root: Optional[np.ndarray] = None, + w: Optional[np.ndarray] = None, + h: Optional[np.ndarray] = None, + ) -> Tuple[np.ndarray, np.ndarray]: + """Decode keypoint coordinates from normalized space to input image + space. + + Args: + encoded (np.ndarray): Coordinates in shape (N, K, C). + target_root (np.ndarray, optional): The pose-lifitng target root + coordinate. Default: ``None``. + target_root (np.ndarray, optional): The pose-lifitng target root + coordinate. Default: ``None``. + w (np.ndarray, optional): The image widths in shape (N, ). + Default: ``None``. + h (np.ndarray, optional): The image heights in shape (N, ). + Default: ``None``. + + Returns: + keypoints (np.ndarray): Decoded coordinates in shape (N, K, C). + scores (np.ndarray): The keypoint scores in shape (N, K). + """ + keypoints = encoded.copy() + + if target_root is not None and target_root.size > 0: + if self.zero_center: + keypoints = keypoints + target_root + if self.remove_root: + keypoints = np.insert( + keypoints, self.root_index, target_root, axis=1) + scores = np.ones(keypoints.shape[:-1], dtype=np.float32) + + if w is not None and w.size > 0: + assert w.shape == h.shape + assert w.shape[0] == keypoints.shape[0] + assert w.ndim in {1, 2} + if w.ndim == 1: + w = w[:, None] + h = h[:, None] + trans = np.append( + np.ones((w.shape[0], 1)), h / w, axis=1)[:, None, :] + keypoints[..., :2] = (keypoints[..., :2] + trans) * w[:, None] / 2 + keypoints[..., 2:] = keypoints[..., 2:] * w[:, None] / 2 + return keypoints, scores diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py index ed34b7bf06..dc51f1f92f 100644 --- a/mmpose/codecs/video_pose_lifting.py +++ b/mmpose/codecs/video_pose_lifting.py @@ -142,8 +142,8 @@ def encode(self, if self.remove_root: lifting_target_label = np.delete( lifting_target_label, self.root_index, axis=-2) - assert lifting_target_weights.ndim in {1, 2} - axis_to_remove = -2 if lifting_target_weights.ndim == 2 else -1 + assert lifting_target_weights.ndim in {2, 3} + axis_to_remove = -2 if lifting_target_weights.ndim == 3 else -1 lifting_target_weights = np.delete( lifting_target_weights, self.root_index, @@ -212,7 +212,7 @@ def decode(self, """ keypoints = encoded.copy() - if target_root.size > 0: + if target_root is not None and target_root.size > 0: keypoints = keypoints + target_root if self.remove_root: keypoints = np.insert( diff --git a/mmpose/evaluation/metrics/keypoint_3d_metrics.py b/mmpose/evaluation/metrics/keypoint_3d_metrics.py index 9e8ab537a3..7c8f73d26c 100644 --- a/mmpose/evaluation/metrics/keypoint_3d_metrics.py +++ b/mmpose/evaluation/metrics/keypoint_3d_metrics.py @@ -76,7 +76,8 @@ def process(self, data_batch: Sequence[dict], # ground truth keypoints coordinates, [T, K, D] gt_coords = gt['lifting_target'] # ground truth keypoints_visible, [T, K, 1] - mask = gt['lifting_target_visible'].astype(bool).reshape(1, -1) + mask = gt['lifting_target_visible'].astype(bool).reshape( + gt_coords.shape[0], -1) # instance action img_path = data_sample['target_img_path'][0] _, rest = osp.basename(img_path).split('_', 1) diff --git a/mmpose/models/backbones/dstformer.py b/mmpose/models/backbones/dstformer.py index 4175b34053..76bf6dd83c 100644 --- a/mmpose/models/backbones/dstformer.py +++ b/mmpose/models/backbones/dstformer.py @@ -99,8 +99,8 @@ def __init__(self, super().__init__() self.st_mode = st_mode - self.norm1_s = nn.LayerNorm(dim) - self.norm1_t = nn.LayerNorm(dim) + self.norm1_s = nn.LayerNorm(dim, eps=1e-06) + self.norm1_t = nn.LayerNorm(dim, eps=1e-06) self.attn_s = Attention( dim, @@ -121,8 +121,8 @@ def __init__(self, self.drop_path = DropPath( drop_path) if drop_path > 0. else nn.Identity() - self.norm2_s = nn.LayerNorm(dim) - self.norm2_t = nn.LayerNorm(dim) + self.norm2_s = nn.LayerNorm(dim, eps=1e-06) + self.norm2_t = nn.LayerNorm(dim, eps=1e-06) mlp_hidden_dim = int(dim * mlp_ratio) mlp_out_dim = int(dim * mlp_out_ratio) @@ -237,7 +237,7 @@ def __init__(self, st_mode='ts') for i in range(depth) ]) - self.norm = nn.LayerNorm(feat_size) + self.norm = nn.LayerNorm(feat_size, eps=1e-06) self.temp_embed = nn.Parameter(torch.zeros(1, seq_len, 1, feat_size)) self.spat_embed = nn.Parameter( diff --git a/tests/test_codecs/test_image_pose_lifting.py b/tests/test_codecs/test_image_pose_lifting.py index e4663cc1b4..7a07f87786 100644 --- a/tests/test_codecs/test_image_pose_lifting.py +++ b/tests/test_codecs/test_image_pose_lifting.py @@ -13,14 +13,18 @@ def setUp(self) -> None: keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256] keypoints = np.round(keypoints).astype(np.float32) keypoints_visible = np.random.randint(2, size=(1, 17)) - lifting_target = (0.1 + 0.8 * np.random.rand(17, 3)) - lifting_target_visible = np.random.randint(2, size=(17, )) + lifting_target = (0.1 + 0.8 * np.random.rand(1, 17, 3)) + lifting_target_visible = np.random.randint( + 2, size=( + 1, + 17, + )) encoded_wo_sigma = np.random.rand(1, 17, 3) self.keypoints_mean = np.random.rand(17, 2).astype(np.float32) self.keypoints_std = np.random.rand(17, 2).astype(np.float32) + 1e-6 - self.target_mean = np.random.rand(17, 3).astype(np.float32) - self.target_std = np.random.rand(17, 3).astype(np.float32) + 1e-6 + self.target_mean = np.random.rand(1, 17, 3).astype(np.float32) + self.target_std = np.random.rand(1, 17, 3).astype(np.float32) + 1e-6 self.data = dict( keypoints=keypoints, @@ -54,10 +58,19 @@ def test_encode(self): lifting_target_visible) self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) - self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) - self.assertEqual(encoded['trajectory_weights'].shape, (17, )) - self.assertEqual(encoded['target_root'].shape, (3, )) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['trajectory_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['target_root'].shape, ( + 1, + 3, + )) # test removing root codec = self.build_pose_lifting_label( @@ -67,10 +80,16 @@ def test_encode(self): self.assertTrue('target_root_removed' in encoded and 'target_root_index' in encoded) - self.assertEqual(encoded['lifting_target_weights'].shape, (16, )) self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (16, 3)) - self.assertEqual(encoded['target_root'].shape, (3, )) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 16, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 1, + 16, + )) + self.assertEqual(encoded['target_root'].shape, ( + 1, + 3, + )) # test normalization codec = self.build_pose_lifting_label( @@ -82,7 +101,7 @@ def test_encode(self): lifting_target_visible) self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) def test_decode(self): lifting_target = self.data['lifting_target'] @@ -116,12 +135,10 @@ def test_cicular_verification(self): lifting_target_visible) _keypoints, _ = codec.decode( - np.expand_dims(encoded['lifting_target_label'], axis=0), + encoded['lifting_target_label'], target_root=lifting_target[..., 0, :]) - self.assertTrue( - np.allclose( - np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) + self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) # test removing root codec = self.build_pose_lifting_label(remove_root=True) @@ -129,12 +146,10 @@ def test_cicular_verification(self): lifting_target_visible) _keypoints, _ = codec.decode( - np.expand_dims(encoded['lifting_target_label'], axis=0), + encoded['lifting_target_label'], target_root=lifting_target[..., 0, :]) - self.assertTrue( - np.allclose( - np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) + self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) # test normalization codec = self.build_pose_lifting_label( @@ -146,9 +161,7 @@ def test_cicular_verification(self): lifting_target_visible) _keypoints, _ = codec.decode( - np.expand_dims(encoded['lifting_target_label'], axis=0), + encoded['lifting_target_label'], target_root=lifting_target[..., 0, :]) - self.assertTrue( - np.allclose( - np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) + self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) diff --git a/tests/test_codecs/test_mono_pose_lifting.py b/tests/test_codecs/test_mono_pose_lifting.py new file mode 100644 index 0000000000..bae1be6d50 --- /dev/null +++ b/tests/test_codecs/test_mono_pose_lifting.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from unittest import TestCase + +import numpy as np +from mmengine.fileio import load + +from mmpose.codecs import MonoPoseLifting +from mmpose.registry import KEYPOINT_CODECS + + +class TestMonoPoseLifting(TestCase): + + def get_camera_param(self, imgname, camera_param) -> dict: + """Get camera parameters of a frame by its image name.""" + subj, rest = osp.basename(imgname).split('_', 1) + action, rest = rest.split('.', 1) + camera, rest = rest.split('_', 1) + return camera_param[(subj, camera)] + + def build_pose_lifting_label(self, **kwargs): + cfg = dict(type='MonoPoseLifting', num_keypoints=17) + cfg.update(kwargs) + return KEYPOINT_CODECS.build(cfg) + + def setUp(self) -> None: + keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [1000, 1002] + keypoints = np.round(keypoints).astype(np.float32) + keypoints_visible = np.random.randint(2, size=(1, 17)) + lifting_target = (0.1 + 0.8 * np.random.rand(1, 17, 3)) + lifting_target_visible = np.random.randint( + 2, size=( + 1, + 17, + )) + encoded_wo_sigma = np.random.rand(1, 17, 3) + + camera_param = load('tests/data/h36m/cameras.pkl') + camera_param = self.get_camera_param( + 'S1/S1_Directions_1.54138969/S1_Directions_1.54138969_000001.jpg', + camera_param) + + self.data = dict( + keypoints=keypoints, + keypoints_visible=keypoints_visible, + lifting_target=lifting_target, + lifting_target_visible=lifting_target_visible, + camera_param=camera_param, + encoded_wo_sigma=encoded_wo_sigma) + + def test_build(self): + codec = self.build_pose_lifting_label() + self.assertIsInstance(codec, MonoPoseLifting) + + def test_encode(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + lifting_target = self.data['lifting_target'] + lifting_target_visible = self.data['lifting_target_visible'] + camera_param = self.data['camera_param'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['trajectory_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['target_root'].shape, ( + 1, + 3, + )) + + # test not zero-centering + codec = self.build_pose_lifting_label(zero_center=False) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['trajectory_weights'].shape, ( + 1, + 17, + )) + + # test removing root + codec = self.build_pose_lifting_label( + remove_root=True, save_index=True) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + self.assertTrue('target_root_removed' in encoded + and 'target_root_index' in encoded) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 16, 3)) + self.assertEqual(encoded['target_root'].shape, ( + 1, + 3, + )) + + # test concatenating visibility + codec = self.build_pose_lifting_label(concat_vis=True) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 3)) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) + self.assertEqual(encoded['target_root'].shape, ( + 1, + 3, + )) + + def test_decode(self): + lifting_target = self.data['lifting_target'] + encoded_wo_sigma = self.data['encoded_wo_sigma'] + camera_param = self.data['camera_param'] + + # test default settings + codec = self.build_pose_lifting_label() + + decoded, scores = codec.decode( + encoded_wo_sigma, target_root=lifting_target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 17, 3)) + self.assertEqual(scores.shape, (1, 17)) + + # test `remove_root=True` + codec = self.build_pose_lifting_label(remove_root=True) + + decoded, scores = codec.decode( + encoded_wo_sigma, target_root=lifting_target[..., 0, :]) + + self.assertEqual(decoded.shape, (1, 18, 3)) + self.assertEqual(scores.shape, (1, 18)) + + # test denormalize according to image shape + codec = self.build_pose_lifting_label(zero_center=False) + + decoded, scores = codec.decode( + encoded_wo_sigma, + w=np.array([camera_param['w']]), + h=np.array([camera_param['h']])) + + self.assertEqual(decoded.shape, (1, 17, 3)) + self.assertEqual(scores.shape, (1, 17)) + + def test_cicular_verification(self): + keypoints = self.data['keypoints'] + keypoints_visible = self.data['keypoints_visible'] + lifting_target = self.data['lifting_target'] + lifting_target_visible = self.data['lifting_target_visible'] + camera_param = self.data['camera_param'] + + # test default settings + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + _keypoints, _ = codec.decode( + encoded['lifting_target_label'], + target_root=lifting_target[..., 0, :]) + + self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) + + # test removing root + codec = self.build_pose_lifting_label(remove_root=True) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + _keypoints, _ = codec.decode( + encoded['lifting_target_label'], + target_root=lifting_target[..., 0, :]) + + self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) + + # test denormalize according to image shape + keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 3)) * [1000, 1002, 1] + keypoints = np.round(keypoints).astype(np.float32) + codec = self.build_pose_lifting_label(zero_center=False) + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + _keypoints, _ = codec.decode( + encoded['keypoint_labels'], + w=np.array([camera_param['w']]), + h=np.array([camera_param['h']])) + + self.assertTrue(np.allclose(keypoints, _keypoints, atol=5.)) diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py index d1cca17bbd..4ed11f7f7a 100644 --- a/tests/test_codecs/test_video_pose_lifting.py +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -28,8 +28,12 @@ def setUp(self) -> None: keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 2)) * [192, 256] keypoints = np.round(keypoints).astype(np.float32) keypoints_visible = np.random.randint(2, size=(1, 17)) - lifting_target = (0.1 + 0.8 * np.random.rand(17, 3)) - lifting_target_visible = np.random.randint(2, size=(17, )) + lifting_target = (0.1 + 0.8 * np.random.rand(1, 17, 3)) + lifting_target_visible = np.random.randint( + 2, size=( + 1, + 17, + )) encoded_wo_sigma = np.random.rand(1, 17, 3) camera_param = load('tests/data/h36m/cameras.pkl') @@ -62,10 +66,19 @@ def test_encode(self): lifting_target_visible, camera_param) self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) - self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) - self.assertEqual(encoded['trajectory_weights'].shape, (17, )) - self.assertEqual(encoded['target_root'].shape, (3, )) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['trajectory_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['target_root'].shape, ( + 1, + 3, + )) # test not zero-centering codec = self.build_pose_lifting_label(zero_center=False) @@ -73,9 +86,15 @@ def test_encode(self): lifting_target_visible, camera_param) self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) - self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) - self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['trajectory_weights'].shape, ( + 1, + 17, + )) # test reshape_keypoints codec = self.build_pose_lifting_label(reshape_keypoints=True) @@ -83,9 +102,15 @@ def test_encode(self): lifting_target_visible, camera_param) self.assertEqual(encoded['keypoint_labels'].shape, (34, 1)) - self.assertEqual(encoded['lifting_target_label'].shape, (17, 3)) - self.assertEqual(encoded['lifting_target_weights'].shape, (17, )) - self.assertEqual(encoded['trajectory_weights'].shape, (17, )) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 1, + 17, + )) + self.assertEqual(encoded['trajectory_weights'].shape, ( + 1, + 17, + )) # test removing root codec = self.build_pose_lifting_label( @@ -95,10 +120,16 @@ def test_encode(self): self.assertTrue('target_root_removed' in encoded and 'target_root_index' in encoded) - self.assertEqual(encoded['lifting_target_weights'].shape, (16, )) self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (16, 3)) - self.assertEqual(encoded['target_root'].shape, (3, )) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 16, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 1, + 16, + )) + self.assertEqual(encoded['target_root'].shape, ( + 1, + 3, + )) # test normalizing camera codec = self.build_pose_lifting_label(normalize_camera=True) @@ -146,12 +177,10 @@ def test_cicular_verification(self): lifting_target_visible, camera_param) _keypoints, _ = codec.decode( - np.expand_dims(encoded['lifting_target_label'], axis=0), + encoded['lifting_target_label'], target_root=lifting_target[..., 0, :]) - self.assertTrue( - np.allclose( - np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) + self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) # test removing root codec = self.build_pose_lifting_label(remove_root=True) @@ -159,9 +188,7 @@ def test_cicular_verification(self): lifting_target_visible, camera_param) _keypoints, _ = codec.decode( - np.expand_dims(encoded['lifting_target_label'], axis=0), + encoded['lifting_target_label'], target_root=lifting_target[..., 0, :]) - self.assertTrue( - np.allclose( - np.expand_dims(lifting_target, axis=0), _keypoints, atol=5.)) + self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) From ffec7403b95afd6fb7842229e1f2303a8a731625 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Fri, 9 Jun 2023 17:33:55 +0800 Subject: [PATCH 08/30] reform lifting target, complete new codec --- .../h36m/vid_pl_motionbert_8xb32-120e_h36m.py | 10 ++-- mmpose/codecs/image_pose_lifting.py | 19 ++++++-- mmpose/codecs/mono_pose_lifting.py | 48 ++++++++++++------- mmpose/codecs/utils/__init__.py | 4 +- .../codecs/utils/camera_image_projection.py | 33 +++++++++++++ mmpose/codecs/video_pose_lifting.py | 14 ++++-- mmpose/datasets/transforms/formatting.py | 5 +- .../evaluation/metrics/keypoint_3d_metrics.py | 2 +- .../temporal_regression_head.py | 2 +- .../trajectory_regression_head.py | 2 +- tests/test_codecs/test_image_pose_lifting.py | 4 +- tests/test_codecs/test_mono_pose_lifting.py | 7 +-- tests/test_codecs/test_video_pose_lifting.py | 4 +- .../test_body_datasets/test_h36m_dataset.py | 11 +++++ .../test_transforms/test_pose3d_transforms.py | 19 +++++--- .../test_metrics/test_keypoint_3d_metrics.py | 9 ++-- 16 files changed, 135 insertions(+), 58 deletions(-) create mode 100644 mmpose/codecs/utils/camera_image_projection.py diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py index d800f23601..e542784dec 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py @@ -27,12 +27,9 @@ # codec settings codec = dict( - type='VideoPoseLifting', + type='MonoPoseLifting', num_keypoints=17, - zero_center=True, - root_index=0, - remove_root=False, - reshape_keypoints=False, + zero_center=False, concat_vis=True) # model settings @@ -68,7 +65,7 @@ dict( type='PackPoseInputs', meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', - 'target_root')) + 'target_root', 'camera_param')) ] # data loaders @@ -82,7 +79,6 @@ ann_file='annotation_body3d/fps50/h36m_test.npz', seq_len=1, seq_step=1, - merge_seq=243, pad_video_seq=True, camera_param_file='annotation_body3d/cameras.pkl', data_root=data_root, diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 852d583c39..d954eb7fb7 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -89,7 +89,9 @@ def encode(self, encoded (dict): Contains the following items: - keypoint_labels (np.ndarray): The processed keypoints in - shape (K * D, N) where D is 2 for 2d coordinates. + shape like (N, K, D) or (K * D, N). + - keypoint_labels_visible (np.ndarray): The processed + keypoints' weights in shape (N, K, ) or (N-1, K, ). - lifting_target_label: The processed target coordinate in shape (K, C) or (K-1, C). - lifting_target_weights (np.ndarray): The target weights in @@ -101,11 +103,13 @@ def encode(self, In addition, there are some optional items it may contain: + - target_root (np.ndarray): The root coordinate of target in + shape (C, ). Exists if ``zero_center`` is ``True``. - target_root_removed (bool): Indicate whether the root of - pose lifting target is removed. Added if ``self.remove_root`` - is ``True``. + pose-lifitng target is removed. Exists if + ``remove_root`` is ``True``. - target_root_index (int): An integer indicating the index of - root. Added if ``self.remove_root`` and ``self.save_index`` + root. Exists if ``remove_root`` and ``save_index`` are ``True``. """ if keypoints_visible is None: @@ -139,6 +143,8 @@ def encode(self, if self.remove_root: lifting_target_label = np.delete( lifting_target_label, self.root_index, axis=-2) + lifting_target_visible = np.delete( + lifting_target_visible, self.root_index, axis=-2) assert lifting_target_weights.ndim in {2, 3} axis_to_remove = -2 if lifting_target_weights.ndim == 3 else -1 lifting_target_weights = np.delete( @@ -160,6 +166,9 @@ def encode(self, keypoint_labels = (keypoint_labels - self.keypoints_mean) / self.keypoints_std if self.target_mean is not None and self.target_std is not None: + assert self.target_mean.ndim in {2, 3} + if self.target_mean.ndim == 2: + self.target_mean = self.target_mean[None, :] target_shape = lifting_target_label.shape assert self.target_mean.shape == target_shape @@ -183,7 +192,7 @@ def encode(self, keypoint_labels = keypoint_labels.transpose(1, 2, 0).reshape(-1, N) encoded['keypoint_labels'] = keypoint_labels - encoded['keypoints_visible'] = keypoints_visible + encoded['keypoint_labels_visible'] = keypoints_visible encoded['lifting_target_label'] = lifting_target_label encoded['lifting_target_weights'] = lifting_target_weights encoded['trajectory_weights'] = trajectory_weights diff --git a/mmpose/codecs/mono_pose_lifting.py b/mmpose/codecs/mono_pose_lifting.py index db686a848b..bf477a9b60 100644 --- a/mmpose/codecs/mono_pose_lifting.py +++ b/mmpose/codecs/mono_pose_lifting.py @@ -7,6 +7,7 @@ from mmpose.registry import KEYPOINT_CODECS from .base import BaseKeypointCodec +from .utils import camera_to_image_coord @KEYPOINT_CODECS.register_module() @@ -76,26 +77,28 @@ def encode(self, encoded (dict): Contains the following items: - keypoint_labels (np.ndarray): The processed keypoints in - shape (K * D, N) where D is 2 for 2d coordinates. + shape like (N, K, D). + - keypoint_labels_visible (np.ndarray): The processed + keypoints' weights in shape (N, K, ) or (N-1, K, ). - lifting_target_label: The processed target coordinate in shape (K, C) or (K-1, C). - lifting_target_weights (np.ndarray): The target weights in shape (K, ) or (K-1, ). - trajectory_weights (np.ndarray): The trajectory weights in shape (K, ). + - factor (np.ndarray): The factor mapping camera and image + coordinate. In addition, there are some optional items it may contain: - target_root (np.ndarray): The root coordinate of target in - shape (C, ). Exists if ``self.zero_center`` is ``True``. + shape (C, ). Exists if ``zero_center`` is ``True``. - target_root_removed (bool): Indicate whether the root of - pose-lifitng target is removed. Exists if - ``self.remove_root`` is ``True``. + pose-lifitng target is removed. Exists if ``remove_root`` is + ``True``. - target_root_index (int): An integer indicating the index of - root. Exists if ``self.remove_root`` and ``self.save_index`` - are ``True``. - - camera_param (dict): The updated camera parameter dictionary. - Exists if ``self.normalize_camera`` is ``True``. + root. Exists if ``remove_root`` and ``save_index`` are + ``True``. """ if keypoints_visible is None: keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) @@ -121,6 +124,12 @@ def encode(self, encoded = dict() lifting_target_label = lifting_target.copy() + keypoint_labels = keypoints.copy() + + assert keypoint_labels.ndim in {2, 3} + if keypoint_labels.ndim == 2: + keypoint_labels = keypoint_labels[None, ...] + # Zero-center the target pose around a given root keypoint if self.zero_center: assert (lifting_target.ndim >= 2 and @@ -134,6 +143,8 @@ def encode(self, if self.remove_root: lifting_target_label = np.delete( lifting_target_label, self.root_index, axis=-2) + lifting_target_visible = np.delete( + lifting_target_visible, self.root_index, axis=-2) assert lifting_target_weights.ndim == 2 lifting_target_weights = np.delete( lifting_target_weights, self.root_index, axis=-1) @@ -149,15 +160,15 @@ def encode(self, _camera_param = deepcopy(camera_param) assert 'w' in _camera_param and 'h' in _camera_param w, h = _camera_param['w'], _camera_param['h'] - keypoint_labels = keypoints.copy() - keypoint_labels[:, :, :2] = keypoint_labels[:, :, :2] / w * 2 - [ - 1, h / w - ] - keypoint_labels[:, :, 2:] = keypoint_labels[:, :, 2:] / w * 2 + keypoint_labels[ + ..., :2] = keypoint_labels[..., :2] / w * 2 - [1, h / w] - assert keypoint_labels.ndim in {2, 3} - if keypoint_labels.ndim == 2: - keypoint_labels = keypoint_labels[None, ...] + # convert target to image coordinate + lifting_target_label, factor = camera_to_image_coord( + self.root_index, lifting_target_label, _camera_param) + lifting_target_label[ + ..., :2] = lifting_target_label[..., :2] / w * 2 - [1, h / w] + lifting_target_label[..., 2:] = lifting_target_label[..., 2:] / w * 2 if self.concat_vis: keypoints_visible_ = keypoints_visible @@ -167,10 +178,13 @@ def encode(self, (keypoint_labels, keypoints_visible_), axis=2) encoded['keypoint_labels'] = keypoint_labels - encoded['keypoints_visible'] = keypoints_visible + encoded['keypoint_labels_visible'] = keypoints_visible encoded['lifting_target_label'] = lifting_target_label encoded['lifting_target_weights'] = lifting_target_weights + encoded['lifting_target'] = lifting_target_label + encoded['lifting_target_visible'] = lifting_target_visible encoded['trajectory_weights'] = trajectory_weights + encoded['factor'] = factor return encoded diff --git a/mmpose/codecs/utils/__init__.py b/mmpose/codecs/utils/__init__.py index eaa093f12b..38bbae5c39 100644 --- a/mmpose/codecs/utils/__init__.py +++ b/mmpose/codecs/utils/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .camera_image_projection import camera_to_image_coord, camera_to_pixel from .gaussian_heatmap import (generate_gaussian_heatmaps, generate_udp_gaussian_heatmaps, generate_unbiased_gaussian_heatmaps) @@ -19,5 +20,6 @@ 'batch_heatmap_nms', 'refine_keypoints', 'refine_keypoints_dark', 'refine_keypoints_dark_udp', 'generate_displacement_heatmap', 'refine_simcc_dark', 'gaussian_blur1d', 'get_diagonal_lengths', - 'get_instance_root', 'get_instance_bbox', 'get_simcc_normalized' + 'get_instance_root', 'get_instance_bbox', 'get_simcc_normalized', + 'camera_to_image_coord', 'camera_to_pixel' ] diff --git a/mmpose/codecs/utils/camera_image_projection.py b/mmpose/codecs/utils/camera_image_projection.py new file mode 100644 index 0000000000..5f149cd060 --- /dev/null +++ b/mmpose/codecs/utils/camera_image_projection.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + + +def camera_to_image_coord(root_index, kpts_3d_cam, camera_param): + root = kpts_3d_cam[..., root_index, :] + tl_kpt = root.copy() + tl_kpt[:2] -= 1.0 + br_kpt = root.copy() + br_kpt[:2] += 1.0 + tl_kpt = np.reshape(tl_kpt, (1, 3)) + br_kpt = np.reshape(br_kpt, (1, 3)) + fx, fy = camera_param['f'] / 1000. + cx, cy = camera_param['c'] / 1000. + + tl2d = camera_to_pixel(tl_kpt, fx, fy, cx, cy).flatten() + br2d = camera_to_pixel(br_kpt, fx, fy, cx, cy).flatten() + rectangle_3d_size = 2.0 + kpts_3d_image = np.zeros_like(kpts_3d_cam) + kpts_3d_image[:, :2] = camera_to_pixel(kpts_3d_cam.copy(), fx, fy, cx, cy) + ratio = (br2d[0] - tl2d[0] + 0.001) / rectangle_3d_size + kpts_3d_depth = ratio * (kpts_3d_cam[:, 2] - kpts_3d_cam[root_index, 2]) + kpts_3d_image[:, 2] = kpts_3d_depth + return kpts_3d_image, ratio + + +def camera_to_pixel(kpts_3d, fx, fy, cx, cy): + pose_2d = kpts_3d[:, :2] / kpts_3d[:, 2:3] + pose_2d[:, 0] *= fx + pose_2d[:, 1] *= fy + pose_2d[:, 0] += cx + pose_2d[:, 1] += cy + return pose_2d diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py index dc51f1f92f..55b06e079b 100644 --- a/mmpose/codecs/video_pose_lifting.py +++ b/mmpose/codecs/video_pose_lifting.py @@ -84,7 +84,9 @@ def encode(self, encoded (dict): Contains the following items: - keypoint_labels (np.ndarray): The processed keypoints in - shape (K * D, N) where D is 2 for 2d coordinates. + shape like (N, K, D) or (K * D, N). + - keypoint_labels_visible (np.ndarray): The processed + keypoints' weights in shape (N, K, ) or (N-1, K, ). - lifting_target_label: The processed target coordinate in shape (K, C) or (K-1, C). - lifting_target_weights (np.ndarray): The target weights in @@ -95,15 +97,15 @@ def encode(self, In addition, there are some optional items it may contain: - target_root (np.ndarray): The root coordinate of target in - shape (C, ). Exists if ``self.zero_center`` is ``True``. + shape (C, ). Exists if ``zero_center`` is ``True``. - target_root_removed (bool): Indicate whether the root of pose-lifitng target is removed. Exists if - ``self.remove_root`` is ``True``. + ``remove_root`` is ``True``. - target_root_index (int): An integer indicating the index of - root. Exists if ``self.remove_root`` and ``self.save_index`` + root. Exists if ``remove_root`` and ``save_index`` are ``True``. - camera_param (dict): The updated camera parameter dictionary. - Exists if ``self.normalize_camera`` is ``True``. + Exists if ``normalize_camera`` is ``True``. """ if keypoints_visible is None: keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) @@ -142,6 +144,8 @@ def encode(self, if self.remove_root: lifting_target_label = np.delete( lifting_target_label, self.root_index, axis=-2) + lifting_target_visible = np.delete( + lifting_target_visible, self.root_index, axis=-2) assert lifting_target_weights.ndim in {2, 3} axis_to_remove = -2 if lifting_target_weights.ndim == 3 else -1 lifting_target_weights = np.delete( diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index 38e7fbc3fb..132ecd8bb6 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -207,9 +207,8 @@ def transform(self, results: dict) -> dict: for key, packed_key in self.label_mapping_table.items(): if key in results: # For pose-lifting, store only target-related fields - if 'lifting_target_label' in results and key in { - 'keypoint_labels', 'keypoint_weights', - 'transformed_keypoints_visible' + if 'lifting_target' in results and key in { + 'keypoint_labels', 'keypoint_weights' }: continue if isinstance(results[key], list): diff --git a/mmpose/evaluation/metrics/keypoint_3d_metrics.py b/mmpose/evaluation/metrics/keypoint_3d_metrics.py index 7c8f73d26c..02c7f1aeab 100644 --- a/mmpose/evaluation/metrics/keypoint_3d_metrics.py +++ b/mmpose/evaluation/metrics/keypoint_3d_metrics.py @@ -67,7 +67,7 @@ def process(self, data_batch: Sequence[dict], the model. """ for data_sample in data_samples: - # predicted keypoints coordinates, [1, T, K, D] + # predicted keypoints coordinates, [T, K, D] pred_coords = data_sample['pred_instances']['keypoints'] if pred_coords.ndim == 4: pred_coords = np.squeeze(pred_coords, axis=0) diff --git a/mmpose/models/heads/regression_heads/temporal_regression_head.py b/mmpose/models/heads/regression_heads/temporal_regression_head.py index ac76316842..9ed2e9f4fa 100644 --- a/mmpose/models/heads/regression_heads/temporal_regression_head.py +++ b/mmpose/models/heads/regression_heads/temporal_regression_head.py @@ -101,7 +101,7 @@ def predict(self, else: target_root = torch.stack([ torch.empty((0), dtype=torch.float32) - for _ in batch_data_samples[0].metainfo + for _ in batch_data_samples ]) preds = self.decode((batch_coords, target_root)) diff --git a/mmpose/models/heads/regression_heads/trajectory_regression_head.py b/mmpose/models/heads/regression_heads/trajectory_regression_head.py index adfd7353d3..a1608aaae7 100644 --- a/mmpose/models/heads/regression_heads/trajectory_regression_head.py +++ b/mmpose/models/heads/regression_heads/trajectory_regression_head.py @@ -101,7 +101,7 @@ def predict(self, else: target_root = torch.stack([ torch.empty((0), dtype=torch.float32) - for _ in batch_data_samples[0].metainfo + for _ in batch_data_samples ]) preds = self.decode((batch_coords, target_root)) diff --git a/tests/test_codecs/test_image_pose_lifting.py b/tests/test_codecs/test_image_pose_lifting.py index 7a07f87786..78b19ec59b 100644 --- a/tests/test_codecs/test_image_pose_lifting.py +++ b/tests/test_codecs/test_image_pose_lifting.py @@ -80,12 +80,12 @@ def test_encode(self): self.assertTrue('target_root_removed' in encoded and 'target_root_index' in encoded) - self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (1, 16, 3)) self.assertEqual(encoded['lifting_target_weights'].shape, ( 1, 16, )) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 16, 3)) self.assertEqual(encoded['target_root'].shape, ( 1, 3, diff --git a/tests/test_codecs/test_mono_pose_lifting.py b/tests/test_codecs/test_mono_pose_lifting.py index bae1be6d50..b5a1792197 100644 --- a/tests/test_codecs/test_mono_pose_lifting.py +++ b/tests/test_codecs/test_mono_pose_lifting.py @@ -186,8 +186,8 @@ def test_cicular_verification(self): self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) # test denormalize according to image shape - keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 3)) * [1000, 1002, 1] - keypoints = np.round(keypoints).astype(np.float32) + keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 3)) + keypoints[..., 2] = np.round(keypoints[..., 2]).astype(np.float32) codec = self.build_pose_lifting_label(zero_center=False) encoded = codec.encode(keypoints, keypoints_visible, lifting_target, lifting_target_visible, camera_param) @@ -197,4 +197,5 @@ def test_cicular_verification(self): w=np.array([camera_param['w']]), h=np.array([camera_param['h']])) - self.assertTrue(np.allclose(keypoints, _keypoints, atol=5.)) + self.assertTrue( + np.allclose(keypoints[..., :2], _keypoints[..., :2], atol=5.)) diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py index 4ed11f7f7a..8366953d54 100644 --- a/tests/test_codecs/test_video_pose_lifting.py +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -120,12 +120,12 @@ def test_encode(self): self.assertTrue('target_root_removed' in encoded and 'target_root_index' in encoded) - self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (1, 16, 3)) self.assertEqual(encoded['lifting_target_weights'].shape, ( 1, 16, )) + self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) + self.assertEqual(encoded['lifting_target_label'].shape, (1, 16, 3)) self.assertEqual(encoded['target_root'].shape, ( 1, 3, diff --git a/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py b/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py index 88944dc11f..314189fde6 100644 --- a/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py +++ b/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py @@ -116,6 +116,17 @@ def test_topdown(self): self.assertEqual(len(dataset), 4) self.check_data_info_keys(dataset[0]) + dataset = self.build_h36m_dataset( + data_mode='topdown', + seq_len=1, + seq_step=1, + merge_seq=2, + causal=False, + pad_video_seq=True, + camera_param_file='cameras.pkl') + self.assertEqual(len(dataset), 2) + self.check_data_info_keys(dataset[0]) + # test topdown testing with 2d keypoint detection file and # sequence config dataset = self.build_h36m_dataset( diff --git a/tests/test_datasets/test_transforms/test_pose3d_transforms.py b/tests/test_datasets/test_transforms/test_pose3d_transforms.py index 5f5d5aa096..db7a612dee 100644 --- a/tests/test_datasets/test_transforms/test_pose3d_transforms.py +++ b/tests/test_datasets/test_transforms/test_pose3d_transforms.py @@ -35,7 +35,7 @@ def _parse_h36m_imgname(imgname): scales = data['scale'].astype(np.float32) idx = 0 - target_idx = 0 + target_idx = [0] data_info = { 'keypoints': keypoints[idx, :, :2].reshape(1, -1, 2), @@ -52,7 +52,6 @@ def _parse_h36m_imgname(imgname): 'sample_idx': idx, 'lifting_target': keypoints_3d[target_idx, :, :3], 'lifting_target_visible': keypoints_3d[target_idx, :, 3], - 'target_img_path': osp.join('tests/data/h36m', imgnames[target_idx]), } # add camera parameters @@ -108,9 +107,12 @@ def test_transform(self): tar_vis2 = results['lifting_target_visible'] self.assertEqual(kpts_vis2.shape, (1, 17)) - self.assertEqual(tar_vis2.shape, (17, )) + self.assertEqual(tar_vis2.shape, ( + 1, + 17, + )) self.assertEqual(kpts2.shape, (1, 17, 2)) - self.assertEqual(tar2.shape, (17, 3)) + self.assertEqual(tar2.shape, (1, 17, 3)) flip_indices = [ 0, 4, 5, 6, 1, 2, 3, 7, 8, 9, 10, 14, 15, 16, 11, 12, 13 @@ -121,12 +123,15 @@ def test_transform(self): self.assertTrue( np.allclose(kpts1[0][left][1:], kpts2[0][right][1:], atol=4.)) self.assertTrue( - np.allclose(tar1[left][1:], tar2[right][1:], atol=4.)) + np.allclose( + tar1[..., left, 1:], tar2[..., right, 1:], atol=4.)) self.assertTrue( - np.allclose(kpts_vis1[0][left], kpts_vis2[0][right], atol=4.)) + np.allclose( + kpts_vis1[..., left], kpts_vis2[..., right], atol=4.)) self.assertTrue( - np.allclose(tar_vis1[left], tar_vis2[right], atol=4.)) + np.allclose( + tar_vis1[..., left], tar_vis2[..., right], atol=4.)) # test camera flipping transform = RandomFlipAroundRoot( diff --git a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py index d51d493cbc..391b7b194a 100644 --- a/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py +++ b/tests/test_evaluation/test_metrics/test_keypoint_3d_metrics.py @@ -20,7 +20,8 @@ def setUp(self): for i in range(self.batch_size): gt_instances = InstanceData() keypoints = np.random.random((1, num_keypoints, 3)) - gt_instances.lifting_target = keypoints + gt_instances.lifting_target = np.random.random( + (1, num_keypoints, 3)) gt_instances.lifting_target_visible = np.ones( (1, num_keypoints, 1)).astype(bool) @@ -32,8 +33,10 @@ def setUp(self): data_sample = PoseDataSample( gt_instances=gt_instances, pred_instances=pred_instances) data_sample.set_metainfo( - dict(target_img_path='tests/data/h36m/S7/' - 'S7_Greeting.55011271/S7_Greeting.55011271_000396.jpg')) + dict(target_img_path=[ + 'tests/data/h36m/S7/' + 'S7_Greeting.55011271/S7_Greeting.55011271_000396.jpg' + ])) self.data_batch.append(data) self.data_samples.append(data_sample.to_dict()) From 2bd6de5afa1a86d640de5642cf5128e02be0893f Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 12 Jun 2023 17:26:50 +0800 Subject: [PATCH 09/30] debug transforms --- .../h36m/vid_pl_motionbert_8xb32-120e_h36m.py | 6 +- mmpose/codecs/mono_pose_lifting.py | 56 ------------- .../codecs/utils/camera_image_projection.py | 23 +++--- .../motion_regression_head.py | 21 +++-- tests/test_codecs/test_mono_pose_lifting.py | 78 +------------------ 5 files changed, 31 insertions(+), 153 deletions(-) diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py index e542784dec..be6886bccb 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py @@ -26,11 +26,7 @@ ) # codec settings -codec = dict( - type='MonoPoseLifting', - num_keypoints=17, - zero_center=False, - concat_vis=True) +codec = dict(type='MonoPoseLifting', num_keypoints=17, concat_vis=True) # model settings model = dict( diff --git a/mmpose/codecs/mono_pose_lifting.py b/mmpose/codecs/mono_pose_lifting.py index bf477a9b60..e7687d4d04 100644 --- a/mmpose/codecs/mono_pose_lifting.py +++ b/mmpose/codecs/mono_pose_lifting.py @@ -23,8 +23,6 @@ class MonoPoseLifting(BaseKeypointCodec): Args: num_keypoints (int): The number of keypoints in the dataset. - zero_center: Whether to zero-center the target around root. Default: - ``True``. root_index (int): Root keypoint index in the pose. Default: 0. remove_root (bool): If true, remove the root keypoint from the pose. Default: ``False``. @@ -41,7 +39,6 @@ class MonoPoseLifting(BaseKeypointCodec): def __init__(self, num_keypoints: int, - zero_center: bool = True, root_index: int = 0, remove_root: bool = False, save_index: bool = False, @@ -49,7 +46,6 @@ def __init__(self, super().__init__() self.num_keypoints = num_keypoints - self.zero_center = zero_center self.root_index = root_index self.remove_root = remove_root self.save_index = save_index @@ -88,17 +84,6 @@ def encode(self, shape (K, ). - factor (np.ndarray): The factor mapping camera and image coordinate. - - In addition, there are some optional items it may contain: - - - target_root (np.ndarray): The root coordinate of target in - shape (C, ). Exists if ``zero_center`` is ``True``. - - target_root_removed (bool): Indicate whether the root of - pose-lifitng target is removed. Exists if ``remove_root`` is - ``True``. - - target_root_index (int): An integer indicating the index of - root. Exists if ``remove_root`` and ``save_index`` are - ``True``. """ if keypoints_visible is None: keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) @@ -130,32 +115,6 @@ def encode(self, if keypoint_labels.ndim == 2: keypoint_labels = keypoint_labels[None, ...] - # Zero-center the target pose around a given root keypoint - if self.zero_center: - assert (lifting_target.ndim >= 2 and - lifting_target.shape[-2] > self.root_index), \ - f'Got invalid joint shape {lifting_target.shape}' - - root = lifting_target[..., self.root_index, :] - lifting_target_label = lifting_target_label - root[:, None] - encoded['target_root'] = root - - if self.remove_root: - lifting_target_label = np.delete( - lifting_target_label, self.root_index, axis=-2) - lifting_target_visible = np.delete( - lifting_target_visible, self.root_index, axis=-2) - assert lifting_target_weights.ndim == 2 - lifting_target_weights = np.delete( - lifting_target_weights, self.root_index, axis=-1) - # Add a flag to avoid latter transforms that rely on the root - # joint or the original joint index - encoded['target_root_removed'] = True - - # Save the root index for restoring the global pose - if self.save_index: - encoded['target_root_index'] = self.root_index - # Normalize the 2D keypoint coordinate with image width and height _camera_param = deepcopy(camera_param) assert 'w' in _camera_param and 'h' in _camera_param @@ -166,9 +125,6 @@ def encode(self, # convert target to image coordinate lifting_target_label, factor = camera_to_image_coord( self.root_index, lifting_target_label, _camera_param) - lifting_target_label[ - ..., :2] = lifting_target_label[..., :2] / w * 2 - [1, h / w] - lifting_target_label[..., 2:] = lifting_target_label[..., 2:] / w * 2 if self.concat_vis: keypoints_visible_ = keypoints_visible @@ -191,7 +147,6 @@ def encode(self, def decode( self, encoded: np.ndarray, - target_root: Optional[np.ndarray] = None, w: Optional[np.ndarray] = None, h: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: @@ -200,10 +155,6 @@ def decode( Args: encoded (np.ndarray): Coordinates in shape (N, K, C). - target_root (np.ndarray, optional): The pose-lifitng target root - coordinate. Default: ``None``. - target_root (np.ndarray, optional): The pose-lifitng target root - coordinate. Default: ``None``. w (np.ndarray, optional): The image widths in shape (N, ). Default: ``None``. h (np.ndarray, optional): The image heights in shape (N, ). @@ -214,13 +165,6 @@ def decode( scores (np.ndarray): The keypoint scores in shape (N, K). """ keypoints = encoded.copy() - - if target_root is not None and target_root.size > 0: - if self.zero_center: - keypoints = keypoints + target_root - if self.remove_root: - keypoints = np.insert( - keypoints, self.root_index, target_root, axis=1) scores = np.ones(keypoints.shape[:-1], dtype=np.float32) if w is not None and w.size > 0: diff --git a/mmpose/codecs/utils/camera_image_projection.py b/mmpose/codecs/utils/camera_image_projection.py index 5f149cd060..75a7efdefa 100644 --- a/mmpose/codecs/utils/camera_image_projection.py +++ b/mmpose/codecs/utils/camera_image_projection.py @@ -5,9 +5,9 @@ def camera_to_image_coord(root_index, kpts_3d_cam, camera_param): root = kpts_3d_cam[..., root_index, :] tl_kpt = root.copy() - tl_kpt[:2] -= 1.0 + tl_kpt[..., :2] -= 1.0 br_kpt = root.copy() - br_kpt[:2] += 1.0 + br_kpt[..., :2] += 1.0 tl_kpt = np.reshape(tl_kpt, (1, 3)) br_kpt = np.reshape(br_kpt, (1, 3)) fx, fy = camera_param['f'] / 1000. @@ -15,19 +15,22 @@ def camera_to_image_coord(root_index, kpts_3d_cam, camera_param): tl2d = camera_to_pixel(tl_kpt, fx, fy, cx, cy).flatten() br2d = camera_to_pixel(br_kpt, fx, fy, cx, cy).flatten() + rectangle_3d_size = 2.0 kpts_3d_image = np.zeros_like(kpts_3d_cam) - kpts_3d_image[:, :2] = camera_to_pixel(kpts_3d_cam.copy(), fx, fy, cx, cy) + kpts_3d_image[..., :2] = camera_to_pixel(kpts_3d_cam.copy(), fx, fy, cx, + cy) ratio = (br2d[0] - tl2d[0] + 0.001) / rectangle_3d_size - kpts_3d_depth = ratio * (kpts_3d_cam[:, 2] - kpts_3d_cam[root_index, 2]) - kpts_3d_image[:, 2] = kpts_3d_depth + kpts_3d_depth = ratio * ( + kpts_3d_cam[..., 2] - kpts_3d_cam[..., root_index, 2]) + kpts_3d_image[..., 2] = kpts_3d_depth return kpts_3d_image, ratio def camera_to_pixel(kpts_3d, fx, fy, cx, cy): - pose_2d = kpts_3d[:, :2] / kpts_3d[:, 2:3] - pose_2d[:, 0] *= fx - pose_2d[:, 1] *= fy - pose_2d[:, 0] += cx - pose_2d[:, 1] += cy + pose_2d = kpts_3d[..., :2] / kpts_3d[..., 2:3] + pose_2d[..., 0] *= fx + pose_2d[..., 1] *= fy + pose_2d[..., 0] += cx + pose_2d[..., 1] += cy return pose_2d diff --git a/mmpose/models/heads/regression_heads/motion_regression_head.py b/mmpose/models/heads/regression_heads/motion_regression_head.py index 4a5e51f472..3691c8469e 100644 --- a/mmpose/models/heads/regression_heads/motion_regression_head.py +++ b/mmpose/models/heads/regression_heads/motion_regression_head.py @@ -98,20 +98,27 @@ def predict(self, batch_coords = self.forward(feats) # (B, K, D) # Restore global position with target_root - target_root = batch_data_samples[0].metainfo.get('target_root', None) - if target_root is not None: - target_root = torch.stack([ - torch.from_numpy(b.metainfo['target_root']) + camera_param = batch_data_samples[0].metainfo.get('camera_param', None) + if camera_param is not None: + w = torch.stack([ + torch.from_numpy(np.array([b.metainfo['camera_param']['w']])) + for b in batch_data_samples + ]) + h = torch.stack([ + torch.from_numpy(np.array([b.metainfo['camera_param']['h']])) for b in batch_data_samples ]) else: - target_root = torch.stack([ + w = torch.stack([ + torch.empty((0), dtype=torch.float32) + for _ in batch_data_samples + ]) + h = torch.stack([ torch.empty((0), dtype=torch.float32) for _ in batch_data_samples ]) - target_root = target_root[..., None, :] - preds = self.decode((batch_coords, target_root)) + preds = self.decode((batch_coords, w, h)) return preds diff --git a/tests/test_codecs/test_mono_pose_lifting.py b/tests/test_codecs/test_mono_pose_lifting.py index b5a1792197..69b5c48a7f 100644 --- a/tests/test_codecs/test_mono_pose_lifting.py +++ b/tests/test_codecs/test_mono_pose_lifting.py @@ -74,41 +74,6 @@ def test_encode(self): 1, 17, )) - self.assertEqual(encoded['target_root'].shape, ( - 1, - 3, - )) - - # test not zero-centering - codec = self.build_pose_lifting_label(zero_center=False) - encoded = codec.encode(keypoints, keypoints_visible, lifting_target, - lifting_target_visible, camera_param) - - self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) - self.assertEqual(encoded['lifting_target_weights'].shape, ( - 1, - 17, - )) - self.assertEqual(encoded['trajectory_weights'].shape, ( - 1, - 17, - )) - - # test removing root - codec = self.build_pose_lifting_label( - remove_root=True, save_index=True) - encoded = codec.encode(keypoints, keypoints_visible, lifting_target, - lifting_target_visible, camera_param) - - self.assertTrue('target_root_removed' in encoded - and 'target_root_index' in encoded) - self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) - self.assertEqual(encoded['lifting_target_label'].shape, (1, 16, 3)) - self.assertEqual(encoded['target_root'].shape, ( - 1, - 3, - )) # test concatenating visibility codec = self.build_pose_lifting_label(concat_vis=True) @@ -117,36 +82,21 @@ def test_encode(self): self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 3)) self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) - self.assertEqual(encoded['target_root'].shape, ( - 1, - 3, - )) def test_decode(self): - lifting_target = self.data['lifting_target'] encoded_wo_sigma = self.data['encoded_wo_sigma'] camera_param = self.data['camera_param'] # test default settings codec = self.build_pose_lifting_label() - decoded, scores = codec.decode( - encoded_wo_sigma, target_root=lifting_target[..., 0, :]) + decoded, scores = codec.decode(encoded_wo_sigma) self.assertEqual(decoded.shape, (1, 17, 3)) self.assertEqual(scores.shape, (1, 17)) - # test `remove_root=True` - codec = self.build_pose_lifting_label(remove_root=True) - - decoded, scores = codec.decode( - encoded_wo_sigma, target_root=lifting_target[..., 0, :]) - - self.assertEqual(decoded.shape, (1, 18, 3)) - self.assertEqual(scores.shape, (1, 18)) - # test denormalize according to image shape - codec = self.build_pose_lifting_label(zero_center=False) + codec = self.build_pose_lifting_label() decoded, scores = codec.decode( encoded_wo_sigma, @@ -163,32 +113,10 @@ def test_cicular_verification(self): lifting_target_visible = self.data['lifting_target_visible'] camera_param = self.data['camera_param'] - # test default settings - codec = self.build_pose_lifting_label() - encoded = codec.encode(keypoints, keypoints_visible, lifting_target, - lifting_target_visible, camera_param) - - _keypoints, _ = codec.decode( - encoded['lifting_target_label'], - target_root=lifting_target[..., 0, :]) - - self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) - - # test removing root - codec = self.build_pose_lifting_label(remove_root=True) - encoded = codec.encode(keypoints, keypoints_visible, lifting_target, - lifting_target_visible, camera_param) - - _keypoints, _ = codec.decode( - encoded['lifting_target_label'], - target_root=lifting_target[..., 0, :]) - - self.assertTrue(np.allclose(lifting_target, _keypoints, atol=5.)) - # test denormalize according to image shape keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 3)) keypoints[..., 2] = np.round(keypoints[..., 2]).astype(np.float32) - codec = self.build_pose_lifting_label(zero_center=False) + codec = self.build_pose_lifting_label() encoded = codec.encode(keypoints, keypoints_visible, lifting_target, lifting_target_visible, camera_param) From b35fe2c8448512a8dd2a1d358348753dacd99bb2 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Tue, 13 Jun 2023 10:31:39 +0800 Subject: [PATCH 10/30] add zero-centering --- mmpose/codecs/mono_pose_lifting.py | 4 ++++ tests/test_codecs/test_mono_pose_lifting.py | 7 +++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mmpose/codecs/mono_pose_lifting.py b/mmpose/codecs/mono_pose_lifting.py index e7687d4d04..8b4e3bc804 100644 --- a/mmpose/codecs/mono_pose_lifting.py +++ b/mmpose/codecs/mono_pose_lifting.py @@ -125,6 +125,8 @@ def encode(self, # convert target to image coordinate lifting_target_label, factor = camera_to_image_coord( self.root_index, lifting_target_label, _camera_param) + lifting_target_label[..., :, :] = lifting_target_label[ + ..., :, :] - lifting_target_label[..., self.root_index, :] if self.concat_vis: keypoints_visible_ = keypoints_visible @@ -178,4 +180,6 @@ def decode( np.ones((w.shape[0], 1)), h / w, axis=1)[:, None, :] keypoints[..., :2] = (keypoints[..., :2] + trans) * w[:, None] / 2 keypoints[..., 2:] = keypoints[..., 2:] * w[:, None] / 2 + keypoints[..., :, :] = keypoints[..., :, :] - keypoints[ + ..., self.root_index, :] return keypoints, scores diff --git a/tests/test_codecs/test_mono_pose_lifting.py b/tests/test_codecs/test_mono_pose_lifting.py index 69b5c48a7f..2c1f1cd6c2 100644 --- a/tests/test_codecs/test_mono_pose_lifting.py +++ b/tests/test_codecs/test_mono_pose_lifting.py @@ -107,7 +107,6 @@ def test_decode(self): self.assertEqual(scores.shape, (1, 17)) def test_cicular_verification(self): - keypoints = self.data['keypoints'] keypoints_visible = self.data['keypoints_visible'] lifting_target = self.data['lifting_target'] lifting_target_visible = self.data['lifting_target_visible'] @@ -115,7 +114,6 @@ def test_cicular_verification(self): # test denormalize according to image shape keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 3)) - keypoints[..., 2] = np.round(keypoints[..., 2]).astype(np.float32) codec = self.build_pose_lifting_label() encoded = codec.encode(keypoints, keypoints_visible, lifting_target, lifting_target_visible, camera_param) @@ -125,5 +123,6 @@ def test_cicular_verification(self): w=np.array([camera_param['w']]), h=np.array([camera_param['h']])) - self.assertTrue( - np.allclose(keypoints[..., :2], _keypoints[..., :2], atol=5.)) + keypoints[..., :, :] = keypoints[..., :, :] - keypoints[..., 0, :] + + self.assertTrue(np.allclose(keypoints[..., :2], _keypoints[..., :2])) From 0187295deb40cb767e0920b4a39b3f36fc71f8e7 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Tue, 13 Jun 2023 15:14:44 +0800 Subject: [PATCH 11/30] add factor --- .../h36m/vid_pl_motionbert_8xb32-120e_h36m.py | 2 +- mmpose/codecs/mono_pose_lifting.py | 11 +++++++- .../motion_regression_head.py | 16 ++++++++++-- tests/test_codecs/test_mono_pose_lifting.py | 26 +++++++++++++++++++ 4 files changed, 51 insertions(+), 4 deletions(-) diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py index be6886bccb..ad0ea8c606 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py @@ -61,7 +61,7 @@ dict( type='PackPoseInputs', meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', - 'target_root', 'camera_param')) + 'factor', 'camera_param')) ] # data loaders diff --git a/mmpose/codecs/mono_pose_lifting.py b/mmpose/codecs/mono_pose_lifting.py index 8b4e3bc804..9e0d5f59bb 100644 --- a/mmpose/codecs/mono_pose_lifting.py +++ b/mmpose/codecs/mono_pose_lifting.py @@ -127,6 +127,7 @@ def encode(self, self.root_index, lifting_target_label, _camera_param) lifting_target_label[..., :, :] = lifting_target_label[ ..., :, :] - lifting_target_label[..., self.root_index, :] + lifting_target_label *= 1000 / factor if self.concat_vis: keypoints_visible_ = keypoints_visible @@ -142,7 +143,7 @@ def encode(self, encoded['lifting_target'] = lifting_target_label encoded['lifting_target_visible'] = lifting_target_visible encoded['trajectory_weights'] = trajectory_weights - encoded['factor'] = factor + encoded['factor'] = np.array([factor]) return encoded @@ -151,6 +152,7 @@ def decode( encoded: np.ndarray, w: Optional[np.ndarray] = None, h: Optional[np.ndarray] = None, + factor: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, np.ndarray]: """Decode keypoint coordinates from normalized space to input image space. @@ -161,6 +163,8 @@ def decode( Default: ``None``. h (np.ndarray, optional): The image heights in shape (N, ). Default: ``None``. + factor (np.ndarray, optional): The factor for projection in shape + (N, ). Default: ``None``. Returns: keypoints (np.ndarray): Decoded coordinates in shape (N, K, C). @@ -180,6 +184,11 @@ def decode( np.ones((w.shape[0], 1)), h / w, axis=1)[:, None, :] keypoints[..., :2] = (keypoints[..., :2] + trans) * w[:, None] / 2 keypoints[..., 2:] = keypoints[..., 2:] * w[:, None] / 2 + if factor is not None and factor.size > 0: + assert factor.shape[0] == keypoints.shape[0] + if factor.ndim == 1: + factor = factor[:, None] + keypoints[..., :, :] /= factor[..., :] keypoints[..., :, :] = keypoints[..., :, :] - keypoints[ ..., self.root_index, :] return keypoints, scores diff --git a/mmpose/models/heads/regression_heads/motion_regression_head.py b/mmpose/models/heads/regression_heads/motion_regression_head.py index 3691c8469e..558d44df21 100644 --- a/mmpose/models/heads/regression_heads/motion_regression_head.py +++ b/mmpose/models/heads/regression_heads/motion_regression_head.py @@ -97,7 +97,7 @@ def predict(self, batch_coords = self.forward(feats) # (B, K, D) - # Restore global position with target_root + # Restore global position with camera_param and factor camera_param = batch_data_samples[0].metainfo.get('camera_param', None) if camera_param is not None: w = torch.stack([ @@ -118,7 +118,19 @@ def predict(self, for _ in batch_data_samples ]) - preds = self.decode((batch_coords, w, h)) + factor = batch_data_samples[0].metainfo.get('factor', None) + if factor is not None: + factor = torch.stack([ + torch.from_numpy(np.array([b.metainfo['factor']])) + for b in batch_data_samples + ]) + else: + factor = torch.stack([ + torch.empty((0), dtype=torch.float32) + for _ in batch_data_samples + ]) + + preds = self.decode((batch_coords, w, h, factor)) return preds diff --git a/tests/test_codecs/test_mono_pose_lifting.py b/tests/test_codecs/test_mono_pose_lifting.py index 2c1f1cd6c2..ee08b0a225 100644 --- a/tests/test_codecs/test_mono_pose_lifting.py +++ b/tests/test_codecs/test_mono_pose_lifting.py @@ -106,6 +106,15 @@ def test_decode(self): self.assertEqual(decoded.shape, (1, 17, 3)) self.assertEqual(scores.shape, (1, 17)) + # test with factor + codec = self.build_pose_lifting_label() + + decoded, scores = codec.decode( + encoded_wo_sigma, factor=np.array([0.23])) + + self.assertEqual(decoded.shape, (1, 17, 3)) + self.assertEqual(scores.shape, (1, 17)) + def test_cicular_verification(self): keypoints_visible = self.data['keypoints_visible'] lifting_target = self.data['lifting_target'] @@ -126,3 +135,20 @@ def test_cicular_verification(self): keypoints[..., :, :] = keypoints[..., :, :] - keypoints[..., 0, :] self.assertTrue(np.allclose(keypoints[..., :2], _keypoints[..., :2])) + + # test with factor + keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 3)) + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + _keypoints, _ = codec.decode( + encoded['keypoint_labels'], + w=np.array([camera_param['w']]), + h=np.array([camera_param['h']]), + factor=encoded['factor']) + + keypoints /= encoded['factor'][0] + keypoints[..., :, :] = keypoints[..., :, :] - keypoints[..., 0, :] + + self.assertTrue(np.allclose(keypoints[..., :2], _keypoints[..., :2])) From 466e127aff47b2d3b65ca11641653062061db557 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 19 Jun 2023 17:02:27 +0800 Subject: [PATCH 12/30] fix codec & dataset, enhance metric, add results --- .../video_pose_lift/h36m/motionbert_h36m.md | 15 ++- .../video_pose_lift/h36m/motionbert_h36m.yml | 10 +- .../h36m/vid_pl_motionbert_8xb32-120e_h36m.py | 15 ++- mmpose/codecs/mono_pose_lifting.py | 36 +++++--- .../codecs/utils/camera_image_projection.py | 17 ++-- .../datasets/base/base_mocap_dataset.py | 5 +- .../datasets/datasets/body3d/h36m_dataset.py | 92 +++++++++++-------- .../datasets/transforms/common_transforms.py | 4 +- .../evaluation/metrics/keypoint_3d_metrics.py | 20 +++- .../motion_regression_head.py | 2 +- tests/test_codecs/test_mono_pose_lifting.py | 2 +- 11 files changed, 133 insertions(+), 85 deletions(-) diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md index fcce90e80a..7fa41b12d0 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md @@ -38,7 +38,14 @@ year = {2014} Testing results on Human3.6M dataset with ground truth 2D detections -| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | -| :---------------------------------------------------------------------------------------------------------- | :---: | :-----: | :-----: | :--------: | :-------: | -| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 34.87 | 14.95 | 34.02 | [ckpt](<>) | [log](<>) | -| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 34.97 | 14.94 | 34.11 | [ckpt](<>) | [log](<>) | +| Arch | MPJPE | P-MPJPE | ckpt | +| :---------------------------------------------------------------------------------------------------------- | :---: | :-----: | :--------: | +| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 35.3 | 27.7 | [ckpt](<>) | +| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 27.5 | 21.6 | [ckpt](<>) | + +Testing results on Human3.6M dataset from the [official repo](https://github.com/Walter0807/MotionBERT) with ground truth 2D detections + +| Arch | MPJPE | P-MPJPE | ckpt | +| :---------------------------------------------------------------------------------------------------------- | :---: | :-----: | :--------: | +| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 40.5 | 34.1 | [ckpt](<>) | +| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 38.2 | 32.6 | [ckpt](<>) | diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml index 2a386ff01f..a4ed9970a3 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml @@ -15,9 +15,8 @@ Models: Results: - Dataset: Human3.6M Metrics: - MPJPE: 34.87 - P-MPJPE: 14.95 - N-MPJPE: 34.02 + MPJPE: 35.3 + P-MPJPE: 27.7 Task: Body 3D Keypoint Weights: - Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py @@ -29,8 +28,7 @@ Models: Results: - Dataset: Human3.6M Metrics: - MPJPE: 34.97 - P-MPJPE: 14.94 - N-MPJPE: 34.11 + MPJPE: 27.5 + P-MPJPE: 21.6 Task: Body 3D Keypoint Weights: diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py index ad0ea8c606..e2c346daf6 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py @@ -26,7 +26,8 @@ ) # codec settings -codec = dict(type='MonoPoseLifting', num_keypoints=17, concat_vis=True) +codec = dict( + type='MonoPoseLifting', num_keypoints=17, concat_vis=True, rootrel=True) # model settings model = dict( @@ -67,6 +68,9 @@ # data loaders val_dataloader = dict( batch_size=32, + shuffle=False, + prefetch_factor=4, + pin_memory=True, num_workers=2, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), @@ -74,8 +78,8 @@ type=dataset_type, ann_file='annotation_body3d/fps50/h36m_test.npz', seq_len=1, + merge_seq=243, seq_step=1, - pad_video_seq=True, camera_param_file='annotation_body3d/cameras.pkl', data_root=data_root, data_prefix=dict(img='images/'), @@ -85,8 +89,11 @@ test_dataloader = val_dataloader # evaluators +skip_list = [ + 'S9_Greet', 'S9_SittingDown', 'S9_Wait_1', 'S9_Greeting', 'S9_Waiting_1' +] val_evaluator = [ - dict(type='MPJPE', mode='mpjpe'), - # dict(type='MPJPE', mode='p-mpjpe') + dict(type='MPJPE', mode='mpjpe', skip_list=skip_list), + dict(type='MPJPE', mode='p-mpjpe', skip_list=skip_list) ] test_evaluator = val_evaluator diff --git a/mmpose/codecs/mono_pose_lifting.py b/mmpose/codecs/mono_pose_lifting.py index 9e0d5f59bb..335003d3e1 100644 --- a/mmpose/codecs/mono_pose_lifting.py +++ b/mmpose/codecs/mono_pose_lifting.py @@ -34,7 +34,7 @@ class MonoPoseLifting(BaseKeypointCodec): """ auxiliary_encode_keys = { - 'lifting_target', 'lifting_target_visible', 'camera_param' + 'lifting_target', 'lifting_target_visible', 'camera_param', 'factor' } def __init__(self, @@ -42,7 +42,8 @@ def __init__(self, root_index: int = 0, remove_root: bool = False, save_index: bool = False, - concat_vis: bool = False): + concat_vis: bool = False, + rootrel: bool = False): super().__init__() self.num_keypoints = num_keypoints @@ -50,13 +51,15 @@ def __init__(self, self.remove_root = remove_root self.save_index = save_index self.concat_vis = concat_vis + self.rootrel = rootrel def encode(self, keypoints: np.ndarray, keypoints_visible: Optional[np.ndarray] = None, lifting_target: Optional[np.ndarray] = None, lifting_target_visible: Optional[np.ndarray] = None, - camera_param: Optional[dict] = None) -> dict: + camera_param: Optional[dict] = None, + factor: Optional[np.ndarray] = None) -> dict: """Encoding keypoints from input image space to normalized space. Args: @@ -83,13 +86,13 @@ def encode(self, - trajectory_weights (np.ndarray): The trajectory weights in shape (K, ). - factor (np.ndarray): The factor mapping camera and image - coordinate. + coordinate in shape (N, 1). """ if keypoints_visible is None: keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) if lifting_target is None: - lifting_target = [keypoints[0]] + lifting_target = [keypoints[..., 0, :, :]] # set initial value for `lifting_target_weights` # and `trajectory_weights` @@ -123,11 +126,17 @@ def encode(self, ..., :2] = keypoint_labels[..., :2] / w * 2 - [1, h / w] # convert target to image coordinate - lifting_target_label, factor = camera_to_image_coord( + lifting_target_label, factor_ = camera_to_image_coord( self.root_index, lifting_target_label, _camera_param) lifting_target_label[..., :, :] = lifting_target_label[ - ..., :, :] - lifting_target_label[..., self.root_index, :] - lifting_target_label *= 1000 / factor + ..., :, :] - lifting_target_label[..., + self.root_index:self.root_index + + 1, :] + if factor is None: + factor = factor_ + if factor.ndim == 1: + factor = factor[:, None] + lifting_target_label *= 1000 * factor[..., None] if self.concat_vis: keypoints_visible_ = keypoints_visible @@ -143,7 +152,7 @@ def encode(self, encoded['lifting_target'] = lifting_target_label encoded['lifting_target_visible'] = lifting_target_visible encoded['trajectory_weights'] = trajectory_weights - encoded['factor'] = np.array([factor]) + encoded['factor'] = factor return encoded @@ -173,6 +182,9 @@ def decode( keypoints = encoded.copy() scores = np.ones(keypoints.shape[:-1], dtype=np.float32) + if self.rootrel: + keypoints[..., 0, :] = 0 + if w is not None and w.size > 0: assert w.shape == h.shape assert w.shape[0] == keypoints.shape[0] @@ -186,9 +198,7 @@ def decode( keypoints[..., 2:] = keypoints[..., 2:] * w[:, None] / 2 if factor is not None and factor.size > 0: assert factor.shape[0] == keypoints.shape[0] - if factor.ndim == 1: - factor = factor[:, None] - keypoints[..., :, :] /= factor[..., :] + keypoints *= factor[..., None] keypoints[..., :, :] = keypoints[..., :, :] - keypoints[ - ..., self.root_index, :] + ..., self.root_index:self.root_index + 1, :] return keypoints, scores diff --git a/mmpose/codecs/utils/camera_image_projection.py b/mmpose/codecs/utils/camera_image_projection.py index 75a7efdefa..847062ce7e 100644 --- a/mmpose/codecs/utils/camera_image_projection.py +++ b/mmpose/codecs/utils/camera_image_projection.py @@ -8,23 +8,24 @@ def camera_to_image_coord(root_index, kpts_3d_cam, camera_param): tl_kpt[..., :2] -= 1.0 br_kpt = root.copy() br_kpt[..., :2] += 1.0 - tl_kpt = np.reshape(tl_kpt, (1, 3)) - br_kpt = np.reshape(br_kpt, (1, 3)) + tl_kpt = np.reshape(tl_kpt, (-1, 3)) + br_kpt = np.reshape(br_kpt, (-1, 3)) fx, fy = camera_param['f'] / 1000. cx, cy = camera_param['c'] / 1000. - tl2d = camera_to_pixel(tl_kpt, fx, fy, cx, cy).flatten() - br2d = camera_to_pixel(br_kpt, fx, fy, cx, cy).flatten() + tl2d = camera_to_pixel(tl_kpt, fx, fy, cx, cy) + br2d = camera_to_pixel(br_kpt, fx, fy, cx, cy) rectangle_3d_size = 2.0 kpts_3d_image = np.zeros_like(kpts_3d_cam) kpts_3d_image[..., :2] = camera_to_pixel(kpts_3d_cam.copy(), fx, fy, cx, cy) - ratio = (br2d[0] - tl2d[0] + 0.001) / rectangle_3d_size - kpts_3d_depth = ratio * ( - kpts_3d_cam[..., 2] - kpts_3d_cam[..., root_index, 2]) + ratio = (br2d[..., 0] - tl2d[..., 0] + 0.001) / rectangle_3d_size + factor = rectangle_3d_size / (br2d[..., 0] - tl2d[..., 0] + 0.001) + kpts_3d_depth = ratio[:, None] * ( + kpts_3d_cam[..., 2] - kpts_3d_cam[..., root_index:root_index + 1, 2]) kpts_3d_image[..., 2] = kpts_3d_depth - return kpts_3d_image, ratio + return kpts_3d_image, factor def camera_to_pixel(kpts_3d, fx, fy, cx, cy): diff --git a/mmpose/datasets/datasets/base/base_mocap_dataset.py b/mmpose/datasets/datasets/base/base_mocap_dataset.py index 7724c5e436..8fd81d01a7 100644 --- a/mmpose/datasets/datasets/base/base_mocap_dataset.py +++ b/mmpose/datasets/datasets/base/base_mocap_dataset.py @@ -293,9 +293,8 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: image_list = [] for idx, frame_ids in enumerate(self.sequence_indices): - assert len(frame_ids) == ( - self.merge_seq * - self.seq_len if self.merge_seq else self.seq_len) + assert len(frame_ids) == (self.merge_seq if self.merge_seq else + self.seq_len), f'{len(frame_ids)}' _img_names = img_names[frame_ids] diff --git a/mmpose/datasets/datasets/body3d/h36m_dataset.py b/mmpose/datasets/datasets/body3d/h36m_dataset.py index 9150416239..86dd4fb8d6 100644 --- a/mmpose/datasets/datasets/body3d/h36m_dataset.py +++ b/mmpose/datasets/datasets/body3d/h36m_dataset.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import itertools import os.path as osp from collections import defaultdict from typing import Callable, List, Optional, Sequence, Tuple, Union @@ -68,6 +67,9 @@ class Human36mDataset(BaseMocapDataset): If set, 2d keypoint loaded from this file will be used instead of ground-truth keypoints. This setting is only when ``keypoint_2d_src`` is ``'detection'``. Default: ``None``. + factor_file (str, optional): The projection factors' file. If set, + factor loaded from this file will be used instead of calculated + factors. Default: ``None``. camera_param_file (str): Cameras' parameters file. Default: ``None``. data_mode (str): Specifies the mode of data samples: ``'topdown'`` or ``'bottomup'``. In ``'topdown'`` mode, each data sample contains @@ -113,6 +115,7 @@ def __init__(self, subset_frac: float = 1.0, keypoint_2d_src: str = 'gt', keypoint_2d_det_file: Optional[str] = None, + factor_file: Optional[str] = None, camera_param_file: Optional[str] = None, data_mode: str = 'topdown', metainfo: Optional[dict] = None, @@ -142,6 +145,12 @@ def __init__(self, self.seq_step = seq_step self.pad_video_seq = pad_video_seq + if factor_file: + if not is_abs(factor_file): + factor_file = osp.join(data_root, factor_file) + assert exists(factor_file), 'Annotation file does not exist.' + self.factor_file = factor_file + super().__init__( ann_file=ann_file, seq_len=seq_len, @@ -176,35 +185,46 @@ def get_sequence_indices(self) -> List[List[int]]: sequence_indices = [] _len = (self.seq_len - 1) * self.seq_step + 1 _step = self.seq_step - for _, _indices in sorted(video_frames.items()): - n_frame = len(_indices) - - if self.pad_video_seq: - # Pad the sequence so that every frame in the sequence will be - # predicted. - if self.causal: - frames_left = self.seq_len - 1 - frames_right = 0 - else: - frames_left = (self.seq_len - 1) // 2 - frames_right = frames_left - for i in range(n_frame): - pad_left = max(0, frames_left - i // _step) - pad_right = max(0, - frames_right - (n_frame - 1 - i) // _step) - start = max(i % _step, i - frames_left * _step) - end = min(n_frame - (n_frame - 1 - i) % _step, - i + frames_right * _step + 1) - sequence_indices.append([_indices[0]] * pad_left + - _indices[start:end:_step] + - [_indices[-1]] * pad_right) - else: + + if self.merge_seq: + for _, _indices in sorted(video_frames.items()): + n_frame = len(_indices) seqs_from_video = [ - _indices[i:(i + _len):_step] - for i in range(0, n_frame - _len + 1) - ] + _indices[i:(i + self.merge_seq):_step] + for i in range(0, n_frame, self.merge_seq) + ][:n_frame // self.merge_seq] sequence_indices.extend(seqs_from_video) + else: + for _, _indices in sorted(video_frames.items()): + n_frame = len(_indices) + + if self.pad_video_seq: + # Pad the sequence so that every frame in the sequence will + # be predicted. + if self.causal: + frames_left = self.seq_len - 1 + frames_right = 0 + else: + frames_left = (self.seq_len - 1) // 2 + frames_right = frames_left + for i in range(n_frame): + pad_left = max(0, frames_left - i // _step) + pad_right = max( + 0, frames_right - (n_frame - 1 - i) // _step) + start = max(i % _step, i - frames_left * _step) + end = min(n_frame - (n_frame - 1 - i) % _step, + i + frames_right * _step + 1) + sequence_indices.append([_indices[0]] * pad_left + + _indices[start:end:_step] + + [_indices[-1]] * pad_right) + else: + seqs_from_video = [ + _indices[i:(i + _len):_step] + for i in range(0, n_frame - _len + 1) + ] + sequence_indices.extend(seqs_from_video) + # reduce dataset size if needed subset_size = int(len(sequence_indices) * self.subset_frac) start = np.random.randint(0, len(sequence_indices) - subset_size + 1) @@ -212,17 +232,6 @@ def get_sequence_indices(self) -> List[List[int]]: sequence_indices = sequence_indices[start:end] - if self.merge_seq > 0: - sequence_indices_merged = [] - for i in range(0, len(sequence_indices), self.merge_seq): - if i + self.merge_seq > len(sequence_indices): - break - sequence_indices_merged.append( - list( - itertools.chain.from_iterable( - sequence_indices[i:i + self.merge_seq]))) - sequence_indices = sequence_indices_merged - return sequence_indices def _load_annotations(self) -> Tuple[List[dict], List[dict]]: @@ -248,6 +257,13 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: 'keypoints_visible': keypoints_visible }) + if self.factor_file: + with get_local_path(self.factor_file) as local_path: + factors = np.load(local_path).astype(np.float32) + assert factors.shape[0] == kpts_3d.shape[0] + for idx, frame_ids in enumerate(self.sequence_indices): + factor = factors[frame_ids].astype(np.float32) + instance_list[idx].update({'factor': factor}) return instance_list, image_list diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py index 87068246f8..ee800ebabe 100644 --- a/mmpose/datasets/transforms/common_transforms.py +++ b/mmpose/datasets/transforms/common_transforms.py @@ -961,7 +961,7 @@ def transform(self, results: Dict) -> Optional[dict]: # into results. auxiliary_encode_kwargs = { key: results[key] - for key in self.encoder.auxiliary_encode_keys + for key in self.encoder.auxiliary_encode_keys if key in results } encoded = self.encoder.encode( keypoints=keypoints, @@ -973,7 +973,7 @@ def transform(self, results: Dict) -> Optional[dict]: for _encoder in self.encoder: auxiliary_encode_kwargs = { key: results[key] - for key in _encoder.auxiliary_encode_keys + for key in _encoder.auxiliary_encode_keys if key in results } encoded_list.append( _encoder.encode( diff --git a/mmpose/evaluation/metrics/keypoint_3d_metrics.py b/mmpose/evaluation/metrics/keypoint_3d_metrics.py index 02c7f1aeab..fb3447bb3f 100644 --- a/mmpose/evaluation/metrics/keypoint_3d_metrics.py +++ b/mmpose/evaluation/metrics/keypoint_3d_metrics.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from collections import defaultdict from os import path as osp -from typing import Dict, Optional, Sequence +from typing import Dict, List, Optional, Sequence import numpy as np from mmengine.evaluator import BaseMetric @@ -38,6 +38,8 @@ class MPJPE(BaseMetric): names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, ``self.default_prefix`` will be used instead. Default: ``None``. + skip_list (list, optional): The list of subject and action combinations + to be skipped. Default: []. """ ALIGNMENT = {'mpjpe': 'none', 'p-mpjpe': 'procrustes', 'n-mpjpe': 'scale'} @@ -45,7 +47,8 @@ class MPJPE(BaseMetric): def __init__(self, mode: str = 'mpjpe', collect_device: str = 'cpu', - prefix: Optional[str] = None) -> None: + prefix: Optional[str] = None, + skip_list: List[str] = []) -> None: super().__init__(collect_device=collect_device, prefix=prefix) allowed_modes = self.ALIGNMENT.keys() if mode not in allowed_modes: @@ -53,6 +56,7 @@ def __init__(self, f"'n-mpjpe', but got '{mode}'.") self.mode = mode + self.skip_list = skip_list def process(self, data_batch: Sequence[dict], data_samples: Sequence[dict]) -> None: @@ -82,12 +86,17 @@ def process(self, data_batch: Sequence[dict], img_path = data_sample['target_img_path'][0] _, rest = osp.basename(img_path).split('_', 1) action, _ = rest.split('.', 1) + actions = np.array([action] * gt_coords.shape[0]) + + subj_act = osp.basename(img_path).split('.')[0] + if subj_act in self.skip_list: + continue result = { 'pred_coords': pred_coords, 'gt_coords': gt_coords, 'mask': mask, - 'action': action + 'actions': actions } self.results.append(result) @@ -113,8 +122,9 @@ def compute_metrics(self, results: list) -> Dict[str, float]: mask = np.concatenate([result['mask'] for result in results]) # action_category_indices: Dict[List[int]] action_category_indices = defaultdict(list) - for idx, result in enumerate(results): - action_category = result['action'].split('_')[0] + actions = np.concatenate([result['actions'] for result in results]) + for idx, action in enumerate(actions): + action_category = action.split('_')[0] action_category_indices[action_category].append(idx) error_name = self.mode.upper() diff --git a/mmpose/models/heads/regression_heads/motion_regression_head.py b/mmpose/models/heads/regression_heads/motion_regression_head.py index 558d44df21..ae146a439b 100644 --- a/mmpose/models/heads/regression_heads/motion_regression_head.py +++ b/mmpose/models/heads/regression_heads/motion_regression_head.py @@ -121,7 +121,7 @@ def predict(self, factor = batch_data_samples[0].metainfo.get('factor', None) if factor is not None: factor = torch.stack([ - torch.from_numpy(np.array([b.metainfo['factor']])) + torch.from_numpy(b.metainfo['factor']) for b in batch_data_samples ]) else: diff --git a/tests/test_codecs/test_mono_pose_lifting.py b/tests/test_codecs/test_mono_pose_lifting.py index ee08b0a225..b6c2ed6ee7 100644 --- a/tests/test_codecs/test_mono_pose_lifting.py +++ b/tests/test_codecs/test_mono_pose_lifting.py @@ -148,7 +148,7 @@ def test_cicular_verification(self): h=np.array([camera_param['h']]), factor=encoded['factor']) - keypoints /= encoded['factor'][0] + keypoints *= encoded['factor'] keypoints[..., :, :] = keypoints[..., :, :] - keypoints[..., 0, :] self.assertTrue(np.allclose(keypoints[..., :2], _keypoints[..., :2])) From 014f9283f213ec30dbb2a396927ac7c03a750005 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 21 Jun 2023 13:26:39 +0800 Subject: [PATCH 13/30] update results --- .../video_pose_lift/h36m/motionbert_h36m.md | 16 ++++++------ ...d_pl_motionbert-243frm_8xb32-120e_h36m.py} | 2 +- .../datasets/base/base_mocap_dataset.py | 25 ++++++++++--------- .../datasets/datasets/body3d/h36m_dataset.py | 16 ++++++------ .../test_body_datasets/test_h36m_dataset.py | 4 +-- 5 files changed, 32 insertions(+), 31 deletions(-) rename configs/body_3d_keypoint/video_pose_lift/h36m/{vid_pl_motionbert_8xb32-120e_h36m.py => vid_pl_motionbert-243frm_8xb32-120e_h36m.py} (98%) diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md index 7fa41b12d0..e04a540b75 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md @@ -38,14 +38,14 @@ year = {2014} Testing results on Human3.6M dataset with ground truth 2D detections -| Arch | MPJPE | P-MPJPE | ckpt | -| :---------------------------------------------------------------------------------------------------------- | :---: | :-----: | :--------: | -| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 35.3 | 27.7 | [ckpt](<>) | -| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 27.5 | 21.6 | [ckpt](<>) | +| Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | +| :----------------------------------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------: | +| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 35.3 | 27.7 | [ckpt](<>) | +| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 27.4 | 21.6 | [ckpt](<>) | Testing results on Human3.6M dataset from the [official repo](https://github.com/Walter0807/MotionBERT) with ground truth 2D detections -| Arch | MPJPE | P-MPJPE | ckpt | -| :---------------------------------------------------------------------------------------------------------- | :---: | :-----: | :--------: | -| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 40.5 | 34.1 | [ckpt](<>) | -| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py) | 38.2 | 32.6 | [ckpt](<>) | +| Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | +| :----------------------------------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------: | +| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py) | 40.5 | 39.9 | 34.1 | [ckpt](<>) | +| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py) | 38.2 | 37.7 | 32.6 | [ckpt](<>) | diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py similarity index 98% rename from configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py rename to configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py index e2c346daf6..04c10b2835 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py @@ -78,7 +78,7 @@ type=dataset_type, ann_file='annotation_body3d/fps50/h36m_test.npz', seq_len=1, - merge_seq=243, + multiple_target=243, seq_step=1, camera_param_file='annotation_body3d/cameras.pkl', data_root=data_root, diff --git a/mmpose/datasets/datasets/base/base_mocap_dataset.py b/mmpose/datasets/datasets/base/base_mocap_dataset.py index 8fd81d01a7..e08ba6ea45 100644 --- a/mmpose/datasets/datasets/base/base_mocap_dataset.py +++ b/mmpose/datasets/datasets/base/base_mocap_dataset.py @@ -22,8 +22,8 @@ class BaseMocapDataset(BaseDataset): Args: ann_file (str): Annotation file path. Default: ''. seq_len (int): Number of frames in a sequence. Default: 1. - merge_seq (int): If larger than 0, merge every ``merge_seq`` sequence - together. Default: 0. + multiple_target (int): If larger than 0, merge every + ``multiple_target`` sequence together. Default: 0. causal (bool): If set to ``True``, the rightmost input frame will be the target frame. Otherwise, the middle input frame will be the target frame. Default: ``True``. @@ -66,7 +66,7 @@ class BaseMocapDataset(BaseDataset): def __init__(self, ann_file: str = '', seq_len: int = 1, - merge_seq: int = 0, + multiple_target: int = 0, causal: bool = True, subset_frac: float = 1.0, camera_param_file: Optional[str] = None, @@ -106,8 +106,8 @@ def __init__(self, self.seq_len = seq_len self.causal = causal - self.merge_seq = merge_seq - if self.merge_seq: + self.multiple_target = multiple_target + if self.multiple_target: assert (self.seq_len == 1) assert 0 < subset_frac <= 1, ( @@ -250,15 +250,15 @@ def get_sequence_indices(self) -> List[List[int]]: else: raise NotImplementedError('Multi-frame data sample unsupported!') - if self.merge_seq > 0: + if self.multiple_target > 0: sequence_indices_merged = [] - for i in range(0, len(sequence_indices), self.merge_seq): - if i + self.merge_seq > len(sequence_indices): + for i in range(0, len(sequence_indices), self.multiple_target): + if i + self.multiple_target > len(sequence_indices): break sequence_indices_merged.append( list( itertools.chain.from_iterable( - sequence_indices[i:i + self.merge_seq]))) + sequence_indices[i:i + self.multiple_target]))) sequence_indices = sequence_indices_merged return sequence_indices @@ -293,7 +293,8 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: image_list = [] for idx, frame_ids in enumerate(self.sequence_indices): - assert len(frame_ids) == (self.merge_seq if self.merge_seq else + assert len(frame_ids) == (self.multiple_target + if self.multiple_target else self.seq_len), f'{len(frame_ids)}' _img_names = img_names[frame_ids] @@ -307,8 +308,8 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: keypoints_3d_visible = _keypoints_3d[..., 3] target_idx = [-1] if self.causal else [int(self.seq_len) // 2] - if self.merge_seq: - target_idx = list(range(self.merge_seq)) + if self.multiple_target: + target_idx = list(range(self.multiple_target)) instance_info = { 'num_keypoints': num_keypoints, diff --git a/mmpose/datasets/datasets/body3d/h36m_dataset.py b/mmpose/datasets/datasets/body3d/h36m_dataset.py index 86dd4fb8d6..984665e062 100644 --- a/mmpose/datasets/datasets/body3d/h36m_dataset.py +++ b/mmpose/datasets/datasets/body3d/h36m_dataset.py @@ -45,8 +45,8 @@ class Human36mDataset(BaseMocapDataset): seq_len (int): Number of frames in a sequence. Default: 1. seq_step (int): The interval for extracting frames from the video. Default: 1. - merge_seq (int): If larger than 0, merge every ``merge_seq`` sequence - together. Default: 0. + multiple_target (int): If larger than 0, merge every + ``multiple_target`` sequence together. Default: 0. pad_video_seq (bool): Whether to pad the video so that poses will be predicted for every frame in the video. Default: ``False``. causal (bool): If set to ``True``, the rightmost input frame will be @@ -109,7 +109,7 @@ def __init__(self, ann_file: str = '', seq_len: int = 1, seq_step: int = 1, - merge_seq: int = 0, + multiple_target: int = 0, pad_video_seq: bool = False, causal: bool = True, subset_frac: float = 1.0, @@ -154,7 +154,7 @@ def __init__(self, super().__init__( ann_file=ann_file, seq_len=seq_len, - merge_seq=merge_seq, + multiple_target=multiple_target, causal=causal, subset_frac=subset_frac, camera_param_file=camera_param_file, @@ -186,13 +186,13 @@ def get_sequence_indices(self) -> List[List[int]]: _len = (self.seq_len - 1) * self.seq_step + 1 _step = self.seq_step - if self.merge_seq: + if self.multiple_target: for _, _indices in sorted(video_frames.items()): n_frame = len(_indices) seqs_from_video = [ - _indices[i:(i + self.merge_seq):_step] - for i in range(0, n_frame, self.merge_seq) - ][:n_frame // self.merge_seq] + _indices[i:(i + self.multiple_target):_step] + for i in range(0, n_frame, self.multiple_target) + ][:n_frame // self.multiple_target] sequence_indices.extend(seqs_from_video) else: diff --git a/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py b/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py index 314189fde6..fd6cdf5f17 100644 --- a/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py +++ b/tests/test_datasets/test_datasets/test_body_datasets/test_h36m_dataset.py @@ -120,11 +120,11 @@ def test_topdown(self): data_mode='topdown', seq_len=1, seq_step=1, - merge_seq=2, + multiple_target=1, causal=False, pad_video_seq=True, camera_param_file='cameras.pkl') - self.assertEqual(len(dataset), 2) + self.assertEqual(len(dataset), 4) self.check_data_info_keys(dataset[0]) # test topdown testing with 2d keypoint detection file and From 1015a53b35caec38b163018c61e605c6ee49138d Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 26 Jun 2023 11:44:43 +0800 Subject: [PATCH 14/30] fix --- mmpose/codecs/mono_pose_lifting.py | 8 +++++--- mmpose/datasets/datasets/body3d/h36m_dataset.py | 10 ++++++---- mmpose/datasets/transforms/common_transforms.py | 4 ++-- tests/test_codecs/test_mono_pose_lifting.py | 7 +++++-- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/mmpose/codecs/mono_pose_lifting.py b/mmpose/codecs/mono_pose_lifting.py index 335003d3e1..ef84abffa9 100644 --- a/mmpose/codecs/mono_pose_lifting.py +++ b/mmpose/codecs/mono_pose_lifting.py @@ -71,6 +71,8 @@ def encode(self, lifting_target_visible (np.ndarray, optional): Target coordinate in shape (T, K, ). camera_param (dict, optional): The camera parameter dictionary. + factor (np.ndarray, optional): The factor mapping camera and image + coordinate in shape (T, ). Returns: encoded (dict): Contains the following items: @@ -78,7 +80,7 @@ def encode(self, - keypoint_labels (np.ndarray): The processed keypoints in shape like (N, K, D). - keypoint_labels_visible (np.ndarray): The processed - keypoints' weights in shape (N, K, ) or (N-1, K, ). + keypoints' weights in shape (N, K, ) or (N, K-1, ). - lifting_target_label: The processed target coordinate in shape (K, C) or (K-1, C). - lifting_target_weights (np.ndarray): The target weights in @@ -86,7 +88,7 @@ def encode(self, - trajectory_weights (np.ndarray): The trajectory weights in shape (K, ). - factor (np.ndarray): The factor mapping camera and image - coordinate in shape (N, 1). + coordinate in shape (T, 1). """ if keypoints_visible is None: keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) @@ -132,7 +134,7 @@ def encode(self, ..., :, :] - lifting_target_label[..., self.root_index:self.root_index + 1, :] - if factor is None: + if factor is None or factor[0] == 0: factor = factor_ if factor.ndim == 1: factor = factor[:, None] diff --git a/mmpose/datasets/datasets/body3d/h36m_dataset.py b/mmpose/datasets/datasets/body3d/h36m_dataset.py index 984665e062..59cd358aa9 100644 --- a/mmpose/datasets/datasets/body3d/h36m_dataset.py +++ b/mmpose/datasets/datasets/body3d/h36m_dataset.py @@ -260,10 +260,12 @@ def _load_annotations(self) -> Tuple[List[dict], List[dict]]: if self.factor_file: with get_local_path(self.factor_file) as local_path: factors = np.load(local_path).astype(np.float32) - assert factors.shape[0] == kpts_3d.shape[0] - for idx, frame_ids in enumerate(self.sequence_indices): - factor = factors[frame_ids].astype(np.float32) - instance_list[idx].update({'factor': factor}) + else: + factors = np.zeros((kpts_3d.shape[0], ), dtype=np.float32) + assert factors.shape[0] == kpts_3d.shape[0] + for idx, frame_ids in enumerate(self.sequence_indices): + factor = factors[frame_ids].astype(np.float32) + instance_list[idx].update({'factor': factor}) return instance_list, image_list diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py index ee800ebabe..87068246f8 100644 --- a/mmpose/datasets/transforms/common_transforms.py +++ b/mmpose/datasets/transforms/common_transforms.py @@ -961,7 +961,7 @@ def transform(self, results: Dict) -> Optional[dict]: # into results. auxiliary_encode_kwargs = { key: results[key] - for key in self.encoder.auxiliary_encode_keys if key in results + for key in self.encoder.auxiliary_encode_keys } encoded = self.encoder.encode( keypoints=keypoints, @@ -973,7 +973,7 @@ def transform(self, results: Dict) -> Optional[dict]: for _encoder in self.encoder: auxiliary_encode_kwargs = { key: results[key] - for key in _encoder.auxiliary_encode_keys if key in results + for key in _encoder.auxiliary_encode_keys } encoded_list.append( _encoder.encode( diff --git a/tests/test_codecs/test_mono_pose_lifting.py b/tests/test_codecs/test_mono_pose_lifting.py index b6c2ed6ee7..69bdba77af 100644 --- a/tests/test_codecs/test_mono_pose_lifting.py +++ b/tests/test_codecs/test_mono_pose_lifting.py @@ -39,6 +39,7 @@ def setUp(self) -> None: camera_param = self.get_camera_param( 'S1/S1_Directions_1.54138969/S1_Directions_1.54138969_000001.jpg', camera_param) + factor = 0.1 + 5 * np.random.rand(1, ) self.data = dict( keypoints=keypoints, @@ -46,6 +47,7 @@ def setUp(self) -> None: lifting_target=lifting_target, lifting_target_visible=lifting_target_visible, camera_param=camera_param, + factor=factor, encoded_wo_sigma=encoded_wo_sigma) def test_build(self): @@ -58,11 +60,12 @@ def test_encode(self): lifting_target = self.data['lifting_target'] lifting_target_visible = self.data['lifting_target_visible'] camera_param = self.data['camera_param'] + factor = self.data['factor'] # test default settings codec = self.build_pose_lifting_label() encoded = codec.encode(keypoints, keypoints_visible, lifting_target, - lifting_target_visible, camera_param) + lifting_target_visible, camera_param, factor) self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 2)) self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) @@ -78,7 +81,7 @@ def test_encode(self): # test concatenating visibility codec = self.build_pose_lifting_label(concat_vis=True) encoded = codec.encode(keypoints, keypoints_visible, lifting_target, - lifting_target_visible, camera_param) + lifting_target_visible, camera_param, factor) self.assertEqual(encoded['keypoint_labels'].shape, (1, 17, 3)) self.assertEqual(encoded['lifting_target_label'].shape, (1, 17, 3)) From a50ec62c4595e76f8ac66f4957b550eda3c6e680 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 26 Jun 2023 15:28:56 +0800 Subject: [PATCH 15/30] change trunc_normal_ import --- mmpose/models/backbones/dstformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmpose/models/backbones/dstformer.py b/mmpose/models/backbones/dstformer.py index 76bf6dd83c..2ef13bdb02 100644 --- a/mmpose/models/backbones/dstformer.py +++ b/mmpose/models/backbones/dstformer.py @@ -3,7 +3,7 @@ import torch.nn as nn from mmcv.cnn.bricks import DropPath from mmengine.model import BaseModule, constant_init -from timm.models.layers import trunc_normal_ +from mmengine.model.weight_init import trunc_normal_ from mmpose.registry import MODELS from .base_backbone import BaseBackbone From edbd6e2d9fdb2efbd07f5a359b41def408d00517 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Tue, 27 Jun 2023 15:45:51 +0800 Subject: [PATCH 16/30] fix problems in codec --- mmpose/codecs/image_pose_lifting.py | 3 +- mmpose/codecs/video_pose_lifting.py | 3 +- tests/test_codecs/test_video_pose_lifting.py | 29 ++++++++++++++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index d954eb7fb7..31e8dcfef1 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -138,7 +138,8 @@ def encode(self, f'Got invalid joint shape {lifting_target.shape}' root = lifting_target[..., self.root_index, :] - lifting_target_label = lifting_target - root[:, None] + lifting_target_label = lifting_target - lifting_target[ + ..., self.root_index:self.root_index + 1, :] if self.remove_root: lifting_target_label = np.delete( diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py index 55b06e079b..9e409a663c 100644 --- a/mmpose/codecs/video_pose_lifting.py +++ b/mmpose/codecs/video_pose_lifting.py @@ -138,7 +138,8 @@ def encode(self, f'Got invalid joint shape {lifting_target.shape}' root = lifting_target[..., self.root_index, :] - lifting_target_label = lifting_target_label - root[:, None] + lifting_target_label -= lifting_target_label[ + ..., self.root_index:self.root_index + 1, :] encoded['target_root'] = root if self.remove_root: diff --git a/tests/test_codecs/test_video_pose_lifting.py b/tests/test_codecs/test_video_pose_lifting.py index 8366953d54..31a095e927 100644 --- a/tests/test_codecs/test_video_pose_lifting.py +++ b/tests/test_codecs/test_video_pose_lifting.py @@ -144,6 +144,35 @@ def test_encode(self): encoded['camera_param']['f'], atol=4.)) + # test with multiple targets + keypoints = (0.1 + 0.8 * np.random.rand(2, 17, 2)) * [192, 256] + keypoints = np.round(keypoints).astype(np.float32) + keypoints_visible = np.random.randint(2, size=(2, 17)) + lifting_target = (0.1 + 0.8 * np.random.rand(2, 17, 3)) + lifting_target_visible = np.random.randint( + 2, size=( + 2, + 17, + )) + codec = self.build_pose_lifting_label() + encoded = codec.encode(keypoints, keypoints_visible, lifting_target, + lifting_target_visible, camera_param) + + self.assertEqual(encoded['keypoint_labels'].shape, (2, 17, 2)) + self.assertEqual(encoded['lifting_target_label'].shape, (2, 17, 3)) + self.assertEqual(encoded['lifting_target_weights'].shape, ( + 2, + 17, + )) + self.assertEqual(encoded['trajectory_weights'].shape, ( + 2, + 17, + )) + self.assertEqual(encoded['target_root'].shape, ( + 2, + 3, + )) + def test_decode(self): lifting_target = self.data['lifting_target'] encoded_wo_sigma = self.data['encoded_wo_sigma'] From b0b6fc18805c246fdc138cc3b4b72d7dccda246f Mon Sep 17 00:00:00 2001 From: LareinaM Date: Fri, 30 Jun 2023 11:14:31 +0800 Subject: [PATCH 17/30] rename codec, add docstring --- .../h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py | 4 +++- .../codecs/{mono_pose_lifting.py => motionbert_label.py} | 7 +++++-- ...test_mono_pose_lifting.py => test_motionbert_label.py} | 8 ++++---- 3 files changed, 12 insertions(+), 7 deletions(-) rename mmpose/codecs/{mono_pose_lifting.py => motionbert_label.py} (97%) rename tests/test_codecs/{test_mono_pose_lifting.py => test_motionbert_label.py} (96%) diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py index 04c10b2835..2803323ed2 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py @@ -27,7 +27,7 @@ # codec settings codec = dict( - type='MonoPoseLifting', num_keypoints=17, concat_vis=True, rootrel=True) + type='MotionBERTLabel', num_keypoints=17, concat_vis=True, rootrel=True) # model settings model = dict( @@ -77,6 +77,8 @@ dataset=dict( type=dataset_type, ann_file='annotation_body3d/fps50/h36m_test.npz', + # ann_file='annotation_body3d/fps50/h36m_test_original.npz', + # factor_file='annotation_body3d/fps50/h36m_factors.npy', seq_len=1, multiple_target=243, seq_step=1, diff --git a/mmpose/codecs/mono_pose_lifting.py b/mmpose/codecs/motionbert_label.py similarity index 97% rename from mmpose/codecs/mono_pose_lifting.py rename to mmpose/codecs/motionbert_label.py index ef84abffa9..1d036c49bb 100644 --- a/mmpose/codecs/mono_pose_lifting.py +++ b/mmpose/codecs/motionbert_label.py @@ -11,8 +11,9 @@ @KEYPOINT_CODECS.register_module() -class MonoPoseLifting(BaseKeypointCodec): - r"""Generate keypoint coordinates for pose lifter. +class MotionBERTLabel(BaseKeypointCodec): + r"""Generate keypoint and label coordinates for `MotionBERT`_ by Zhu et al + (2022). Note: @@ -31,6 +32,8 @@ class MonoPoseLifting(BaseKeypointCodec): Default: ``False``. concat_vis (bool): If true, concat the visibility item of keypoints. Default: ``False``. + rootrel (bool): If true, the root keypoint will be set to the + coordinate origin. Default: ``False``. """ auxiliary_encode_keys = { diff --git a/tests/test_codecs/test_mono_pose_lifting.py b/tests/test_codecs/test_motionbert_label.py similarity index 96% rename from tests/test_codecs/test_mono_pose_lifting.py rename to tests/test_codecs/test_motionbert_label.py index 69bdba77af..47ce3dfc68 100644 --- a/tests/test_codecs/test_mono_pose_lifting.py +++ b/tests/test_codecs/test_motionbert_label.py @@ -5,11 +5,11 @@ import numpy as np from mmengine.fileio import load -from mmpose.codecs import MonoPoseLifting +from mmpose.codecs import MotionBERTLabel from mmpose.registry import KEYPOINT_CODECS -class TestMonoPoseLifting(TestCase): +class TestMotionBERTLabel(TestCase): def get_camera_param(self, imgname, camera_param) -> dict: """Get camera parameters of a frame by its image name.""" @@ -19,7 +19,7 @@ def get_camera_param(self, imgname, camera_param) -> dict: return camera_param[(subj, camera)] def build_pose_lifting_label(self, **kwargs): - cfg = dict(type='MonoPoseLifting', num_keypoints=17) + cfg = dict(type='MotionBERTLabel', num_keypoints=17) cfg.update(kwargs) return KEYPOINT_CODECS.build(cfg) @@ -52,7 +52,7 @@ def setUp(self) -> None: def test_build(self): codec = self.build_pose_lifting_label() - self.assertIsInstance(codec, MonoPoseLifting) + self.assertIsInstance(codec, MotionBERTLabel) def test_encode(self): keypoints = self.data['keypoints'] From 032acd19ee8c18723cd9e4b724170cf87d1cf4db Mon Sep 17 00:00:00 2001 From: LareinaM Date: Fri, 30 Jun 2023 11:16:31 +0800 Subject: [PATCH 18/30] modify import --- mmpose/codecs/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py index 9c17513afe..1a48b7f851 100644 --- a/mmpose/codecs/__init__.py +++ b/mmpose/codecs/__init__.py @@ -4,7 +4,7 @@ from .image_pose_lifting import ImagePoseLifting from .integral_regression_label import IntegralRegressionLabel from .megvii_heatmap import MegviiHeatmap -from .mono_pose_lifting import MonoPoseLifting +from .motionbert_label import MotionBERTLabel from .msra_heatmap import MSRAHeatmap from .regression_label import RegressionLabel from .simcc_label import SimCCLabel @@ -16,5 +16,5 @@ 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel', 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR', 'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting', - 'MonoPoseLifting' + 'MotionBERTLabel' ] From e4c46c894a04829df5fdcd814410252882ef8aaa Mon Sep 17 00:00:00 2001 From: LareinaM Date: Fri, 30 Jun 2023 13:29:15 +0800 Subject: [PATCH 19/30] fix codec --- mmpose/codecs/image_pose_lifting.py | 30 ++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 31e8dcfef1..0e17f99642 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -60,9 +60,18 @@ def __init__(self, self.save_index = save_index self.reshape_keypoints = reshape_keypoints self.concat_vis = concat_vis - if keypoints_mean is not None and keypoints_std is not None: + if keypoints_mean is not None: + keypoints_mean = np.array(keypoints_mean).reshape( + 1, num_keypoints, -1) + keypoints_std = np.array(keypoints_std).reshape( + 1, num_keypoints, -1) + assert keypoints_std is not None assert keypoints_mean.shape == keypoints_std.shape - if target_mean is not None and target_std is not None: + if target_mean is not None: + target_dim = num_keypoints - 1 if remove_root else num_keypoints + target_mean = np.array(target_mean).reshape(1, target_dim, -1) + target_std = np.array(target_std).reshape(1, target_dim, -1) + assert target_std is not None assert target_mean.shape == target_std.shape self.keypoints_mean = keypoints_mean self.keypoints_std = keypoints_std @@ -160,18 +169,17 @@ def encode(self, # Normalize the 2D keypoint coordinate with mean and std keypoint_labels = keypoints.copy() - if self.keypoints_mean is not None and self.keypoints_std is not None: - keypoints_shape = keypoints.shape - assert self.keypoints_mean.shape == keypoints_shape[1:] + if self.keypoints_mean is not None: + assert self.keypoints_mean.shape[1:] == keypoints.shape[1:] + encoded['keypoints_mean'] = self.keypoints_mean.copy() + encoded['keypoints_std'] = self.keypoints_std.copy() keypoint_labels = (keypoint_labels - self.keypoints_mean) / self.keypoints_std - if self.target_mean is not None and self.target_std is not None: - assert self.target_mean.ndim in {2, 3} - if self.target_mean.ndim == 2: - self.target_mean = self.target_mean[None, :] - target_shape = lifting_target_label.shape - assert self.target_mean.shape == target_shape + if self.target_mean is not None: + assert self.target_mean.shape == lifting_target_label.shape + encoded['target_mean'] = self.target_mean.copy() + encoded['target_std'] = self.target_std.copy() lifting_target_label = (lifting_target_label - self.target_mean) / self.target_std From 5ebba18576e7b096ad0bff148ca52213521bf644 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Fri, 30 Jun 2023 16:23:28 +0800 Subject: [PATCH 20/30] fix typing --- mmpose/codecs/image_pose_lifting.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 0e17f99642..aae6c3b5be 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -61,16 +61,19 @@ def __init__(self, self.reshape_keypoints = reshape_keypoints self.concat_vis = concat_vis if keypoints_mean is not None: - keypoints_mean = np.array(keypoints_mean).reshape( - 1, num_keypoints, -1) - keypoints_std = np.array(keypoints_std).reshape( - 1, num_keypoints, -1) + keypoints_mean = np.array( + keypoints_mean, + dtype=np.float32).reshape(1, num_keypoints, -1) + keypoints_std = np.array( + keypoints_std, dtype=np.float32).reshape(1, num_keypoints, -1) assert keypoints_std is not None assert keypoints_mean.shape == keypoints_std.shape if target_mean is not None: target_dim = num_keypoints - 1 if remove_root else num_keypoints - target_mean = np.array(target_mean).reshape(1, target_dim, -1) - target_std = np.array(target_std).reshape(1, target_dim, -1) + target_mean = np.array( + target_mean, dtype=np.float32).reshape(1, target_dim, -1) + target_std = np.array( + target_std, dtype=np.float32).reshape(1, target_dim, -1) assert target_std is not None assert target_mean.shape == target_std.shape self.keypoints_mean = keypoints_mean From c9014dcf1049b935abec2d8a781e5aa1abf6837f Mon Sep 17 00:00:00 2001 From: LareinaM Date: Tue, 4 Jul 2023 21:24:03 +0800 Subject: [PATCH 21/30] move files, add weights' links --- configs/body_3d_keypoint/pose_lift/README.md | 28 ++++------ .../pose_lift/h36m/motionbert_h36m.md | 51 +++++++++++++++++++ .../h36m/motionbert_h36m.yml | 8 +-- ...lift_motionbert-243frm_8xb32-120e_h36m.py} | 0 .../video_pose_lift/h36m/motionbert_h36m.md | 51 ------------------- 5 files changed, 66 insertions(+), 72 deletions(-) create mode 100644 configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md rename configs/body_3d_keypoint/{video_pose_lift => pose_lift}/h36m/motionbert_h36m.yml (64%) rename configs/body_3d_keypoint/{video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py => pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py} (100%) delete mode 100644 configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md diff --git a/configs/body_3d_keypoint/pose_lift/README.md b/configs/body_3d_keypoint/pose_lift/README.md index 7e5f9f7e2a..f965c70cb2 100644 --- a/configs/body_3d_keypoint/pose_lift/README.md +++ b/configs/body_3d_keypoint/pose_lift/README.md @@ -16,23 +16,17 @@ For single-person 3D pose estimation from a monocular camera, existing works can #### Human3.6m Dataset -| Arch | Receptive Field | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | - -| :------------------------------------------------------ | :-------------: | :---: | :-----: | :-----: | :------------------------------------------------------: | :-----------------------------------------------------: | - -| [VideoPose3D-supervised](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 27 | 40.1 | 30.1 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | - -| [VideoPose3D-supervised](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 81 | 39.1 | 29.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | - -| [VideoPose3D-supervised](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 243 | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | - -| [VideoPose3D-supervised-CPN](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 1 | 53.0 | 41.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | - -| [VideoPose3D-supervised-CPN](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 243 | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | - -| [VideoPose3D-semi-supervised](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 27 | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | - -| [VideoPose3D-semi-supervised-CPN](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 27 | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | +| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | Details and Download | +| :-------------------------------------------- | :---: | :-----: | :-----: | :-------------------------------------------: | :------------------------------------------: | :---------------------------------------------: | +| [VideoPose3D-supervised-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 40.1 | 30.1 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-81frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 39.1 | 29.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-243frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py) | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-CPN-1frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 53.0 | 41.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-CPN-243frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-semi-supervised-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-semi-supervised-CPN-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [MotionBERT](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 27.7 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) | +| [MotionBERT-finetuned](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 21.6 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) | ## Image-based Single-view 3D Human Body Pose Estimation diff --git a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md new file mode 100644 index 0000000000..f7b8faab1e --- /dev/null +++ b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md @@ -0,0 +1,51 @@ + + +
+MotionBERT (2022) + +```bibtex + @misc{Zhu_Ma_Liu_Liu_Wu_Wang_2022, + title={Learning Human Motion Representations: A Unified Perspective}, + author={Zhu, Wentao and Ma, Xiaoxuan and Liu, Zhaoyang and Liu, Libin and Wu, Wayne and Wang, Yizhou}, + year={2022}, + month={Oct}, + language={en-US} + } +``` + +
+ + + +
+Human3.6M (TPAMI'2014) + +```bibtex +@article{h36m_pami, +author = {Ionescu, Catalin and Papava, Dragos and Olaru, Vlad and Sminchisescu, Cristian}, +title = {Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human Sensing in Natural Environments}, +journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, +publisher = {IEEE Computer Society}, +volume = {36}, +number = {7}, +pages = {1325-1339}, +month = {jul}, +year = {2014} +} +``` + +
+ +Testing results on Human3.6M dataset with ground truth 2D detections + +| Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | +| :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: | +| [MotionBERT](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 35.3 | 27.7 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | +| [MotionBERT-finetuned](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 27.4 | 21.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | + +Testing results on Human3.6M dataset from the [official repo](https://github.com/Walter0807/MotionBERT) with ground truth 2D detections + +| Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | +| :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: | +| [MotionBERT](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 40.5 | 39.9 | 34.1 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | +| [MotionBERT-finetuned](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 38.2 | 37.7 | 32.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml similarity index 64% rename from configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml rename to configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml index a4ed9970a3..7257fea5a6 100644 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.yml +++ b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.yml @@ -5,7 +5,7 @@ Collections: URL: https://arxiv.org/abs/2210.06551 README: https://github.com/open-mmlab/mmpose/blob/main/docs/en/papers/algorithms/motionbert.md Models: -- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py +- Config: configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert_8xb32-120e_h36m.py In Collection: MotionBERT Metadata: Architecture: &id001 @@ -18,8 +18,8 @@ Models: MPJPE: 35.3 P-MPJPE: 27.7 Task: Body 3D Keypoint - Weights: -- Config: configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert_8xb32-120e_h36m.py + Weights: https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth +- Config: configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert_8xb32-120e_h36m.py In Collection: MotionBERT Metadata: Architecture: *id001 @@ -31,4 +31,4 @@ Models: MPJPE: 27.5 P-MPJPE: 21.6 Task: Body 3D Keypoint - Weights: + Weights: https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py similarity index 100% rename from configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py rename to configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py diff --git a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md deleted file mode 100644 index e04a540b75..0000000000 --- a/configs/body_3d_keypoint/video_pose_lift/h36m/motionbert_h36m.md +++ /dev/null @@ -1,51 +0,0 @@ - - -
-MotionBERT (2022) - -```bibtex - @misc{Zhu_Ma_Liu_Liu_Wu_Wang_2022, - title={Learning Human Motion Representations: A Unified Perspective}, - author={Zhu, Wentao and Ma, Xiaoxuan and Liu, Zhaoyang and Liu, Libin and Wu, Wayne and Wang, Yizhou}, - year={2022}, - month={Oct}, - language={en-US} - } -``` - -
- - - -
-Human3.6M (TPAMI'2014) - -```bibtex -@article{h36m_pami, -author = {Ionescu, Catalin and Papava, Dragos and Olaru, Vlad and Sminchisescu, Cristian}, -title = {Human3.6M: Large Scale Datasets and Predictive Methods for 3D Human Sensing in Natural Environments}, -journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, -publisher = {IEEE Computer Society}, -volume = {36}, -number = {7}, -pages = {1325-1339}, -month = {jul}, -year = {2014} -} -``` - -
- -Testing results on Human3.6M dataset with ground truth 2D detections - -| Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | -| :----------------------------------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------: | -| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 35.3 | 27.7 | [ckpt](<>) | -| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 27.4 | 21.6 | [ckpt](<>) | - -Testing results on Human3.6M dataset from the [official repo](https://github.com/Walter0807/MotionBERT) with ground truth 2D detections - -| Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | -| :----------------------------------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------: | -| [MotionBERT](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py) | 40.5 | 39.9 | 34.1 | [ckpt](<>) | -| [MotionBERT-finetuned](/configs/body_3d_keypoint/video_pose_lift/h36m/vid_pl_motionbert-243frm_8xb32-120e_h36m.py) | 38.2 | 37.7 | 32.6 | [ckpt](<>) | From e54f13cbdaf2855c65ee988a093436386dc03382 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 5 Jul 2023 11:42:23 +0800 Subject: [PATCH 22/30] fix bug --- mmpose/apis/inferencers/pose3d_inferencer.py | 4 ++-- mmpose/datasets/transforms/formatting.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mmpose/apis/inferencers/pose3d_inferencer.py b/mmpose/apis/inferencers/pose3d_inferencer.py index 0fe66ac72b..819273af66 100644 --- a/mmpose/apis/inferencers/pose3d_inferencer.py +++ b/mmpose/apis/inferencers/pose3d_inferencer.py @@ -271,8 +271,8 @@ def preprocess_single(self, K, ), dtype=np.float32) - data_info['lifting_target'] = np.zeros((K, 3), dtype=np.float32) - data_info['lifting_target_visible'] = np.ones((K, 1), + data_info['lifting_target'] = np.zeros((1, K, 3), dtype=np.float32) + data_info['lifting_target_visible'] = np.ones((1, K, 1), dtype=np.float32) data_info['camera_param'] = dict(w=width, h=height) diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py index 132ecd8bb6..d047cff3c3 100644 --- a/mmpose/datasets/transforms/formatting.py +++ b/mmpose/datasets/transforms/formatting.py @@ -207,8 +207,9 @@ def transform(self, results: dict) -> dict: for key, packed_key in self.label_mapping_table.items(): if key in results: # For pose-lifting, store only target-related fields - if 'lifting_target' in results and key in { - 'keypoint_labels', 'keypoint_weights' + if 'lifting_target' in results and packed_key in { + 'keypoint_labels', 'keypoint_weights', + 'keypoints_visible' }: continue if isinstance(results[key], list): From cb9d6b9b4ffa892b4be254aa274e9521301ed6bf Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 5 Jul 2023 14:20:18 +0800 Subject: [PATCH 23/30] update md --- configs/body_3d_keypoint/pose_lift/README.md | 16 +++++++++------- .../pose_lift/h36m/motionbert_h36m.md | 10 ++++++---- .../pose_lift/h36m/videopose3d_h36m.md | 18 +++++++++--------- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/README.md b/configs/body_3d_keypoint/pose_lift/README.md index f965c70cb2..b5453b4437 100644 --- a/configs/body_3d_keypoint/pose_lift/README.md +++ b/configs/body_3d_keypoint/pose_lift/README.md @@ -20,13 +20,15 @@ For single-person 3D pose estimation from a monocular camera, existing works can | :-------------------------------------------- | :---: | :-----: | :-----: | :-------------------------------------------: | :------------------------------------------: | :---------------------------------------------: | | [VideoPose3D-supervised-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 40.1 | 30.1 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | | [VideoPose3D-supervised-81frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 39.1 | 29.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [VideoPose3D-supervised-243frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py) | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-243frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 37.6 | 28.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | | [VideoPose3D-supervised-CPN-1frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 53.0 | 41.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [VideoPose3D-supervised-CPN-243frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | | | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-CPN-243frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 47.9 | 38.0 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | | [VideoPose3D-semi-supervised-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | | [VideoPose3D-semi-supervised-CPN-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [MotionBERT](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 27.7 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) | -| [MotionBERT-finetuned](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 21.6 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) | +| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 27.7 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) | +| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 21.6 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) | + +*Models with * are converted from the official repo. The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* ## Image-based Single-view 3D Human Body Pose Estimation @@ -40,6 +42,6 @@ For single-person 3D pose estimation from a monocular camera, existing works can #### Human3.6m Dataset -| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | -| :------------------------------------------------------ | :-------------: | :---: | :-----: | :-----: | :------------------------------------------------------: | :-----------------------------------------------------: | -| [SimpleBaseline3D-tcn](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_simplebaseline3d_8xb64-200e_h36m.py) | 43.4 | 34.3 | /|[ckpt](https://download.openmmlab.com/mmpose/body3d/simple_baseline/simple3Dbaseline_h36m-f0ad73a4_20210419.pth) | [log](https://download.openmmlab.com/mmpose/body3d/simple_baseline/20210415_065056.log.json) | +| Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | Details and Download | +| :---------------------------------------- | :---: | :-----: | :-----: | :---------------------------------------: | :---------------------------------------: | :--------------------------------------------------------: | +| [SimpleBaseline3D-tcn](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_simplebaseline3d_8xb64-200e_h36m.py) | 43.4 | 34.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/simple_baseline/simple3Dbaseline_h36m-f0ad73a4_20210419.pth) | [log](https://download.openmmlab.com/mmpose/body3d/simple_baseline/20210415_065056.log.json) | [simplebaseline3d_h36m.md](./h36m/simplebaseline3d_h36m.md) | diff --git a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md index f7b8faab1e..d830d65c18 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md +++ b/configs/body_3d_keypoint/pose_lift/h36m/motionbert_h36m.md @@ -40,12 +40,14 @@ Testing results on Human3.6M dataset with ground truth 2D detections | Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | | :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: | -| [MotionBERT](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 35.3 | 27.7 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | -| [MotionBERT-finetuned](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 27.4 | 21.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | +| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 35.3 | 27.7 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | +| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 27.4 | 21.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | Testing results on Human3.6M dataset from the [official repo](https://github.com/Walter0807/MotionBERT) with ground truth 2D detections | Arch | MPJPE | average MPJPE | P-MPJPE | ckpt | | :-------------------------------------------------------------------------------------- | :---: | :-----------: | :-----: | :--------------------------------------------------------------------------------------: | -| [MotionBERT](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 40.5 | 39.9 | 34.1 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | -| [MotionBERT-finetuned](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 38.2 | 37.7 | 32.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | +| [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 40.5 | 39.9 | 34.1 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | +| [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 38.2 | 37.7 | 32.6 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | + +*Models with * are converted from the [official repo](https://github.com/Walter0807/MotionBERT). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* diff --git a/configs/body_3d_keypoint/pose_lift/h36m/videopose3d_h36m.md b/configs/body_3d_keypoint/pose_lift/h36m/videopose3d_h36m.md index f1c75d786a..48502c7b09 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/videopose3d_h36m.md +++ b/configs/body_3d_keypoint/pose_lift/h36m/videopose3d_h36m.md @@ -41,27 +41,27 @@ Testing results on Human3.6M dataset with ground truth 2D detections, supervised | Arch | Receptive Field | MPJPE | P-MPJPE | ckpt | log | | :--------------------------------------------------------- | :-------------: | :---: | :-----: | :--------------------------------------------------------: | :-------------------------------------------------------: | -| [VideoPose3D](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 27 | 40.1 | 30.1 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | -| [VideoPose3D](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 81 | 39.1 | 29.3 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | -| [VideoPose3D](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 243 | | | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | +| [VideoPose3D-supervised-27frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 27 | 40.1 | 30.1 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | +| [VideoPose3D-supervised-81frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 81 | 39.1 | 29.3 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | +| [VideoPose3D-supervised-243frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 243 | 37.6 | 28.3 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | Testing results on Human3.6M dataset with CPN 2D detections1, supervised training | Arch | Receptive Field | MPJPE | P-MPJPE | ckpt | log | | :--------------------------------------------------------- | :-------------: | :---: | :-----: | :--------------------------------------------------------: | :-------------------------------------------------------: | -| [VideoPose3D](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 1 | 53.0 | 41.3 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | -| [VideoPose3D](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 243 | | | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | +| [VideoPose3D-supervised-CPN-1frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 1 | 53.0 | 41.3 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | +| [VideoPose3D-supervised-CPN-243frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 243 | 47.9 | 38.0 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | Testing results on Human3.6M dataset with ground truth 2D detections, semi-supervised training | Training Data | Arch | Receptive Field | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | | :------------ | :-------------------------------------------------: | :-------------: | :---: | :-----: | :-----: | :-------------------------------------------------: | :-------------------------------------------------: | -| 10% S1 | [VideoPose3D](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 27 | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | +| 10% S1 | [VideoPose3D-semi-supervised-27frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 27 | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | Testing results on Human3.6M dataset with CPN 2D detections1, semi-supervised training -| Training Data | Arch | Receptive Field | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | -| :------------ | :----------------------------: | :-------------: | :---: | :-----: | :-----: | :------------------------------------------------------------: | :-----------------------------------------------------------: | -| 10% S1 | [VideoPose3D](/configs/xxx.py) | 27 | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | +| Training Data | Arch | Receptive Field | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | +| :------------ | :-------------------------------------------------: | :-------------: | :---: | :-----: | :-----: | :-------------------------------------------------: | :-------------------------------------------------: | +| 10% S1 | [VideoPose3D-semi-supervised-CPN-27frm](/configs/xxx.py) | 27 | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | 1 CPN 2D detections are provided by [official repo](https://github.com/facebookresearch/VideoPose3D/blob/master/DATASETS.md). The reformatted version used in this repository can be downloaded from [train_detection](https://download.openmmlab.com/mmpose/body3d/videopose/cpn_ft_h36m_dbb_train.npy) and [test_detection](https://download.openmmlab.com/mmpose/body3d/videopose/cpn_ft_h36m_dbb_test.npy). From 8a001e53fb514d480ebba35dd9573fa06c0e674f Mon Sep 17 00:00:00 2001 From: LareinaM Date: Thu, 6 Jul 2023 15:49:56 +0800 Subject: [PATCH 24/30] remove useless info --- .../h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py index 2803323ed2..5339d1ff13 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py @@ -77,8 +77,6 @@ dataset=dict( type=dataset_type, ann_file='annotation_body3d/fps50/h36m_test.npz', - # ann_file='annotation_body3d/fps50/h36m_test_original.npz', - # factor_file='annotation_body3d/fps50/h36m_factors.npy', seq_len=1, multiple_target=243, seq_step=1, From d06b39dfc4200842c7bf07931e83228d6a2da385 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Fri, 7 Jul 2023 11:34:24 +0800 Subject: [PATCH 25/30] update --- configs/body_3d_keypoint/pose_lift/README.md | 14 +++++++------- tests/test_codecs/test_motionbert_label.py | 6 ++++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/README.md b/configs/body_3d_keypoint/pose_lift/README.md index b5453b4437..e3e6ff7176 100644 --- a/configs/body_3d_keypoint/pose_lift/README.md +++ b/configs/body_3d_keypoint/pose_lift/README.md @@ -18,13 +18,13 @@ For single-person 3D pose estimation from a monocular camera, existing works can | Arch | MPJPE | P-MPJPE | N-MPJPE | ckpt | log | Details and Download | | :-------------------------------------------- | :---: | :-----: | :-----: | :-------------------------------------------: | :------------------------------------------: | :---------------------------------------------: | -| [VideoPose3D-supervised-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 40.1 | 30.1 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [VideoPose3D-supervised-81frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 39.1 | 29.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [VideoPose3D-supervised-243frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 37.6 | 28.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [VideoPose3D-supervised-CPN-1frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 53.0 | 41.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [VideoPose3D-supervised-CPN-243frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 47.9 | 38.0 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [VideoPose3D-semi-supervised-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | -| [VideoPose3D-semi-supervised-CPN-27frm](/configs/body_3d_keypoint/video_pose_lift/h36m/vid-pl_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-27frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py) | 40.1 | 30.1 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised-fe8fbba9_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-81frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py) | 39.1 | 29.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised-1f2d1104_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_81frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-243frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py) | 37.6 | 28.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised-880bea25_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-CPN-1frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py) | 53.0 | 41.3 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft-5c3afaed_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_1frame_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-supervised-CPN-243frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv-cpn-ft_8xb128-200e_h36m.py) | 47.9 | 38.0 | / | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft-88f5abbb_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_243frames_fullconv_supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-semi-supervised-27frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv_8xb64-200e_h36m.py) | 57.2 | 42.4 | 54.2 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised-54aef83b_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | +| [VideoPose3D-semi-supervised-CPN-27frm](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-semi-supv-cpn-ft_8xb64-200e_h36m.py) | 67.3 | 50.4 | 63.6 | [ckpt](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft-71be9cde_20210527.pth) | [log](https://download.openmmlab.com/mmpose/body3d/videopose/videopose_h36m_27frames_fullconv_semi-supervised_cpn_ft_20210527.log.json) | [videpose3d_h36m.md](./h36m/videpose3d_h36m.md) | | [MotionBERT\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 35.3 | 27.7 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_h36m-f554954f_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) | | [MotionBERT-finetuned\*](/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py) | 27.5 | 21.6 | / | [ckpt](https://download.openmmlab.com/mmpose/v1/body_3d_keypoint/pose_lift/h36m/motionbert_ft_h36m-d80af323_20230531.pth) | / | [motionbert_h36m.md](./h36m/motionbert_h36m.md) | diff --git a/tests/test_codecs/test_motionbert_label.py b/tests/test_codecs/test_motionbert_label.py index 47ce3dfc68..01c9c654a2 100644 --- a/tests/test_codecs/test_motionbert_label.py +++ b/tests/test_codecs/test_motionbert_label.py @@ -137,7 +137,8 @@ def test_cicular_verification(self): keypoints[..., :, :] = keypoints[..., :, :] - keypoints[..., 0, :] - self.assertTrue(np.allclose(keypoints[..., :2], _keypoints[..., :2])) + self.assertTrue( + np.allclose(keypoints[..., :2] / 1000, _keypoints[..., :2])) # test with factor keypoints = (0.1 + 0.8 * np.random.rand(1, 17, 3)) @@ -154,4 +155,5 @@ def test_cicular_verification(self): keypoints *= encoded['factor'] keypoints[..., :, :] = keypoints[..., :, :] - keypoints[..., 0, :] - self.assertTrue(np.allclose(keypoints[..., :2], _keypoints[..., :2])) + self.assertTrue( + np.allclose(keypoints[..., :2] / 1000, _keypoints[..., :2])) From 3a278a335dff7c0dcea7aa6e857e5d00acade561 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Thu, 6 Jul 2023 16:44:44 +0800 Subject: [PATCH 26/30] fix problems related to demo --- mmpose/apis/inference_3d.py | 1 + mmpose/codecs/motionbert_label.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mmpose/apis/inference_3d.py b/mmpose/apis/inference_3d.py index 8725b27caa..d4b9623b86 100644 --- a/mmpose/apis/inference_3d.py +++ b/mmpose/apis/inference_3d.py @@ -317,6 +317,7 @@ def inference_pose_lifter_model(model, K, ), dtype=np.float32) data_info['lifting_target'] = np.zeros((1, K, 3), dtype=np.float32) + data_info['factor'] = np.zeros((T, ), dtype=np.float32) data_info['lifting_target_visible'] = np.ones((1, K, 1), dtype=np.float32) diff --git a/mmpose/codecs/motionbert_label.py b/mmpose/codecs/motionbert_label.py index 1d036c49bb..ce3a9b4f65 100644 --- a/mmpose/codecs/motionbert_label.py +++ b/mmpose/codecs/motionbert_label.py @@ -131,8 +131,11 @@ def encode(self, ..., :2] = keypoint_labels[..., :2] / w * 2 - [1, h / w] # convert target to image coordinate - lifting_target_label, factor_ = camera_to_image_coord( - self.root_index, lifting_target_label, _camera_param) + T = keypoint_labels.shape[0] + factor_ = np.array([4] * T, dtype=np.float32).reshape(T, ) + if 'f' in _camera_param and 'c' in _camera_param: + lifting_target_label, factor_ = camera_to_image_coord( + self.root_index, lifting_target_label, _camera_param) lifting_target_label[..., :, :] = lifting_target_label[ ..., :, :] - lifting_target_label[..., self.root_index:self.root_index + @@ -141,7 +144,7 @@ def encode(self, factor = factor_ if factor.ndim == 1: factor = factor[:, None] - lifting_target_label *= 1000 * factor[..., None] + lifting_target_label *= factor[..., None] if self.concat_vis: keypoints_visible_ = keypoints_visible @@ -206,4 +209,5 @@ def decode( keypoints *= factor[..., None] keypoints[..., :, :] = keypoints[..., :, :] - keypoints[ ..., self.root_index:self.root_index + 1, :] + keypoints /= 1000. return keypoints, scores From fcf1ec68d15492b32e7fd52ceb2a059338aa17f6 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 10 Jul 2023 10:39:34 +0800 Subject: [PATCH 27/30] rename --- ... pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-160e_h36m.py} | 2 +- ...py => pose-lift_videopose3d-243frm-supv_8xb128-160e_h36m.py} | 2 +- ....py => pose-lift_videopose3d-27frm-supv_8xb128-120e_h36m.py} | 2 +- ....py => pose-lift_videopose3d-81frm-supv_8xb128-160e_h36m.py} | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename configs/body_3d_keypoint/pose_lift/h36m/{pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py => pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-160e_h36m.py} (98%) rename configs/body_3d_keypoint/pose_lift/h36m/{pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py => pose-lift_videopose3d-243frm-supv_8xb128-160e_h36m.py} (98%) rename configs/body_3d_keypoint/pose_lift/h36m/{pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py => pose-lift_videopose3d-27frm-supv_8xb128-120e_h36m.py} (98%) rename configs/body_3d_keypoint/pose_lift/h36m/{pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py => pose-lift_videopose3d-81frm-supv_8xb128-160e_h36m.py} (98%) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-160e_h36m.py similarity index 98% rename from configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py rename to configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-160e_h36m.py index 0cbf89142d..c1190fe83e 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-80e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-1frm-supv-cpn-ft_8xb128-160e_h36m.py @@ -7,7 +7,7 @@ type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') # runtime -train_cfg = dict(max_epochs=80, val_interval=10) +train_cfg = dict(max_epochs=160, val_interval=10) # optimizer optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-4)) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-160e_h36m.py similarity index 98% rename from configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py rename to configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-160e_h36m.py index 0f311ac5cf..0d241c498f 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-80e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-243frm-supv_8xb128-160e_h36m.py @@ -7,7 +7,7 @@ type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') # runtime -train_cfg = dict(max_epochs=80, val_interval=10) +train_cfg = dict(max_epochs=160, val_interval=10) # optimizer optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3)) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-120e_h36m.py similarity index 98% rename from configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py rename to configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-120e_h36m.py index 2589b493a6..803f907b7b 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-80e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-27frm-supv_8xb128-120e_h36m.py @@ -7,7 +7,7 @@ type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') # runtime -train_cfg = dict(max_epochs=80, val_interval=10) +train_cfg = dict(max_epochs=160, val_interval=10) # optimizer optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3)) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-160e_h36m.py similarity index 98% rename from configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py rename to configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-160e_h36m.py index f2c27e423d..4b370fe76e 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-80e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_videopose3d-81frm-supv_8xb128-160e_h36m.py @@ -7,7 +7,7 @@ type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') # runtime -train_cfg = dict(max_epochs=80, val_interval=10) +train_cfg = dict(max_epochs=160, val_interval=10) # optimizer optim_wrapper = dict(optimizer=dict(type='Adam', lr=1e-3)) From 0d3b47ae0f7826ca8bf91d66e11bd2ee76757eb6 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Mon, 10 Jul 2023 11:49:46 +0800 Subject: [PATCH 28/30] add transforms and full config --- ...-lift_motionbert-243frm_8xb32-120e_h36m.py | 41 +++++++++++++++++-- .../datasets/datasets/body3d/h36m_dataset.py | 12 +++++- .../datasets/transforms/pose3d_transforms.py | 18 ++++++-- .../test_transforms/test_pose3d_transforms.py | 20 +++++++++ 4 files changed, 83 insertions(+), 8 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py index 5339d1ff13..1d82ae2a4f 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py @@ -7,11 +7,16 @@ type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer') # runtime -train_cfg = None +train_cfg = dict(max_epochs=120, val_interval=10) # optimizer +optim_wrapper = dict( + optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.01)) # learning policy +param_scheduler = [ + dict(type='ExponentialLR', gamma=0.99, end=120, by_epoch=True) +] auto_scale_lr = dict(base_batch_size=512) @@ -57,6 +62,18 @@ data_root = 'data/h36m/' # pipelines +train_pipeline = [ + dict( + type='RandomFlipAroundRoot', + keypoints_flip_cfg={}, + target_flip_cfg={}, + flip_image=True), + dict(type='GenerateTarget', encoder=codec), + dict( + type='PackPoseInputs', + meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', + 'factor', 'camera_param')) +] val_pipeline = [ dict(type='GenerateTarget', encoder=codec), dict( @@ -66,9 +83,27 @@ ] # data loaders +train_dataloader = dict( + batch_size=32, + prefetch_factor=4, + pin_memory=True, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + ann_file='annotation_body3d/fps50/h36m_train.npz', + seq_len=1, + multiple_target=243, + multiple_target_step=81, + camera_param_file='annotation_body3d/cameras.pkl', + data_root=data_root, + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + )) + val_dataloader = dict( batch_size=32, - shuffle=False, prefetch_factor=4, pin_memory=True, num_workers=2, @@ -78,8 +113,8 @@ type=dataset_type, ann_file='annotation_body3d/fps50/h36m_test.npz', seq_len=1, - multiple_target=243, seq_step=1, + multiple_target=243, camera_param_file='annotation_body3d/cameras.pkl', data_root=data_root, data_prefix=dict(img='images/'), diff --git a/mmpose/datasets/datasets/body3d/h36m_dataset.py b/mmpose/datasets/datasets/body3d/h36m_dataset.py index 59cd358aa9..b7a4f71d65 100644 --- a/mmpose/datasets/datasets/body3d/h36m_dataset.py +++ b/mmpose/datasets/datasets/body3d/h36m_dataset.py @@ -47,6 +47,8 @@ class Human36mDataset(BaseMocapDataset): Default: 1. multiple_target (int): If larger than 0, merge every ``multiple_target`` sequence together. Default: 0. + multiple_target_step (int): The interval for merging sequence. Only + valid when ``multiple_target`` is larger than 0. Default: 0. pad_video_seq (bool): Whether to pad the video so that poses will be predicted for every frame in the video. Default: ``False``. causal (bool): If set to ``True``, the rightmost input frame will be @@ -110,6 +112,7 @@ def __init__(self, seq_len: int = 1, seq_step: int = 1, multiple_target: int = 0, + multiple_target_step: int = 0, pad_video_seq: bool = False, causal: bool = True, subset_frac: float = 1.0, @@ -151,6 +154,10 @@ def __init__(self, assert exists(factor_file), 'Annotation file does not exist.' self.factor_file = factor_file + if multiple_target > 0 and multiple_target_step == 0: + multiple_target_step = multiple_target + self.multiple_target_step = multiple_target_step + super().__init__( ann_file=ann_file, seq_len=seq_len, @@ -191,8 +198,9 @@ def get_sequence_indices(self) -> List[List[int]]: n_frame = len(_indices) seqs_from_video = [ _indices[i:(i + self.multiple_target):_step] - for i in range(0, n_frame, self.multiple_target) - ][:n_frame // self.multiple_target] + for i in range(0, n_frame, self.multiple_target_step) + ][:(n_frame + self.multiple_target_step - + self.multiple_target) // self.multiple_target_step] sequence_indices.extend(seqs_from_video) else: diff --git a/mmpose/datasets/transforms/pose3d_transforms.py b/mmpose/datasets/transforms/pose3d_transforms.py index e6559fa398..2149d7cb30 100644 --- a/mmpose/datasets/transforms/pose3d_transforms.py +++ b/mmpose/datasets/transforms/pose3d_transforms.py @@ -25,6 +25,8 @@ class RandomFlipAroundRoot(BaseTransform): flip_prob (float): Probability of flip. Default: 0.5. flip_camera (bool): Whether to flip horizontal distortion coefficients. Default: ``False``. + flip_image (bool): Whether to flip keypoints horizontally according + to image size. Default: ``False``. Required keys: keypoints @@ -39,14 +41,16 @@ def __init__(self, keypoints_flip_cfg, target_flip_cfg, flip_prob=0.5, - flip_camera=False): + flip_camera=False, + flip_image=False): self.keypoints_flip_cfg = keypoints_flip_cfg self.target_flip_cfg = target_flip_cfg self.flip_prob = flip_prob self.flip_camera = flip_camera + self.flip_image = flip_image def transform(self, results: Dict) -> dict: - """The transform function of :class:`ZeroCenterPose`. + """The transform function of :class:`RandomFlipAroundRoot`. See ``transform()`` method of :class:`BaseTransform` for details. @@ -76,6 +80,15 @@ def transform(self, results: Dict) -> dict: flip_indices = results['flip_indices'] # flip joint coordinates + _camera_param = deepcopy(results['camera_param']) + if self.flip_image: + assert 'camera_param' in results, \ + 'Camera parameters are missing.' + assert 'w' in _camera_param + w = _camera_param['w'] / 2 + self.keypoints_flip_cfg['center_x'] = w + self.target_flip_cfg['center_x'] = w + keypoints, keypoints_visible = flip_keypoints_custom_center( keypoints, keypoints_visible, flip_indices, **self.keypoints_flip_cfg) @@ -92,7 +105,6 @@ def transform(self, results: Dict) -> dict: if self.flip_camera: assert 'camera_param' in results, \ 'Camera parameters are missing.' - _camera_param = deepcopy(results['camera_param']) assert 'c' in _camera_param _camera_param['c'][0] *= -1 diff --git a/tests/test_datasets/test_transforms/test_pose3d_transforms.py b/tests/test_datasets/test_transforms/test_pose3d_transforms.py index db7a612dee..b87931bb74 100644 --- a/tests/test_datasets/test_transforms/test_pose3d_transforms.py +++ b/tests/test_datasets/test_transforms/test_pose3d_transforms.py @@ -153,3 +153,23 @@ def test_transform(self): -self.data_info['camera_param']['p'][0], camera2['p'][0], atol=4.)) + + # test flipping w.r.t. image + transform = RandomFlipAroundRoot({}, {}, flip_prob=1, flip_image=True) + results = deepcopy(self.data_info) + results = transform(results) + kpts2 = results['keypoints'] + tar2 = results['lifting_target'] + + camera_param = results['camera_param'] + for left, right in enumerate(flip_indices): + self.assertTrue( + np.allclose( + camera_param['w'] - kpts1[0][left][:1], + kpts2[0][right][:1], + atol=4.)) + self.assertTrue( + np.allclose(kpts1[0][left][1:], kpts2[0][right][1:], atol=4.)) + self.assertTrue( + np.allclose( + tar1[..., left, 1:], tar2[..., right, 1:], atol=4.)) From 16d1c0115ee1328c7f726f045b898f174bdfd3e6 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Wed, 12 Jul 2023 15:19:18 +0800 Subject: [PATCH 29/30] add loss, change init method --- ...-lift_motionbert-243frm_8xb32-120e_h36m.py | 16 ++-- mmpose/codecs/motionbert_label.py | 9 ++- .../motion_regression_head.py | 17 ++-- mmpose/models/losses/regression_loss.py | 78 +++++++++++++++++++ 4 files changed, 104 insertions(+), 16 deletions(-) diff --git a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py index 1d82ae2a4f..88f6c3897d 100644 --- a/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py +++ b/configs/body_3d_keypoint/pose_lift/h36m/pose-lift_motionbert-243frm_8xb32-120e_h36m.py @@ -31,7 +31,13 @@ ) # codec settings -codec = dict( +train_codec = dict( + type='MotionBERTLabel', + num_keypoints=17, + concat_vis=True, + rootrel=True, + factor_label=False) +val_codec = dict( type='MotionBERTLabel', num_keypoints=17, concat_vis=True, rootrel=True) # model settings @@ -52,8 +58,8 @@ in_channels=512, out_channels=3, embedding_size=512, - loss=dict(type='MPJPELoss'), - decoder=codec, + loss=dict(type='MPJPEVelocityJointLoss'), + decoder=val_codec, ), ) @@ -68,14 +74,14 @@ keypoints_flip_cfg={}, target_flip_cfg={}, flip_image=True), - dict(type='GenerateTarget', encoder=codec), + dict(type='GenerateTarget', encoder=train_codec), dict( type='PackPoseInputs', meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', 'factor', 'camera_param')) ] val_pipeline = [ - dict(type='GenerateTarget', encoder=codec), + dict(type='GenerateTarget', encoder=val_codec), dict( type='PackPoseInputs', meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices', diff --git a/mmpose/codecs/motionbert_label.py b/mmpose/codecs/motionbert_label.py index ce3a9b4f65..d0c8cd0d40 100644 --- a/mmpose/codecs/motionbert_label.py +++ b/mmpose/codecs/motionbert_label.py @@ -34,6 +34,8 @@ class MotionBERTLabel(BaseKeypointCodec): Default: ``False``. rootrel (bool): If true, the root keypoint will be set to the coordinate origin. Default: ``False``. + factor_label (bool): If true, the label will be multiplied by a factor. + Default: ``True``. """ auxiliary_encode_keys = { @@ -46,7 +48,8 @@ def __init__(self, remove_root: bool = False, save_index: bool = False, concat_vis: bool = False, - rootrel: bool = False): + rootrel: bool = False, + factor_label: bool = True): super().__init__() self.num_keypoints = num_keypoints @@ -55,6 +58,7 @@ def __init__(self, self.save_index = save_index self.concat_vis = concat_vis self.rootrel = rootrel + self.factor_label = factor_label def encode(self, keypoints: np.ndarray, @@ -144,7 +148,8 @@ def encode(self, factor = factor_ if factor.ndim == 1: factor = factor[:, None] - lifting_target_label *= factor[..., None] + if self.factor_label: + lifting_target_label *= factor[..., None] if self.concat_vis: keypoints_visible_ = keypoints_visible diff --git a/mmpose/models/heads/regression_heads/motion_regression_head.py b/mmpose/models/heads/regression_heads/motion_regression_head.py index ae146a439b..a0037180c7 100644 --- a/mmpose/models/heads/regression_heads/motion_regression_head.py +++ b/mmpose/models/heads/regression_heads/motion_regression_head.py @@ -6,7 +6,7 @@ import torch from torch import Tensor, nn -from mmpose.evaluation.functional import keypoint_pck_accuracy +from mmpose.evaluation.functional import keypoint_mpjpe from mmpose.registry import KEYPOINT_CODECS, MODELS from mmpose.utils.tensor_utils import to_numpy from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList, @@ -142,11 +142,11 @@ def loss(self, pred_outputs = self.forward(inputs) - lifting_target_label = torch.cat([ + lifting_target_label = torch.stack([ d.gt_instance_labels.lifting_target_label for d in batch_data_samples ]) - lifting_target_weights = torch.cat([ + lifting_target_weights = torch.stack([ d.gt_instance_labels.lifting_target_weights for d in batch_data_samples ]) @@ -159,19 +159,18 @@ def loss(self, losses.update(loss_pose3d=loss) # calculate accuracy - _, avg_acc, _ = keypoint_pck_accuracy( + mpjpe_err = keypoint_mpjpe( pred=to_numpy(pred_outputs), gt=to_numpy(lifting_target_label), - mask=to_numpy(lifting_target_weights) > 0, - thr=0.05, - norm_factor=np.ones((pred_outputs.size(0), 3), dtype=np.float32)) + mask=to_numpy(lifting_target_weights) > 0) - mpjpe_pose = torch.tensor(avg_acc, device=lifting_target_label.device) + mpjpe_pose = torch.tensor( + mpjpe_err, device=lifting_target_label.device) losses.update(mpjpe=mpjpe_pose) return losses @property def default_init_cfg(self): - init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] + init_cfg = [dict(type='TruncNormal', layer=['Linear'], std=0.02)] return init_cfg diff --git a/mmpose/models/losses/regression_loss.py b/mmpose/models/losses/regression_loss.py index 9a64a4adfe..b50ad99f04 100644 --- a/mmpose/models/losses/regression_loss.py +++ b/mmpose/models/losses/regression_loss.py @@ -365,6 +365,84 @@ def forward(self, output, target, target_weight=None): return loss * self.loss_weight +@MODELS.register_module() +class MPJPEVelocityJointLoss(nn.Module): + """MPJPE (Mean Per Joint Position Error) loss. + + Args: + loss_weight (float): Weight of the loss. Default: 1.0. + lambda_scale (float): Factor of the N-MPJPE loss. Default: 0.5. + lambda_3d_velocity (float): Factor of the velocity loss. Default: 20.0. + """ + + def __init__(self, + use_target_weight=False, + loss_weight=1., + lambda_scale=0.5, + lambda_3d_velocity=20.0): + super().__init__() + self.use_target_weight = use_target_weight + self.loss_weight = loss_weight + self.lambda_scale = lambda_scale + self.lambda_3d_velocity = lambda_3d_velocity + + def forward(self, output, target, target_weight=None): + """Forward function. + + Note: + - batch_size: N + - num_keypoints: K + - dimension of keypoints: D (D=2 or D=3) + + Args: + output (torch.Tensor[N, K, D]): Output regression. + target (torch.Tensor[N, K, D]): Target regression. + target_weight (torch.Tensor[N,K,D]): + Weights across different joint types. + """ + norm_output = torch.mean( + torch.sum(torch.square(output), dim=-1, keepdim=True), + dim=-2, + keepdim=True) + norm_target = torch.mean( + torch.sum(target * output, dim=-1, keepdim=True), + dim=-2, + keepdim=True) + + velocity_output = output[..., 1:, :, :] - output[..., :-1, :, :] + velocity_target = target[..., 1:, :, :] - target[..., :-1, :, :] + + if self.use_target_weight: + assert target_weight is not None + mpjpe = torch.mean( + torch.norm((output - target) * target_weight, dim=-1)) + + nmpjpe = torch.mean( + torch.norm( + (norm_target / norm_output * output - target) * + target_weight, + dim=-1)) + + loss_3d_velocity = torch.mean( + torch.norm( + (velocity_output - velocity_target) * target_weight, + dim=-1)) + else: + mpjpe = torch.mean(torch.norm(output - target, dim=-1)) + + nmpjpe = torch.mean( + torch.norm( + norm_target / norm_output * output - target, dim=-1)) + + loss_3d_velocity = torch.mean( + torch.norm(velocity_output - velocity_target, dim=-1)) + + loss = mpjpe + nmpjpe * self.lambda_scale + \ + loss_3d_velocity * self.lambda_3d_velocity + + return loss * self.loss_weight + + @MODELS.register_module() class MPJPELoss(nn.Module): """MPJPE (Mean Per Joint Position Error) loss. From b184f994ce2452c2740bbfea4ee4d5eabc149ef1 Mon Sep 17 00:00:00 2001 From: LareinaM Date: Thu, 13 Jul 2023 14:59:41 +0800 Subject: [PATCH 30/30] add docstring --- .../codecs/utils/camera_image_projection.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/mmpose/codecs/utils/camera_image_projection.py b/mmpose/codecs/utils/camera_image_projection.py index 847062ce7e..5ed4d14109 100644 --- a/mmpose/codecs/utils/camera_image_projection.py +++ b/mmpose/codecs/utils/camera_image_projection.py @@ -1,8 +1,27 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Tuple + import numpy as np -def camera_to_image_coord(root_index, kpts_3d_cam, camera_param): +def camera_to_image_coord(root_index: int, kpts_3d_cam: np.ndarray, + camera_param: Dict) -> Tuple[np.ndarray, np.ndarray]: + """Project keypoints from camera space to image space and calculate factor. + + Args: + root_index (int): Index for root keypoint. + kpts_3d_cam (np.ndarray): Keypoint coordinates in camera space in + shape (N, K, D). + camera_param (dict): Parameters for the camera. + + Returns: + tuple: + - kpts_3d_image (np.ndarray): Keypoint coordinates in image space in + shape (N, K, D). + - factor (np.ndarray): The scaling factor that maps keypoints from + image space to camera space in shape (N, ). + """ + root = kpts_3d_cam[..., root_index, :] tl_kpt = root.copy() tl_kpt[..., :2] -= 1.0 @@ -28,7 +47,20 @@ def camera_to_image_coord(root_index, kpts_3d_cam, camera_param): return kpts_3d_image, factor -def camera_to_pixel(kpts_3d, fx, fy, cx, cy): +def camera_to_pixel(kpts_3d: np.ndarray, fx: float, fy: float, cx: float, + cy: float) -> np.ndarray: + """Project keypoints from camera space to image space. + + Args: + kpts_3d (np.ndarray): Keypoint coordinates in camera space. + fx (float): x-coordinate of camera's focal length. + fy (float): y-coordinate of camera's focal length. + cx (float): x-coordinate of image center. + cy (float): y-coordinate of image center. + + Returns: + pose_2d (np.ndarray): Projected keypoint coordinates in image space. + """ pose_2d = kpts_3d[..., :2] / kpts_3d[..., 2:3] pose_2d[..., 0] *= fx pose_2d[..., 1] *= fy