Skip to content

Commit

Permalink
[FEAT] Convert joints format in HTFPersonDatapoint
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Aug 18, 2022
1 parent c2bccc9 commit dcf4916
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion hourglass_tensorflow/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Type
from typing import Union
from typing import Literal
from typing import Optional

from pydantic import BaseModel
Expand All @@ -26,6 +30,24 @@ class HTFPersonDatapoint(BaseModel):
person_id: int
source_image: str
bbox: HTFPersonBBox
joints: List[HTFPersonJoint]
joints: Union[List[HTFPersonJoint], Dict[int, HTFPersonJoint]]
center: HTFPoint
scale: float

def convert_joint(
self, to=Union[Literal["list"], Literal["dict"], Type[dict], Type[list]]
) -> None:
if to in ["list", list]:
self._convert_joints_to_list()
if to in ["dict", dict]:
self._convert_joints_to_dict()

def _convert_joints_to_dict(self) -> None:
if isinstance(self.joints, dict):
return
self.joints = {j.id: j for j in self.joints}

def _convert_joints_to_list(self) -> None:
if isinstance(self.joints, list):
return
self.joints = [j for j in self.joints.values()]

0 comments on commit dcf4916

Please sign in to comment.