Skip to content

Commit

Permalink
Handle breaking change pose_limb_color -> pose_link_color
Browse files Browse the repository at this point in the history
  • Loading branch information
ly015 committed Aug 26, 2021
1 parent d65a7ae commit 22cebf6
Show file tree
Hide file tree
Showing 14 changed files with 217 additions and 325 deletions.
17 changes: 3 additions & 14 deletions mmpose/core/visualization/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import trimesh
from matplotlib import pyplot as plt
from mmcv.utils.misc import deprecated_api_warning
from mmcv.visualization.color import color_val

try:
Expand Down Expand Up @@ -94,6 +95,7 @@ def imshow_bboxes(img,
return img


@deprecated_api_warning({'pose_limb_color': 'pose_link_color'})
def imshow_keypoints(img,
pose_result,
skeleton=None,
Expand All @@ -102,8 +104,7 @@ def imshow_keypoints(img,
pose_link_color=None,
radius=4,
thickness=1,
show_keypoint_weight=False,
pose_limb_color=None):
show_keypoint_weight=False):
"""Draw keypoints and links on an image.
Args:
Expand All @@ -116,23 +117,11 @@ def imshow_keypoints(img,
to be shown. Default: 0.3.
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
the keypoint will not be drawn.
pose_limb_color (np.array[Mx3]): Deprecated (see `pose_link_color).
Color of M links. If None, the links will not be drawn.
pose_link_color (np.array[Mx3]): Color of M links. If None, the
links will not be drawn.
thickness (int): Thickness of lines.
"""

# TODO: These will be removed in the later versions.
if pose_limb_color is not None:
warnings.warn(
'pose_limb_color is deprecated.'
'Please use pose_link_color instead.'
'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
DeprecationWarning)
if pose_link_color is None:
pose_link_color = pose_limb_color

img = mmcv.imread(img)
img_h, img_w, _ = img.shape

Expand Down
18 changes: 4 additions & 14 deletions mmpose/models/detectors/associative_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import mmcv
import torch
from mmcv.image import imwrite
from mmcv.utils.misc import deprecated_api_warning
from mmcv.visualization.image import imshow

from mmpose.core.evaluation import (aggregate_results, get_group_preds,
Expand Down Expand Up @@ -298,6 +299,8 @@ def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):

return result

@deprecated_api_warning({'pose_limb_color': 'pose_link_color'},
cls_name='AssociativeEmbedding')
def show_result(self,
img,
result,
Expand All @@ -313,8 +316,7 @@ def show_result(self,
show=False,
show_keypoint_weight=False,
wait_time=0,
out_file=None,
pose_limb_color=None):
out_file=None):
"""Draw `result` over `img`.
Args:
Expand All @@ -327,7 +329,6 @@ def show_result(self,
to be shown. Default: 0.3.
pose_kpt_color (np.array[Nx3]`): Color of N keypoints.
If None, do not draw keypoints.
pose_limb_color (np.array[Mx3]): Deprecated (see `pose_link_color).
pose_link_color (np.array[Mx3]): Color of M links.
If None, do not draw links.
radius (int): Radius of circles.
Expand All @@ -345,17 +346,6 @@ def show_result(self,
Returns:
Tensor: Visualized image only if not `show` or `out_file`
"""

# TODO: These will be removed in the later versions.
if pose_limb_color is not None:
warnings.warn(
'pose_limb_color is deprecated.'
'Please use pose_link_color instead.'
'Check https://github.com/open-mmlab/mmpose/pull/663 for '
'details.', DeprecationWarning)
if pose_link_color is None:
pose_link_color = pose_limb_color

img = mmcv.imread(img)
img = img.copy()
img_h, img_w, _ = img.shape
Expand Down
20 changes: 4 additions & 16 deletions mmpose/models/detectors/interhand_3d.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings

import mmcv
import numpy as np
from mmcv.utils.misc import deprecated_api_warning

from mmpose.core import imshow_keypoints, imshow_keypoints_3d
from ..builder import POSENETS
Expand Down Expand Up @@ -101,6 +100,8 @@ def forward_test(self, img, img_metas, **kwargs):
result = {}
return result

@deprecated_api_warning({'pose_limb_color': 'pose_link_color'},
cls_name='Interhand3D')
def show_result(self,
result,
img=None,
Expand All @@ -116,8 +117,7 @@ def show_result(self,
win_name='',
show=False,
wait_time=0,
out_file=None,
pose_limb_color=None):
out_file=None):
"""Visualize 3D pose estimation results.
Args:
Expand All @@ -138,7 +138,6 @@ def show_result(self,
thickness (int): Thickness of lines.
pose_kpt_color (np.array[Nx3]`): Color of N keypoints.
If None, do not draw keypoints.
pose_limb_color (np.array[Mx3]): Deprecated (see `pose_link_color).
pose_link_color (np.array[Mx3]): Color of M limbs.
If None, do not draw limbs.
vis_height (int): The image hight of the visualization. The width
Expand All @@ -158,17 +157,6 @@ def show_result(self,
Returns:
Tensor: Visualized img, only if not `show` or `out_file`.
"""

# TODO: These will be removed in the later versions.
if pose_limb_color is not None:
warnings.warn(
'pose_limb_color is deprecated.'
'Please use pose_link_color instead.'
'Check https://github.com/open-mmlab/mmpose/pull/663 for '
'details.', DeprecationWarning)
if pose_link_color is None:
pose_link_color = pose_limb_color

if num_instances < 0:
assert len(result) > 0
result = sorted(result, key=lambda x: x.get('track_id', 0))
Expand Down
18 changes: 4 additions & 14 deletions mmpose/models/detectors/pose_lifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import mmcv
import numpy as np
from mmcv.utils.misc import deprecated_api_warning

from mmpose.core import imshow_bboxes, imshow_keypoints, imshow_keypoints_3d
from .. import builder
Expand Down Expand Up @@ -282,6 +283,8 @@ def forward_dummy(self, input):

return output

@deprecated_api_warning({'pose_limb_color': 'pose_link_color'},
cls_name='PoseLifter')
def show_result(self,
result,
img=None,
Expand All @@ -295,8 +298,7 @@ def show_result(self,
win_name='',
show=False,
wait_time=0,
out_file=None,
pose_limb_color=None):
out_file=None):
"""Visualize 3D pose estimation results.
Args:
Expand All @@ -312,7 +314,6 @@ def show_result(self,
links, each is a pair of joint indices.
pose_kpt_color (np.array[Nx3]`): Color of N keypoints.
If None, do not draw keypoints.
pose_limb_color (np.array[Mx3]): Deprecated (see `pose_link_color).
pose_link_color (np.array[Mx3]): Color of M links.
If None, do not draw links.
radius (int): Radius of circles.
Expand All @@ -329,17 +330,6 @@ def show_result(self,
Returns:
Tensor: Visualized img, only if not `show` or `out_file`.
"""

# TODO: These will be removed in the later versions.
if pose_limb_color is not None:
warnings.warn(
'pose_limb_color is deprecated.'
'Please use pose_link_color instead.'
'Check https://github.com/open-mmlab/mmpose/pull/663 for '
'details.', DeprecationWarning)
if pose_link_color is None:
pose_link_color = pose_limb_color

if num_instances < 0:
assert len(result) > 0
result = sorted(result, key=lambda x: x.get('track_id', 1e4))
Expand Down
18 changes: 4 additions & 14 deletions mmpose/models/detectors/top_down.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import mmcv
import numpy as np
from mmcv.image import imwrite
from mmcv.utils.misc import deprecated_api_warning
from mmcv.visualization.image import imshow

from mmpose.core import imshow_bboxes, imshow_keypoints
Expand Down Expand Up @@ -214,6 +215,8 @@ def forward_dummy(self, img):
output = self.keypoint_head(output)
return output

@deprecated_api_warning({'pose_limb_color': 'pose_link_color'},
cls_name='TopDown')
def show_result(self,
img,
result,
Expand All @@ -231,8 +234,7 @@ def show_result(self,
show=False,
show_keypoint_weight=False,
wait_time=0,
out_file=None,
pose_limb_color=None):
out_file=None):
"""Draw `result` over `img`.
Args:
Expand All @@ -246,7 +248,6 @@ def show_result(self,
bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
pose_kpt_color (np.array[Nx3]`): Color of N keypoints.
If None, do not draw keypoints.
pose_limb_color (np.array[Mx3]): Deprecated (see `pose_link_color).
pose_link_color (np.array[Mx3]): Color of M links.
If None, do not draw links.
text_color (str or tuple or :obj:`Color`): Color of texts.
Expand All @@ -265,17 +266,6 @@ def show_result(self,
Returns:
Tensor: Visualized img, only if not `show` or `out_file`.
"""

# TODO: These will be removed in the later versions.
if pose_limb_color is not None:
warnings.warn(
'pose_limb_color is deprecated.'
'Please use pose_link_color instead.'
'Check https://github.com/open-mmlab/mmpose/pull/663 for '
'details.', DeprecationWarning)
if pose_link_color is None:
pose_link_color = pose_limb_color

img = mmcv.imread(img)
img = img.copy()

Expand Down
51 changes: 1 addition & 50 deletions tests/test_datasets/test_animal_dataset.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,14 @@
import copy
import tempfile

import numpy as np
import pytest
from mmcv import Config
from numpy.testing import assert_almost_equal
from tests.utils.data_utils import convert_db_to_output

from mmpose.datasets import DATASETS


def convert_db_to_output(db, batch_size=2, keys=None, is_3d=False):
outputs = []
len_db = len(db)
for i in range(0, len_db, batch_size):
if is_3d:
keypoints = np.stack([
db[j]['joints_3d'].reshape((-1, 3))
for j in range(i, min(i + batch_size, len_db))
])
else:
keypoints = np.stack([
np.hstack([
db[j]['joints_3d'].reshape((-1, 3))[:, :2],
db[j]['joints_3d_visible'].reshape((-1, 3))[:, :1]
]) for j in range(i, min(i + batch_size, len_db))
])
image_paths = [
db[j]['image_file'] for j in range(i, min(i + batch_size, len_db))
]
bbox_ids = [j for j in range(i, min(i + batch_size, len_db))]
box = np.stack([
np.array([
db[j]['center'][0], db[j]['center'][1], db[j]['scale'][0],
db[j]['scale'][1],
db[j]['scale'][0] * db[j]['scale'][1] * 200 * 200, 1.0
],
dtype=np.float32)
for j in range(i, min(i + batch_size, len_db))
])

output = {}
output['preds'] = keypoints
output['boxes'] = box
output['image_paths'] = image_paths
output['output_heatmap'] = None
output['bbox_ids'] = bbox_ids

if keys is not None:
keys = keys if isinstance(keys, list) else [keys]
for key in keys:
output[key] = [
db[j][key] for j in range(i, min(i + batch_size, len_db))
]

outputs.append(output)

return outputs


def test_animal_horse10_dataset():
dataset = 'AnimalHorse10Dataset'
dataset_class = DATASETS.get(dataset)
Expand Down
Loading

0 comments on commit 22cebf6

Please sign in to comment.