-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
94 lines (73 loc) · 3.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import cv2
import streamlit as st
# Class-name list for the 80-category COCO label set, in the order written
# here: index 0 is 'person', index 79 is 'toothbrush'.
# NOTE(review): not referenced by the functions in this section; presumably
# used by callers (e.g. a Streamlit class selector) to build the ``classes``
# argument passed to process_image / process_tracks / process_byte_tracks.
# Whether list indices must match the detector's output indices — verify.
COCO_CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']
def process_image(od, image, confidence_threshold, nms_threshold, classes=None):
    """Run the detector on one image and return that image's detections.

    Side effect: the detector's ``thresh`` and ``nms_thresh`` attributes are
    overwritten with the given values and stay set after this call returns.

    Args:
        od: Detector object exposing ``thresh``, ``nms_thresh`` and
            ``detect_get_box_in`` (project type, not defined here).
        image: The image to run detection on; sent as a one-item batch.
        confidence_threshold: Value stored into ``od.thresh``.
        nms_threshold: Value stored into ``od.nms_thresh``.
        classes: Passed through unchanged as ``classes=``; presumably limits
            which classes are reported — confirm against the detector.

    Returns:
        Element 0 of the batch result, i.e. the detections for ``image``,
        requested in the ``'ltrb'`` box format.
    """
    od.thresh = confidence_threshold
    od.nms_thresh = nms_threshold
    batch_output = od.detect_get_box_in([image], box_format='ltrb', classes=classes)
    detections_for_image = batch_output[0]
    return detections_for_image
def annotate_image(image, detections, font=cv2.FONT_HERSHEY_SIMPLEX, color=(255, 0, 0),
                   font_size=1):
    """Draw detection boxes and ``'<class>: <score>'`` labels on a copy of ``image``.

    Args:
        image: Image array to annotate; it is copied, never modified. Only
            ``image.shape[:2]`` (height, width) is read, so both 2-D
            (grayscale) and 3-D arrays are accepted.
        detections: Iterable of ``(bbox, score, det_class)`` triples, with
            ``bbox`` unpacked as ``(left, top, right, bottom)``. This is
            presumably the ``'ltrb'`` output of ``process_image``.
        font: OpenCV font face used for the label text.
        color: Colour passed to ``cv2.rectangle`` and ``cv2.putText``.
        font_size: Font scale passed to ``cv2.putText``.

    Returns:
        The annotated copy of ``image``.
    """
    draw_frame = image.copy()
    # shape[:2] works for both (h, w) and (h, w, c) arrays.
    frame_h, frame_w = image.shape[:2]
    for bb, score, det_class in detections:
        # OpenCV drawing functions require integer pixel coordinates; cast the
        # same way _draw_track / _draw_byte_track do so float boxes also work.
        l, t, r, b = (int(x) for x in bb)
        # Clamp the box to the frame.
        l = max(l, 0)
        t = max(t, 0)
        r = min(r, frame_w)
        b = min(b, frame_h)
        cv2.rectangle(draw_frame, (l, t), (r, b), color, 2)
        label = f'{det_class}: {score:.2f}'
        # Label is drawn just inside the top-left corner of the box.
        cv2.putText(draw_frame, label, (l + 5, t + 30), font, font_size, color, 2)
    return draw_frame
def process_tracks(od, tracker, image, confidence_threshold, nms_threshold, classes=None):
    """Detect objects in ``image`` and feed them to a track-based tracker.

    Side effect: ``od.thresh`` and ``od.nms_thresh`` are overwritten with the
    supplied thresholds and remain set after the call.

    Args:
        od: Detector exposing ``thresh``, ``nms_thresh`` and
            ``detect_get_box_in`` (project type).
        tracker: Tracker exposing ``update_tracks(frame=..., raw_detections=...)``.
        image: Current frame; sent to the detector as a one-item batch and
            also passed to the tracker as ``frame``.
        confidence_threshold: Stored into ``od.thresh``.
        nms_threshold: Stored into ``od.nms_thresh``.
        classes: Forwarded unchanged as ``classes=`` to the detector.

    Returns:
        Whatever ``tracker.update_tracks`` returns for this frame.
    """
    od.thresh = confidence_threshold
    od.nms_thresh = nms_threshold
    # The tracker is fed boxes in 'ltwh' (left, top, width, height) form.
    per_image_results = od.detect_get_box_in([image], box_format='ltwh', classes=classes)
    frame_detections = per_image_results[0]
    return tracker.update_tracks(frame=image, raw_detections=frame_detections)
def annotate_tracks(image, tracks, color=(255, 0, 0), font_size=1):
    """Return a copy of ``image`` with the currently visible tracks drawn on it.

    A track is drawn only if ``track.is_confirmed()`` is true and its
    ``time_since_update`` is not greater than 0. All other tracks are skipped.

    Args:
        image: Image array; it is copied, and the copy is drawn on.
        tracks: Iterable of track objects (tracker-library type).
        color: Colour forwarded to ``_draw_track``.
        font_size: Font scale forwarded to ``_draw_track``.

    Returns:
        The annotated copy of ``image``.
    """
    canvas = image.copy()
    # Lazy filter: is_confirmed() is checked first and short-circuits, exactly
    # like the original guard clause.
    visible_tracks = (
        trk for trk in tracks
        if trk.is_confirmed() and not trk.time_since_update > 0
    )
    for trk in visible_tracks:
        _draw_track(canvas, trk, color=color, font_size=font_size)
    return canvas
def _draw_track(frame, track, font=cv2.FONT_HERSHEY_SIMPLEX, color=(255, 0, 0), font_size=1):
    """Draw one track's box and label directly onto ``frame`` (mutates it).

    The box comes from ``track.to_ltrb(orig=True)``. It is cast to ints,
    unpacked as left, top, right, bottom, and clamped to the frame.

    The label is ``'<det_class> <track_id>: <conf>'`` with two decimals. If
    ``track.get_det_conf()`` returns ``None``, the ``': <conf>'`` suffix is
    omitted instead of raising ``TypeError``.

    Args:
        frame: Image array drawn on in place. Only ``frame.shape[:2]`` is
            read, so 2-D (grayscale) and 3-D arrays are both accepted.
        track: Track object exposing ``to_ltrb``, ``get_det_class``,
            ``track_id`` and ``get_det_conf`` (tracker-library type).
        font: OpenCV font face for the label.
        color: Colour for both the rectangle and the text.
        font_size: Font scale passed to ``cv2.putText``.
    """
    # shape[:2] works for both (h, w) and (h, w, c) arrays.
    frame_h, frame_w = frame.shape[:2]
    # OpenCV drawing calls need integer pixel coordinates.
    l, t, r, b = [int(x) for x in track.to_ltrb(orig=True)]
    # Clamp the box to the frame boundaries.
    l = max(l, 0)
    t = max(t, 0)
    r = min(r, frame_w)
    b = min(b, frame_h)
    label = f'{track.get_det_class()} {track.track_id}'
    det_conf = track.get_det_conf()
    # NOTE(review): some trackers (presumably deep_sort_realtime for tracks not
    # matched to a detection this frame) report None here. annotate_tracks
    # skips such tracks, but guard so direct callers don't hit a TypeError.
    if det_conf is not None:
        label = f'{label}: {det_conf:.2f}'
    cv2.rectangle(frame, (l, t), (r, b), color, 2)
    cv2.putText(frame, label, (l + 5, t + 30), font, font_size, color, 2)
def process_byte_tracks(od, tracker, image, nms_threshold, low_confidence_threshold=0.1, classes=None):
    """Detect objects in ``image`` with a low score cut-off and update the byte tracker.

    The detector's score threshold is set to ``low_confidence_threshold``
    (default 0.1) rather than a normal confidence threshold, so low-scoring
    boxes are also handed to the tracker. Side effect: ``od.thresh`` and
    ``od.nms_thresh`` are overwritten and remain set after the call.

    Args:
        od: Detector exposing ``thresh``, ``nms_thresh`` and
            ``detect_get_box_in`` (project type).
        tracker: Tracker exposing ``update(detections=...)``.
        image: Current frame; sent to the detector as a one-item batch.
        nms_threshold: Stored into ``od.nms_thresh``.
        low_confidence_threshold: Stored into ``od.thresh``.
        classes: Forwarded unchanged as ``classes=`` to the detector.

    Returns:
        Whatever ``tracker.update`` returns for this frame.
    """
    od.thresh = low_confidence_threshold
    od.nms_thresh = nms_threshold
    # Boxes are requested in 'ltwh' (left, top, width, height) form.
    per_image_results = od.detect_get_box_in([image], box_format='ltwh', classes=classes)
    frame_detections = per_image_results[0]
    return tracker.update(detections=frame_detections)
def annotate_byte_tracks(image, tracks, color=(255, 0, 0), font_size=1):
    """Return a copy of ``image`` with every byte-tracker track drawn on it.

    Unlike ``annotate_tracks``, no filtering is applied: each element of
    ``tracks`` is passed to ``_draw_byte_track`` in order.

    Args:
        image: Image array; it is copied, and the copy is drawn on.
        tracks: Iterable of track objects from the byte tracker.
        color: Colour forwarded to ``_draw_byte_track``.
        font_size: Font scale forwarded to ``_draw_byte_track``.

    Returns:
        The annotated copy of ``image``.
    """
    canvas = image.copy()
    for each_track in tracks:
        _draw_byte_track(canvas, each_track, color=color, font_size=font_size)
    return canvas
def _draw_byte_track(frame, track, font=cv2.FONT_HERSHEY_SIMPLEX, color=(255, 0, 0), font_size=1):
    """Draw one byte-tracker track's box and label directly onto ``frame`` (mutates it).

    The box is read from ``track.ltrb``. It is cast to ints, unpacked as
    left, top, right, bottom, and clamped to the frame. The label is
    ``'<det_class> <track_id>: <score>'`` with the score to two decimals.

    Args:
        frame: Image array drawn on in place. Only ``frame.shape[:2]`` is
            read, so 2-D (grayscale) and 3-D arrays are both accepted.
        track: Track object exposing ``ltrb``, ``det_class``, ``track_id``
            and ``score`` attributes (tracker-library type).
        font: OpenCV font face for the label.
        color: Colour for both the rectangle and the text.
        font_size: Font scale passed to ``cv2.putText``.
    """
    # shape[:2] works for both (h, w) and (h, w, c) arrays; unpacking the full
    # shape into three names would raise ValueError on a grayscale frame.
    frame_h, frame_w = frame.shape[:2]
    # OpenCV drawing calls need integer pixel coordinates.
    l, t, r, b = [int(x) for x in track.ltrb]
    # Clamp the box to the frame boundaries.
    l = max(l, 0)
    t = max(t, 0)
    r = min(r, frame_w)
    b = min(b, frame_h)
    label = f'{track.det_class} {track.track_id}: {track.score:.2f}'
    cv2.rectangle(frame, (l, t), (r, b), color, 2)
    cv2.putText(frame, label, (l + 5, t + 30), font, font_size, color, 2)