-
Notifications
You must be signed in to change notification settings - Fork 4
/
gen_coco_tfrecord.py
137 lines (123 loc) · 5.3 KB
/
gen_coco_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import sys
sys.path.append('coco-text')
import coco_text
import argparse
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import os
import io
from PIL import Image
from object_detection.utils import dataset_util
def parse_arguments(argv=None):
    """Parse the command-line options of the COCO-Text -> TFRecord converter.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and programmatic use).

    Returns:
        argparse.Namespace with ``train_or_val`` (lower-cased, either
        ``'train'`` or ``'val'``), ``cocotext_json``, ``coco_imgdir`` and
        ``output_path``.
    """
    parser = argparse.ArgumentParser()
    # Lower-case before validating so "Train"/"VAL" are accepted, while a typo
    # such as "test" is rejected instead of silently selecting the val split.
    parser.add_argument('--train_or_val', required=True, type=str.lower,
                        choices=('train', 'val'), help='train or val')
    parser.add_argument('--cocotext_json', required=True, help='Path to cocotext.v2.json')
    parser.add_argument('--coco_imgdir', required=True, help='Path to COCO/images/ directory')
    parser.add_argument('--output_path', required=True, help='Path to output tfrecord')
    return parser.parse_args(argv)
def create_tf_example(ann, file_name, width, height, encoded_jpg):
    """Build a single-box tf.train.Example for one COCO-Text annotation.

    Args:
        ann: COCO-Text annotation dict; only ``ann['bbox']`` is read, as
            ``[x, y, w, h]`` in pixels.
        file_name: Image file name, stored as both filename and source_id.
        width: Image width in pixels, used to normalize x coordinates.
        height: Image height in pixels, used to normalize y coordinates.
        encoded_jpg: Raw encoded image bytes.

    Returns:
        tf.train.Example holding one box of class 'Text' (label id 1).
    """
    # Keep the float coordinates: truncating x, y, w and h separately with
    # int() drops sub-pixel precision and can shift the right/bottom edge by
    # up to two pixels. This matches create_tf_examples, which uses floats.
    x1, y1, w, h = (float(v) for v in ann['bbox'])
    xmin = [x1 / width]
    xmax = [(x1 + w) / width]
    ymin = [y1 / height]
    ymax = [(y1 + h) / height]
    cls_text = ['Text'.encode('utf8')]
    cls_idx = [1]  # bbox is 'Text' only, which id is defined in label_map
    filename = file_name.encode('utf8')
    image_format = b'jpg'
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(cls_text),
        'image/object/class/label': dataset_util.int64_list_feature(cls_idx),
    }))
    return tf_example
def create_tf_examples(writer, anns, path, file_name, width, height, encoded_jpg):
    """Write one tf.train.Example containing every in-image box of an image.

    Args:
        writer: Object with a ``write(bytes)`` method (the TFRecord writer).
        anns: Iterable of COCO-Text annotation dicts; only ``ann['bbox']``
            (``[x, y, w, h]`` in pixels) is read.
        path: Directory of the image; joined with ``file_name`` to form the
            stored filename/source_id.
        file_name: Image file name.
        width: Image width in pixels, used to normalize x coordinates.
        height: Image height in pixels, used to normalize y coordinates.
        encoded_jpg: Raw encoded image bytes.

    Returns:
        1 if an example was written, 0 if no annotation had a box lying
        inside the image (nothing is written in that case).
    """
    xmins, ymins, xmaxs, ymaxs = [], [], [], []
    for ann in anns:
        bbox = ann['bbox']
        # normalize to [0, 1] relative to the image size
        xmin = bbox[0] / width
        xmax = (bbox[0] + bbox[2]) / width
        ymin = bbox[1] / height
        ymax = (bbox[1] + bbox[3]) / height
        # Keep only boxes that lie inside the image. A box edge exactly on the
        # right/bottom border (== 1.0) is valid; negative or inverted extents
        # are not.
        if 0 <= xmin <= xmax <= 1 and 0 <= ymin <= ymax <= 1:
            xmins.append(xmin)
            xmaxs.append(xmax)
            ymins.append(ymin)
            ymaxs.append(ymax)

    if not xmins:
        return 0

    num_boxes = len(xmins)
    classes_text = ['Text'.encode('utf8')] * num_boxes
    classes = [1] * num_boxes  # 'Text' is the only class; id from label_map
    filename = os.path.join(path, file_name).encode('utf8')
    image_format = b'jpg'
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    writer.write(tf_example.SerializeToString())
    return 1
def main():
    """Convert the legible COCO-Text images of one split into a TFRecord file.

    Reads options via parse_arguments(), selects the train or val image ids
    that have legible text, and writes one example per image through
    create_tf_examples(). Prints the output path and the example count.
    """
    args = parse_arguments()
    train_or_val = args.train_or_val.lower()
    ct = coco_text.COCO_Text(args.cocotext_json)
    split_ids = ct.train if train_or_val == 'train' else ct.val
    img_ids = ct.getImgIds(imgIds=split_ids, catIds=[('legibility', 'legible')])
    # NOTE(review): both splits read from train2014, presumably because
    # COCO-Text annotates COCO 2014 train images only; confirm for your data.
    img_dir = os.path.join(args.coco_imgdir, 'train2014')

    seen = set()
    num_examples = 0
    writer = tf.python_io.TFRecordWriter(args.output_path)
    try:
        for img_id in img_ids:
            img = ct.loadImgs(img_id)[0]
            file_name = img['file_name']
            if file_name in seen:
                continue
            seen.add(file_name)

            # Read the file once; PIL decodes the header from memory, so no
            # second file handle is opened (and none is leaked).
            with tf.gfile.GFile(os.path.join(img_dir, file_name), 'rb') as fid:
                encoded_jpg = fid.read()
            with Image.open(io.BytesIO(encoded_jpg)) as pil_img:
                actual_size = pil_img.size

            # Sanity check: boxes are in the annotation's pixel space, so the
            # annotated size is used for normalization; report mismatches.
            width, height = img['width'], img['height']
            if actual_size != (width, height):
                print('Warning: {} is {}x{} but annotated as {}x{}'.format(
                    file_name, actual_size[0], actual_size[1], width, height),
                    file=sys.stderr)
            if width == 0 or height == 0:
                continue

            # NOTE(review): only imgIds is passed here, so this presumably
            # returns every annotation of the image (legible or not); confirm.
            ann_ids = ct.getAnnIds(img['id'])
            anns = ct.loadAnns(ann_ids)
            num_examples += create_tf_examples(
                writer, anns, img_dir, file_name, width, height, encoded_jpg)
    finally:
        # Flush and close the record file even if a later image fails.
        writer.close()
    print('Generated {} ({} examples)'.format(args.output_path, num_examples))


if __name__ == "__main__":
    main()