Skip to content

Commit

Permalink
rebase master
Browse files Browse the repository at this point in the history
  • Loading branch information
jin-s13 committed Jun 9, 2021
1 parent a1c3e06 commit d9dc88a
Show file tree
Hide file tree
Showing 14 changed files with 208 additions and 254 deletions.
2 changes: 1 addition & 1 deletion configs/_base_/datasets/interhand3d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
dataset_info = dict(
dataset_name='interhand2d',
dataset_name='interhand3d',
paper_info=dict(
author='Moon, Gyeongsik and Yu, Shoou-I and Wen, He and '
'Shiratori, Takaaki and Lee, Kyoung Mu',
Expand Down
23 changes: 23 additions & 0 deletions demo/body3d_two_stage_img_demo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os.path as osp
import warnings
from argparse import ArgumentParser

import mmcv
Expand All @@ -10,6 +11,7 @@
inference_top_down_pose_model, vis_3d_pose_result)
from mmpose.apis.inference import init_pose_model
from mmpose.core import SimpleCamera
from mmpose.datasets import DatasetInfo


def _keypoint_camera_to_world(keypoints,
Expand Down Expand Up @@ -150,6 +152,16 @@ def main():
'model is supported for the 1st stage (2D pose detection)'

dataset = pose_det_model.cfg.data['test']['type']
dataset_info = pose_det_model.cfg.data['test'].get(
'dataset_info', None)
if dataset_info is None:
warnings.warn(
'Please set `dataset_info` in the config.'
'Check https://github.com/open-mmlab/mmpose/pull/663 '
'for details.', DeprecationWarning)
else:
dataset_info = DatasetInfo(dataset_info)

img_keys = list(coco.imgs.keys())

for i in mmcv.track_iter_progress(range(len(img_keys))):
Expand All @@ -174,6 +186,7 @@ def main():
bbox_thr=None,
format='xywh',
dataset=dataset,
dataset_info=dataset_info,
return_heatmap=False,
outputs=None)

Expand All @@ -193,6 +206,14 @@ def main():
'"PoseLifter" model is supported for the 2nd stage ' \
'(2D-to-3D lifting)'
dataset = pose_lift_model.cfg.data['test']['type']
dataset_info = pose_lift_model.cfg.data['test'].get('dataset_info', None)
if dataset_info is None:
warnings.warn(
'Please set `dataset_info` in the config.'
'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
DeprecationWarning)
else:
dataset_info = DatasetInfo(dataset_info)

camera_params = None
if args.camera_param_file is not None:
Expand All @@ -207,6 +228,7 @@ def main():
pose_lift_model,
pose_results_2d=[pose_det_results],
dataset=dataset,
dataset_info=dataset_info,
with_track_id=False)

image_name = pose_det_results[0]['image_name']
Expand Down Expand Up @@ -255,6 +277,7 @@ def main():
pose_lift_model,
result=pose_lift_results_vis,
img=pose_lift_results[0]['image_name'],
dataset_info=dataset_info,
out_file=out_file)


Expand Down
76 changes: 50 additions & 26 deletions mmpose/apis/inference_3d.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
import torch
from mmcv.parallel import collate, scatter
Expand Down Expand Up @@ -71,6 +73,7 @@ def _collate_pose_sequence(pose_results, with_track_id=True):
def inference_pose_lifter_model(model,
pose_results_2d,
dataset,
dataset_info=None,
with_track_id=True):
"""Inference 3D pose from 2D pose sequences using a pose lifter model.
Expand Down Expand Up @@ -100,11 +103,19 @@ def inference_pose_lifter_model(model,
cfg = model.cfg
test_pipeline = Compose(cfg.test_pipeline)

flip_pairs = None
if dataset == 'Body3DH36MDataset':
flip_pairs = [[1, 4], [2, 5], [3, 6], [11, 14], [12, 15], [13, 16]]
if dataset_info is not None:
flip_pairs = dataset_info.flip_pairs
else:
raise NotImplementedError()
warnings.warn(
'dataset is deprecated.'
'Please set `dataset_info` in the config.'
'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
DeprecationWarning)
# TODO: These will be removed in the later versions.
if dataset == 'Body3DH36MDataset':
flip_pairs = [[1, 4], [2, 5], [3, 6], [11, 14], [12, 15], [13, 16]]
else:
raise NotImplementedError()

pose_sequences_2d = _collate_pose_sequence(pose_results_2d, with_track_id)

Expand Down Expand Up @@ -184,6 +195,7 @@ def vis_3d_pose_result(model,
img=None,
dataset='Body3DH36MDataset',
kpt_score_thr=0.3,
dataset_info=None,
show=False,
out_file=None):
"""Visualize the 3D pose estimation results.
Expand All @@ -192,30 +204,42 @@ def vis_3d_pose_result(model,
model (nn.Module): The loaded model.
result (list[dict])
"""
if hasattr(model, 'module'):
model = model.module

palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
[230, 230, 0], [255, 153, 255], [153, 204, 255],
[255, 102, 255], [255, 51, 255], [102, 178, 255],
[51, 153, 255], [255, 153, 153], [255, 102, 102],
[255, 51, 51], [153, 255, 153], [102, 255, 102],
[51, 255, 51], [0, 255, 0], [0, 0, 255], [255, 0, 0],
[255, 255, 255]])

if dataset == 'Body3DH36MDataset':
skeleton = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7],
[7, 8], [8, 9], [9, 10], [8, 11], [11, 12], [12, 13],
[8, 14], [14, 15], [15, 16]]

pose_kpt_color = palette[[
9, 0, 0, 0, 16, 16, 16, 9, 9, 9, 9, 16, 16, 16, 0, 0, 0
]]
pose_link_color = palette[[
0, 0, 0, 16, 16, 16, 9, 9, 9, 9, 16, 16, 16, 0, 0, 0
]]
if dataset_info is not None:
skeleton = dataset_info.skeleton
pose_kpt_color = dataset_info.pose_kpt_color
pose_link_color = dataset_info.pose_link_color
else:
raise NotImplementedError
warnings.warn(
'dataset is deprecated.'
'Please set `dataset_info` in the config.'
'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
DeprecationWarning)
# TODO: These will be removed in the later versions.
palette = np.array([[255, 128, 0], [255, 153, 51], [255, 178, 102],
[230, 230, 0], [255, 153, 255], [153, 204, 255],
[255, 102, 255], [255, 51, 255], [102, 178, 255],
[51, 153, 255], [255, 153, 153], [255, 102, 102],
[255, 51, 51], [153, 255, 153], [102, 255, 102],
[51, 255, 51], [0, 255, 0], [0, 0, 255],
[255, 0, 0], [255, 255, 255]])

if dataset == 'Body3DH36MDataset':
skeleton = [[0, 1], [1, 2], [2, 3], [0, 4], [4, 5], [5, 6], [0, 7],
[7, 8], [8, 9], [9, 10], [8, 11], [11, 12], [12, 13],
[8, 14], [14, 15], [15, 16]]

pose_kpt_color = palette[[
9, 0, 0, 0, 16, 16, 16, 9, 9, 9, 9, 16, 16, 16, 0, 0, 0
]]
pose_link_color = palette[[
0, 0, 0, 16, 16, 16, 9, 9, 9, 9, 16, 16, 16, 0, 0, 0
]]
else:
raise NotImplementedError

if hasattr(model, 'module'):
model = model.module

img = model.show_result(
result,
Expand Down
1 change: 1 addition & 0 deletions mmpose/datasets/datasets/animal/animal_macaque_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mmcv import Config
from xtcocotools.cocoeval import COCOeval

from ....core.post_processing import oks_nms, soft_oks_nms
from ...builder import DATASETS
from .._base_ import Kpt2dSviewRgbImgTopDownDataset

Expand Down
1 change: 1 addition & 0 deletions mmpose/datasets/datasets/animal/animal_pose_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from mmcv import Config
from xtcocotools.cocoeval import COCOeval

from ....core.post_processing import oks_nms, soft_oks_nms
from ...builder import DATASETS
from .._base_ import Kpt2dSviewRgbImgTopDownDataset

Expand Down
2 changes: 1 addition & 1 deletion mmpose/datasets/datasets/body3d/body3d_h36m_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from mmcv import Config

from mmpose.core.evaluation import keypoint_mpjpe
from ...builder import DATASETS
from mmpose.datasets.datasets._base_ import Kpt3dSviewKpt2dDataset
from ...builder import DATASETS


@DATASETS.register_module()
Expand Down
Loading

0 comments on commit d9dc88a

Please sign in to comment.