Skip to content

Commit

Permalink
Flesh out mediapipe model info (#15)
Browse files Browse the repository at this point in the history
Co-authored-by: aaroncherian <[email protected]>
  • Loading branch information
philipqueen and aaroncherian authored Aug 28, 2024
1 parent d3f9cd4 commit 915ef7b
Show file tree
Hide file tree
Showing 20 changed files with 931 additions and 118 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,4 @@ cython_debug/
*.jpeg

skellytracker/utilities/quine_directory_printer/output/*
recorded_objects.npy
1 change: 1 addition & 0 deletions skellytracker/RUN_ME.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def main(demo_tracker: str = "mediapipe_holistic_tracker"):
charuco_squares_y = 5
number_of_charuco_markers = (charuco_squares_x - 1) * (charuco_squares_y - 1)
charuco_ids = [str(index) for index in range(number_of_charuco_markers)]

CharucoTracker(
tracked_object_names=charuco_ids,
squares_x=charuco_squares_x,
Expand Down
1 change: 1 addition & 0 deletions skellytracker/SINGLE_IMAGE_RUN.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
charuco_squares_y = 5
number_of_charuco_markers = (charuco_squares_x - 1) * (charuco_squares_y - 1)
charuco_ids = [str(index) for index in range(number_of_charuco_markers)]

CharucoTracker(
tracked_object_names=charuco_ids,
squares_x=charuco_squares_x,
Expand Down
59 changes: 43 additions & 16 deletions skellytracker/process_folder_of_videos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from pydantic import BaseModel


from skellytracker.system.constants import BASE_2D_FILE_NAME
from skellytracker.trackers.base_tracker.base_tracker import BaseTracker
from skellytracker.trackers.base_tracker.model_info import ModelInfo
from skellytracker.trackers.bright_point_tracker.brightest_point_tracker import (
BrightestPointTracker,
)
Expand Down Expand Up @@ -37,16 +39,18 @@

logger = logging.getLogger(__name__)

file_name_dictionary = {
"MediapipeHolisticTracker": "mediapipe2dData_numCams_numFrames_numTrackedPoints_pixelXY.npy",
"YOLOMediapipeComboTracker": "mediapipe2dData_numCams_numFrames_numTrackedPoints_pixelXY.npy",
"YOLOPoseTracker": "yolo2dData_numCams_numFrames_numTrackedPoints_pixelXY.npy",
"BrightestPointTracker": "brightestPoint2dData_numCams_numFrames_numTrackedPoints_pixelXY.npy",
}
try:
from skellytracker.trackers.openpose_tracker.openpose_tracker import (
OpenPoseTracker,
)
except ModuleNotFoundError:
print("To use openpose_tracker, install skellytracker[openpose]")

logger = logging.getLogger(__name__)


def process_folder_of_videos(
tracker_name: str,
model_info: ModelInfo,
tracking_params: BaseModel,
synchronized_video_path: Path,
output_folder_path: Optional[Path] = None,
Expand All @@ -57,7 +61,7 @@ def process_folder_of_videos(
Process a folder of synchronized videos with the given tracker.
Tracked data will be saved to a .npy file with the shape (numCams, numFrames, numTrackedPoints, pixelXYZ).
:param tracker_name: Tracker to use.
:param model_info: Model info for tracker.
:param tracking_params: Tracking parameters to use.
:param synchronized_video_path: Path to folder of synchronized videos.
:param output_folder_path: Path to save tracked data to.
Expand All @@ -72,7 +76,7 @@ def process_folder_of_videos(
else:
num_processes = min(num_processes, len(video_paths), cpu_count() - 1)

file_name = file_name_dictionary[tracker_name]
file_name = model_info.name + "_" + BASE_2D_FILE_NAME
synchronized_video_path = Path(synchronized_video_path)
if output_folder_path is None:
output_folder_path = (
Expand All @@ -89,7 +93,7 @@ def process_folder_of_videos(
annotated_video_path.mkdir(parents=True, exist_ok=True)

tasks = [
(tracker_name, tracking_params, video_path, annotated_video_path)
(model_info.tracker_name, tracking_params, video_path, annotated_video_path)
for video_path in video_paths
]

Expand Down Expand Up @@ -126,9 +130,14 @@ def process_single_video(
:param annotated_video_path: Path to save annotated video to.
:return: Array of tracking data
"""
video_name = (
video_path.stem + "_mediapipe.mp4"
) # TODO: fix it so blender output doesn't require mediapipe addendum here

if tracker_name == "OpenPoseTracker":
video_name = video_path.stem + "_openpose.avi"
else:
video_name = (
video_path.stem + "_mediapipe.mp4"
) # TODO: fix it so blender output doesn't require mediapipe addendum here

tracker = get_tracker(tracker_name=tracker_name, tracking_params=tracking_params)
logger.info(
f"Processing video: {video_name} with tracker: {tracker.__class__.__name__}"
Expand All @@ -137,7 +146,7 @@ def process_single_video(
input_video_filepath=video_path,
output_video_filepath=annotated_video_path / video_name,
save_data_bool=False,
)
) # TODO: raise a custom error here if output_array is None?
return output_array


Expand Down Expand Up @@ -177,6 +186,17 @@ def get_tracker(tracker_name: str, tracking_params: BaseModel) -> BaseTracker:
elif tracker_name == "BrightestPointTracker":
tracker = BrightestPointTracker()

elif tracker_name == "OpenPoseTracker":
tracker = OpenPoseTracker(
openpose_root_folder_path=tracking_params.openpose_root_folder_path,
output_json_folder_path=tracking_params.output_json_path,
net_resolution=tracking_params.net_resolution,
number_people_max=tracking_params.number_people_max,
track_faces=tracking_params.track_face,
track_hands=tracking_params.track_hands,
output_resolution=tracking_params.output_resolution,
)

else:
raise ValueError("Invalid tracker type")

Expand All @@ -192,19 +212,26 @@ def get_tracker_params(tracker_name: str) -> BaseModel:
return YOLOTrackingParams()
elif tracker_name == "BrightestPointTracker":
return BaseModel()
elif tracker_name == "OpenPoseTracker":
raise ValueError(
"OpenPoseTracker requires explicitly setting the OpenPose root folder path and output json path, please provide tracking params directly"
)
else:
raise ValueError("Invalid tracker type")


if __name__ == "__main__":
from skellytracker.trackers.mediapipe_tracker.mediapipe_model_info import MediapipeModelInfo

synchronized_video_path = Path(
"/Users/philipqueen/freemocap_data/recording_sessions/freemocap_sample_data/synchronized_videos"
"/Your/Path/To/freemocap_data/recording_sessions/freemocap_sample_data/synchronized_videos"
)

tracker_name = "YOLOMediapipeComboTracker"
num_processes = None

process_folder_of_videos(
tracker_name=tracker_name,
model_info=MediapipeModelInfo(),
tracking_params=get_tracker_params(tracker_name=tracker_name),
synchronized_video_path=synchronized_video_path,
num_processes=num_processes,
Expand Down
1 change: 1 addition & 0 deletions skellytracker/system/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
BASE_2D_FILE_NAME = "2dData_numCams_numFrames_numTrackedPoints_pixelXY.npy"
2 changes: 1 addition & 1 deletion skellytracker/tests/test_mediapipe_holistic_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_record(test_image):
assert processed_results is not None
assert processed_results.shape == (
1,
MediapipeModelInfo.num_tracked_points_total,
MediapipeModelInfo.num_tracked_points,
3,
)

Expand Down
6 changes: 3 additions & 3 deletions skellytracker/tests/test_yolo_mediapipe_combo_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_record_no_buffer(test_image):
assert processed_results is not None
assert processed_results.shape == (
1,
MediapipeModelInfo.num_tracked_points_total,
MediapipeModelInfo.num_tracked_points,
3,
)

Expand Down Expand Up @@ -122,7 +122,7 @@ def test_record_buffer_by_image_size(test_image):
assert processed_results is not None
assert processed_results.shape == (
1,
MediapipeModelInfo.num_tracked_points_total,
MediapipeModelInfo.num_tracked_points,
3,
)

Expand Down Expand Up @@ -186,7 +186,7 @@ def test_record_buffer_by_box_size(test_image):
assert processed_results is not None
assert processed_results.shape == (
1,
MediapipeModelInfo.num_tracked_points_total,
MediapipeModelInfo.num_tracked_points,
3,
)

Expand Down
22 changes: 19 additions & 3 deletions skellytracker/trackers/base_tracker/base_recorder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
import logging
from typing import Dict, Optional
from pathlib import Path
from typing import Dict, Union, Optional

import numpy as np

Expand All @@ -26,7 +27,6 @@ def record(
Record the tracked objects as they are created by the tracker.
:param tracked_object: A tracked objects dictionary.
:param annotated_image: Image array with tracking results annotated.
:return: None
"""
pass
Expand All @@ -45,7 +45,7 @@ def clear_recorded_objects(self):
self.recorded_objects = []
self.recorded_objects_array = None

def save(self, file_path: str) -> None:
def save(self, file_path: Union[str, Path]) -> None:
"""
Save the recorded objects to a file.
Expand All @@ -58,3 +58,19 @@ def save(self, file_path: str) -> None:
recorded_objects_array = self.recorded_objects_array
logger.info(f"Saving recorded objects to {file_path}")
np.save(file_path, recorded_objects_array)


class BaseCumulativeRecorder(BaseRecorder):
"""
A base class for recording data from cumulative trackers.
Throws a descriptive error for methods that do not apply to recording data from this type of tracker.
Trackers implementing this will only use the process_tracked_objects method to get data in the proper format.
"""

def __init__(self):
super().__init__()

def record(self, tracked_objects: Dict[str, TrackedObject]) -> None:
raise NotImplementedError(
"This tracker does not support by frame recording, please use process_tracked_objects instead"
)
59 changes: 56 additions & 3 deletions skellytracker/trackers/base_tracker/base_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tqdm import tqdm


from skellytracker.trackers.base_tracker.base_recorder import BaseRecorder
from skellytracker.trackers.base_tracker.base_recorder import BaseCumulativeRecorder, BaseRecorder
from skellytracker.trackers.base_tracker.tracked_object import TrackedObject
from skellytracker.trackers.base_tracker.video_handler import VideoHandler
from skellytracker.trackers.demo_viewers.image_demo_viewer import ImageDemoViewer
Expand Down Expand Up @@ -65,15 +65,15 @@ def process_video(
output_video_filepath: Optional[Union[str, Path]] = None,
save_data_bool: bool = False,
use_tqdm: bool = True,
) -> Optional[np.ndarray]:
) -> Union[np.ndarray, None]:
"""
Run the tracker on a video.
:param input_video_filepath: Path to video file.
:param output_video_filepath: Path to save annotated video to, does not save video if None.
:param save_data_bool: Whether to save the data to a file.
:param use_tqdm: Whether to use tqdm to show a progress bar
:return: Array of tracked keypoint data, if save_data_bool is True
:return: Array of tracked keypoint data if tracker has an associated recorder
"""

cap = cv2.VideoCapture(str(input_video_filepath))
Expand Down Expand Up @@ -160,3 +160,56 @@ def image_demo(self, image_path: Path) -> None:

image_viewer = ImageDemoViewer(self, self.__class__.__name__)
image_viewer.run(image_path=image_path)


class BaseCumulativeTracker(BaseTracker):
"""
A base class for tracking algorithms that run cumulatively, i.e are not able to process videos frame by frame.
Throws a descriptive error for the abstract methods of BaseTracker that do not apply to this type of tracker.
Trackers inheriting from this will need to overwrite the `process_video` method.
"""

def __init__(
self,
tracked_object_names: List[str],
recorder: BaseCumulativeRecorder,
**data: Any,
):
super().__init__(
tracked_object_names=tracked_object_names, recorder=recorder, **data
)

def process_image(self, **kwargs) -> None:
raise NotImplementedError(
"This tracker does not support processing individual images, please use process_video instead."
)

def annotate_image(self, **kwargs) -> None:
raise NotImplementedError(
"This tracker does not support processing individual images, please use process_video instead."
)

@abstractmethod
def process_video(
self,
input_video_filepath: Union[str, Path],
output_video_filepath: Optional[Union[str, Path]] = None,
save_data_bool: bool = False,
use_tqdm: bool = True,
**kwargs,
) -> Union[np.ndarray, None]:
"""
Run the tracker on a video.
:param input_video_filepath: Path to video file.
:param output_video_filepath: Path to save annotated video to, does not save video if None.
:param save_data_bool: Whether to save the data to a file.
:param use_tqdm: Whether to use tqdm to show a progress bar
:return: Array of tracked keypoint data
"""
pass

def image_demo(self, image_path: Path) -> None:
raise NotImplementedError(
"This tracker does not support processing individual images, please use process_video instead."
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

class BaseTrackingParams(BaseModel):
num_processes: int = 1
run_image_tracking: bool = True
run_image_tracking: bool = True
13 changes: 13 additions & 0 deletions skellytracker/trackers/base_tracker/model_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Dict, List, Optional, Union


class ModelInfo(dict):
name: str
tracker_name: str
landmark_names: List[str]
num_tracked_points: int
tracked_object_names: Optional[list] = None
virtual_markers_definitions: Optional[Dict[str, Dict[str, List[Union[str, float]]]]] = None
segment_connections: Optional[Dict[str, Dict[str, str]]] = None
center_of_mass_definitions: Optional[Dict[str, Dict[str, float]]] = None
joint_hierarchy: Optional[Dict[str, List[str]]] = None
4 changes: 3 additions & 1 deletion skellytracker/trackers/base_tracker/video_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def __init__(
"""
self.output_path = output_path
fourcc = cv2.VideoWriter.fourcc(*codec)
self.video_writer = cv2.VideoWriter(str(output_path), fourcc, fps, frame_size)
self.video_writer = cv2.VideoWriter(
str(output_path), fourcc, fps, frame_size
)

def add_frame(self, frame: np.ndarray) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def record(self, tracked_objects: Dict[str, TrackedObject]) -> None:
self.recorded_objects.append(
[
deepcopy(tracked_objects[tracked_object_name])
for tracked_object_name in MediapipeModelInfo.mediapipe_tracked_object_names
for tracked_object_name in MediapipeModelInfo.tracked_object_names
]
)

Expand All @@ -27,7 +27,7 @@ def process_tracked_objects(self, **kwargs) -> np.ndarray:
self.recorded_objects_array = np.zeros(
(
len(self.recorded_objects),
MediapipeModelInfo.num_tracked_points_total,
MediapipeModelInfo.num_tracked_points,
3,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
smooth_landmarks=True,
):
super().__init__(
tracked_object_names=MediapipeModelInfo.mediapipe_tracked_object_names,
tracked_object_names=MediapipeModelInfo.tracked_object_names,
recorder=MediapipeHolisticRecorder(),
)
self.mp_drawing = mp.solutions.drawing_utils
Expand Down
Loading

0 comments on commit 915ef7b

Please sign in to comment.