From dcf49162af4f200136457908ce5963248a5475be Mon Sep 17 00:00:00 2001 From: wbenbihi Date: Thu, 18 Aug 2022 16:53:02 +0800 Subject: [PATCH] [FEAT] Convert joints format in HTFPersonDatapoint --- hourglass_tensorflow/types.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/hourglass_tensorflow/types.py b/hourglass_tensorflow/types.py index ac08046..2bce059 100644 --- a/hourglass_tensorflow/types.py +++ b/hourglass_tensorflow/types.py @@ -1,5 +1,9 @@ from typing import Any +from typing import Dict from typing import List +from typing import Type +from typing import Union +from typing import Literal from typing import Optional from pydantic import BaseModel @@ -26,6 +30,24 @@ class HTFPersonDatapoint(BaseModel): person_id: int source_image: str bbox: HTFPersonBBox - joints: List[HTFPersonJoint] + joints: Union[List[HTFPersonJoint], Dict[int, HTFPersonJoint]] center: HTFPoint scale: float + + def convert_joint( + self, to=Union[Literal["list"], Literal["dict"], Type[dict], Type[list]] + ) -> None: + if to in ["list", list]: + self._convert_joints_to_list() + if to in ["dict", dict]: + self._convert_joints_to_dict() + + def _convert_joints_to_dict(self) -> None: + if isinstance(self.joints, dict): + return + self.joints = {j.id: j for j in self.joints} + + def _convert_joints_to_list(self) -> None: + if isinstance(self.joints, list): + return + self.joints = [j for j in self.joints.values()]