Skip to content

Commit

Permalink
Merge pull request #196 from tryolabs/multiple-trackers
Browse files Browse the repository at this point in the history
Support for object count when using multiple `Tracker` instances
  • Loading branch information
javiber authored Sep 30, 2022
2 parents d501ffb + 7c916bd commit 968f62e
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 58 deletions.
196 changes: 140 additions & 56 deletions norfair/tracker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Hashable, List, Optional, Sequence, Union
from typing import Any, Callable, Hashable, List, Optional, Sequence, Tuple, Union

import numpy as np
from rich import print
Expand Down Expand Up @@ -115,10 +115,9 @@ def __init__(

self.distance_threshold = distance_threshold
self.detection_threshold = detection_threshold
TrackedObject.count = 0
self.reid_distance_function = reid_distance_function
self.reid_distance_threshold = reid_distance_threshold
self.abs_to_rel = None
self._obj_factory = _TrackedObjectFactory()

def update(
self,
Expand Down Expand Up @@ -157,11 +156,7 @@ def update(
"""
if coord_transformations is not None:
for det in detections:
det.absolute_points = coord_transformations.rel_to_abs(
det.absolute_points
)
self.abs_to_rel = coord_transformations.abs_to_rel
self.period = period
det.update_coordinate_transformation(coord_transformations)

# Remove stale trackers and make candidate object real if the hit counter is positive
alive_objects = []
Expand All @@ -185,7 +180,7 @@ def update(
# Update tracker
for obj in self.tracked_objects:
obj.tracker_step()
obj.abs_to_rel = self.abs_to_rel
obj.update_coordinate_transformation(coord_transformations)

# Update initialized tracked objects with detections
(
Expand All @@ -197,6 +192,7 @@ def update(
self.distance_threshold,
[o for o in alive_objects if not o.is_initializing],
detections,
period,
)

# Update not yet initialized tracked objects with yet unmatched detections
Expand All @@ -209,6 +205,7 @@ def update(
self.distance_threshold,
[o for o in alive_objects if o.is_initializing],
unmatched_detections,
period,
)

if self.reid_distance_function is not None:
Expand All @@ -218,25 +215,46 @@ def update(
self.reid_distance_threshold,
unmatched_init_trackers + dead_objects,
matched_not_init_trackers,
period,
)

# Create new tracked objects from remaining unmatched detections
for detection in unmatched_detections:
self.tracked_objects.append(
TrackedObject(
detection,
self.hit_counter_max,
self.initialization_delay,
self.pointwise_hit_counter_max,
self.detection_threshold,
self.period,
self.filter_factory,
self.past_detections_length,
self.reid_hit_counter_max,
self.abs_to_rel,
self._obj_factory.create(
initial_detection=detection,
hit_counter_max=self.hit_counter_max,
initialization_delay=self.initialization_delay,
pointwise_hit_counter_max=self.pointwise_hit_counter_max,
detection_threshold=self.detection_threshold,
period=period,
filter_factory=self.filter_factory,
past_detections_length=self.past_detections_length,
reid_hit_counter_max=self.reid_hit_counter_max,
coord_transformations=coord_transformations,
)
)

return self.get_active_objects()

@property
def current_object_count(self) -> int:
"""Number of active TrackedObjects"""
return len(self.get_active_objects())

@property
def total_object_count(self) -> int:
"""Total number of TrackedObjects initialized in the by this Tracker"""
return self._obj_factory.count

def get_active_objects(self) -> List["TrackedObject"]:
"""Get the list of active objects
Returns
-------
List["TrackedObject"]
The list of active objects
"""
return [
o
for o in self.tracked_objects
Expand Down Expand Up @@ -276,6 +294,7 @@ def _update_objects_in_place(
distance_threshold,
objects: Sequence["TrackedObject"],
candidates: Optional[Union[List["Detection"], List["TrackedObject"]]],
period: int,
):
if candidates is not None and len(candidates) > 0:
distance_matrix = self._get_distances(
Expand Down Expand Up @@ -324,7 +343,7 @@ def _update_objects_in_place(
matched_object = objects[match_obj_idx]
if match_distance < distance_threshold:
if isinstance(matched_candidate, Detection):
matched_object.hit(matched_candidate, period=self.period)
matched_object.hit(matched_candidate, period=period)
matched_object.last_distance = match_distance
matched_objects.append(matched_object)
elif isinstance(matched_candidate, TrackedObject):
Expand Down Expand Up @@ -383,6 +402,51 @@ def match_dets_and_objs(self, distance_matrix: np.ndarray, distance_threshold):
return [], []


class _TrackedObjectFactory:
global_count = 0

def __init__(self) -> None:
self.count = 0
self.initializing_count = 0

def create(
self,
initial_detection: "Detection",
hit_counter_max: int,
initialization_delay: int,
pointwise_hit_counter_max: int,
detection_threshold: float,
period: int,
filter_factory: "FilterFactory",
past_detections_length: int,
reid_hit_counter_max: Optional[int],
coord_transformations: CoordinatesTransformation,
) -> "TrackedObject":
obj = TrackedObject(
obj_factory=self,
initial_detection=initial_detection,
hit_counter_max=hit_counter_max,
initialization_delay=initialization_delay,
pointwise_hit_counter_max=pointwise_hit_counter_max,
detection_threshold=detection_threshold,
period=period,
filter_factory=filter_factory,
past_detections_length=past_detections_length,
reid_hit_counter_max=reid_hit_counter_max,
coord_transformations=coord_transformations,
)
return obj

def get_initializing_id(self) -> int:
self.initializing_count += 1
return self.initializing_count

def get_ids(self) -> Tuple[int, int]:
self.count += 1
_TrackedObjectFactory.global_count += 1
return self.count, _TrackedObjectFactory.global_count


class TrackedObject:
"""
The objects returned by the tracker's `update` function on each iteration.
Expand All @@ -396,10 +460,11 @@ class TrackedObject:
----------
estimate : np.ndarray
Where the tracker predicts the point will be in the current frame based on past detections.
A numpy array with the same shape as the detections being fed to the tracker that produced it.
id : Optional[int]
The unique identifier assigned to this object by the tracker.
The unique identifier assigned to this object by the tracker. Set to `None` if the object is initializing.
global_id : Optional[int]
The globally unique identifier assigned to this object. Set to `None` if the object is initializing
last_detection : Detection
The last detection that matched with this tracked object.
Useful if you are storing embeddings in your detections and want to do metric learning, or for debugging.
Expand All @@ -419,14 +484,11 @@ class TrackedObject:
Each new object created by the `Tracker` starts as an uninitialized `TrackedObject`,
which needs to reach a certain match rate to be converted into a full blown `TrackedObject`.
`initializing_id` is the id temporarily assigned to `TrackedObject` while they are getting initialized.
"""

count = 0
initializing_count = 0

def __init__(
self,
obj_factory: _TrackedObjectFactory,
initial_detection: "Detection",
hit_counter_max: int,
initialization_delay: int,
Expand All @@ -436,21 +498,19 @@ def __init__(
filter_factory: "FilterFactory",
past_detections_length: int,
reid_hit_counter_max: Optional[int],
abs_to_rel: Callable[[np.array], np.array],
coord_transformations: Optional[CoordinatesTransformation] = None,
):
if not isinstance(initial_detection, Detection):
print(
f"\n[red]ERROR[/red]: The detection list fed into `tracker.update()` should be composed of {Detection} objects not {type(initial_detection)}.\n"
)
exit()

self._obj_factory = obj_factory
self.dim_points = initial_detection.absolute_points.shape[1]
self.num_points = initial_detection.absolute_points.shape[0]
self.hit_counter_max: int = hit_counter_max
self.pointwise_hit_counter_max: int = pointwise_hit_counter_max
self.pointwise_hit_counter_max: int = max(pointwise_hit_counter_max, period)
self.initialization_delay = initialization_delay
if self.pointwise_hit_counter_max < period:
self.pointwise_hit_counter_max = period
self.detection_threshold: float = detection_threshold
self.initial_period: int = period
self.hit_counter: int = period
Expand All @@ -460,12 +520,14 @@ def __init__(
self.current_min_distance: Optional[float] = None
self.last_detection: "Detection" = initial_detection
self.age: int = 0
self.is_initializing_flag: bool = True
self.is_initializing: bool = self.hit_counter <= self.initialization_delay

self.initializing_id: Optional[int] = self._obj_factory.get_initializing_id()
self.id: Optional[int] = None
self.initializing_id: int = (
TrackedObject.initializing_count
) # Just for debugging
TrackedObject.initializing_count += 1
self.global_id: Optional[int] = None
if not self.is_initializing:
self._acquire_ids()

if initial_detection.scores is None:
self.detected_at_least_once_points = np.array([True] * self.num_points)
else:
Expand All @@ -486,7 +548,9 @@ def __init__(
self.filter = filter_factory.create_filter(initial_detection.absolute_points)
self.dim_z = self.dim_points * self.num_points
self.label = initial_detection.label
self.abs_to_rel = abs_to_rel
self.abs_to_rel = None
if coord_transformations is not None:
self.update_coordinate_transformation(coord_transformations)

def tracker_step(self):
self.hit_counter -= 1
Expand All @@ -500,14 +564,6 @@ def tracker_step(self):
# Advances the tracker's state
self.filter.predict()

@property
def is_initializing(self):
if self.is_initializing_flag and self.hit_counter > self.initialization_delay:
self.is_initializing_flag = False
TrackedObject.count += 1
self.id = TrackedObject.count
return self.is_initializing_flag

@property
def hit_counter_is_positive(self):
return self.hit_counter >= 0
Expand All @@ -519,9 +575,7 @@ def reid_hit_counter_is_positive(self):
@property
def estimate(self):
positions = self.filter.x.T.flatten()[: self.dim_z].reshape(-1, self.dim_points)
velocities = self.filter.x.T.flatten()[self.dim_z :].reshape(
-1, self.dim_points
)

if self.abs_to_rel is not None:
return self.abs_to_rel(positions)
return positions
Expand All @@ -546,11 +600,24 @@ def live_points(self):
return self.point_hit_counter > 0

def hit(self, detection: "Detection", period: int = 1):
"""Update tracked object with a new detection
Parameters
----------
detection : Detection
the new detection matched to this tracked object
period : int, optional
frames corresponding to the period of time since last update.
"""
self._conditionally_add_to_past_detections(detection)

self.last_detection = detection
self.hit_counter = min(self.hit_counter + 2 * period, self.hit_counter_max)

if self.is_initializing and self.hit_counter > self.initialization_delay:
self.is_initializing = False
self._acquire_ids()

# We use a kalman filter in which we consider each coordinate on each point as a sensor.
# This is a hacky way to update only certain sensors (only x, y coordinates for
# points which were detected).
Expand Down Expand Up @@ -579,13 +646,6 @@ def hit(self, detection: "Detection", period: int = 1):
np.expand_dims(detection.absolute_points.flatten(), 0).T, None, H
)

# Force points being detected for the first time to have velocity = 0
# This is needed because some detectors (like OpenPose) set points with
# low confidence to coordinates (0, 0). And when they then get their first
# real detection this creates a huge velocity vector in our KalmanFilter
# and causes the tracker to start with wildly inaccurate estimations which
# eventually coverge to the real detections.

detected_at_least_once_mask = np.array(
[(m,) * self.dim_points for m in self.detected_at_least_once_points]
).flatten()
Expand All @@ -600,6 +660,12 @@ def hit(self, detection: "Detection", period: int = 1):
detection.absolute_points.flatten(), 0
).T[first_detection_mask]

# Force points being detected for the first time to have velocity = 0
# This is needed because some detectors (like OpenPose) set points with
# low confidence to coordinates (0, 0). And when they then get their first
# real detection this creates a huge velocity vector in our KalmanFilter
# and causes the tracker to start with wildly inaccurate estimations which
# eventually coverge to the real detections.
self.filter.x[self.dim_z :][np.logical_not(detected_at_least_once_mask)] = 0
self.detected_at_least_once_points = np.logical_or(
self.detected_at_least_once_points, points_over_threshold_mask
Expand Down Expand Up @@ -651,6 +717,15 @@ def merge(self, tracked_object):
for past_detection in tracked_object.past_detections:
self._conditionally_add_to_past_detections(past_detection)

def update_coordinate_transformation(
self, coordinate_transformation: CoordinatesTransformation
):
if coordinate_transformation is not None:
self.abs_to_rel = coordinate_transformation.abs_to_rel

def _acquire_ids(self):
self.id, self.global_id = self._obj_factory.get_ids()


class Detection:
"""Detections returned by the detector must be converted to a `Detection` object before being used by Norfair.
Expand Down Expand Up @@ -693,3 +768,12 @@ def __init__(
self.label = label
self.absolute_points = self.points.copy()
self.embedding = embedding
self.age = None

def update_coordinate_transformation(
self, coordinate_transformation: CoordinatesTransformation
):
if coordinate_transformation is not None:
self.absolute_points = coordinate_transformation.rel_to_abs(
self.absolute_points
)
Loading

0 comments on commit 968f62e

Please sign in to comment.