Merge pull request #396 from pkuyym/fix-393
Give option to disable converting from transcription text to ids.
xinghai-sun authored Nov 3, 2017
2 parents 8bd02c4 + 081789b commit b37ece0
Showing 6 changed files with 28 additions and 24 deletions.
17 changes: 12 additions & 5 deletions deep_speech_2/data_utils/data.py
@@ -55,6 +55,10 @@ class DataGenerator(object):
     :type num_threads: int
     :param random_seed: Random seed.
     :type random_seed: int
+    :param keep_transcription_text: If set to True, transcription text will
+                                    be passed forward directly without
+                                    converting to index sequence.
+    :type keep_transcription_text: bool
     """

     def __init__(self,
@@ -69,7 +73,8 @@ def __init__(self,
                  specgram_type='linear',
                  use_dB_normalization=True,
                  num_threads=multiprocessing.cpu_count() // 2,
-                 random_seed=0):
+                 random_seed=0,
+                 keep_transcription_text=False):
         self._max_duration = max_duration
         self._min_duration = min_duration
         self._normalizer = FeatureNormalizer(mean_std_filepath)
@@ -84,6 +89,7 @@ def __init__(self,
             use_dB_normalization=use_dB_normalization)
         self._num_threads = num_threads
         self._rng = random.Random(random_seed)
+        self._keep_transcription_text = keep_transcription_text
         self._epoch = 0
         # for caching tar files info
         self._local_data = local()
@@ -97,8 +103,8 @@ def process_utterance(self, filename, transcript):
         :type filename: basestring | file
         :param transcript: Transcription text.
         :type transcript: basestring
-        :return: Tuple of audio feature tensor and list of token ids for
-                 transcription.
+        :return: Tuple of audio feature tensor and data of transcription part,
+                 where transcription part could be token ids or text.
         :rtype: tuple of (2darray, list)
         """
         if filename.startswith('tar:'):
@@ -107,9 +113,10 @@ def process_utterance(self, filename, transcript):
         else:
             speech_segment = SpeechSegment.from_file(filename, transcript)
         self._augmentation_pipeline.transform_audio(speech_segment)
-        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
+        specgram, transcript_part = self._speech_featurizer.featurize(
+            speech_segment, self._keep_transcription_text)
         specgram = self._normalizer.apply(specgram)
-        return specgram, text_ids
+        return specgram, transcript_part

     def batch_reader_creator(self,
                              manifest_path,
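In practice the flag only changes what the second element of process_utterance's return value holds. A minimal usage sketch, assuming the deep_speech_2 package is importable and the vocabulary/mean-std files exist; the file paths and the sample utterance below are illustrative placeholders, not taken from the commit:

# Sketch only: paths and the sample utterance are placeholders.
from data_utils.data import DataGenerator

# Training-style generator: transcripts come back as lists of token ids.
train_gen = DataGenerator(
    vocab_filepath='data/vocab.txt',
    mean_std_filepath='data/mean_std.npz',
    augmentation_config='{}',
    keep_transcription_text=False)

# Decoding/evaluation-style generator: transcripts come back as raw text,
# ready to be compared against decoder output.
eval_gen = DataGenerator(
    vocab_filepath='data/vocab.txt',
    mean_std_filepath='data/mean_std.npz',
    augmentation_config='{}',
    keep_transcription_text=True)

specgram, token_ids = train_gen.process_utterance('utt.wav', 'hello world')
specgram, text = eval_gen.process_utterance('utt.wav', 'hello world')  # text == 'hello world'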
8 changes: 5 additions & 3 deletions deep_speech_2/data_utils/featurizer/speech_featurizer.py
@@ -60,12 +60,12 @@ def __init__(self,
             target_dB=target_dB)
         self._text_featurizer = TextFeaturizer(vocab_filepath)

-    def featurize(self, speech_segment):
+    def featurize(self, speech_segment, keep_transcription_text):
         """Extract features for speech segment.

         1. For audio parts, extract the audio features.
-        2. For transcript parts, convert text string to a list of token indices
-           in char-level.
+        2. For transcript parts, keep the original text or convert text string
+           to a list of token indices in char-level.

         :param audio_segment: Speech segment to extract features from.
         :type audio_segment: SpeechSegment
@@ -74,6 +74,8 @@ def featurize(self, speech_segment):
         :rtype: tuple
         """
         audio_feature = self._audio_featurizer.featurize(speech_segment)
+        if keep_transcription_text:
+            return audio_feature, speech_segment.transcript
         text_ids = self._text_featurizer.featurize(speech_segment.transcript)
         return audio_feature, text_ids

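The new branch in featurize is easy to see in isolation. Below is a self-contained toy sketch of the same keep-or-convert logic; TinyTextFeaturizer and featurize_transcript are illustrative stand-ins, not code from this repository:

# Toy stand-in for the branch added to SpeechFeaturizer.featurize:
# either pass the transcript through untouched or map it to char-level ids.
class TinyTextFeaturizer(object):
    def __init__(self, vocab):
        self._vocab = {ch: i for i, ch in enumerate(vocab)}

    def featurize(self, text):
        return [self._vocab[ch] for ch in text]


def featurize_transcript(text_featurizer, transcript, keep_transcription_text):
    if keep_transcription_text:
        return transcript                          # raw string, e.g. for WER/CER scoring
    return text_featurizer.featurize(transcript)   # char-level token ids, for training


featurizer = TinyTextFeaturizer(vocab=" abcdefghijklmnopqrstuvwxyz")
print(featurize_transcript(featurizer, "hi there", True))    # 'hi there'
print(featurize_transcript(featurizer, "hi there", False))   # [8, 9, 0, 20, 8, 5, 18, 5]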
3 changes: 2 additions & 1 deletion deep_speech_2/deploy/demo_server.py
@@ -146,7 +146,8 @@ def start_server():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     # prepare ASR model
     ds2_model = DeepSpeech2Model(
         vocab_size=data_generator.vocab_size,
8 changes: 3 additions & 5 deletions deep_speech_2/infer.py
@@ -68,7 +68,8 @@ def infer():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.infer_manifest,
         batch_size=args.num_samples,
@@ -102,10 +103,7 @@ def infer():
         num_processes=args.num_proc_bsearch)

     error_rate_func = cer if args.error_rate_type == 'cer' else wer
-    target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
-    ]
+    target_transcripts = [transcript for _, transcript in infer_data]
     for target, result in zip(target_transcripts, result_transcripts):
         print("\nTarget Transcription: %s\nOutput Transcription: %s" %
               (target, result))
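The one-line target_transcripts replacement works because, with keep_transcription_text=True, the second element of every (specgram, transcript) pair is already a string rather than a token-id list. A self-contained before/after comparison over made-up data (vocab_list and the two infer_data variants below are toy values):

# Toy check that the old join-over-vocab expression and the new pass-through
# expression yield the same targets; all data here is made up.
vocab_list = [' ', 'a', 'c', 't']

# Before this change: transcripts arrived as token-id lists.
infer_data_ids = [(None, [2, 1, 3]), (None, [1, 0, 2, 1, 3])]
old_targets = [
    ''.join([vocab_list[token] for token in transcript])
    for _, transcript in infer_data_ids
]

# With keep_transcription_text=True: transcripts arrive as raw text.
infer_data_text = [(None, 'cat'), (None, 'a cat')]
new_targets = [transcript for _, transcript in infer_data_text]

assert old_targets == new_targets == ['cat', 'a cat']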
8 changes: 3 additions & 5 deletions deep_speech_2/test.py
@@ -69,7 +69,8 @@ def evaluate():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.test_manifest,
         batch_size=args.batch_size,
@@ -103,10 +104,7 @@ def evaluate():
             vocab_list=vocab_list,
             language_model_path=args.lang_model_path,
             num_processes=args.num_proc_bsearch)
-        target_transcripts = [
-            ''.join([data_generator.vocab_list[token] for token in transcript])
-            for _, transcript in infer_data
-        ]
+        target_transcripts = [transcript for _, transcript in infer_data]
         for target, result in zip(target_transcripts, result_transcripts):
             error_sum += error_rate_func(target, result)
             num_ins += 1
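evaluate() then averages per-utterance error rates over these raw-text targets. A compact, self-contained illustration of that accumulation; char_error_rate below is an illustrative stand-in, not the cer/wer implementation from utils.error_rate:

# Illustrative accumulation of an average character error rate over raw-text
# targets; char_error_rate is a stand-in, not the repository's implementation.
def char_error_rate(target, result):
    """Levenshtein distance between the two strings, normalized by target length."""
    prev = list(range(len(result) + 1))
    for i, t in enumerate(target, 1):
        curr = [i]
        for j, r in enumerate(result, 1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (t != r)))  # substitution
        prev = curr
    return float(prev[-1]) / len(target)

pairs = [('a cat', 'a cat'), ('hello there', 'hallo there')]
error_sum, num_ins = 0.0, 0
for target, result in pairs:
    error_sum += char_error_rate(target, result)
    num_ins += 1
print("average cer over %d utterances = %.4f" % (num_ins, error_sum / num_ins))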
8 changes: 3 additions & 5 deletions deep_speech_2/tools/tune.py
@@ -87,7 +87,8 @@ def tune():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)

     audio_data = paddle.layer.data(
         name="audio_spectrogram",
@@ -163,10 +164,7 @@ def tune():
             for i in xrange(len(infer_data))
         ]

-        target_transcripts = [
-            ''.join([data_generator.vocab_list[token] for token in transcript])
-            for _, transcript in infer_data
-        ]
+        target_transcripts = [transcript for _, transcript in infer_data]

         num_ins += len(target_transcripts)
         # grid search
