forked from TrackingLaboratory/tracklab
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathyolov8_api.py
75 lines (63 loc) · 2.21 KB
/
yolov8_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import torch
import pandas as pd
from typing import Any
from tracklab.pipeline.imagelevel_module import ImageLevelModule
# Set before `ultralytics` is imported below.
# NOTE(review): presumably the package reads YOLO_VERBOSE at import time to
# silence its own logging - confirm against the ultralytics documentation.
os.environ["YOLO_VERBOSE"] = "False"
from ultralytics import YOLO
from tracklab.utils.coordinates import ltrb_to_ltwh
import logging
# Module-level logger named after this module.
log = logging.getLogger(__name__)
def collate_fn(batch):
    """Regroup a batch of ``(index, sample)`` pairs into parallel lists.

    Each ``sample`` is a mapping with an ``"image"`` and a ``"shape"`` entry.

    Returns:
        A tuple ``(idxs, (images, shapes))`` where the three lists follow the
        order of ``batch``.
    """
    idxs = [entry[0] for entry in batch]
    images, shapes = [], []
    for _, sample in batch:
        images.append(sample["image"])
        shapes.append(sample["shape"])
    return idxs, (images, shapes)
class YOLOv8(ImageLevelModule):
    """Image-level person detector backed by an Ultralytics ``YOLO`` model.

    Every kept detection is emitted as a ``pd.Series`` carrying the fields
    listed in ``output_columns``.
    """

    collate_fn = collate_fn
    input_columns = []
    output_columns = ["image_id", "video_id", "category_id", "bbox_ltwh", "bbox_conf"]

    def __init__(self, cfg, device, batch_size, **kwargs):
        """Load the checkpoint at ``cfg.path_to_checkpoint`` and move it to ``device``.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__(batch_size)
        self.cfg, self.device = cfg, device
        self.model = YOLO(cfg.path_to_checkpoint)
        self.model.to(device)
        # Running counter used as the ``name`` of each emitted detection.
        self.id = 0

    @torch.no_grad()
    def preprocess(self, image, detections, metadata: pd.Series):
        """Bundle ``image`` with its first two shape entries in swapped order.

        ``detections`` and ``metadata`` are not read here.
        """
        dims = image.shape
        # NOTE(review): presumably (width, height) for an HxWxC image - confirm.
        swapped = (dims[1], dims[0])
        return {"image": image, "shape": swapped}

    @torch.no_grad()
    def process(self, batch: Any, detections: pd.DataFrame, metadatas: pd.DataFrame):
        """Run the model on a collated batch and collect person detections.

        Args:
            batch: ``(images, shapes)`` pair as produced by ``collate_fn``.
            detections: not read by this method.
            metadatas: one row per image; the row's index label becomes
                ``image_id`` and its ``video_id`` field is copied through.

        Returns:
            A list of ``pd.Series``, one per box whose class is 0 and whose
            confidence is at least ``cfg.min_confidence``. Each series is
            named with the current value of ``self.id``, which is then
            incremented.
        """
        images, shapes = batch
        predictions = self.model(images)
        kept = []
        rows = metadatas.iterrows()
        for prediction, shape, (_, metadata) in zip(predictions, shapes, rows):
            for box in prediction.boxes.cpu().numpy():
                # Only class 0 (`person`) above the confidence threshold is kept.
                if not (box.cls == 0 and box.conf >= self.cfg.min_confidence):
                    continue
                record = {
                    "image_id": metadata.name,
                    "bbox_ltwh": ltrb_to_ltwh(box.xyxy[0], shape),
                    "bbox_conf": box.conf[0],
                    "video_id": metadata.video_id,
                    "category_id": 1,  # `person` class in posetrack
                }
                kept.append(pd.Series(record, name=self.id))
                self.id += 1
        return kept