# tf2_test_data_loader.py
# (GitHub page chrome and the rendered line-number gutter from the
# original scrape were removed; the code below is the file content.)
import os
import math
import tensorflow as tf
import matplotlib.pyplot as plt
# Schema used to parse serialised tf.train.Example records.
# Corresponding changes are to be made here
# if the feature description in tf2_preprocessing.py
# is changed
test_feature_description = {
    'segment': tf.io.FixedLenFeature([], tf.string),  # serialised tensor of encoded JPEGs
    'file': tf.io.FixedLenFeature([], tf.string),     # source filename
    'num': tf.io.FixedLenFeature([], tf.int64),       # segment number within the file
    'label': tf.io.FixedLenFeature([], tf.int64)      # target class label
}
def build_test_dataset(dir_path, batch_size=16, file_buffer=500*1024*1024):
    '''Return a tf.data.Dataset over all TFRecord files in dir_path.

    Args:
        dir_path: directory containing the *.tfrecord files
        batch_size: number of examples per element of the dataset
        file_buffer: per-file read buffer size in bytes for TFRecords

    Returns:
        a batched, parsed and prefetched tf.data.Dataset yielding
        (segment, label) tuples
    '''
    pattern = os.path.join(dir_path, '*.tfrecord')
    # list_files shuffles the matched filenames by default
    filenames = tf.data.Dataset.list_files(pattern)
    # read from multiple files in parallel
    dataset = tf.data.TFRecordDataset(
        filenames,
        num_parallel_reads=tf.data.experimental.AUTOTUNE,
        buffer_size=file_buffer)
    # batch first, then parse each batch in one call;
    # the remainder is dropped because parsing assumes a full batch
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(
        lambda serialised: _my_test_parser(serialised, batch_size),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)
def _my_test_parser(examples, batch_size):
    '''Parse a batch of serialised tf.train.Example protos.

    Args:
        examples: a batch of serialised tf.train.Example(s)
        batch_size: batch size (unused here; kept so the signature
            matches what build_test_dataset passes in)

    Returns:
        a tuple (segment, label), where segment is a tensor of shape
        (#in_batch, #frames, h, w, #channels)
    '''
    parsed = tf.io.parse_example(examples, features=test_feature_description)
    # each 'segment' entry is itself a serialised tensor of encoded
    # JPEGs; decode per example
    segments = tf.map_fn(_parse_segment, parsed['segment'], dtype=tf.uint8)
    # filename and segment number are ignored for now
    return (segments, parsed['label'])
def _parse_segment(segment):
    '''Decode one serialised segment into a uint8 frame tensor.

    A segment is a serialised tensor containing a number of
    JPEG-encoded frames.
    '''
    # deserialise -> tensor of JPEG byte strings
    encoded_frames = tf.io.parse_tensor(segment, out_type=tf.string)
    # decode every frame -> tensor of shape (#frames, h, w, #channels)
    return tf.map_fn(tf.io.decode_jpeg, encoded_frames, dtype=tf.uint8)
def display_segment(segment, batch_size):
    '''Show the frames of one segment in a grid of subplots.

    Args:
        segment: indexable sequence of frames; presumably shaped
            (#frames, h, w, #channels) — confirm against _parse_segment
        batch_size: requested number of frames, used to size the grid

    Blocks until the plot window is closed (plt.show default).
    '''
    fig = plt.figure(figsize=(16, 16))
    columns = int(math.sqrt(batch_size))
    rows = math.ceil(batch_size / float(columns))
    # BUG FIX: the grid can have more cells than frames
    # (columns*rows >= batch_size, and the segment itself may hold
    # fewer frames than batch_size) — the original indexed past the
    # end and raised IndexError; clamp the loop instead
    n_frames = min(columns * rows, len(segment))
    for i in range(1, n_frames + 1):
        fig.add_subplot(rows, columns, i)
        plt.imshow(segment[i - 1])
    plt.show()
if __name__ == "__main__":
    # Interactively step through the validation samples: show each
    # segment's label and frames, then wait for the user to continue.
    dir_path = './ValidSetSamples'
    batch_size = 16
    ds = build_test_dataset(dir_path, batch_size=batch_size)
    count = 0
    quit_requested = False
    for batch in ds:
        segments, labels = batch
        for itr in range(batch_size):
            print(labels[itr].numpy())
            display_segment(segments[itr], batch_size)
            print('Close the plot window manually')
            inp = input("Hit q to quit, any other key to continue: ")
            if inp == 'q':
                quit_requested = True
                break
        if quit_requested:
            break