Follow up on PR comments
tilmankamp committed May 18, 2020
1 parent 7b08e59 commit f48b92c
Showing 5 changed files with 59 additions and 59 deletions.
23 changes: 12 additions & 11 deletions doc/TRAINING.rst
@@ -270,10 +270,10 @@ Augmentation
Augmentation is a useful technique for better generalization of machine learning models. Thus, a pre-processing pipeline with various augmentation techniques on raw pcm and spectrogram has been implemented and can be used while training the model. Following are the available augmentation techniques that can be enabled at training time by using the corresponding flags in the command line.


Audio Augmentation before feature caching
-----------------------------------------
Audio Augmentation
------------------

Augmentations that are applied before potential feature caching can be specified through the ``--augment`` multi-flag.
Augmentations that are applied before potential feature caching can be specified through the ``--augment`` flag. Being a multi-flag, it can be specified multiple times (see below for an example).

Each sample of the training data will get treated by every specified augmentation in their given order. However: whether an augmentation will actually get applied to a sample is decided by chance on base of the augmentation's probability value. For example a value of ``p=0.1`` would apply the according augmentation to just 10% of all samples. This also means that augmentations are not mutually exclusive on a per-sample basis.
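
The per-sample behaviour described above can be sketched as follows. This is a minimal illustration, not the project's pipeline: `apply_augmentations` and the `(name, probability, transform)` tuples are hypothetical names, and samples are plain lists of floats.

```python
import random


def apply_augmentations(sample, augmentations, rng=random):
    """Run `sample` through every augmentation in order; each one fires with its own probability."""
    applied = []
    for name, probability, transform in augmentations:
        # Independent draw per augmentation: several (or none) may hit the same sample.
        if rng.random() < probability:
            sample = transform(sample)
            applied.append(name)
    return sample, applied


# Hypothetical stand-ins for real augmentations, operating on a list of floats.
AUGMENTATIONS = [
    ("halve", 0.1, lambda audio: [x * 0.5 for x in audio]),
    ("invert", 1.0, lambda audio: [-x for x in audio]),
]

rng = random.Random(0)
hits = sum("halve" in apply_augmentations([1.0], AUGMENTATIONS, rng)[1] for _ in range(10000))
print(hits / 10000)  # close to 0.1
```

Because every augmentation draws independently, a sample can receive both, one, or neither of the two transforms.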

@@ -294,24 +294,25 @@ In the documentation below, whenever a value is specified as ``<float-range>`` o

* ``<start>:<end>~<r>``: Combination of the two previous cases with a ranging center value. E.g. ``4-6~2`` would at the beginning of an epoch pick values between 2 and 6 and at the end of an epoch between 4 and 8.

Ranges specified with integer limits will only assume integer (rounded) values.
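
A small sketch of how a range with start, end and radius could be sampled, modelled on the `pick_value_from_range` lines that appear further down in this commit. The clamping of `clock` and the `ValueRange` definition are assumptions made to keep the example self-contained.

```python
import random
from collections import namedtuple

# Assumed shape of the range tuple: start value, end value, radius around the moving center.
ValueRange = namedtuple("ValueRange", "start end r")


def pick_value(value_range, clock):
    """Pick a value for a progress value `clock` in [0.0, 1.0]."""
    clock = max(0.0, min(1.0, float(clock)))
    center = value_range.start + clock * (value_range.end - value_range.start)
    value = random.uniform(center - value_range.r, center + value_range.r)
    # Integer limits yield integer (rounded) values.
    return round(value) if isinstance(value_range.start, int) else value


int_range = ValueRange(4, 6, 2)  # start=4, end=6, r=2
early = [pick_value(int_range, 0.0) for _ in range(1000)]
late = [pick_value(int_range, 1.0) for _ in range(1000)]
print(min(early), max(early))  # stays within 2..6
print(min(late), max(late))    # stays within 4..8
```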

The flag ``--augmentations_per_epoch N`` receives an integer value and defaults to 1. During training, each epoch will do ``N`` passes over the training set, each time performing augmentation independently of previous passes. Be aware: this will also multiply the required size of the feature cache if it's enabled.
If feature caching is enabled, these augmentations will only be performed on the first epoch and the result will be reused for subsequent epochs. The flag ``--augmentations_per_epoch N`` (by default `N` is 1) could be used to get more than one epoch worth of augmentations into the cache. During training, each epoch will do ``N`` passes over the training set, each time performing augmentation independently of previous passes. Be aware: this will also multiply the required size of the feature cache if it's enabled.


**Overlay augmentation** ``--augment overlay[p=<float>,source=<str>,snr=<float-range>,layers=<int-range>]``
Layers another audio source (multiple times) onto augmented samples.

* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method

* **source**: path to the sample collection to use for augmenting (*.sdb or *.csv file)
* **source**: path to the sample collection to use for augmenting (*.sdb or *.csv file). It will be repeated if there are not enough samples left.
* **snr**: signal to noise ratio in dB - positive values for lowering volume of the overlay in relation to the sample

* **layers**: number of layers of the overlay signal (e.g. 10 layers of speech to get "cocktail-party effect")
* **layers**: number of layers added onto the sample (e.g. 10 layers of speech to get "cocktail-party effect"). A layer is just a sample of the same duration as the sample to augment. It gets stitched together from as many source samples as required.
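
The layer stitching and SNR scaling described in these bullets can be illustrated roughly like this. It is a sketch under stated assumptions, not the code of the `Overlay` class: levels are measured with plain RMS, samples are lists of floats, and `rms`, `build_layer` and `overlay` are hypothetical helper names.

```python
import itertools
import math


def rms(audio):
    return math.sqrt(sum(x * x for x in audio) / len(audio))


def build_layer(source_samples, length):
    """Stitch source samples together until `length` values are available."""
    layer = []
    while len(layer) < length:
        layer.extend(next(source_samples))
    return layer[:length]


def overlay(sample, source_samples, snr_db, n_layers):
    """Add `n_layers` layers so that the summed overlay sits `snr_db` dB below the sample."""
    mix = [0.0] * len(sample)
    for _ in range(n_layers):
        layer = build_layer(source_samples, len(sample))
        mix = [m + x for m, x in zip(mix, layer)]
    # Positive SNR lowers the overlay relative to the sample.
    gain = rms(sample) / (max(1e-16, rms(mix)) * 10.0 ** (snr_db / 20.0))
    return [s + gain * m for s, m in zip(sample, mix)]


# Hypothetical data: a short "speech" sample and two even shorter noise clips.
speech = [0.5, -0.5, 0.5, -0.5, 0.5, -0.5]
noise = itertools.cycle([[0.1, -0.2], [0.3, -0.1, 0.2]])  # repeats when exhausted
augmented = overlay(speech, noise, snr_db=6.0, n_layers=2)
print(len(augmented))  # same duration as the input sample
```

`itertools.cycle` stands in for the "repeated if there are not enough samples left" behaviour of the source collection.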


**Reverb augmentation** ``--augment reverb[p=<float>,delay=<float-range>,decay=<float-range>]``
Adds reverberation to the augmented samples.
Adds simplified (no all-pass filters) `Schroeder reverberation <https://ccrma.stanford.edu/~jos/pasp/Schroeder_Reverberators.html>`_ to the augmented samples.

* **p**: probability value between 0.0 (never) and 1.0 (always) if a given sample gets augmented by this method
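
For orientation, a Schroeder-style reverberator without the all-pass stage reduces to a bank of parallel feedback comb filters. The sketch below shows that generic idea only; how the `delay` and `decay` parameters of the flag map onto filter delays and gains is not shown in this diff, so `comb_filter` and `simple_reverb` take a delay in samples and a raw feedback gain as assumptions.

```python
def comb_filter(audio, delay, gain):
    """Feedback comb filter: y[n] = x[n] + gain * y[n - delay]."""
    out = list(audio)
    for n in range(delay, len(out)):
        out[n] += gain * out[n - delay]
    return out


def simple_reverb(audio, delays, gain):
    """Average several parallel comb filters (no all-pass stage)."""
    wet = [comb_filter(audio, d, gain) for d in delays]
    return [sum(values) / len(wet) for values in zip(*wet)]


impulse = [1.0] + [0.0] * 9
print(comb_filter(impulse, delay=3, gain=0.5))
# [1.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.25, 0.0, 0.0, 0.125]
```

Each echo of the impulse arrives `delay` samples after the previous one and is attenuated by `gain`, which is the decaying tail a listener perceives as reverberation.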

@@ -387,15 +388,15 @@ Example simulation of the codec augmentation of a wav-file first at the beginnin
bin/play.py --augment codec[p=0.1,bitrate=48000:16000] --clock 1.0 test.wav
Audio Augmentation after feature caching
----------------------------------------
The following augmentations are applied after feature caching, hence the way they are applied will not repeat epoch-wise.
Working on spectrogram and feature level, `bin/play.py` offers no ability to simulate them.

#. **Standard deviation for Gaussian additive noise:** ``--data_aug_features_additive``
#. **Standard deviation for Normal distribution around 1 for multiplicative noise:** ``--data_aug_features_multiplicative``
#. **Standard deviation for speeding-up tempo. If Standard deviation is 0, this augmentation is not performed:** ``--augmentation_speed_up_std``

Spectrogram Augmentation after feature caching
----------------------------------------------
Spectrogram Augmentation
------------------------

Inspired by Google Paper on `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition <https://arxiv.org/abs/1904.08779>`_

8 changes: 2 additions & 6 deletions training/deepspeech_training/util/audio.py
Expand Up @@ -6,11 +6,10 @@
import collections
import numpy as np

from .helpers import LimitingPool, np_capped_squares
from .helpers import LimitingPool
from collections import namedtuple

AudioFormat = namedtuple('AudioFormat', 'rate channels width')
dBFS = namedtuple('dBFS', 'mean max')

DEFAULT_RATE = 16000
DEFAULT_CHANNELS = 1
@@ -398,11 +397,8 @@ def rms_to_dbfs(rms):
return 20.0 * math.log10(max(1e-16, rms)) + 3.0103


def mean_dbfs(sample_data):
return rms_to_dbfs(math.sqrt(np.mean(np_capped_squares(sample_data))))


def max_dbfs(sample_data):
# Peak dBFS based on the maximum energy sample. Will prevent overdrive if used for normalization.
return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))
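
A quick worked check of the dBFS convention used here: the `+ 3.0103` offset places a full-scale sine wave (RMS of 1/sqrt(2)) at 0 dBFS, so a peak value of 1.0 maps to 3.0103 dBFS. The sketch re-implements both functions on plain Python lists instead of NumPy arrays, which is an assumption made to keep it dependency-free.

```python
import math


def rms_to_dbfs(rms):
    # Same formula as in the diff: +3.0103 dB puts a full-scale sine (RMS 1/sqrt(2)) at 0 dBFS.
    return 20.0 * math.log10(max(1e-16, rms)) + 3.0103


def max_dbfs(sample_data):
    # Peak level from the largest absolute sample value (plain-list variant of the NumPy code).
    return rms_to_dbfs(max(abs(min(sample_data)), abs(max(sample_data))))


sine = [math.sin(2.0 * math.pi * n / 100.0) for n in range(1000)]
sine_rms = math.sqrt(sum(x * x for x in sine) / len(sine))
print(round(rms_to_dbfs(sine_rms), 3))  # approximately 0.0
print(round(max_dbfs(sine), 4))         # 3.0103 for a peak of 1.0
```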


16 changes: 1 addition & 15 deletions training/deepspeech_training/util/helpers.py
@@ -1,11 +1,9 @@
import os
import sys
import time
import math
import heapq
import semver
import random
import numpy as np

from multiprocessing import Pool
from collections import namedtuple
@@ -158,7 +156,7 @@ def get_value_range(value, target_type):
if len(value) == 2:
return ValueRange(target_type(value[0]), target_type(value[1]), 0)
if len(value) == 3:
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[1]))
return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[2]))
raise ValueError('Cannot convert to ValueRange: Wrong tuple size')
return ValueRange(target_type(value), target_type(value), 0)
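
The effect of the index fix in the three-element case can be reproduced in isolation. The sketch covers only the tuple branch; `tuple_to_value_range` is a hypothetical name and the `ValueRange` definition is assumed from the attribute names used elsewhere in this file.

```python
from collections import namedtuple

# Assumed definition, matching the .start/.end/.r attributes used in pick_value_from_range.
ValueRange = namedtuple("ValueRange", "start end r")


def tuple_to_value_range(value, target_type):
    """Tuple branch of the conversion, with the third element used as the radius."""
    if len(value) == 2:
        return ValueRange(target_type(value[0]), target_type(value[1]), 0)
    if len(value) == 3:
        return ValueRange(target_type(value[0]), target_type(value[1]), target_type(value[2]))
    raise ValueError("Cannot convert to ValueRange: Wrong tuple size")


print(tuple_to_value_range((4, 6, 2), int))
# ValueRange(start=4, end=6, r=2); indexing value[1] for r would have given r=6
```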

@@ -176,15 +174,3 @@ def pick_value_from_range(value_range, clock=None):
value = value_range.start + clock * (value_range.end - value_range.start)
value = random.uniform(value - value_range.r, value + value_range.r)
return round(value) if isinstance(value_range.start, int) else value


def call_if_exists(o, name, *args, **kwargs):
method = getattr(o, name, None)
if callable(method):
method(*args, **kwargs)


def np_capped_squares(data):
sqrt_max = math.sqrt(np.finfo(data.dtype).max)
data = np.minimum(np.maximum(data, -sqrt_max), sqrt_max) # prevent overflow during squaring
return data ** 2
15 changes: 5 additions & 10 deletions training/deepspeech_training/util/sample_collections.py
@@ -8,7 +8,7 @@
from functools import partial

from .signal_augmentations import parse_augmentation
from .helpers import MEGABYTE, GIGABYTE, Interleaved, LimitingPool, call_if_exists
from .helpers import MEGABYTE, GIGABYTE, Interleaved, LimitingPool
from .audio import Sample, DEFAULT_FORMAT, AUDIO_TYPE_OPUS, AUDIO_TYPE_NP, SERIALIZABLE_AUDIO_TYPES, get_audio_type_from_extension

BIG_ENDIAN = 'big'
@@ -300,16 +300,15 @@ def __del__(self):


class SampleList:
"""Sample collection reader for reading a DeepSpeech CSV file
Automatically orders samples by CSV column wav_filesize (if available)."""
"""Sample collection base class with samples loaded from a list of in-memory paths."""
def __init__(self, samples, labeled=True):
"""
Parameters
----------
samples : iterable of tuples of the form (sample_filename, filesize [, transcript])
File-size is used for ordering the samples; transcript has to be provided if labeled=True
labeled : bool or None
If True: Reads LabeledSample instances. Fails, if CSV file has no transcript column.
If True: Reads LabeledSample instances.
If False: Ignores transcripts (if available) and reads (unlabeled) util.audio.Sample instances.
"""
self.labeled = labeled
@@ -320,10 +319,6 @@ def __getitem__(self, i):
sample_spec = self.samples[i]
return load_sample(sample_spec[0], label=sample_spec[2] if self.labeled else None)

def __iter__(self):
for i in range(len(self.samples)):
yield self[i]

def __len__(self):
return len(self.samples)

@@ -493,7 +488,7 @@ def timed_samples():
augmentations = [] if augmentation_specs is None else list(map(parse_augmentation, augmentation_specs))
try:
for augmentation in augmentations:
call_if_exists(augmentation, 'start', buffering=buffering)
augmentation.start(buffering=buffering)
context = PreparationContext(audio_type, augmentations)
if process_ahead == 0:
for timed_sample in timed_samples():
@@ -505,4 +500,4 @@ def timed_samples():
yield from pool.imap(_augment_sample, timed_samples())
finally:
for augmentation in augmentations:
call_if_exists(augmentation, 'stop')
augmentation.stop()
56 changes: 39 additions & 17 deletions training/deepspeech_training/util/signal_augmentations.py
@@ -9,11 +9,31 @@
from .audio import gain_db_to_ratio, max_dbfs, normalize_audio, AUDIO_TYPE_NP, AUDIO_TYPE_PCM, AUDIO_TYPE_OPUS
from .helpers import int_range, float_range, pick_value_from_range, MEGABYTE

SPEC_PARSER = re.compile(r'^([a-z]+)(\[(.*)\])?$')
SPEC_PARSER = re.compile(r'^(?P<cls>[a-z]+)(\[(?P<params>.*)\])?$')
BUFFER_SIZE = 1 * MEGABYTE


class Augmentation:
def __init__(self, p=1.0):
self.probability = float(p)

def start(self, buffering=BUFFER_SIZE):
pass

def apply(self, sample, clock):
raise NotImplementedError

def stop(self):
pass
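
With default no-op `start` and `stop` hooks on the base class, callers can invoke the lifecycle methods unconditionally, which is what replaces the removed `call_if_exists` helper. A self-contained sketch of that pattern follows; the `Invert` subclass, the dict-based sample and the `BUFFER_SIZE` value are illustrative assumptions.

```python
BUFFER_SIZE = 1024 * 1024  # stand-in for 1 * MEGABYTE


class Augmentation:
    def __init__(self, p=1.0):
        self.probability = float(p)

    def start(self, buffering=BUFFER_SIZE):
        pass

    def apply(self, sample, clock):
        raise NotImplementedError

    def stop(self):
        pass


class Invert(Augmentation):
    """Hypothetical augmentation that flips the sign of every value."""
    def apply(self, sample, clock):
        sample["audio"] = [-x for x in sample["audio"]]


augmentations = [Invert(p="0.5")]
try:
    for augmentation in augmentations:
        augmentation.start(buffering=BUFFER_SIZE)  # no-op unless overridden
    sample = {"audio": [0.25, -0.5]}
    for augmentation in augmentations:
        augmentation.apply(sample, clock=0.0)
finally:
    for augmentation in augmentations:
        augmentation.stop()  # no-op unless overridden
print(sample["audio"], augmentations[0].probability)  # [-0.25, 0.5] 0.5
```

Subclasses that need setup or teardown, such as one that spawns a feeder process, override `start` and `stop`; the others only implement `apply`.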


def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
"""
As the central distribution point for overlay samples this function is supposed to run in one process only.
This ensures that samples are not used twice if not required.
It loads the (raw and still compressed) data and provides it to the actual augmentation workers.
These are then doing decompression, potential conversion and overlaying in parallel.
"""
# preventing cyclic import problems
from .sample_collections import samples_from_source # pylint: disable=import-outside-toplevel
samples = samples_from_source(sample_source, buffering=buffering, labeled=False)
@@ -22,11 +42,11 @@ def _enqueue_overlay_samples(sample_source, queue, buffering=BUFFER_SIZE):
queue.put(sample)


class Overlay:
class Overlay(Augmentation):
"""See "Overlay augmentation" in TRAINING.rst"""
def __init__(self, source, p=1.0, snr=3.0, layers=1):
super(Overlay, self).__init__(p)
self.source = source
self.probability = float(p)
self.snr = float_range(snr)
self.layers = int_range(layers)
self.queue = Queue(max(1, math.floor(self.probability * self.layers[1] * os.cpu_count())))
@@ -72,10 +92,10 @@ def stop(self):
self.enqueue_process.terminate()


class Reverb:
class Reverb(Augmentation):
"""See "Reverb augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, delay=20.0, decay=10.0):
self.probability = float(p)
super(Reverb, self).__init__(p)
self.delay = float_range(delay)
self.decay = float_range(decay)

@@ -102,10 +122,10 @@ def apply(self, sample, clock):
sample.audio = np.array(audio, dtype=np.float32)


class Resample:
class Resample(Augmentation):
"""See "Resample augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, rate=8000):
self.probability = float(p)
super(Resample, self).__init__(p)
self.rate = int_range(rate)

def apply(self, sample, clock):
@@ -122,10 +142,10 @@ def apply(self, sample, clock):
sample.audio = audio


class Codec:
class Codec(Augmentation):
"""See "Codec augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, bitrate=3200):
self.probability = float(p)
super(Codec, self).__init__(p)
self.bitrate = int_range(bitrate)

def apply(self, sample, clock):
@@ -134,10 +154,10 @@ def apply(self, sample, clock):
sample.change_audio_type(new_audio_type=AUDIO_TYPE_OPUS, bitrate=bitrate) # will get decoded again downstream


class Gaps:
class Gaps(Augmentation):
"""See "Gaps augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, n=1, size=50.0):
self.probability = float(p)
super(Gaps, self).__init__(p)
self.n_gaps = int_range(n)
self.size = float_range(size)

@@ -154,10 +174,10 @@ def apply(self, sample, clock):
sample.audio = audio


class Volume:
class Volume(Augmentation):
"""See "Volume augmentation" in TRAINING.rst"""
def __init__(self, p=1.0, dbfs=3.0103):
self.probability = float(p)
super(Volume, self).__init__(p)
self.target_dbfs = float_range(dbfs)

def apply(self, sample, clock):
@@ -182,11 +202,13 @@ def parse_augmentation(augmentation_spec):
match = SPEC_PARSER.match(augmentation_spec)
if not match:
raise ValueError('Augmentation specification has wrong format')
cls_name = match.group(1)[0].upper() + match.group(1)[1:]
if cls_name not in globals():
cls_name = match.group('cls')
cls_name = cls_name[0].upper() + cls_name[1:]
augmentation_cls = globals()[cls_name] if cls_name in globals() else None
if not issubclass(augmentation_cls, Augmentation) or augmentation_cls == Augmentation:
raise ValueError('Unknown augmentation: {}'.format(cls_name))
augmentation_cls = globals()[cls_name]
parameters = [] if match.group(3) is None else match.group(3).split(',')
parameters = match.group('params')
parameters = [] if parameters is None else parameters.split(',')
args = []
kwargs = {}
for parameter in parameters:
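
One detail of the new lookup is worth checking in isolation: when the class name is unknown, `augmentation_cls` is `None`, and `issubclass(None, Augmentation)` raises `TypeError` instead of reaching the intended `ValueError`. The sketch below shows one possible guard together with the named-group parsing. `lookup_augmentation` and its `namespace` argument are hypothetical, and the two classes are trimmed-down stand-ins.

```python
import re

SPEC_PARSER = re.compile(r'^(?P<cls>[a-z]+)(\[(?P<params>.*)\])?$')


class Augmentation:
    def __init__(self, p=1.0):
        self.probability = float(p)


class Volume(Augmentation):
    def __init__(self, p=1.0, dbfs=3.0103):
        super().__init__(p)
        self.dbfs = dbfs


def lookup_augmentation(spec, namespace):
    """Split a spec like 'volume[p=0.1]' into its augmentation class and raw parameter strings."""
    match = SPEC_PARSER.match(spec)
    if not match:
        raise ValueError('Augmentation specification has wrong format')
    cls_name = match.group('cls').capitalize()
    candidate = namespace.get(cls_name)
    # Check for a class first: issubclass() raises TypeError for non-class arguments such as None.
    if not isinstance(candidate, type) or not issubclass(candidate, Augmentation) or candidate is Augmentation:
        raise ValueError('Unknown augmentation: {}'.format(cls_name))
    params = match.group('params')
    return candidate, ([] if params is None else params.split(','))


cls, params = lookup_augmentation('volume[p=0.1,dbfs=-10:-40]', globals())
print(cls.__name__, params)  # Volume ['p=0.1', 'dbfs=-10:-40']
```

Excluding the base class itself keeps a spec such as `augmentation` from instantiating an object whose `apply` only raises `NotImplementedError`.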
