Skip to content

Commit

Permalink
separate ann and pred transform
Browse files Browse the repository at this point in the history
  • Loading branch information
Tau-J committed Oct 8, 2023
1 parent 048811e commit bace6e0
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 34 deletions.
4 changes: 2 additions & 2 deletions mmpose/evaluation/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
multilabel_classification_accuracy,
pose_pck_accuracy, simcc_pck_accuracy)
from .nms import nms, nms_torch, oks_nms, soft_oks_nms
from .transforms import transform_keypoints, transform_sigmas
from .transforms import transform_ann, transform_pred, transform_sigmas

__all__ = [
'keypoint_pck_accuracy', 'keypoint_auc', 'keypoint_nme', 'keypoint_epe',
'pose_pck_accuracy', 'multilabel_classification_accuracy',
'simcc_pck_accuracy', 'nms', 'oks_nms', 'soft_oks_nms', 'keypoint_mpjpe',
'nms_torch', 'transform_keypoints', 'transform_sigmas'
'nms_torch', 'transform_ann', 'transform_sigmas', 'transform_pred'
]
78 changes: 54 additions & 24 deletions mmpose/evaluation/functional/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,73 @@ def transform_sigmas(sigmas: Union[List, np.ndarray], num_keypoints: int,
return new_sigmas


def transform_keypoints(kpt_info: Union[dict, list], num_keypoints: int,
mapping: Union[List[Tuple[int, int]],
List[Tuple[Tuple, int]]]):
"""Transforms anns and predictions of keypoints based on the mapping."""
def transform_ann(ann_info: Union[dict, list], num_keypoints: int,
mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
int]]]):
"""Transforms annotations based on the mapping."""
if len(mapping):
source_index, target_index = map(list, zip(*mapping))
else:
source_index, target_index = [], []

list_input = True
if not isinstance(kpt_info, list):
kpt_info = [kpt_info]
if not isinstance(ann_info, list):
ann_info = [ann_info]
list_input = False

for each in kpt_info:
for each in ann_info:
if 'keypoints' in each:
keypoints = np.array(each['keypoints'])
if len(keypoints.shape) > 1:
# transform predictions
N, _, C = keypoints.shape
new_keypoints = np.zeros((N, num_keypoints, C),
dtype=keypoints.dtype)
new_keypoints[:, target_index] = keypoints[:, source_index]
each['keypoints'] = new_keypoints
else:
# transform annotations
C = 3
keypoints = keypoints.reshape(-1, C)
new_keypoints = np.zeros((num_keypoints, C),
dtype=keypoints.dtype)
new_keypoints[target_index] = keypoints[source_index]
each['keypoints'] = new_keypoints.reshape(-1).tolist()

C = 3
keypoints = keypoints.reshape(-1, C)
new_keypoints = np.zeros((num_keypoints, C), dtype=keypoints.dtype)
new_keypoints[target_index] = keypoints[source_index]
each['keypoints'] = new_keypoints.reshape(-1).tolist()

if 'num_keypoints' in each:
each['num_keypoints'] = num_keypoints

if not list_input:
ann_info = ann_info[0]

return ann_info


def transform_pred(pred_info: Union[dict, list], num_keypoints: int,
mapping: Union[List[Tuple[int, int]], List[Tuple[Tuple,
int]]]):
"""Transforms predictions based on the mapping."""
if len(mapping):
source_index, target_index = map(list, zip(*mapping))
else:
source_index, target_index = [], []

list_input = True
if not isinstance(pred_info, list):
pred_info = [pred_info]
list_input = False

for each in pred_info:
if 'keypoints' in each:
keypoints = np.array(each['keypoints'])

N, _, C = keypoints.shape
new_keypoints = np.zeros((N, num_keypoints, C),
dtype=keypoints.dtype)
new_keypoints[:, target_index] = keypoints[:, source_index]
each['keypoints'] = new_keypoints

keypoint_scores = np.array(each['keypoint_scores'])
new_scores = np.zeros((N, num_keypoints),
dtype=keypoint_scores.dtype)
new_scores[:, target_index] = keypoint_scores[:, source_index]
each['keypoint_scores'] = new_scores

if 'num_keypoints' in each:
each['num_keypoints'] = num_keypoints

if not list_input:
kpt_info = kpt_info[0]
pred_info = pred_info[0]

return kpt_info
return pred_info
10 changes: 5 additions & 5 deletions mmpose/evaluation/metrics/coco_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from mmpose.registry import METRICS
from mmpose.structures.bbox import bbox_xyxy2xywh
from ..functional import (oks_nms, soft_oks_nms, transform_keypoints,
from ..functional import (oks_nms, soft_oks_nms, transform_ann, transform_pred,
transform_sigmas)


Expand Down Expand Up @@ -394,7 +394,7 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
self.coco = COCO(coco_json_path)
if self.gt_converter is not None:
for id_, ann in self.coco.anns.items():
self.coco.anns[id_] = transform_keypoints(
self.coco.anns[id_] = transform_ann(
ann, self.gt_converter['num_keypoints'],
self.gt_converter['mapping'])

Expand All @@ -405,9 +405,9 @@ def compute_metrics(self, results: list) -> Dict[str, float]:
img_id = pred['img_id']

if self.pred_converter is not None:
pred = transform_keypoints(
pred, self.pred_converter['num_keypoints'],
self.pred_converter['mapping'])
pred = transform_pred(pred,
self.pred_converter['num_keypoints'],
self.pred_converter['mapping'])

for idx, keypoints in enumerate(pred['keypoints']):

Expand Down
28 changes: 25 additions & 3 deletions tests/test_evaluation/test_functional/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

import numpy as np

from mmpose.evaluation.functional import transform_keypoints, transform_sigmas
from mmpose.evaluation.functional import (transform_ann, transform_pred,
transform_sigmas)


class TestKeypointEval(TestCase):
Expand All @@ -19,7 +20,7 @@ def test_transform_sigmas(self):
for i, j in mapping:
self.assertEqual(sigmas[i], new_sigmas[j])

def test_transform_keypoints(self):
def test_transform_ann(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
num_keypoints = 5

Expand All @@ -28,10 +29,31 @@ def test_transform_keypoints(self):
keypoints=np.random.randint(3, size=(17 * 3, )).tolist())
kpt_info_copy = deepcopy(kpt_info)

_ = transform_keypoints(kpt_info, num_keypoints, mapping)
_ = transform_ann(kpt_info, num_keypoints, mapping)

self.assertEqual(kpt_info['num_keypoints'], 5)
self.assertEqual(len(kpt_info['keypoints']), 15)
for i, j in mapping:
self.assertListEqual(kpt_info_copy['keypoints'][i * 3:i * 3 + 3],
kpt_info['keypoints'][j * 3:j * 3 + 3])

def test_transform_pred(self):
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
num_keypoints = 5

kpt_info = dict(
num_keypoints=17,
keypoints=np.random.randint(3, size=(
1,
17,
3,
)))
kpt_info_copy = deepcopy(kpt_info)

_ = transform_pred(kpt_info, num_keypoints, mapping)

self.assertEqual(kpt_info['num_keypoints'], 5)
self.assertEqual(len(kpt_info['keypoints']), 1)
for i, j in mapping:
self.assertListEqual(kpt_info_copy['keypoints'][:, i, :],
kpt_info['keypoints'][:, j, :])

0 comments on commit bace6e0

Please sign in to comment.