diff --git a/README.md b/README.md
index 07d458c75..bd113caf6 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,10 @@ Find detailed info on script usage (predict, coco2yolov5, coco_error_analysis) a
Find detailed info on COCO utilities (yolov5 conversion, slicing, subsampling, filtering, merging, splitting) at [COCO.md](docs/COCO.md).
+## MOT Challenge Utilities
+
+Find detailed info on MOT utilities (ground truth dataset creation, exporting tracker output in MOT challenge format) at [MOT.md](docs/MOT.md).
+
## Adding new detection framework support
`sahi` library currently supports all [YOLOv5 models](https://github.com/ultralytics/yolov5/releases) and [MMDetection models](https://github.com/open-mmlab/mmdetection/blob/master/docs/model_zoo.md). Moreover, it is easy to add new frameworks.
diff --git a/docs/MOT.md b/docs/MOT.md
index f714ccfd7..3c84fc008 100644
--- a/docs/MOT.md
+++ b/docs/MOT.md
@@ -1,6 +1,9 @@
# MOT Utilities
-## MOT dataset creation steps:
+
+
+MOT Challenge formatted ground truth dataset creation:
+
- import required classes:
@@ -11,7 +14,7 @@ from sahi.utils.mot import MotAnnotation, MotFrame, MotVideo
- init video:
```python
-mot_video = MotVideo(export_dir="mot_video")
+mot_video = MotVideo(name="sequence_name")
```
- init first frame:
@@ -38,9 +41,135 @@ mot_frame.add_annotation(
mot_video.add_frame(mot_frame)
```
-- after adding all frames, your MOT formatted files are ready at `mot_video/` folder.
+- export in MOT challenge format:
-## Advanced MOT dataset creation:
+```python
+mot_video.export(export_dir="mot_gt", type="gt")
+```
+
+- your MOT challenge formatted ground truth files are ready under `mot_gt/sequence_name/` folder.
+
+
+
+
+Advanced MOT Challenge formatted ground truth dataset creation:
+
+
+- you can customize tracker while initializing mot video object:
+
+```python
+tracker_params = {
+ 'max_distance_between_points': 30,
+ 'min_detection_threshold': 0,
+ 'hit_inertia_min': 10,
+ 'hit_inertia_max': 12,
+ 'point_transience': 4,
+}
+# for details: https://github.com/tryolabs/norfair/tree/master/docs#arguments
+
+mot_video = MotVideo(tracker_kwargs=tracker_params)
+```
+
+- you can omit automatic track id generation and directly provide track ids of annotations:
+
+
+```python
+# create annotations with track ids:
+mot_frame.add_annotation(
+ MotAnnotation(bbox=[x_min, y_min, width, height], track_id=1)
+)
+
+mot_frame.add_annotation(
+ MotAnnotation(bbox=[x_min, y_min, width, height], track_id=2)
+)
+
+# add frame to video:
+mot_video.add_frame(mot_frame)
+
+# export in MOT challenge format without automatic track id generation:
+mot_video.export(export_dir="mot_gt", type="gt", use_tracker=False)
+```
+
+- you can overwrite the results into already present directory by adding `exist_ok=True`:
+
+```python
+mot_video.export(export_dir="mot_gt", type="gt", exist_ok=True)
+```
+
+
+
+
+MOT Challenge formatted tracker output creation:
+
+
+- import required classes:
+
+```python
+from sahi.utils.mot import MotAnnotation, MotFrame, MotVideo
+```
+
+- init video by providing video name:
+
+```python
+mot_video = MotVideo(name="sequence_name")
+```
+
+- init first frame:
+
+```python
+mot_frame = MotFrame()
+```
+
+- add tracker outputs to frame:
+
+```python
+mot_frame.add_annotation(
+ MotAnnotation(bbox=[x_min, y_min, width, height], track_id=1)
+)
+
+mot_frame.add_annotation(
+ MotAnnotation(bbox=[x_min, y_min, width, height], track_id=2)
+)
+```
+
+- add frame to video:
+
+```python
+mot_video.add_frame(mot_frame)
+```
+
+- export in MOT challenge format:
+
+```python
+mot_video.export(export_dir="mot_test", type="test")
+```
+
+- your MOT challenge formatted tracker output file is ready as `mot_test/sequence_name.txt`.
+
+
+
+
+Advanced MOT Challenge formatted tracker output creation:
+
+
+- you can enable tracker and directly provide object detector output:
+
+```python
+# add object detector outputs:
+mot_frame.add_annotation(
+ MotAnnotation(bbox=[x_min, y_min, width, height])
+)
+
+mot_frame.add_annotation(
+ MotAnnotation(bbox=[x_min, y_min, width, height])
+)
+
+# add frame to video:
+mot_video.add_frame(mot_frame)
+
+# export in MOT challenge format by applying a kalman based tracker:
+mot_video.export(export_dir="mot_test", type="test", use_tracker=True)
+```
- you can customize tracker while initializing mot video object:
@@ -54,5 +183,12 @@ tracker_params = {
}
# for details: https://github.com/tryolabs/norfair/tree/master/docs#arguments
-mot_video = MotVideo(export_dir="mot_video", tracker_kwargs=tracker_params)
-```
\ No newline at end of file
+mot_video = MotVideo(tracker_kwargs=tracker_params)
+```
+
+- you can overwrite the results into already present directory by adding `exist_ok=True`:
+
+```python
+mot_video.export(export_dir="mot_test", type="test", exist_ok=True)
+```
+
\ No newline at end of file
diff --git a/sahi/utils/mot.py b/sahi/utils/mot.py
index 248f6649c..308b98246 100644
--- a/sahi/utils/mot.py
+++ b/sahi/utils/mot.py
@@ -9,23 +9,23 @@
try:
import norfair
from norfair import Tracker, Detection
+ from norfair.tracker import TrackedObject
from norfair.metrics import PredictionsTextFile, InformationFile
except ImportError:
raise ImportError('Please run "pip install -U norfair" to install norfair first for MOT format handling.')
-class GroundTruthTextFile(PredictionsTextFile):
- def __init__(self, save_path="."):
+class MotTextFile(PredictionsTextFile):
+ def __init__(self, save_dir: str = ".", save_name: str = "gt"):
- predictions_folder = os.path.join(save_path, "gt")
- if not os.path.exists(predictions_folder):
- os.makedirs(predictions_folder)
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
- self.out_file_name = os.path.join(predictions_folder, "gt" + ".txt")
+ self.out_file_name = os.path.join(save_dir, save_name + ".txt")
self.frame_number = 1
- def update(self, predictions, frame_number=None):
+ def update(self, predictions: List[TrackedObject], frame_number: int = None):
if frame_number is None:
frame_number = self.frame_number
"""
@@ -68,13 +68,15 @@ def euclidean_distance(detection, tracked_object):
class MotAnnotation:
- def __init__(self, bbox: List[int], score: Optional[float] = 1):
+ def __init__(self, bbox: List[int], track_id: Optional[int] = None, score: Optional[float] = 1):
"""
Args:
bbox (List[int]): [x_min, y_min, width, height]
+ track_id: (Optional[int]): track id of the annotation
score (Optional[float])
"""
self.bbox = bbox
+ self.track_id = track_id
self.score = score
@@ -116,16 +118,70 @@ def to_norfair_detections(self, track_points: str = "bbox"):
norfair_detections.append(Detection(points=points, scores=scores))
return norfair_detections
+ def to_norfair_trackedobjects(self, track_points: str = "bbox"):
+ """
+ Args:
+ track_points (str): 'centroid' or 'bbox'. Defaults to 'bbox'.
+ """
+ tracker = Tracker(
+ distance_function=euclidean_distance,
+ distance_threshold=30,
+ detection_threshold=0,
+ hit_inertia_min=10,
+ hit_inertia_max=12,
+ point_transience=4,
+ )
+
+ tracked_object_list: List[TrackedObject] = []
+        # convert all annotations to norfair TrackedObject instances
+ for annotation in self.annotation_list:
+ # ensure annotation.track_id is not None
+ assert annotation.track_id is not None, TypeError(
+ "to_norfair_trackedobjects() requires annotation.track_id to be set."
+ )
+ # calculate bbox points
+ xmin = annotation.bbox[0]
+ ymin = annotation.bbox[1]
+ xmax = annotation.bbox[0] + annotation.bbox[2]
+ ymax = annotation.bbox[1] + annotation.bbox[3]
+ track_id = annotation.track_id
+ scores = None
+ # calculate points as bbox or centroid
+ if track_points == "bbox":
+ points = np.array([[xmin, ymin], [xmax, ymax]]) # bbox
+ if annotation.score is not None:
+ scores = np.array([annotation.score, annotation.score])
+
+ elif track_points == "centroid":
+ points = np.array([(xmin + xmax) / 2, (ymin + ymax) / 2]) # centroid
+ if annotation.score is not None:
+ scores = np.array([annotation.score])
+ else:
+                raise ValueError("'track_points' should be one of ['centroid', 'bbox'].")
+ # create norfair formatted detection
+ detection = Detection(points=points, scores=scores)
+ # create trackedobject from norfair detection
+ tracked_object = TrackedObject(
+ detection,
+ tracker.hit_inertia_min,
+ tracker.hit_inertia_max,
+ tracker.initialization_delay,
+ tracker.detection_threshold,
+ period=1,
+ point_transience=tracker.point_transience,
+ filter_setup=tracker.filter_setup,
+ )
+ tracked_object.id = track_id
+ # append to tracked_object_list
+ tracked_object_list.append(tracked_object)
+ return tracked_object_list
+
class MotVideo:
- def __init__(
- self, export_dir: str = "runs/mot", track_points: str = "bbox", tracker_kwargs: Optional[Dict] = dict()
- ):
+ def __init__(self, name: Optional[str] = None, tracker_kwargs: Optional[Dict] = dict()):
"""
Args
- export_dir (str): Folder directory that will contain gt/gt.txt and seqinfo.ini
- For details: https://github.com/tryolabs/norfair/issues/42#issuecomment-819211873
- track_points (str): Track detections based on 'centroid' or 'bbox'. Defaults to 'bbox'.
+ name (str): Name of the video file.
tracker_kwargs (dict): a dict contains the tracker keys as below:
- max_distance_between_points (int)
- min_detection_threshold (float)
@@ -135,38 +191,26 @@ def __init__(
For details: https://github.com/tryolabs/norfair/tree/master/docs#arguments
"""
- self.export_dir: str = str(increment_path(Path(export_dir), exist_ok=False))
- self.track_points: str = track_points
+ self.name = name
+ self.tracker_kwargs = tracker_kwargs
- self.groundtruth_text_file: Optional[GroundTruthTextFile] = None
- self.tracker: Optional[Tracker] = None
-
- self._create_gt_file()
- self._init_tracker(
- tracker_kwargs.get("max_distance_between_points", 30),
- tracker_kwargs.get("min_detection_threshold", 0),
- tracker_kwargs.get("hit_inertia_min", 10),
- tracker_kwargs.get("hit_inertia_max", 12),
- tracker_kwargs.get("point_transience", 4),
- )
+ self.frame_list: List[MotFrame] = []
- def _create_info_file(self, seq_length: int):
+ def _create_info_file(self, seq_length: int, export_dir: str):
"""
Args:
seq_length (int): Number of frames present in video (seqLength parameter in seqinfo.ini)
For details: https://github.com/tryolabs/norfair/issues/42#issuecomment-819211873
+ export_dir (str): Folder directory that will contain exported file.
"""
# set file path
- filepath = Path(self.export_dir) / "seqinfo.ini"
+ filepath = Path(export_dir) / "seqinfo.ini"
# create folder directory if not exists
filepath.parent.mkdir(exist_ok=True)
# create seqinfo.ini file with seqLength
with open(str(filepath), "w") as file:
file.write(f"seqLength={seq_length}")
- def _create_gt_file(self):
- self.groundtruth_text_file = GroundTruthTextFile(save_path=self.export_dir)
-
def _init_tracker(
self,
max_distance_between_points: int = 30,
@@ -174,7 +218,7 @@ def _init_tracker(
hit_inertia_min: int = 10,
hit_inertia_max: int = 12,
point_transience: int = 4,
- ):
+ ) -> Tracker:
"""
Args
max_distance_between_points (int)
@@ -182,9 +226,11 @@ def _init_tracker(
hit_inertia_min (int)
hit_inertia_max (int)
point_transience (int)
+ Returns:
+ tracker: norfair.tracking.Tracker
For details: https://github.com/tryolabs/norfair/tree/master/docs#arguments
"""
- self.tracker = Tracker(
+ tracker = Tracker(
distance_function=euclidean_distance,
initialization_delay=0,
distance_threshold=max_distance_between_points,
@@ -193,10 +239,51 @@ def _init_tracker(
hit_inertia_max=hit_inertia_max,
point_transience=point_transience,
)
+ return tracker
def add_frame(self, frame: MotFrame):
assert type(frame) == MotFrame, "'frame' should be a MotFrame object."
- norfair_detections: List[Detection] = frame.to_norfair_detections(track_points=self.track_points)
- tracked_objects = self.tracker.update(detections=norfair_detections)
- self.groundtruth_text_file.update(predictions=tracked_objects)
- self._create_info_file(seq_length=self.groundtruth_text_file.frame_number)
+ self.frame_list.append(frame)
+
+ def export(self, export_dir: str = "runs/mot", type: str = "gt", use_tracker: bool = None, exist_ok=False):
+ """
+ Args
+ export_dir (str): Folder directory that will contain exported mot challenge formatted data.
+            type (str): Type of the MOT challenge export. 'gt' for ground truth data export, 'test' for tracker predictions export.
+ use_tracker (bool): Determines whether to apply kalman based tracker over frame detections or not.
+ Default is True for type='gt', False for type='test'.
+ exist_ok (bool): If True overwrites given directory.
+ """
+ assert type in ["gt", "test"], TypeError(f"'type' can be one of ['gt', 'test'], you provided: {type}")
+
+ export_dir: str = str(increment_path(Path(export_dir), exist_ok=exist_ok))
+
+ if type == "gt":
+ gt_dir = os.path.join(export_dir, self.name if self.name else "", "gt")
+ mot_text_file: MotTextFile = MotTextFile(save_dir=gt_dir, save_name="gt")
+ if use_tracker is None:
+ use_tracker = True
+ elif type == "test":
+ assert self.name is not None, TypeError("You have to set 'name' property of 'MotVideo'.")
+ mot_text_file: MotTextFile = MotTextFile(save_dir=export_dir, save_name=self.name)
+ if use_tracker is None:
+ use_tracker = False
+
+ tracker: Tracker = self._init_tracker(
+ self.tracker_kwargs.get("max_distance_between_points", 30),
+ self.tracker_kwargs.get("min_detection_threshold", 0),
+ self.tracker_kwargs.get("hit_inertia_min", 10),
+ self.tracker_kwargs.get("hit_inertia_max", 12),
+ self.tracker_kwargs.get("point_transience", 4),
+ )
+ for mot_frame in self.frame_list:
+ if use_tracker:
+ norfair_detections: List[Detection] = mot_frame.to_norfair_detections(track_points="bbox")
+ tracked_objects = tracker.update(detections=norfair_detections)
+ else:
+ tracked_objects = mot_frame.to_norfair_trackedobjects(track_points="bbox")
+ mot_text_file.update(predictions=tracked_objects)
+
+ if type == "gt":
+ info_dir = os.path.join(export_dir, self.name if self.name else "")
+ self._create_info_file(seq_length=mot_text_file.frame_number, export_dir=info_dir)
diff --git a/tests/test_motutils.py b/tests/test_motutils.py
index eb49466be..46d812bc2 100644
--- a/tests/test_motutils.py
+++ b/tests/test_motutils.py
@@ -10,7 +10,11 @@ class TestMotUtils(unittest.TestCase):
def test_mot_vid(self):
from sahi.utils.mot import MotAnnotation, MotFrame, MotVideo
- mot_video = MotVideo(export_dir="tests/data/mot/")
+ export_dir = "tests/data/mot/"
+ if os.path.isdir(export_dir):
+ shutil.rmtree(export_dir)
+
+ mot_video = MotVideo()
# frame 0
mot_frame = MotFrame()
mot_detection = MotAnnotation(bbox=[10, 10, 100, 100])
@@ -23,6 +27,24 @@ def test_mot_vid(self):
mot_detection = MotAnnotation(bbox=[95, 95, 98, 98])
mot_frame.add_annotation(mot_detection)
mot_video.add_frame(mot_frame)
+ # export
+ mot_video.export(export_dir=export_dir, type="gt", exist_ok=True)
+
+ mot_video = MotVideo(name="video.mp4")
+ # frame 0
+ mot_frame = MotFrame()
+ mot_detection = MotAnnotation(bbox=[10, 10, 100, 100], track_id=1)
+ mot_frame.add_annotation(mot_detection)
+ mot_video.add_frame(mot_frame)
+ # frame 1
+ mot_frame = MotFrame()
+ mot_detection = MotAnnotation(bbox=[12, 12, 98, 98], track_id=1)
+ mot_frame.add_annotation(mot_detection)
+ mot_detection = MotAnnotation(bbox=[95, 95, 98, 98], track_id=2)
+ mot_frame.add_annotation(mot_detection)
+ mot_video.add_frame(mot_frame)
+ # export
+ mot_video.export(export_dir=export_dir, type="test", exist_ok=True)
if __name__ == "__main__":