-
Notifications
You must be signed in to change notification settings - Fork 765
/
label_utils.py
142 lines (114 loc) · 3.9 KB
/
label_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Label utility functions
Main use: labeling, dictionary of colors,
label retrieval, loading label csv file,
drawing label on an image
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import csv
import config
import os
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from random import randint
def get_box_color(index=None):
    """Return a matplotlib color code for an object index.

    When `index` is None a random palette entry is chosen; otherwise
    the index wraps around the palette length.
    """
    palette = ['w', 'r', 'b', 'g', 'c', 'm', 'y', 'g', 'c', 'm', 'k']
    if index is not None:
        slot = index % len(palette)
    else:
        slot = randint(0, len(palette) - 1)
    return palette[slot]
def get_box_rgbcolor(index=None):
    """Return an (r, g, b) tuple for an object index.

    When `index` is None a random palette entry is chosen; otherwise
    the index wraps around the palette length.
    """
    palette = [(0, 0, 0), (255, 0, 0), (0, 0, 255), (0, 255, 0), (128, 128, 0)]
    if index is not None:
        slot = index % len(palette)
    else:
        slot = randint(0, len(palette) - 1)
    return palette[slot]
def index2class(index=0):
    """Return the entry of config.params['classes'] at position `index`."""
    return config.params['classes'][index]
def class2index(class_="background"):
    """Return the position of `class_` within config.params['classes']."""
    class_names = config.params['classes']
    position = class_names.index(class_)
    return position
def load_csv(path):
    """Read all rows of a comma-separated file into a numpy array."""
    with open(path) as handle:
        # materialize the rows while the file is still open
        records = list(csv.reader(handle, delimiter=','))
    return np.array(records)
def get_label_dictionary(labels, keys):
    """Group bounding-box labels by filename.

    Arguments:
        labels: iterable of 6-field rows laid out as
            (filename, 4 box coordinates, class id). Fields may be
            strings (as loaded from a csv file) or numbers.
        keys: iterable of the filenames that appear in column 0.

    Returns:
        dict mapping each filename that kept at least one box to a
        list of float32 arrays (4 box coordinates followed by the
        class id). Filenames left without any box are omitted.

    Raises:
        KeyError: if a row's filename is not in `keys`.
        ValueError: if a coordinate or class id is not numeric.
    """
    dictionary = {key: [] for key in keys}
    for label in labels:
        if len(label) != 6:
            print("Incomplete label:", label[0])
            continue
        # Convert before comparing: rows loaded from a csv file hold
        # strings, so `"0" == 0` is always False and "10" != "10.0".
        # box coords are float32
        value = np.asarray(label[1:]).astype(np.float32)
        # skip boxes whose first/second or third/fourth coordinates
        # coincide (zero extent along that axis)
        if value[0] == value[1] or value[2] == value[3]:
            continue
        if value[-1] == 0:
            print("No object labelled as bg:", label[0])
            continue
        # filename is key; value = bounding box coords and class label
        dictionary[label[0]].append(value)
    # remove dataset entries w/o labels (a filtered copy also tolerates
    # duplicate entries in `keys`, which a per-key `del` would not)
    return {key: boxes for key, boxes in dictionary.items() if boxes}
def build_label_dictionary(path):
    """Build the filename -> labels dictionary and class list from a csv.

    Arguments:
        path: csv file whose first row is a header and whose remaining
            rows hold the filename in the first column and the class
            id in the last column.

    Returns:
        (dictionary, classes): `dictionary` maps the csv file's
        directory joined with each filename to an np.array of that
        file's labels as produced by get_label_dictionary; `classes`
        is the ascending list of integer class ids found in the last
        column, always containing the background id 0.
    """
    labels = load_csv(path)
    dir_path = os.path.dirname(path)
    # skip the 1st line header
    labels = labels[1:]
    # keys are filenames
    keys = np.unique(labels[:, 0])
    per_file = get_label_dictionary(labels, keys)
    # named `label_dict` so the builtin `dict` is not shadowed
    label_dict = {
        os.path.join(dir_path, key): np.array(boxes)
        for key, boxes in per_file.items()
    }
    # Cast to int before np.unique so the ids sort numerically
    # (1, 2, 10) rather than lexicographically as strings
    # ("1", "10", "2").
    classes = np.unique(labels[:, -1].astype(int)).tolist()
    # insert background label 0, unless the csv already lists it
    if 0 not in classes:
        classes.insert(0, 0)
    return label_dict, classes
def show_labels(image, labels, ax=None):
    """Display `image` and outline every labelled box on top of it.

    Each label is read as (xmin, xmax, ymin, ymax, class); the class
    selects the edge color through get_box_color. A new axes is
    created when `ax` is None. The function ends by calling plt.show().
    """
    if ax is None:
        _, ax = plt.subplots(1)
    ax.imshow(image)
    for box in labels:
        # default label format is xmin, xmax, ymin, ymax
        xmin, xmax = box[0], box[1]
        ymin, ymax = box[2], box[3]
        width = xmax - xmin
        height = ymax - ymin
        edge = get_box_color(int(box[4]))
        # Rectangle ((xmin, ymin), width, height)
        outline = Rectangle((xmin, ymin),
                            width,
                            height,
                            linewidth=2,
                            edgecolor=edge,
                            facecolor='none')
        ax.add_patch(outline)
    plt.show()