
Commit

Merge pull request #74 from qingqing01/ds2
Support variable input batch and SortaGrad.
xinghai-sun authored Jun 12, 2017
2 parents d67d362 + cb6da07 commit d91dab0
Showing 2 changed files with 87 additions and 76 deletions.
98 changes: 63 additions & 35 deletions deep_speech_2/audio_data_utils.py
@@ -8,6 +8,7 @@
import random
import soundfile
import numpy as np
import itertools
import os

RANDOM_SEED = 0
@@ -62,6 +63,7 @@ def __init__(self,
self.__stride_ms__ = stride_ms
self.__window_ms__ = window_ms
self.__max_frequency__ = max_frequency
self.__epoch__ = 0  # epoch counter, used to apply SortaGrad only in the first epoch
self.__random__ = random.Random(RANDOM_SEED)
# load vocabulary (dictionary)
self.__vocab_dict__, self.__vocab_list__ = \
@@ -245,43 +247,56 @@ def __padding_batch__(self, batch, padding_to=-1, flatten=False):
new_batch.append((padded_audio, text))
return new_batch

def instance_reader_creator(self,
manifest_path,
sort_by_duration=True,
shuffle=False):
def __batch_shuffle__(self, manifest, batch_size):
"""
Instances in a dataset have different lengths, so they cannot all be
combined into a single matrix multiplication. The usual remedy is to
sort the training examples by length, combine only similarly-sized
instances into minibatches, and pad with silence where necessary so
that all instances in a batch have the same length. This batch-shuffle
function groups similarly-sized instances into minibatches and then
shuffles batch-wise:
1. Sort the audio clips by duration.
2. Generate a random number `k`, with `k` in [0, batch_size).
3. Skip the first `k` sorted clips (so that batch boundaries differ
between epochs), then group the remaining clips into minibatches of
size batch_size; the skipped clips and any leftover are appended at
the end.
4. Shuffle the minibatches.
:param manifest: Manifest contents (a list of instance metadata).
:type manifest: list
:param batch_size: Batch size. This size is also used to generate the
random shift `k` for the batch shuffle.
:type batch_size: int
:return: Batch-shuffled manifest.
:rtype: list
"""
manifest.sort(key=lambda x: x["duration"])
shift_len = self.__random__.randint(0, batch_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self.__random__.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[len(manifest) - res_len:])  # appends nothing when res_len == 0
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest
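
For illustration, below is a minimal standalone sketch of the batch-shuffle scheme described in the docstring above; the helper name, the sample clips and the fixed seed are assumptions rather than part of this diff.

import random

def batch_shuffle_sketch(clips, batch_size, seed=0):
    # clips: list of dicts with a "duration" key, mirroring manifest entries.
    rng = random.Random(seed)
    # 1. sort by duration so similarly-sized clips sit next to each other
    clips = sorted(clips, key=lambda x: x["duration"])
    # 2. pick a random offset k in [0, batch_size)
    k = rng.randint(0, batch_size - 1)
    # 3. skip the first k clips, then cut the rest into full-size batches
    body = clips[k:]
    n_full = len(body) // batch_size * batch_size
    batches = [body[i:i + batch_size] for i in range(0, n_full, batch_size)]
    # 4. shuffle whole batches, flatten, then append leftover and skipped clips
    rng.shuffle(batches)
    flat = [clip for batch in batches for clip in batch]
    flat.extend(body[n_full:])   # clips that did not fill a complete batch
    flat.extend(clips[:k])       # the k clips skipped at the front
    return flat

# Example: 10 clips, batch size 3; batch boundaries differ per seed.
clips = [{"duration": d} for d in [3.2, 1.1, 2.5, 0.9, 4.0, 2.2, 1.8, 3.7, 0.5, 2.9]]
print([c["duration"] for c in batch_shuffle_sketch(clips, batch_size=3)])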

def instance_reader_creator(self, manifest):
"""
Instance reader creator for audio data. Create a callable function to
produce instances of data.
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokenized and indexed transcription text.
:param manifest_path: Filepath of manifest for audio clip files.
:type manifest_path: basestring
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:param manifest: Manifest contents (a list of instance metadata) for audio clip files.
:type manifest: list
:return: Data reader function.
:rtype: callable
"""
if sort_by_duration and shuffle:
sort_by_duration = False
logger.warn("When shuffle set to true, "
"sort_by_duration is forced to set False.")

def reader():
# read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path,
max_duration=self.__max_duration__,
min_duration=self.__min_duration__)
# sort (by duration) or shuffle manifest
if sort_by_duration:
manifest.sort(key=lambda x: x["duration"])
if shuffle:
self.__random__.shuffle(manifest)
# extract spectrogram feature
for instance in manifest:
spectrogram = self.__audio_featurize__(
@@ -296,8 +311,8 @@ def batch_reader_creator(self,
batch_size,
padding_to=-1,
flatten=False,
sort_by_duration=True,
shuffle=False):
sortagrad=False,
batch_shuffle=False):
"""
Batch data reader creator for audio data. Create a callable function to
produce batches of data.
@@ -317,20 +332,32 @@
:param flatten: If set True, audio data will be flattened into a 1-dim
ndarray; otherwise it stays a 2-dim ndarray. Default is False.
:type flatten: bool
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:param sortagrad: Sort the audio clips by duration in the first epoch
if set True.
:type sortagrad: bool
:param batch_shuffle: Shuffle the audio clips if set True. It is
not a thorough instance-wise shuffle, but a
specific batch-wise shuffle. For more details,
please see the `__batch_shuffle__` method.
:type batch_shuffle: bool
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""

def batch_reader():
instance_reader = self.instance_reader_creator(
# read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path,
sort_by_duration=sort_by_duration,
shuffle=shuffle)
max_duration=self.__max_duration__,
min_duration=self.__min_duration__)

# sort (by duration) or shuffle manifest
if self.__epoch__ == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
elif batch_shuffle:
manifest = self.__batch_shuffle__(manifest, batch_size)

instance_reader = self.instance_reader_creator(manifest)
batch = []
for instance in instance_reader():
batch.append(instance)
@@ -339,6 +366,7 @@ def batch_reader():
batch = []
if len(batch) > 0:
yield self.__padding_batch__(batch, padding_to, flatten)
self.__epoch__ += 1

return batch_reader
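
As a usage illustration, here is a sketch of how the new reader might be driven across epochs; the import path, file paths and batch size are placeholders, while the constructor and keyword arguments mirror those used in train.py below.

from audio_data_utils import DataGenerator

generator = DataGenerator(
    vocab_filepath="data/eng_vocab.txt",             # placeholder path
    normalizer_manifest_path="data/manifest.train",  # placeholder path
    normalizer_num_samples=200,
    max_duration=20.0,
    min_duration=0.0,
    stride_ms=10,
    window_ms=20)

train_reader = generator.batch_reader_creator(
    manifest_path="data/manifest.train",             # placeholder path
    batch_size=32,
    sortagrad=True,       # epoch 0 is read sorted by duration
    batch_shuffle=True)   # later epochs use the batch-wise shuffle instead

for epoch in range(3):
    for batch in train_reader():
        # each batch is a list of (spectrogram ndarray, token-id list) pairs
        pass

The epoch counter lives on the DataGenerator instance, which is presumably why train.py below switches to separate generator objects for the training and dev readers.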

65 changes: 24 additions & 41 deletions deep_speech_2/train.py
@@ -93,23 +93,27 @@ def train():
"""
DeepSpeech2 training.
"""

# initialize data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
def data_generator():
return DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)

train_generator = data_generator()
test_generator = data_generator()
# create network config
dict_size = data_generator.vocabulary_size()
dict_size = train_generator.vocabulary_size()
# paddle.data_type.dense_array is used for variable-size batch input;
# the size 161 * 161 is only a placeholder value and the real shape
# of the input batch data will be set for each batch.
audio_data = paddle.layer.data(
name="audio_spectrogram",
height=161,
width=2000,
type=paddle.data_type.dense_vector(322000))
name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size))
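
The comment above concerns variable-size batches: with padding_to left at its default, each batch is presumably padded only to the duration of its longest clip, so the spectrogram width varies per batch while the height (161 frequency bins) stays fixed. Below is a rough numpy illustration of that per-batch padding, independent of paddle; the helper name and shapes are assumptions.

import numpy as np

def pad_batch(spectrograms, flatten=True):
    # spectrograms: 2-D arrays of shape (161, n_frames); n_frames varies.
    max_frames = max(s.shape[1] for s in spectrograms)
    padded = []
    for s in spectrograms:
        p = np.zeros((s.shape[0], max_frames), dtype=s.dtype)
        p[:, :s.shape[1]] = s          # pad the tail with silence (zeros)
        padded.append(p.flatten() if flatten else p)
    return padded

# Two batches whose longest clips differ produce different widths.
batch_a = pad_batch([np.random.rand(161, 300), np.random.rand(161, 420)])
batch_b = pad_batch([np.random.rand(161, 150), np.random.rand(161, 180)])
print(batch_a[0].shape, batch_b[0].shape)   # (161*420,) vs (161*180,)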
@@ -136,28 +140,16 @@ def train():
cost=cost, parameters=parameters, update_equation=optimizer)

# prepare data reader
train_batch_reader_sortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=True)
test_batch_reader = data_generator.batch_reader_creator(
sortagrad=True if args.init_model_path is None else False,
batch_shuffle=True)
test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False)
feeding = data_generator.data_name_feeding()
batch_shuffle=False)
feeding = train_generator.data_name_feeding()

# create event handler
def event_handler(event):
@@ -183,17 +175,8 @@ def event_handler(event):
time.time() - start_time, event.pass_id, result.cost)

# run train
# first pass with sortagrad
if args.use_sortagrad:
trainer.train(
reader=train_batch_reader_sortagrad,
event_handler=event_handler,
num_passes=1,
feeding=feeding)
args.num_passes -= 1
# other passes without sortagrad
trainer.train(
reader=train_batch_reader_nosortagrad,
reader=train_batch_reader,
event_handler=event_handler,
num_passes=args.num_passes,
feeding=feeding)
