Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

online mix noise audio data in training step #2622

Open
wants to merge 32 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
681f470
Remove comments check from alphabet
carlfm01 Jun 5, 2019
421243d
Remove sort from feeding
carlfm01 Jun 5, 2019
d08efad
Remove sort from evaluate tools
carlfm01 Jun 5, 2019
b0a14b5
Merge pull request #1 from carlfm01/master
carlfm01 Jun 29, 2019
ba1a587
Remove TF dependency
carlfm01 Jun 29, 2019
aebd08d
[ADD] mix noise audio
mychiux413 Dec 30, 2019
d255c3f
[FIX] add missing file decoded_augmentation.py
mychiux413 Dec 30, 2019
ec25136
mix noise works, but performance is bad
mychiux413 Dec 31, 2019
484134e
[MOD] use tf.Dataset to cache noise audio
mychiux413 Dec 31, 2019
4f24f08
rename decoded -> audio
mychiux413 Dec 31, 2019
1f57ece
[FIX] don't create tf.Dataset in other tf.Dataset's pipeline
mychiux413 Jan 2, 2020
66cc7c4
limit audio signal between +-1.0
mychiux413 Jan 13, 2020
b7eb0f4
[FIX] switch shuffle/map for memory cost, replace cache with prefetch…
mychiux413 Feb 11, 2020
ccae7cc
[MOD] limit the buffer size of .shuffle() to protect memory usage
mychiux413 Feb 17, 2020
8cc95f9
[ADD] bin/normalize_noise_audio.py
mychiux413 Feb 19, 2020
9e2648a
[MOD] mix noise into complete audio
mychiux413 Feb 21, 2020
2269514
[ADD] dev/test dataset can also mix noise [MOD] use SNR to balance no…
mychiux413 Mar 6, 2020
0b8147c
[ADD] use dbfs and SNR to determine the balance of audio/noise, add o…
mychiux413 Mar 16, 2020
42bc45b
[FIX] audiofile_to_features & samples_to_mfccs return 3 values now, a…
mychiux413 Mar 19, 2020
289722d
Fix issues.
Mar 29, 2020
9334e79
Save invalid files.
Mar 29, 2020
25736e0
Merge remote-tracking branch 'noiseaug/more-augment-options' into noi…
Mar 29, 2020
40b431b
Fix merging errors.
Mar 29, 2020
f7d1279
[FIX] replace tqdm with prograssbar [ADD] separate speech/noise mixin…
mychiux413 Mar 31, 2020
7792226
Merge branch 'no-sort' into more-augment-options
carlfm01 Apr 2, 2020
c4c3ced
Merge #f7d1279.
Apr 12, 2020
c151b1d
Merge branch 'master' into noisetest
Apr 17, 2020
c089b7f
Fix merge not detecting moved scripts.
Apr 17, 2020
491a4b0
Undo personal changes.
Apr 17, 2020
735cbbb
Merge branch 'master' of https://github.com/mozilla/DeepSpeech into n…
Apr 23, 2020
2fa91e8
To recover the incorrect merge
mychiux413 May 12, 2020
6b820bb
Merge pull request #1 from DanBmh/noiseaugmaster
mychiux413 May 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions bin/normalize_noise_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from __future__ import absolute_import, division, print_function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about prepare_noise.py?


# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository

from librosa import get_duration
from multiprocessing import Pool
from functools import partial
import math
import argparse
import sys
import os
import progressbar
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from util.feeding import secs_to_hours

try:
from pydub import AudioSegment
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to keep dependencies at a minimum. Please check, if your required functionality couldn't be covered by e.g. librosa. In the past I had some trouble with pydub leaking memory and moved away from it.

except ImportError as err:
print('[ImportError] try `sudo apt-get install ffmpeg && pip install pydub`')
raise err


def detect_silence(sound: AudioSegment, silence_threshold=-50.0, chunk_size=10):
start_trim = 0 # ms
sound_size = len(sound)
assert chunk_size > 0 # to avoid infinite loop
while sound[start_trim:(start_trim + chunk_size)].dBFS < silence_threshold and start_trim < sound_size:
start_trim += chunk_size

end_trim = sound_size
while sound[(end_trim - chunk_size):end_trim].dBFS < silence_threshold and end_trim > 0:
end_trim -= chunk_size

start_trim = min(sound_size, start_trim)
end_trim = max(0, end_trim)

return min([start_trim, end_trim]), max([start_trim, end_trim])


def trim_silence_audio(sound: AudioSegment, silence_threshold=-50.0, chunk_size=10):
start_trim, end_trim = detect_silence(sound, silence_threshold, chunk_size)
return sound[start_trim:end_trim]


def convert(filename, dst_dirpath, dirpath, normalize, trim_silence,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check, if how this is covered by or could be merged into the current audio.py.

min_duration_seconds, max_duration_seconds):
if not filename.endswith(('.wav', '.raw')):
return

filepath = os.path.join(dirpath, filename)
if filename.endswith('.wav'):
sound: AudioSegment = AudioSegment.from_file(filepath)
else:
try:
sound: AudioSegment = AudioSegment.from_raw(filepath,
sample_width=2,
frame_rate=44100,
channels=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take sample_width, frame_rate and channels from the command line.

except Exception as err: # pylint: disable=broad-except
print('Retrying conversion: {}'.format(err))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

try:
sound: AudioSegment = AudioSegment.from_raw(filepath,
sample_width=2,
frame_rate=48000,
channels=1)
except Exception as err: # pylint: disable=broad-except
print('Skipping file {}, got error: {}'.format(filepath, err))
return
try:
sound = sound.set_frame_rate(16000)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make this command-line configurable.

except Exception as err: # pylint: disable=broad-except
print('Skipping {}'.format(err))
return

n_splits = max(1, math.ceil(sound.duration_seconds / max_duration_seconds))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea to split noise into chunks to limit wasted overlap during augmentation!

chunk_duration_ms = math.ceil(len(sound) / n_splits)
chunks = []

for i in range(n_splits):
end_ms = min((i + 1) * chunk_duration_ms, len(sound))
chunk = sound[(i * chunk_duration_ms):end_ms]
chunks.append(chunk)

for i, chunk in enumerate(chunks):
dst_path = os.path.join(dst_dirpath, str(i) + '_' + filename)
if dst_path.endswith('.raw'):
dst_path = dst_path[:-4] + '.wav'

if os.path.exists(dst_path):
print('Audio already exists: {}'.format(dst_path))
return

if normalize:
chunk = chunk.normalize()
if chunk.dBFS < -30.0:
chunk = chunk.compress_dynamic_range().normalize()
if chunk.dBFS < -30.0:
chunk = chunk.compress_dynamic_range().normalize()
if trim_silence:
chunk = trim_silence_audio(chunk)

if chunk.duration_seconds < min_duration_seconds:
return
chunk.export(dst_path, format='wav')


def get_noise_duration(dst_dir):
duration = 0.0
file_num = 0
for dirpath, _, filenames in os.walk(dst_dir):
for f in filenames:
if not f.endswith('.wav'):
continue
duration += get_duration(filename=os.path.join(dirpath, f))
file_num += 1
return duration, file_num


def main(src_dir,
dst_dir,
min_duration_seconds,
max_duration_seconds,
normalize=True,
trim_silence=True):
assert os.path.exists(src_dir)
if not os.path.exists(dst_dir):
os.makedirs(dst_dir, exist_ok=False)
src_dir = os.path.abspath(src_dir)
dst_dir = os.path.abspath(dst_dir)

for dirpath, _, filenames in os.walk(src_dir):
dirpath = os.path.abspath(dirpath)
dst_dirpath = os.path.join(
dst_dir, dirpath.replace(src_dir, '').lstrip('/'))

print('Converting directory: {} -> {}'.format(dirpath, dst_dirpath))
if not os.path.exists(dst_dirpath):
os.makedirs(dst_dirpath, exist_ok=False)

convert_func = partial(convert,
dst_dirpath=dst_dirpath,
dirpath=dirpath,
normalize=normalize,
trim_silence=trim_silence,
min_duration_seconds=min_duration_seconds,
max_duration_seconds=max_duration_seconds)

pool = Pool(processes=None)
pbar = progressbar.ProgressBar(prefix='Preparing Noise Dataset', max_value=len(filenames)).start()
for i, _ in enumerate(pool.imap_unordered(convert_func, filenames)):
pbar.update(i)
pbar.finish()


if __name__ == "__main__":
PARSER = argparse.ArgumentParser(description='Optimize noise files')
PARSER.add_argument('--from_dir', help='Convert wav from directory', type=str)
PARSER.add_argument('--to_dir', help='save wav to directory', type=str)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This tool should also be able to produce SDBs like our SDB tool.

I'll put up a PR for changing classes util.audio.LabeledSample, util.sample_collections.CSV, util.sample_collections.SDB and util.sample_collections.DirectSDBWriter to allow unlabeled (noise) samples. Once that is merged, this PR has to be re-based and refactored accordingly.

PARSER.add_argument('--min_sec', help='min duration seconds of saved file', type=float, default=1.0)
PARSER.add_argument('--max_sec', help='max duration seconds of saved file', type=float, default=30.0)
PARSER.add_argument('--normalize', action='store_true', help='Normalize sound range, default is true', default=True)
PARSER.add_argument('--trim', action='store_true', help='Trim silence, default is true', default=True)
PARAMS = PARSER.parse_args()

main(PARAMS.from_dir, PARAMS.to_dir, PARAMS.min_sec, PARAMS.max_sec, PARAMS.normalize, PARAMS.trim)

DURATION, FILE_NUM = get_noise_duration(PARAMS.to_dir)
print("Your noise dataset has {} files and a duration of {}\n".format(FILE_NUM, secs_to_hours(DURATION)))
6 changes: 3 additions & 3 deletions training/deepspeech_training/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,21 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
return [alphabet.decode(res) for res in results]


def evaluate(test_csvs, create_model):
def evaluate(test_csvs, create_model, noise_sources=None, speech_sources=None):
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet)
else:
scorer = None

test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False, noise_sources=noise_sources, speech_sources=speech_sources) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
batch_wav_filename, (batch_x, batch_x_len), batch_y, _ = iterator.get_next()

# One rate per layer
no_dropout = [None] * 6
Expand Down
50 changes: 37 additions & 13 deletions training/deepspeech_training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
the decoded result and the batch's original Y.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
batch_filenames, (batch_x, batch_seq_len), batch_y, review_audio = iterator.get_next()

if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn
Expand All @@ -239,7 +239,9 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
logits, _ = create_model(batch_x, batch_seq_len, dropout, reuse=reuse, rnn_impl=rnn_impl)

# Compute the CTC loss using TensorFlow's `ctc_loss`
total_loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
total_loss = tfv1.nn.ctc_loss(labels=batch_y,
inputs=logits,
sequence_length=batch_seq_len)

# Check if any files lead to non finite loss
non_finite_files = tf.gather(batch_filenames, tfv1.where(~tf.math.is_finite(total_loss)))
Expand All @@ -248,7 +250,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
avg_loss = tf.reduce_mean(input_tensor=total_loss)

# Finally we return the average loss
return avg_loss, non_finite_files
return avg_loss, non_finite_files, review_audio


# Adam Optimization
Expand Down Expand Up @@ -309,7 +311,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
avg_loss, non_finite_files, review_audio = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)

# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
Expand All @@ -326,6 +328,8 @@ def get_tower_results(iterator, optimizer, dropout_rates):
tower_non_finite_files.append(non_finite_files)

avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
if FLAGS.review_audio_steps:
tfv1.summary.audio(name='step_audio', tensor=review_audio, sample_rate=FLAGS.audio_sample_rate, collections=['step_audio_summaries'])
tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
Expand Down Expand Up @@ -415,7 +419,9 @@ def train():
FLAGS.augmentation_freq_and_time_masking or
FLAGS.augmentation_pitch_and_tempo_scaling or
FLAGS.augmentation_speed_up_std > 0 or
FLAGS.augmentation_sparse_warp):
FLAGS.augmentation_sparse_warp or
FLAGS.train_augmentation_noise_files or
FLAGS.train_augmentation_speech_files):
do_cache_dataset = False

exception_box = ExceptionBox()
Expand All @@ -428,7 +434,9 @@ def train():
train_phase=True,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2,
buffering=FLAGS.read_buffer)
buffering=FLAGS.read_buffer,
noise_sources=FLAGS.train_augmentation_noise_files,
speech_sources=FLAGS.train_augmentation_speech_files)

iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
Expand All @@ -444,7 +452,9 @@ def train():
train_phase=False,
exception_box=exception_box,
process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2,
buffering=FLAGS.read_buffer) for source in dev_sources]
buffering=FLAGS.read_buffer,
noise_sources=FLAGS.dev_augmentation_noise_files,
speech_sources=FLAGS.dev_augmentation_speech_files) for source in dev_sources]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

# Dropout
Expand Down Expand Up @@ -482,6 +492,7 @@ def train():
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

# Summaries
step_audio_summaries_op = tfv1.summary.merge_all('step_audio_summaries')
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
Expand Down Expand Up @@ -541,11 +552,21 @@ def __call__(self, progress, data, **kwargs):
session.run(init_op)

# Batch loop

audio_summary_steps = 0
while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
step_audio_summary = None
if audio_summary_steps < FLAGS.review_audio_steps and epoch == 0:
_, current_step, batch_loss, problem_files, step_summary, step_audio_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op, step_audio_summaries_op],
feed_dict=feed_dict)
audio_summary_steps += 1
else:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)

exception_box.raise_if_set()
except tf.errors.InvalidArgumentError as err:
if FLAGS.augmentation_sparse_warp:
Expand All @@ -566,6 +587,9 @@ def __call__(self, progress, data, **kwargs):

pbar.update(step_count)

if step_audio_summary is not None:
step_summary_writer.add_summary(step_audio_summary, current_step)

step_summary_writer.add_summary(step_summary, current_step)

if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
Expand Down Expand Up @@ -639,7 +663,7 @@ def __call__(self, progress, data, **kwargs):


def test():
samples = evaluate(FLAGS.test_files.split(','), create_model)
samples = evaluate(FLAGS.test_files.split(','), create_model, noise_sources=FLAGS.test_augmentation_noise_files, speech_sources=FLAGS.test_augmentation_speech_files)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
Expand All @@ -651,7 +675,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# Create feature computation graph
input_samples = tfv1.placeholder(tf.float32, [Config.audio_window_samples], 'input_samples')
samples = tf.expand_dims(input_samples, -1)
mfccs, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs, _, _ = samples_to_mfccs(samples, FLAGS.audio_sample_rate)
mfccs = tf.identity(mfccs, name='mfccs')

# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
Expand Down Expand Up @@ -851,7 +875,7 @@ def do_single_file_inference(input_file_path):
# Restore variables from training checkpoint
load_graph_for_evaluation(session)

features, features_len = audiofile_to_features(input_file_path)
features, features_len, _ = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim])
previous_state_h = np.zeros([1, Config.n_cell_dim])

Expand Down
Loading