-
Notifications
You must be signed in to change notification settings - Fork 3
/
tfr_util.py
136 lines (107 loc) · 4.33 KB
/
tfr_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Original work Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Modifications Copyright 2018 Defense Innovation Unit Experimental.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
'''
Modifications copyright (C) 2018 <eScience Institute at University of Washington>
Licensed under CC BY-NC-ND 4.0 License [see LICENSE-CC BY-NC-ND 4.0.markdown for details]
Written by An Yan
'''
from PIL import Image
import tensorflow as tf
import io
import numpy as np
'''
TensorflowRecord (TFRecord) processing helper functions to be re-used by any scripts
that create or read TFRecord files.
'''
def to_tf_example(img, boxes, class_num):
    """
    Converts a single image with respective boxes into a TFExample.  Multiple TFExamples make up a TFRecord.

    Args:
        img: an image array indexed as (rows, columns[, channels]); shape[0] is
            taken as the height and shape[1] as the width
        boxes: an array of bounding boxes ([xmin, ymin, xmax, ymax]) for the given
            image, expressed in the same units as the image dimensions
        class_num: an array of class numbers, one for each bounding box

    Output:
        A TFExample containing encoded image data, scaled bounding boxes with classes, and other metadata.
    """
    encoded = convertToJpeg(img)

    # Image arrays are row-major: axis 0 counts rows (height) and axis 1 counts
    # columns (width).  Reading them the other way round only works for square
    # images and mis-scales every box on rectangular ones.
    height = img.shape[0]
    width = img.shape[1]

    # float() forces true division even for integer boxes on Python 2, where
    # int / int would otherwise floor every normalised coordinate to 0.
    x_scale = float(width)
    y_scale = float(height)

    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []

    for ind, box in enumerate(boxes):
        # x coordinates are normalised by the width, y coordinates by the height.
        xmin.append(box[0] / x_scale)
        ymin.append(box[1] / y_scale)
        xmax.append(box[2] / x_scale)
        ymax.append(box[3] / y_scale)
        classes.append(int(class_num[ind]))

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/encoded': bytes_feature(encoded),
        'image/format': bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': float_list_feature(xmin),
        'image/object/bbox/xmax': float_list_feature(xmax),
        'image/object/bbox/ymin': float_list_feature(ymin),
        'image/object/bbox/ymax': float_list_feature(ymax),
        'image/object/class/label': int64_list_feature(classes),
    }))

    return example
def convertToJpeg(im):
    """
    Encodes an image array as JPEG and returns the raw encoded bytes.

    Args:
        im: an image array accepted by PIL's Image.fromarray

    Output:
        a byte string holding the JPEG-encoded image.
    """
    buffer = io.BytesIO()
    try:
        # Encode into the in-memory buffer, then copy the bytes out before it is closed.
        Image.fromarray(im).save(buffer, format='JPEG')
        return buffer.getvalue()
    finally:
        buffer.close()
def create_tf_record(output_filename, images, boxes, class_nums=None):
    """ DEPRECATED
    Creates a TFRecord file from examples.

    Only examples whose first normalised xmin value is non-zero are written;
    examples with no boxes at all are skipped.

    Args:
        output_filename: Path to where output file is saved.
        images: an array of images to create a record for
        boxes: an array of bounding box coordinates ([xmin,ymin,xmax,ymax]) with the same index as images
        class_nums: optional array of per-box class numbers with the same index as
            images.  When omitted, every box is given the placeholder label 0.
    """
    writer = tf.python_io.TFRecordWriter(output_filename)
    saved = 0
    try:
        for idx, image in enumerate(images):
            if idx % 100 == 0:
                print('On image %d of %d' % (idx, len(images)))

            image_boxes = boxes[idx]
            # This call used to pass an undefined name as the class list and
            # raised NameError on the first image.  Use the caller's labels when
            # given, otherwise one placeholder label per box.
            if class_nums is None:
                labels = [0] * len(image_boxes)
            else:
                labels = class_nums[idx]

            tf_example = to_tf_example(image, image_boxes, labels)

            # Guard the [0] lookup: an image without boxes yields an empty list.
            xmins = tf_example.features.feature['image/object/bbox/xmin'].float_list.value
            if len(xmins) > 0 and np.array(xmins[0]).any():
                writer.write(tf_example.SerializeToString())
                saved += 1

        print("saved: %d chips" % saved)
    finally:
        # Always release the file handle, even if an example fails to convert.
        writer.close()
## VARIOUS HELPERS BELOW ##
def int64_feature(value):
    """Wraps a single value in a tf.train.Feature backed by a one-element Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def int64_list_feature(value):
    """Wraps an iterable of values in a tf.train.Feature backed by an Int64List."""
    int_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int_list)
def bytes_feature(value):
    """Wraps a single value in a tf.train.Feature backed by a one-element BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
def bytes_list_feature(value):
    """Wraps an iterable of values in a tf.train.Feature backed by a BytesList."""
    byte_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=byte_list)
def float_list_feature(value):
    """Wraps an iterable of values in a tf.train.Feature backed by a FloatList."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)