Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update perturb.py #5231

Merged
merged 9 commits into from
Oct 26, 2022
52 changes: 48 additions & 4 deletions nemo/collections/asr/parts/preprocessing/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,11 @@ def get_one_noise_sample(self, target_sr):
)

def perturb(self, data, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
ref_mic (int): reference mic index for scaling multi-channel audios
"""
noise = read_one_audiosegment(
self._manifest,
data.sample_rate,
Expand All @@ -442,6 +447,23 @@ def perturb(self, data, ref_mic=0):
self.perturb_with_input_noise(data, noise, ref_mic=ref_mic)

def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
noise (AudioSegment): noise data
data_rms (Union[float, List[float]]): rms_db for data input
ref_mic (int): reference mic index for scaling multi-channel audios
"""
if data.num_channels != noise.num_channels:
raise ValueError(
f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})."
)

if not (0 <= ref_mic < data.num_channels):
raise ValueError(
f" reference mic ID must be an integer in [0, {data.num_channels}), got {ref_mic} instead."
)

snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
if data_rms is None:
data_rms = data.rms_db
Expand All @@ -467,14 +489,36 @@ def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0):
else:
data._samples += noise._samples

def perturb_with_foreground_noise(
self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1,
):
def perturb_with_foreground_noise(self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
noise (AudioSegment): noise data
data_rms (Union[float, List[float]]): rms_db for data input
max_noise_dur (float): max noise duration
max_additions (int): number of times for adding noise
ref_mic (int): reference mic index for scaling multi-channel audios
"""
if data.num_channels != noise.num_channels:
raise ValueError(
f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})."
)

if not (0 <= ref_mic < data.num_channels):
raise ValueError(
f" reference mic ID must be an integer in [0, {data.num_channels}), got {ref_mic} instead."
)

snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
if not data_rms:
data_rms = data.rms_db

noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db)
if data.num_channels > 1:
noise_gain_db = data_rms[ref_mic] - noise.rms_db[ref_mic] - snr_db
else:
noise_gain_db = data_rms - noise.rms_db - snr_db
noise_gain_db = min(noise_gain_db, self._max_gain_db)

n_additions = self._rng.randint(1, max_additions)

for i in range(n_additions):
Expand Down
57 changes: 57 additions & 0 deletions tests/collections/asr/test_preprocessing_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
from typing import List, Type, Union
Expand All @@ -20,6 +21,7 @@
import pytest
import soundfile as sf

from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.collections.asr.parts.utils.audio_utils import select_channels

Expand Down Expand Up @@ -120,3 +122,58 @@ def test_from_file(self, num_channels, channel_selector):
assert uut.duration == self.signal_duration_sec
max_diff = np.max(np.abs(uut.samples - golden_samples))
assert max_diff < self.max_diff_tol

@pytest.mark.unit
@pytest.mark.parametrize("data_channels", [1, 4])
@pytest.mark.parametrize("noise_channels", [1, 4])
def test_noise_perturb_channels(self, data_channels, noise_channels):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙏

"""Test loading a signal from a file.
"""
with tempfile.TemporaryDirectory() as test_dir:
# Prepare a wav file
audio_file = os.path.join(test_dir, 'audio.wav')
if data_channels == 1:
# samples is a one-dimensional vector for single-channel signal
samples = np.random.rand(self.num_samples)
else:
samples = np.random.rand(self.num_samples, data_channels)
sf.write(audio_file, samples, self.sample_rate, 'float')

noise_file = os.path.join(test_dir, 'noise.wav')
if noise_channels == 1:
# samples is a one-dimensional vector for single-channel signal
samples = np.random.rand(self.num_samples)
else:
samples = np.random.rand(self.num_samples, data_channels)
sf.write(noise_file, samples, self.sample_rate, 'float')

manifest_file = os.path.join(test_dir, 'noise_manifest.json')
with open(manifest_file, 'w') as fout:
item = {'audio_filepath': os.path.abspath(noise_file), 'label': '-', 'duration': 0.1, 'offset': 0.0}
fout.write(f'{json.dumps(item)}\n')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: consider using write_manifest


perturber = NoisePerturbation(manifest_file)
audio = AudioSegment.from_file(audio_file)
noise = AudioSegment.from_file(noise_file)

if data_channels == noise_channels:
try:
_ = perturber.perturb_with_input_noise(audio, noise, ref_mic=0)
except ValueError as e:
assert False, "perturb_with_input_noise failed with ref_mic=0"

with pytest.raises(ValueError):
_ = perturber.perturb_with_input_noise(audio, noise, ref_mic=data_channels)

try:
_ = perturber.perturb_with_foreground_noise(audio, noise, ref_mic=0)
except ValueError as e:
assert False, "perturb_with_foreground_noise failed with ref_mic=0"

with pytest.raises(ValueError):
_ = perturber.perturb_with_foreground_noise(audio, noise, ref_mic=data_channels)
else:
with pytest.raises(ValueError):
_ = perturber.perturb_with_input_noise(audio, noise)
with pytest.raises(ValueError):
_ = perturber.perturb_with_foreground_noise(audio, noise)