diff --git a/deep_speech_2/audio_data_utils.py b/deep_speech_2/audio_data_utils.py index c717bcf182..1cd29be114 100644 --- a/deep_speech_2/audio_data_utils.py +++ b/deep_speech_2/audio_data_utils.py @@ -8,6 +8,7 @@ import random import soundfile import numpy as np +import itertools import os RANDOM_SEED = 0 @@ -62,6 +63,7 @@ def __init__(self, self.__stride_ms__ = stride_ms self.__window_ms__ = window_ms self.__max_frequency__ = max_frequency + self.__epoc__ = 0 self.__random__ = random.Random(RANDOM_SEED) # load vocabulary (dictionary) self.__vocab_dict__, self.__vocab_list__ = \ @@ -245,10 +247,42 @@ def __padding_batch__(self, batch, padding_to=-1, flatten=False): new_batch.append((padded_audio, text)) return new_batch - def instance_reader_creator(self, - manifest_path, - sort_by_duration=True, - shuffle=False): + def __batch_shuffle__(self, manifest, batch_size): + """ + The instances have different lengths and they cannot be + combined into a single matrix multiplication. It usually + sorts the training examples by length and combines only + similarly-sized instances into minibatches, pads with + silence when necessary so that all instances in a batch + have the same length. This batch shuffle function is used + to make similarly-sized instances into minibatches and + make a batch-wise shuffle. + + 1. Sort the audio clips by duration. + 2. Generate a random number `k`, k in [0, batch_size). + 3. Skip the first `k` instances in order to make different mini-batches, + then make minibatches and each minibatch size is batch_size. + 4. Shuffle the minibatches; leftover and skipped instances go last. + + :param manifest: Manifest instances; sorted by duration in place. + :type manifest: list + :param batch_size: Batch size. This size is also used to generate + a random number for batch shuffle. + :type batch_size: int + :return: Batch shuffled manifest. 
+ :rtype: list + """ + manifest.sort(key=lambda x: x["duration"]) + shift_len = self.__random__.randint(0, batch_size - 1) + batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size)) + self.__random__.shuffle(batch_manifest) + batch_manifest = list(itertools.chain.from_iterable(batch_manifest)) + res_len = len(manifest) - shift_len - len(batch_manifest) + batch_manifest.extend(manifest[len(manifest) - res_len:]) + batch_manifest.extend(manifest[0:shift_len]) + return batch_manifest + + def instance_reader_creator(self, manifest): """ Instance reader creator for audio data. Creat a callable function to produce instances of data. @@ -256,32 +290,13 @@ def instance_reader_creator(self, Instance: a tuple of a numpy ndarray of audio spectrogram and a list of tokenized and indexed transcription text. - :param manifest_path: Filepath of manifest for audio clip files. - :type manifest_path: basestring - :param sort_by_duration: Sort the audio clips by duration if set True - (for SortaGrad). - :type sort_by_duration: bool - :param shuffle: Shuffle the audio clips if set True. - :type shuffle: bool + :param manifest: Manifest instances of the audio clips to read, in order. + :type manifest: list :return: Data reader function. 
:rtype: callable """ - if sort_by_duration and shuffle: - sort_by_duration = False - logger.warn("When shuffle set to true, " - "sort_by_duration is forced to set False.") def reader(): - # read manifest - manifest = self.__read_manifest__( - manifest_path=manifest_path, - max_duration=self.__max_duration__, - min_duration=self.__min_duration__) - # sort (by duration) or shuffle manifest - if sort_by_duration: - manifest.sort(key=lambda x: x["duration"]) - if shuffle: - self.__random__.shuffle(manifest) # extract spectrogram feature for instance in manifest: spectrogram = self.__audio_featurize__( @@ -296,8 +311,8 @@ def batch_reader_creator(self, batch_size, padding_to=-1, flatten=False, - sort_by_duration=True, - shuffle=False): + sortagrad=False, + batch_shuffle=False): """ Batch data reader creator for audio data. Creat a callable function to produce batches of data. @@ -317,20 +332,32 @@ def batch_reader_creator(self, :param flatten: If set True, audio data will be flatten to be a 1-dim ndarray. Otherwise, 2-dim ndarray. Default is False. :type flatten: bool - :param sort_by_duration: Sort the audio clips by duration if set True - (for SortaGrad). - :type sort_by_duration: bool - :param shuffle: Shuffle the audio clips if set True. - :type shuffle: bool + :param sortagrad: Sort the audio clips by duration in the first epoch + if set True. + :type sortagrad: bool + :param batch_shuffle: Shuffle the audio clips if set True. It is + not a thorough instance-wise shuffle, but a + specific batch-wise shuffle. For more details, + please see `__batch_shuffle__` function. + :type batch_shuffle: bool :return: Batch reader function, producing batches of data when called. 
:rtype: callable """ def batch_reader(): - instance_reader = self.instance_reader_creator( + # read manifest + manifest = self.__read_manifest__( manifest_path=manifest_path, - sort_by_duration=sort_by_duration, - shuffle=shuffle) + max_duration=self.__max_duration__, + min_duration=self.__min_duration__) + + # sort (by duration) or shuffle manifest + if self.__epoc__ == 0 and sortagrad: + manifest.sort(key=lambda x: x["duration"]) + elif batch_shuffle: + manifest = self.__batch_shuffle__(manifest, batch_size) + + instance_reader = self.instance_reader_creator(manifest) batch = [] for instance in instance_reader(): batch.append(instance) @@ -339,6 +366,7 @@ def batch_reader(): batch = [] if len(batch) > 0: yield self.__padding_batch__(batch, padding_to, flatten) + self.__epoc__ += 1 return batch_reader diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index 89ab23c685..957c24267c 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -93,23 +93,27 @@ def train(): """ DeepSpeech2 training. """ + # initialize data generator - data_generator = DataGenerator( - vocab_filepath=args.vocab_filepath, - normalizer_manifest_path=args.normalizer_manifest_path, - normalizer_num_samples=200, - max_duration=20.0, - min_duration=0.0, - stride_ms=10, - window_ms=20) + def data_generator(): + return DataGenerator( + vocab_filepath=args.vocab_filepath, + normalizer_manifest_path=args.normalizer_manifest_path, + normalizer_num_samples=200, + max_duration=20.0, + min_duration=0.0, + stride_ms=10, + window_ms=20) + train_generator = data_generator() + test_generator = data_generator() # create network config - dict_size = data_generator.vocabulary_size() + dict_size = train_generator.vocabulary_size() + # paddle.data_type.dense_array is used for variable batch input. + # the size 161 * 161 is only a placeholder value and the real shape + # of input batch data will be set at each batch. 
audio_data = paddle.layer.data( - name="audio_spectrogram", - height=161, - width=2000, - type=paddle.data_type.dense_vector(322000)) + name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) text_data = paddle.layer.data( name="transcript_text", type=paddle.data_type.integer_value_sequence(dict_size)) @@ -136,28 +140,16 @@ def train(): cost=cost, parameters=parameters, update_equation=optimizer) # prepare data reader - train_batch_reader_sortagrad = data_generator.batch_reader_creator( - manifest_path=args.train_manifest_path, - batch_size=args.batch_size, - padding_to=2000, - flatten=True, - sort_by_duration=True, - shuffle=False) - train_batch_reader_nosortagrad = data_generator.batch_reader_creator( + train_batch_reader = train_generator.batch_reader_creator( manifest_path=args.train_manifest_path, batch_size=args.batch_size, - padding_to=2000, - flatten=True, - sort_by_duration=False, - shuffle=True) - test_batch_reader = data_generator.batch_reader_creator( + sortagrad=True if args.init_model_path is None else False, + batch_shuffle=True) + test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, - padding_to=2000, - flatten=True, - sort_by_duration=False, - shuffle=False) - feeding = data_generator.data_name_feeding() + batch_shuffle=False) + feeding = train_generator.data_name_feeding() # create event handler def event_handler(event): @@ -183,17 +175,8 @@ def event_handler(event): time.time() - start_time, event.pass_id, result.cost) # run train - # first pass with sortagrad - if args.use_sortagrad: - trainer.train( - reader=train_batch_reader_sortagrad, - event_handler=event_handler, - num_passes=1, - feeding=feeding) - args.num_passes -= 1 - # other passes without sortagrad trainer.train( - reader=train_batch_reader_nosortagrad, + reader=train_batch_reader, event_handler=event_handler, num_passes=args.num_passes, feeding=feeding)