Skip to content

Commit

Permalink
Update perturb.py (NVIDIA#5231)
Browse files Browse the repository at this point in the history
* Update perturb.py

Add a check for channel mismatch between audio and noise data, and raise an exception if they have a different number of channels. Also fix `perturb_with_foreground_noise` in the same way as `perturb_with_input_noise`.

Signed-off-by: He Huang (Steve) <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update check and test

Signed-off-by: stevehuang52 <[email protected]>

* fix test

Signed-off-by: stevehuang52 <[email protected]>

Signed-off-by: He Huang (Steve) <[email protected]>
Signed-off-by: stevehuang52 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: 1-800-bad-code <[email protected]>
  • Loading branch information
2 people authored and 1-800-BAD-CODE committed Nov 13, 2022
1 parent 1f172b1 commit 2d535ad
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 4 deletions.
52 changes: 48 additions & 4 deletions nemo/collections/asr/parts/preprocessing/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,11 @@ def get_one_noise_sample(self, target_sr):
)

def perturb(self, data, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
ref_mic (int): reference mic index for scaling multi-channel audios
"""
noise = read_one_audiosegment(
self._manifest,
data.sample_rate,
Expand All @@ -442,6 +447,23 @@ def perturb(self, data, ref_mic=0):
self.perturb_with_input_noise(data, noise, ref_mic=ref_mic)

def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
noise (AudioSegment): noise data
data_rms (Union[float, List[float]]): rms_db for data input
ref_mic (int): reference mic index for scaling multi-channel audios
"""
if data.num_channels != noise.num_channels:
raise ValueError(
f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})."
)

if not (0 <= ref_mic < data.num_channels):
raise ValueError(
f" reference mic ID must be an integer in [0, {data.num_channels}), got {ref_mic} instead."
)

snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
if data_rms is None:
data_rms = data.rms_db
Expand All @@ -467,14 +489,36 @@ def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0):
else:
data._samples += noise._samples

def perturb_with_foreground_noise(
self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1,
):
def perturb_with_foreground_noise(self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
noise (AudioSegment): noise data
data_rms (Union[float, List[float]]): rms_db for data input
max_noise_dur (float): max noise duration
max_additions (int): number of times for adding noise
ref_mic (int): reference mic index for scaling multi-channel audios
"""
if data.num_channels != noise.num_channels:
raise ValueError(
f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})."
)

if not (0 <= ref_mic < data.num_channels):
raise ValueError(
f" reference mic ID must be an integer in [0, {data.num_channels}), got {ref_mic} instead."
)

snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
if not data_rms:
data_rms = data.rms_db

noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db)
if data.num_channels > 1:
noise_gain_db = data_rms[ref_mic] - noise.rms_db[ref_mic] - snr_db
else:
noise_gain_db = data_rms - noise.rms_db - snr_db
noise_gain_db = min(noise_gain_db, self._max_gain_db)

n_additions = self._rng.randint(1, max_additions)

for i in range(n_additions):
Expand Down
57 changes: 57 additions & 0 deletions tests/collections/asr/test_preprocessing_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
from typing import List, Type, Union
Expand All @@ -20,6 +21,7 @@
import pytest
import soundfile as sf

from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.collections.asr.parts.utils.audio_utils import select_channels

Expand Down Expand Up @@ -120,3 +122,58 @@ def test_from_file(self, num_channels, channel_selector):
assert uut.duration == self.signal_duration_sec
max_diff = np.max(np.abs(uut.samples - golden_samples))
assert max_diff < self.max_diff_tol

@pytest.mark.unit
@pytest.mark.parametrize("data_channels", [1, 4])
@pytest.mark.parametrize("noise_channels", [1, 4])
def test_noise_perturb_channels(self, data_channels, noise_channels):
    """Test channel and reference-mic validation in NoisePerturbation.

    Random audio and noise signals with the requested channel counts are
    written to temporary wav files and loaded back as AudioSegment objects.

    * When the channel counts match, perturbing with ``ref_mic=0`` must not
      raise ValueError, while ``ref_mic=data_channels`` (one past the last
      valid index) must raise ValueError.
    * When the channel counts differ, both perturbation methods must raise
      ValueError.

    Args:
        data_channels (int): number of channels of the audio signal
        noise_channels (int): number of channels of the noise signal
    """
    with tempfile.TemporaryDirectory() as test_dir:
        # Prepare the audio wav file
        audio_file = os.path.join(test_dir, 'audio.wav')
        if data_channels == 1:
            # samples is a one-dimensional vector for single-channel signal
            audio_samples = np.random.rand(self.num_samples)
        else:
            audio_samples = np.random.rand(self.num_samples, data_channels)
        sf.write(audio_file, audio_samples, self.sample_rate, 'float')

        # Prepare the noise wav file
        noise_file = os.path.join(test_dir, 'noise.wav')
        if noise_channels == 1:
            # samples is a one-dimensional vector for single-channel signal
            noise_samples = np.random.rand(self.num_samples)
        else:
            noise_samples = np.random.rand(self.num_samples, noise_channels)
        sf.write(noise_file, noise_samples, self.sample_rate, 'float')

        # The perturber is constructed from a manifest listing the noise file
        manifest_file = os.path.join(test_dir, 'noise_manifest.json')
        with open(manifest_file, 'w') as fout:
            item = {'audio_filepath': os.path.abspath(noise_file), 'label': '-', 'duration': 0.1, 'offset': 0.0}
            fout.write(f'{json.dumps(item)}\n')

        perturber = NoisePerturbation(manifest_file)
        audio = AudioSegment.from_file(audio_file)
        noise = AudioSegment.from_file(noise_file)

        if data_channels == noise_channels:
            # Matching channels: a valid reference mic must be accepted
            try:
                _ = perturber.perturb_with_input_noise(audio, noise, ref_mic=0)
            except ValueError as e:
                pytest.fail(f"perturb_with_input_noise failed with ref_mic=0: {e}")

            # ref_mic == data_channels is out of range
            with pytest.raises(ValueError):
                _ = perturber.perturb_with_input_noise(audio, noise, ref_mic=data_channels)

            try:
                _ = perturber.perturb_with_foreground_noise(audio, noise, ref_mic=0)
            except ValueError as e:
                pytest.fail(f"perturb_with_foreground_noise failed with ref_mic=0: {e}")

            with pytest.raises(ValueError):
                _ = perturber.perturb_with_foreground_noise(audio, noise, ref_mic=data_channels)
        else:
            # Mismatched channels must be rejected by both methods
            with pytest.raises(ValueError):
                _ = perturber.perturb_with_input_noise(audio, noise)
            with pytest.raises(ValueError):
                _ = perturber.perturb_with_foreground_noise(audio, noise)

0 comments on commit 2d535ad

Please sign in to comment.