diff --git a/norfair/tracker.py b/norfair/tracker.py index 8220230c..2a1def74 100644 --- a/norfair/tracker.py +++ b/norfair/tracker.py @@ -537,12 +537,12 @@ def __init__( self.update_coordinate_transformation(coord_transformations) def tracker_step(self): - self.hit_counter -= 1 if self.reid_hit_counter is None: if self.hit_counter <= 0: self.reid_hit_counter = self.reid_hit_counter_max else: self.reid_hit_counter -= 1 + self.hit_counter -= 1 self.point_hit_counter -= 1 self.age += 1 # Advances the tracker's state diff --git a/tests/test_tracker.py b/tests/test_tracker.py index 3a4ce103..6e3b9e8a 100644 --- a/tests/test_tracker.py +++ b/tests/test_tracker.py @@ -299,6 +299,60 @@ def test_multiple_trackers(): assert tracker2.total_object_count == 1 +def test_reid_hit_counter(): + # + # test reid hit counter and initializations + # + + # simple reid distance + def dist(new_obj, tracked_obj): + return np.linalg.norm(new_obj.estimate - tracked_obj.estimate) + + hit_counter_max = 2 + reid_hit_counter_max = 2 + + tracker = Tracker( + distance_function="euclidean", + distance_threshold=1, + hit_counter_max=hit_counter_max, + initialization_delay=1, + reid_distance_function=dist, + reid_distance_threshold=5, + reid_hit_counter_max=reid_hit_counter_max, + ) + + tracked_objects = tracker.update([Detection(points=np.array([[1, 1]]))]) + tracked_objects = tracker.update([Detection(points=np.array([[1, 1]]))]) + + # check that hit counters initialize correctly + assert len(tracked_objects) == 1 + assert tracked_objects[0].hit_counter == 2 + assert tracked_objects[0].reid_hit_counter == None + + # check that object is dead if it doesn't get matched to any detections + obj_id = tracked_objects[0].id + for _ in range(hit_counter_max + 1): + tracked_objects = tracker.update() + assert len(tracked_objects) == 0 + + # check that previous object gets back to life after reid matching + for _ in range(hit_counter_max): + tracked_objects = tracker.update([Detection(points=np.array([[2, 2]]))]) + assert len(tracked_objects) == 1 + assert tracked_objects[0].id == obj_id + assert tracked_objects[0].reid_hit_counter == None + assert tracked_objects[0].hit_counter == hit_counter_max + + # check that previous object gets eliminated after hit_counter_max + reid_hit_counter_max + 1 + for _ in range(hit_counter_max + reid_hit_counter_max + 1): + tracked_objects = tracker.update() + assert len(tracked_objects) == 0 + for _ in range(2): + tracked_objects = tracker.update([Detection(points=np.array([[1, 1]]))]) + assert len(tracked_objects) == 1 + assert tracked_objects[0].id != obj_id + + # TODO tests list: # - detections with different labels # - partial matches where some points are missing