diff --git a/nemo/collections/asr/parts/preprocessing/perturb.py b/nemo/collections/asr/parts/preprocessing/perturb.py index 7970a5962efc..14ef3d4082c5 100644 --- a/nemo/collections/asr/parts/preprocessing/perturb.py +++ b/nemo/collections/asr/parts/preprocessing/perturb.py @@ -432,6 +432,11 @@ def get_one_noise_sample(self, target_sr): ) def perturb(self, data, ref_mic=0): + """ + Args: + data (AudioSegment): audio data + ref_mic (int): reference mic index for scaling multi-channel audios + """ noise = read_one_audiosegment( self._manifest, data.sample_rate, @@ -442,6 +447,23 @@ def perturb(self, data, ref_mic=0): self.perturb_with_input_noise(data, noise, ref_mic=ref_mic) def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0): + """ + Args: + data (AudioSegment): audio data + noise (AudioSegment): noise data + data_rms (Union[float, List[float]): rms_db for data input + ref_mic (int): reference mic index for scaling multi-channel audios + """ + if data.num_channels != noise.num_channels: + raise ValueError( + f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})." + ) + + if not (0 <= ref_mic < data.num_channels): + raise ValueError( + f" reference mic ID must be an integer in [0, {data.num_channels}), got {ref_mic} instead." + ) + snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db) if data_rms is None: data_rms = data.rms_db @@ -467,14 +489,36 @@ def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0): else: data._samples += noise._samples - def perturb_with_foreground_noise( - self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1, - ): + def perturb_with_foreground_noise(self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1, ref_mic=0): + """ + Args: + data (AudioSegment): audio data + noise (AudioSegment): noise data + data_rms (Union[float, List[float]): rms_db for data input + max_noise_dur: (float): max noise duration + max_additions (int): number of times for adding noise + ref_mic (int): reference mic index for scaling multi-channel audios + """ + if data.num_channels != noise.num_channels: + raise ValueError( + f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})." + ) + + if not (0 <= ref_mic < data.num_channels): + raise ValueError( + f" reference mic ID must be an integer in [0, {data.num_channels}), got {ref_mic} instead." + ) + snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db) if not data_rms: data_rms = data.rms_db - noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db) + if data.num_channels > 1: + noise_gain_db = data_rms[ref_mic] - noise.rms_db[ref_mic] - snr_db + else: + noise_gain_db = data_rms - noise.rms_db - snr_db + noise_gain_db = min(noise_gain_db, self._max_gain_db) + n_additions = self._rng.randint(1, max_additions) for i in range(n_additions): diff --git a/tests/collections/asr/test_preprocessing_segment.py b/tests/collections/asr/test_preprocessing_segment.py index dd2046e139b7..dba676dee3d9 100644 --- a/tests/collections/asr/test_preprocessing_segment.py +++ b/tests/collections/asr/test_preprocessing_segment.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import tempfile from typing import List, Type, Union @@ -20,6 +21,7 @@ import pytest import soundfile as sf +from nemo.collections.asr.parts.preprocessing.perturb import NoisePerturbation from nemo.collections.asr.parts.preprocessing.segment import AudioSegment from nemo.collections.asr.parts.utils.audio_utils import select_channels @@ -120,3 +122,58 @@ def test_from_file(self, num_channels, channel_selector): assert uut.duration == self.signal_duration_sec max_diff = np.max(np.abs(uut.samples - golden_samples)) assert max_diff < self.max_diff_tol + + @pytest.mark.unit + @pytest.mark.parametrize("data_channels", [1, 4]) + @pytest.mark.parametrize("noise_channels", [1, 4]) + def test_noise_perturb_channels(self, data_channels, noise_channels): + """Test loading a signal from a file. + """ + with tempfile.TemporaryDirectory() as test_dir: + # Prepare a wav file + audio_file = os.path.join(test_dir, 'audio.wav') + if data_channels == 1: + # samples is a one-dimensional vector for single-channel signal + samples = np.random.rand(self.num_samples) + else: + samples = np.random.rand(self.num_samples, data_channels) + sf.write(audio_file, samples, self.sample_rate, 'float') + + noise_file = os.path.join(test_dir, 'noise.wav') + if noise_channels == 1: + # samples is a one-dimensional vector for single-channel signal + samples = np.random.rand(self.num_samples) + else: + samples = np.random.rand(self.num_samples, noise_channels) + sf.write(noise_file, samples, self.sample_rate, 'float') + + manifest_file = os.path.join(test_dir, 'noise_manifest.json') + with open(manifest_file, 'w') as fout: + item = {'audio_filepath': os.path.abspath(noise_file), 'label': '-', 'duration': 0.1, 'offset': 0.0} + fout.write(f'{json.dumps(item)}\n') + + perturber = NoisePerturbation(manifest_file) + audio = AudioSegment.from_file(audio_file) + noise = AudioSegment.from_file(noise_file) + + if data_channels == noise_channels: + try: + _ = perturber.perturb_with_input_noise(audio, noise, ref_mic=0) + except ValueError as e: + assert False, "perturb_with_input_noise failed with ref_mic=0" + + with pytest.raises(ValueError): + _ = perturber.perturb_with_input_noise(audio, noise, ref_mic=data_channels) + + try: + _ = perturber.perturb_with_foreground_noise(audio, noise, ref_mic=0) + except ValueError as e: + assert False, "perturb_with_foreground_noise failed with ref_mic=0" + + with pytest.raises(ValueError): + _ = perturber.perturb_with_foreground_noise(audio, noise, ref_mic=data_channels) + else: + with pytest.raises(ValueError): + _ = perturber.perturb_with_input_noise(audio, noise) + with pytest.raises(ValueError): + _ = perturber.perturb_with_foreground_noise(audio, noise)