Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make annotate images optional #51

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion skellytracker/trackers/base_tracker/base_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def __init__(
self.tracked_objects[name] = TrackedObject(object_id=name)

@abstractmethod
def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]:
def process_image(self, image: np.ndarray, annotate_image: bool = True, **kwargs) -> Dict[str, TrackedObject]:
"""
Process the input image and apply the tracking algorithm.

:param image: An input image.
:param annotate_image: Whether to annotate a copy of the image with the results of the tracking algorithm.
:return: A dictionary of tracked objects
"""
pass
Expand Down
12 changes: 8 additions & 4 deletions skellytracker/trackers/charuco_tracker/charuco_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

default_aruco_dictionary = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_4X4_250)


class CharucoTracker(BaseTracker):
def __init__(
self,
Expand All @@ -36,7 +37,9 @@ def __init__(
self.tracked_object_names = tracked_object_names
self.dictionary = dictionary

def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]:
def process_image(
self, image: np.ndarray, annotate_image: bool = True, **kwargs
) -> Dict[str, TrackedObject]:
# Convert the image to grayscale
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

Expand All @@ -59,9 +62,10 @@ def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]
self.tracked_objects[object_id].pixel_x = corner[0][0]
self.tracked_objects[object_id].pixel_y = corner[0][1]

self.annotated_image = self.annotate_image(
image=image, tracked_objects=self.tracked_objects
)
if annotate_image:
self.annotated_image = self.annotate_image(
image=image, tracked_objects=self.tracked_objects
)

return self.tracked_objects

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def __init__(
)
self.detector = vision.FaceLandmarker.create_from_options(options)

def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]:
def process_image(
self, image: np.ndarray, annotate_image: bool = True, **kwargs
) -> Dict[str, TrackedObject]:
rgb_image = cv2.cvtColor(
image, cv2.COLOR_BGR2RGB
) # TODO: may need to convert this into an `mp.Image`, but can't find documentation about that
Expand All @@ -63,11 +65,12 @@ def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]
blendshape.score for blendshape in results.face_blendshapes[0]
] # TODO: assumes we're only interested in 1 face, but docs say this works for multiple faces??

self.annotated_image = self.annotate_image(
image=image,
tracked_objects=self.tracked_objects,
face_landmarks=results.face_landmarks[0],
)
if annotate_image:
self.annotated_image = self.annotate_image(
image=image,
tracked_objects=self.tracked_objects,
face_landmarks=results.face_landmarks[0],
)

return self.tracked_objects

Expand Down Expand Up @@ -127,7 +130,7 @@ def get_or_download_mediapipe_blendshape_model(self) -> Path:
r.raise_for_status()
model_path.write_bytes(r.content)
return model_path


if __name__ == "__main__":
MediapipeBlendshapeTracker().demo()
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def __init__(
smooth_landmarks=smooth_landmarks,
)

def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]:
def process_image(
self, image: np.ndarray, annotate_image: bool = True, **kwargs
) -> Dict[str, TrackedObject]:
# Convert the image to RGB
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

Expand All @@ -57,9 +59,10 @@ def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]
"landmarks"
] = results.right_hand_landmarks

self.annotated_image = self.annotate_image(
image=image, tracked_objects=self.tracked_objects
)
if annotate_image:
self.annotated_image = self.annotate_image(
image=image, tracked_objects=self.tracked_objects
)

return self.tracked_objects

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(
self.bounding_box_buffer_percentage = bounding_box_buffer_percentage
self.buffer_size_method = buffer_size_method

def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]:
def process_image(
self, image: np.ndarray, annotate_image: bool = True, **kwargs
) -> Dict[str, TrackedObject]:

yolo_results = self.model(image, classes=0, max_det=1, verbose=False)
box_xyxy = np.asarray(yolo_results[0].boxes.xyxy.cpu()).flatten()
Expand Down Expand Up @@ -119,9 +121,10 @@ def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]

bbox_image = buffered_yolo_results[0].plot()

self.annotated_image = self.annotate_image(
image=bbox_image, tracked_objects=self.tracked_objects
)
if annotate_image:
self.annotated_image = self.annotate_image(
image=bbox_image, tracked_objects=self.tracked_objects
)

return self.tracked_objects

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def __init__(
else:
self.classes = None # None includes all classes

def process_image(self, image, **kwargs) -> Dict[str, TrackedObject]:
def process_image(
self, image, annotate_image: bool = True, **kwargs
) -> Dict[str, TrackedObject]:
results = self.model(
image,
classes=self.classes,
Expand All @@ -53,7 +55,8 @@ def process_image(self, image, **kwargs) -> Dict[str, TrackedObject]:
0
].boxes.orig_shape

self.annotated_image = self.annotate_image(image, results=results, **kwargs)
if annotate_image:
self.annotated_image = self.annotate_image(image, results=results, **kwargs)

return self.tracked_objects

Expand Down
11 changes: 7 additions & 4 deletions skellytracker/trackers/yolo_tracker/yolo_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ def __init__(self, model_size: str = "nano"):
pytorch_model = YOLOModelInfo.model_dictionary[model_size]
self.model = YOLO(pytorch_model)

def process_image(self, image: np.ndarray, **kwargs) -> Dict[str, TrackedObject]:
def process_image(
self, image: np.ndarray, annotate_image: bool = True, **kwargs
) -> Dict[str, TrackedObject]:
# "max_det=1" argument to limit to single person tracking for now
results = self.model(image, max_det=1, verbose=False)

self.unpack_results(results)

self.annotated_image = self.annotate_image(
image=image, results=results, **kwargs
)
if annotate_image:
self.annotated_image = self.annotate_image(
image=image, results=results, **kwargs
)

return self.tracked_objects

Expand Down
Loading