Refine data reader and move data_reader to async_data_reader. #726

Merged · 3 commits · Mar 13, 2018
Changes from all commits
@@ -15,7 +15,9 @@
 import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
 import data_utils.augmentor.trans_add_delta as trans_add_delta
 from data_utils.util import suppress_complaints, suppress_signal
-from data_utils.util import CriticalException, ForceExitWrapper
+from data_utils.util import SharedNDArray, SharedMemoryPoolManager
+from data_utils.util import DaemonProcessGroup, batch_to_ndarray
+from data_utils.util import CriticalException, ForceExitWrapper, EpochEndSignal


 class SampleInfo(object):
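The new imports come from data_utils/util.py: SharedNDArray and SharedMemoryPoolManager for passing batches through shared memory, DaemonProcessGroup for managing worker processes, and EpochEndSignal (previously defined in this file) as the end-of-epoch sentinel. For orientation, here is a minimal sketch of what a DaemonProcessGroup could look like, assuming it merely wraps a group of daemonized multiprocessing workers; the real definition lives in data_utils/util.py:

    # Hypothetical sketch of DaemonProcessGroup, for illustration only;
    # see data_utils/util.py for the actual implementation.
    from multiprocessing import Process

    class DaemonProcessGroup(object):
        def __init__(self, proc_num, target, args):
            # Spawn `proc_num` workers that all run the same target.
            self._procs = [
                Process(target=target, args=args) for _ in xrange(proc_num)
            ]
            for proc in self._procs:
                # Daemon processes are terminated with the parent process.
                proc.daemon = True

        def start_all(self):
            for proc in self._procs:
                proc.start()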
@@ -67,7 +69,8 @@ class SampleInfoBucket(object):
         split_sentence_threshold(int): Sentences whose length is larger than
             the value will trigger the split operation.
         split_sub_sentence_len(int): sub-sentence length is equal to
-            (split_sub_sentence_len + rand() % split_perturb).
+            (split_sub_sentence_len + \
+            rand() % split_perturb).
     """

     def __init__(self,
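A quick worked example of the split rule described in the docstring, with illustrative values (these are not the module's defaults):

    # Illustrative values only.
    import random

    split_sub_sentence_len = 512
    split_perturb = 50

    # Each sub-sentence length is the base length plus a small random
    # perturbation, so split points are not perfectly regular.
    sub_len = split_sub_sentence_len + random.randint(0, split_perturb - 1)
    assert split_sub_sentence_len <= sub_len < split_sub_sentence_len + split_perturb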
@@ -159,11 +162,7 @@ def generate_sample_info_list(self):
         return sample_info_list


-class EpochEndSignal():
-    pass
-
-
-class DataReader(object):
+class AsyncDataReader(object):
     """DataReader provides basic audio sample preprocessing pipeline including
     data loading and data augmentation.

@@ -174,7 +173,7 @@ class DataReader(object):
             corresponding description file.
         drop_frame_len (int): Samples whose label length is above the value
             will be dropped. (Using '-1' to disable the policy.)
-        process_num (int): Number of processes for processing data.
+        proc_num (int): Number of processes for processing data.
         sample_buffer_size (int): Buffer size to indicate the maximum samples
             cached.
         sample_info_buffer_size (int): Buffer size to indicate the maximum
@@ -193,10 +192,10 @@ def __init__(self,
                  feature_file_list,
                  label_file_list,
                  drop_frame_len=512,
-                 process_num=10,
+                 proc_num=10,
                  sample_buffer_size=1024,
                  sample_info_buffer_size=1024,
-                 batch_buffer_size=1024,
+                 batch_buffer_size=10,
                  shuffle_block_num=10,
                  random_seed=0,
                  verbose=0):
@@ -213,9 +212,14 @@ def __init__(self,
         self._sample_buffer_size = sample_buffer_size
         self._sample_info_buffer_size = sample_info_buffer_size
         self._batch_buffer_size = batch_buffer_size
-        self._process_num = process_num
+        self._proc_num = proc_num
+        if self._proc_num <= 2:
+            raise ValueError("Value of `proc_num` should be greater than 2.")
+        self._sample_proc_num = self._proc_num - 2
         self._verbose = verbose
         self._force_exit = ForceExitWrapper(self._manager.Value('b', False))
+        self._pool_manager = SharedMemoryPoolManager(self._batch_buffer_size *
+                                                     3, self._manager)

     def generate_bucket_list(self, is_shuffle):
         if self._block_info_list is None:
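The `proc_num <= 2` check exists because two of the processes are reserved: one for the ordered feeding task and one for the batch assembling task, leaving `proc_num - 2` sample-processing workers. For example:

    # Process budget for proc_num=10 (illustration):
    proc_num = 10
    feeding_procs = 1          # ordered_feeding_task
    assembling_procs = 1       # batch_assembling_task
    sample_proc_num = proc_num - feeding_procs - assembling_procs
    assert sample_proc_num == 8
    # proc_num=2 would leave no sample workers at all, hence the ValueError.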
@@ -249,13 +253,23 @@ def generate_bucket_list(self, is_shuffle):
     def set_transformers(self, transformers):
         self._transformers = transformers

-    def _sample_generator(self):
+    def recycle(self, *args):
+        for shared_ndarray in args:
+            if not isinstance(shared_ndarray, SharedNDArray):
+                raise ValueError("Only support recycling SharedNDArray objects.")
+            shared_ndarray.recycle(self._pool_manager.pool)
+
+    def _start_async_processing(self):
         sample_info_queue = self._manager.Queue(self._sample_info_buffer_size)
         sample_queue = self._manager.Queue(self._sample_buffer_size)
         self._order_id = 0

         @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
         def ordered_feeding_task(sample_info_queue):
+            if self._verbose == 0:
+                signal.signal(signal.SIGTERM, suppress_signal)
+                signal.signal(signal.SIGINT, suppress_signal)
+
             for sample_info_bucket in self._bucket_list:
                 try:
                     sample_info_list = \
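The new recycle() method hands consumed SharedNDArray buffers back to the shared-memory pool, and rejects anything else with a ValueError. A standalone sketch of the round-trip it implies, using a stand-in class since SharedNDArray's real definition lives in data_utils/util.py:

    # Stand-in for data_utils.util.SharedNDArray, for illustration only.
    class FakeSharedNDArray(object):
        def __init__(self, name):
            self.name = name

        def recycle(self, pool):
            # Re-register this block in the free pool so the assembling
            # process can pop and refill it for a future batch.
            pool[self.name] = self

    pool = {}
    block = FakeSharedNDArray('ndarray_0')
    block.recycle(pool)            # consumer is done with the batch
    name, reused = pool.popitem()  # assembler grabs a free block (cf. conv_to_shared below)
    assert reused is block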
Expand All @@ -268,13 +282,12 @@ def ordered_feeding_task(sample_info_queue):
sample_info_queue.put((sample_info, self._order_id))
self._order_id += 1

for i in xrange(self._process_num):
for i in xrange(self._sample_proc_num):
sample_info_queue.put(EpochEndSignal())

feeding_thread = Thread(
target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_thread.daemon = True
feeding_thread.start()
feeding_proc = DaemonProcessGroup(
proc_num=1, target=ordered_feeding_task, args=(sample_info_queue, ))
feeding_proc.start_all()

@suppress_complaints(verbose=self._verbose, notify=self._force_exit)
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
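Note that the feeding task now enqueues one EpochEndSignal per sample-processing worker (`self._sample_proc_num` of them), so each worker consumes exactly one sentinel and exits cleanly. A self-contained sketch of this one-sentinel-per-worker pattern, with illustrative names:

    # Standalone illustration of the one-sentinel-per-worker shutdown pattern.
    from multiprocessing import Manager, Process

    class EpochEndSignal(object):
        pass

    def worker(queue):
        while True:
            item = queue.get()
            if isinstance(item, EpochEndSignal):
                break  # each worker consumes exactly one sentinel
            # ... process `item` here ...

    manager = Manager()
    queue = manager.Queue()
    procs = [Process(target=worker, args=(queue, )) for _ in xrange(4)]
    for proc in procs:
        proc.start()
    for item in xrange(100):
        queue.put(item)
    for _ in xrange(4):  # one EpochEndSignal per worker
        queue.put(EpochEndSignal())
    for proc in procs:
        proc.join()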
@@ -301,12 +314,12 @@ def read_bytes(fpath, start, size):
                     sample_info.feature_start,
                     sample_info.feature_size)

-                assert sample_info.feature_frame_num * sample_info.feature_dim * 4 \
-                    == len(feature_bytes), \
-                    (sample_info.feature_bin_path,
-                     sample_info.feature_frame_num,
-                     sample_info.feature_dim,
-                     len(feature_bytes))
+                assert sample_info.feature_frame_num \
+                    * sample_info.feature_dim * 4 == len(feature_bytes), \
+                    (sample_info.feature_bin_path,
+                     sample_info.feature_frame_num,
+                     sample_info.feature_dim,
+                     len(feature_bytes))

                 label_bytes = read_bytes(sample_info.label_bin_path,
                                          sample_info.label_start,
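The reflowed assertion keeps the same invariant: features are serialized as float32, so the feature chunk must be exactly feature_frame_num * feature_dim * 4 bytes long. For example:

    import numpy as np

    feature_frame_num, feature_dim = 300, 120
    feature = np.zeros((feature_frame_num, feature_dim), dtype="float32")
    # float32 occupies 4 bytes, so the serialized size is frame_num * dim * 4.
    assert len(feature.tobytes()) == feature_frame_num * feature_dim * 4  # 144000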
@@ -351,75 +364,71 @@ def read_bytes(fpath, start, size):

         out_order = self._manager.list([0])
         args = (sample_info_queue, sample_queue, out_order)
-        workers = [
-            Process(
-                target=ordered_processing_task, args=args)
-            for _ in xrange(self._process_num)
-        ]
+        sample_proc = DaemonProcessGroup(
+            proc_num=self._sample_proc_num,
+            target=ordered_processing_task,
+            args=args)
+        sample_proc.start_all()

-        for w in workers:
-            w.daemon = True
-            w.start()
+        return sample_queue

-        finished_process_num = 0
-
-        while self._force_exit == False:
-            try:
-                sample = sample_queue.get_nowait()
-            except Queue.Empty:
-                time.sleep(0.001)
-            else:
-                if isinstance(sample, EpochEndSignal):
-                    finished_process_num += 1
-                    if finished_process_num >= self._process_num:
-                        break
-                    continue
-
-                yield sample
-
-    def batch_iterator(self, batch_size, minimum_batch_size):
-        def batch_to_ndarray(batch_samples, lod):
-            assert len(batch_samples)
-            frame_dim = batch_samples[0][0].shape[1]
-            batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
-            batch_label = np.zeros((lod[-1], 1), dtype="int64")
-            start = 0
-            for sample in batch_samples:
-                frame_num = sample[0].shape[0]
-                batch_feature[start:start + frame_num, :] = sample[0]
-                batch_label[start:start + frame_num, :] = sample[1]
-                start += frame_num
-            return (batch_feature, batch_label)
-
-        @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
-        def batch_assembling_task(sample_generator, batch_queue):
+    def batch_iterator(self, batch_size, minimum_batch_size):
+        @suppress_complaints(verbose=self._verbose, notify=self._force_exit)
+        def batch_assembling_task(sample_queue, batch_queue, pool):
+            def conv_to_shared(ndarray):
+                while self._force_exit == False:
+                    try:
+                        (name, shared_ndarray) = pool.popitem()
+                    except Exception as e:
+                        time.sleep(0.001)
+                    else:
+                        shared_ndarray.copy(ndarray)
+                        return shared_ndarray
+
+            if self._verbose == 0:
+                signal.signal(signal.SIGTERM, suppress_signal)
+                signal.signal(signal.SIGINT, suppress_signal)
+
             batch_samples = []
             lod = [0]
-            for sample in sample_generator():
-                batch_samples.append(sample)
-                lod.append(lod[-1] + sample[0].shape[0])
-                if len(batch_samples) == batch_size:
-                    (batch_feature, batch_label) = batch_to_ndarray(
-                        batch_samples, lod)
-                    batch_queue.put((batch_feature, batch_label, lod))
-                    batch_samples = []
-                    lod = [0]
+            done_num = 0
+            while done_num < self._sample_proc_num:
+                sample = sample_queue.get()
+                if isinstance(sample, EpochEndSignal):
+                    done_num += 1
+                else:
+                    batch_samples.append(sample)
+                    lod.append(lod[-1] + sample[0].shape[0])
+                    if len(batch_samples) == batch_size:
+                        feature, label = batch_to_ndarray(batch_samples, lod)
+
+                        feature = conv_to_shared(feature)
+                        label = conv_to_shared(label)
+                        lod = conv_to_shared(np.array(lod).astype('int64'))
+
+                        batch_queue.put((feature, label, lod))
+                        batch_samples = []
+                        lod = [0]

             if len(batch_samples) >= minimum_batch_size:
-                (batch_feature, batch_label) = batch_to_ndarray(batch_samples,
-                                                                lod)
-                batch_queue.put((batch_feature, batch_label, lod))
+                (feature, label) = batch_to_ndarray(batch_samples, lod)
+
+                feature = conv_to_shared(feature)
+                label = conv_to_shared(label)
+                lod = conv_to_shared(np.array(lod).astype('int64'))
+
+                batch_queue.put((feature, label, lod))

             batch_queue.put(EpochEndSignal())

-        batch_queue = Queue.Queue(self._batch_buffer_size)
+        sample_queue = self._start_async_processing()
+        batch_queue = self._manager.Queue(self._batch_buffer_size)

-        assembling_thread = Thread(
+        assembling_proc = DaemonProcessGroup(
+            proc_num=1,
             target=batch_assembling_task,
-            args=(self._sample_generator, batch_queue))
-        assembling_thread.daemon = True
-        assembling_thread.start()
+            args=(sample_queue, batch_queue, self._pool_manager.pool))
+        assembling_proc.start_all()

         while self._force_exit == False:
             try:
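Taken together, the renamed AsyncDataReader is now driven roughly as below. This is a hedged sketch: the file paths are placeholders, the transformer constructor arguments are assumptions (see data_utils/augmentor for the real signatures), and it assumes batch_iterator yields the (feature, label, lod) triples assembled above:

    # Hedged end-to-end usage sketch, not part of this diff.
    import data_utils.augmentor.trans_add_delta as trans_add_delta
    import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
    from data_utils.async_data_reader import AsyncDataReader

    reader = AsyncDataReader(
        feature_file_list='feature.lst',   # placeholder path
        label_file_list='label.lst',       # placeholder path
        drop_frame_len=512,
        proc_num=10)                       # 1 feeder + 1 assembler + 8 workers
    reader.set_transformers([
        trans_add_delta.TransAddDelta(2, 2),                         # assumed args
        trans_mean_variance_norm.TransMeanVarianceNorm('mean_var'),  # assumed args
    ])
    reader.generate_bucket_list(True)

    for feature, label, lod in reader.batch_iterator(
            batch_size=32, minimum_batch_size=1):
        # ... run the trainer on this batch ...
        # Return the SharedNDArray buffers to the pool once consumed.
        reader.recycle(feature, label, lod)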