From c42d9f1b6f79116812a67b88b0e6800d6d9dc1a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ante=20Jukic=CC=81?= Date: Tue, 25 Oct 2022 16:30:18 -0700 Subject: [PATCH] Dataset factories + review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ante Jukić --- nemo/collections/asr/data/audio_to_audio.py | 63 ++-- .../asr/data/audio_to_audio_dataset.py | 91 ++++++ .../common/parts/preprocessing/collections.py | 2 +- tests/collections/asr/test_asr_datasets.py | 270 +++++++++++------- 4 files changed, 288 insertions(+), 138 deletions(-) create mode 100644 nemo/collections/asr/data/audio_to_audio_dataset.py diff --git a/nemo/collections/asr/data/audio_to_audio.py b/nemo/collections/asr/data/audio_to_audio.py index 5c1b07ee790a8..9efe83dc4862d 100644 --- a/nemo/collections/asr/data/audio_to_audio.py +++ b/nemo/collections/asr/data/audio_to_audio.py @@ -37,6 +37,7 @@ ] +# TODO: move utility functions to a more general module # Local utility functions def flatten_iterable(iter: Iterable[Union[str, Iterable[str]]]) -> Iterable[str]: """Flatten an iterable which contains strings or @@ -97,6 +98,7 @@ def load_samples( ) else: # TODO: Load random subsegment of `duration` seconds + # TODO: Load segment of `duration` seconds starting at fix_segment (non-random) raise NotImplementedError() return segment.samples @@ -157,6 +159,7 @@ def load_samples_synchronized( raise RuntimeError(f'Unexpected audio_file type {type(audio_file)}') else: + # TODO: Load segment of `duration` seconds starting at fix_segment (non-random) audio_durations = [librosa.get_duration(filename=f) for f in flatten(audio_files)] min_duration = min(audio_durations) available_duration = min_duration - fixed_offset @@ -511,9 +514,9 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: Dictionary in the following form: ``` { - 'input': single- or multi-channel format, + 'input_signal': single- or multi-channel format, 'input_length': original length of each input signal - 'target': single- or multi-channel format, + 'target_signal': single- or multi-channel format, 'target_length': original length of each target signal } ``` @@ -522,9 +525,9 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) return OrderedDict( - input=sc_audio_type if self.num_channels('input') == 1 else mc_audio_type, + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, input_length=NeuralType(('B',), LengthsType()), - target=sc_audio_type if self.num_channels('target') == 1 else mc_audio_type, + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, target_length=NeuralType(('B',), LengthsType()), ) @@ -538,8 +541,8 @@ def __getitem__(self, index: int) -> Dict[str, torch.tensor]: Dictionary providing mapping from signal to its tensor. 
``` { - 'input': input_tensor, - 'target': target_tensor, + 'input_signal': input_tensor, + 'target_signal': target_tensor, } ``` """ @@ -561,7 +564,7 @@ def __getitem__(self, index: int) -> Dict[str, torch.tensor]: input_signal = list_to_multichannel(input_signal) target_signal = list_to_multichannel(target_signal) - output = OrderedDict(input=torch.tensor(input_signal), target=torch.tensor(target_signal),) + output = OrderedDict(input_signal=torch.tensor(input_signal), target_signal=torch.tensor(target_signal),) return output @@ -657,11 +660,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: Dictionary in the following form: ``` { - 'input': single- or multi-channel format, + 'input_signal': single- or multi-channel format, 'input_length': original length of each input signal - 'target': single- or multi-channel format, + 'target_signal': single- or multi-channel format, 'target_length': original length of each target signal - 'reference': single- or multi-channel format, + 'reference_signal': single- or multi-channel format, 'reference_length': original length of each reference signal } ``` @@ -670,11 +673,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) return OrderedDict( - input=sc_audio_type if self.num_channels('input') == 1 else mc_audio_type, + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, input_length=NeuralType(('B',), LengthsType()), - target=sc_audio_type if self.num_channels('target') == 1 else mc_audio_type, + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, target_length=NeuralType(('B',), LengthsType()), - reference=sc_audio_type if self.num_channels('reference') == 1 else mc_audio_type, + reference_signal=sc_audio_type if self.num_channels('reference_signal') == 1 else mc_audio_type, reference_length=NeuralType(('B',), LengthsType()), ) @@ -688,9 +691,9 @@ def __getitem__(self, index: int) -> Dict[str, torch.tensor]: Dictionary providing mapping from signal to its tensor. 
``` { - 'input': input_tensor, - 'target': target_tensor, - 'reference': reference_tensor, + 'input_signal': input_tensor, + 'target_signal': target_tensor, + 'reference_signal': reference_tensor, } ``` """ @@ -739,9 +742,9 @@ def __getitem__(self, index: int) -> Dict[str, torch.tensor]: # Output dictionary output = OrderedDict( - input=torch.tensor(input_signal), - target=torch.tensor(target_signal), - reference=torch.tensor(reference_signal), + input_signal=torch.tensor(input_signal), + target_signal=torch.tensor(target_signal), + reference_signal=torch.tensor(reference_signal), ) return output @@ -823,11 +826,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: Dictionary in the following form: ``` { - 'input': single- or multi-channel format, + 'input_signal': single- or multi-channel format, 'input_length': original length of each input signal - 'target': single- or multi-channel format, + 'target_signal': single- or multi-channel format, 'target_length': original length of each target signal - 'embedding': batched embedded vector format, + 'embedding_vector': batched embedded vector format, 'embedding_length': original length of each embedding vector } ``` @@ -836,11 +839,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) return OrderedDict( - input=sc_audio_type if self.num_channels('input') == 1 else mc_audio_type, + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, input_length=NeuralType(('B',), LengthsType()), - target=sc_audio_type if self.num_channels('target') == 1 else mc_audio_type, + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, target_length=NeuralType(('B',), LengthsType()), - embedding=NeuralType(('B', 'D'), EncodedRepresentation()), + embedding_vector=NeuralType(('B', 'D'), EncodedRepresentation()), embedding_length=NeuralType(('B',), LengthsType()), ) @@ -854,9 +857,9 @@ def __getitem__(self, index): Dictionary providing mapping from signal to its tensor. ``` { - 'input': input_tensor, - 'target': target_tensor, - 'embedding': embedding_tensor, + 'input_signal': input_tensor, + 'target_signal': target_tensor, + 'embedding_vector': embedding_tensor, } ``` """ @@ -885,7 +888,9 @@ def __getitem__(self, index): # Output dictionary output = OrderedDict( - input=torch.tensor(input_signal), target=torch.tensor(target_signal), embedding=torch.tensor(embedding), + input_signal=torch.tensor(input_signal), + target_signal=torch.tensor(target_signal), + embedding_vector=torch.tensor(embedding), ) return output diff --git a/nemo/collections/asr/data/audio_to_audio_dataset.py b/nemo/collections/asr/data/audio_to_audio_dataset.py new file mode 100644 index 0000000000000..5835d4d349088 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_audio_dataset.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
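For orientation before the new module: a minimal sketch of how the factory functions defined below are typically driven from a plain config dictionary. The manifest path and manifest keys here are hypothetical placeholders; required entries are read with `config[...]`, while optional entries fall back to defaults via `config.get(...)`, mirroring the implementation that follows.

```
# Usage sketch only, not part of the patch. The manifest path and the
# manifest keys are hypothetical; 'manifest_filepath', 'sample_rate',
# 'input_key' and 'target_key' are the entries the factory requires.
from nemo.collections.asr.data import audio_to_audio_dataset

config = {
    'manifest_filepath': '/data/train_manifest.json',  # hypothetical path
    'sample_rate': 16000,
    'input_key': 'input_filepath',
    'target_key': 'target_filepath',
    # Optional entries such as 'audio_duration', 'min_duration' or the
    # channel selectors may be omitted; config.get(...) supplies defaults.
}
dataset = audio_to_audio_dataset.get_audio_to_target_dataset(config)
example = dataset[0]  # dict with 'input_signal' and 'target_signal' tensors
```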
+
+from nemo.collections.asr.data import audio_to_audio
+
+
+def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset:
+    """Instantiates an AudioToTargetDataset from a config dictionary.
+
+    Args:
+        config: Config of AudioToTargetDataset.
+
+    Returns:
+        An instance of AudioToTargetDataset.
+    """
+    dataset = audio_to_audio.AudioToTargetDataset(
+        manifest_filepath=config['manifest_filepath'],
+        sample_rate=config['sample_rate'],
+        input_key=config['input_key'],
+        target_key=config['target_key'],
+        audio_duration=config.get('audio_duration', None),
+        max_duration=config.get('max_duration', None),
+        min_duration=config.get('min_duration', None),
+        max_utts=config.get('max_utts', 0),
+        input_channel_selector=config.get('input_channel_selector', None),
+        target_channel_selector=config.get('target_channel_selector', None),
+    )
+    return dataset
+
+
+def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.AudioToTargetWithReferenceDataset:
+    """Instantiates an AudioToTargetWithReferenceDataset from a config dictionary.
+
+    Args:
+        config: Config of AudioToTargetWithReferenceDataset.
+
+    Returns:
+        An instance of AudioToTargetWithReferenceDataset.
+    """
+    dataset = audio_to_audio.AudioToTargetWithReferenceDataset(
+        manifest_filepath=config['manifest_filepath'],
+        sample_rate=config['sample_rate'],
+        input_key=config['input_key'],
+        target_key=config['target_key'],
+        reference_key=config['reference_key'],
+        audio_duration=config.get('audio_duration', None),
+        max_duration=config.get('max_duration', None),
+        min_duration=config.get('min_duration', None),
+        max_utts=config.get('max_utts', 0),
+        input_channel_selector=config.get('input_channel_selector', None),
+        target_channel_selector=config.get('target_channel_selector', None),
+        reference_channel_selector=config.get('reference_channel_selector', None),
+        reference_is_synchronized=config.get('reference_is_synchronized', True),
+    )
+    return dataset
+
+
+def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.AudioToTargetWithEmbeddingDataset:
+    """Instantiates an AudioToTargetWithEmbeddingDataset from a config dictionary.
+
+    Args:
+        config: Config of AudioToTargetWithEmbeddingDataset.
+ + Returns: + An instance of AudioToTargetWithEmbeddingDataset + """ + dataset = audio_to_audio.AudioToTargetWithEmbeddingDataset( + manifest_filepath=config['manifest_filepath'], + sample_rate=config['sample_rate'], + input_key=config['input_key'], + target_key=config['target_key'], + embedding_key=config['embedding_key'], + audio_duration=config.get('audio_duration', None), + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + max_utts=config.get('max_utts', 0), + input_channel_selector=config.get('input_channel_selector', None), + target_channel_selector=config.get('target_channel_selector', None), + ) + return dataset diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 162761a75f2aa..bc6751073bbfe 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -914,7 +914,7 @@ def get_full_path(audio_file: str, manifest_file: str) -> str: # Handle all audio files audio_files = {} for key in self.audio_keys: - audio_file = item.pop(key) + audio_file = item[key] if isinstance(audio_file, str): # This dictionary entry points to a single file audio_files[key] = get_full_path(audio_file, manifest_file) diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index 04cc8bd2f9e94..931f64545fddd 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -24,7 +24,7 @@ from omegaconf import DictConfig, OmegaConf from torch.utils.data import DataLoader -from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data import audio_to_audio_dataset, audio_to_text_dataset from nemo.collections.asr.data.audio_to_audio import ( AudioToTargetDataset, AudioToTargetWithEmbeddingDataset, @@ -627,21 +627,21 @@ def test_audio_collate_fn(self): """ batch_size = 16 random_seed = 42 - max_diff_tol = 1e-5 + atol = 1e-5 # Generate random signals _rng = np.random.default_rng(seed=random_seed) signal_to_channels = { - 'input': 2, - 'target': 1, - 'reference': 1, + 'input_signal': 2, + 'target_signal': 1, + 'reference_signal': 1, } signal_to_length = { - 'input': _rng.integers(low=5, high=25, size=batch_size), - 'target': _rng.integers(low=5, high=25, size=batch_size), - 'reference': _rng.integers(low=5, high=25, size=batch_size), + 'input_signal': _rng.integers(low=5, high=25, size=batch_size), + 'target_signal': _rng.integers(low=5, high=25, size=batch_size), + 'reference_signal': _rng.integers(low=5, high=25, size=batch_size), } # Generate batch @@ -658,15 +658,15 @@ def test_audio_collate_fn(self): batched = _audio_collate_fn(batch) batched_signals = { - 'input': batched[0].cpu().detach().numpy(), - 'target': batched[2].cpu().detach().numpy(), - 'reference': batched[4].cpu().detach().numpy(), + 'input_signal': batched[0].cpu().detach().numpy(), + 'target_signal': batched[2].cpu().detach().numpy(), + 'reference_signal': batched[4].cpu().detach().numpy(), } batched_lengths = { - 'input': batched[1].cpu().detach().numpy(), - 'target': batched[3].cpu().detach().numpy(), - 'reference': batched[5].cpu().detach().numpy(), + 'input_signal': batched[1].cpu().detach().numpy(), + 'target_signal': batched[3].cpu().detach().numpy(), + 'reference_signal': batched[5].cpu().detach().numpy(), } # Check outputs @@ -681,8 +681,9 @@ def test_audio_collate_fn(self): uut_signal = b_signal[n][:uut_length, 
...] golden_signal = batch[n][signal][:uut_length, ...].cpu().detach().numpy() - max_diff = np.max(np.abs(uut_signal - golden_signal)) - assert max_diff < max_diff_tol, f'Example {n} signal {signal} value mismatch.' + assert np.isclose( + uut_signal, golden_signal, atol=atol + ).all(), f'Example {n} signal {signal} value mismatch.' @pytest.mark.unit def test_audio_to_target_dataset(self): @@ -709,18 +710,19 @@ def test_audio_to_target_dataset(self): sample_rate = 16000 num_examples = 25 data_num_channels = { - 'input': 4, - 'target': 2, + 'input_signal': 4, + 'target_signal': 2, } data_min_duration = 2.0 data_max_duration = 8.0 data_key = { - 'input': 'input_filepath', - 'target': 'target_filepath', + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', } # Tolerance max_diff_tol = 1e-6 + atol = 1e-6 # Generate random signals _rng = np.random.default_rng(seed=random_seed) @@ -769,21 +771,35 @@ def test_audio_to_target_dataset(self): # - No constraints on channels or duration dataset = AudioToTargetDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], sample_rate=sample_rate, ) + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + for n in range(num_examples): item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) for signal in data: item_signal = item[signal].cpu().detach().numpy() golden_signal = data[signal][n] - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' # Test 2 # - Filtering based on signal duration @@ -792,8 +808,8 @@ def test_audio_to_target_dataset(self): dataset = AudioToTargetDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], min_duration=min_duration, max_duration=max_duration, sample_rate=sample_rate, @@ -807,24 +823,23 @@ def test_audio_to_target_dataset(self): for signal in data: item_signal = item[signal].cpu().detach().numpy() golden_signal = data[signal][filtered_examples[n]] - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' # Test 3 # - Use channel selector channel_selector = { - 'input': [0, 2], - 'target': 1, + 'input_signal': [0, 2], + 'target_signal': 1, } dataset = AudioToTargetDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], - input_channel_selector=channel_selector['input'], - 
target_channel_selector=channel_selector['target'], + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + input_channel_selector=channel_selector['input_signal'], + target_channel_selector=channel_selector['target_signal'], sample_rate=sample_rate, ) @@ -835,10 +850,9 @@ def test_audio_to_target_dataset(self): cs = channel_selector[signal] item_signal = item[signal].cpu().detach().numpy() golden_signal = data[signal][n][..., cs] - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' # Test 4 # - Use fixed duration (random segment selection) @@ -847,8 +861,8 @@ def test_audio_to_target_dataset(self): dataset = AudioToTargetDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], sample_rate=sample_rate, min_duration=audio_duration, audio_duration=audio_duration, @@ -877,10 +891,9 @@ def test_audio_to_target_dataset(self): ), f'Test 4: Signal length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' # Test signal values - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 4: Failed for example {n}, signal {signal} (random seed {random_seed})' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 4: Failed for example {n}, signal {signal} (random seed {random_seed})' # Test 5: # - Test collate_fn @@ -919,18 +932,18 @@ def test_audio_to_target_dataset_with_target_list(self): sample_rate = 16000 num_examples = 25 data_num_channels = { - 'input': 4, - 'target': 2, + 'input_signal': 4, + 'target_signal': 2, } data_min_duration = 2.0 data_max_duration = 8.0 data_key = { - 'input': 'input_filepath', - 'target': 'target_filepath', + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', } # Tolerance - max_diff_tol = 1e-6 + atol = 1e-6 # Generate random signals _rng = np.random.default_rng(seed=random_seed) @@ -959,7 +972,7 @@ def test_audio_to_target_dataset_with_target_list(self): meta = dict() for signal in data: - if signal == 'target': + if signal == 'target_signal': # Save targets as individual files signal_filename = [] for ch in range(data_num_channels[signal]): @@ -993,21 +1006,34 @@ def test_audio_to_target_dataset_with_target_list(self): # - No constraints on channels or duration dataset = AudioToTargetDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], sample_rate=sample_rate, ) + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + for n in range(num_examples): item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) for signal in data: item_signal = item[signal].cpu().detach().numpy() golden_signal = data[signal][n] - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' 
+ assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' @pytest.mark.unit def test_audio_to_target_with_reference_dataset(self): @@ -1031,20 +1057,20 @@ def test_audio_to_target_with_reference_dataset(self): sample_rate = 16000 num_examples = 25 data_num_channels = { - 'input': 4, - 'target': 2, - 'reference': 1, + 'input_signal': 4, + 'target_signal': 2, + 'reference_signal': 1, } data_min_duration = 2.0 data_max_duration = 8.0 data_key = { - 'input': 'input_filepath', - 'target': 'target_filepath', - 'reference': 'reference_filepath', + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + 'reference_signal': 'reference_filepath', } # Tolerance - max_diff_tol = 1e-6 + atol = 1e-6 # Generate random signals _rng = np.random.default_rng(seed=random_seed) @@ -1091,26 +1117,42 @@ def test_audio_to_target_with_reference_dataset(self): # Test 1 # - No constraints on channels or duration - # - Reference is synchronized with input and target, so whole reference signal will be loaded + # - Reference is not synchronized with input and target, so whole reference signal will be loaded dataset = AudioToTargetWithReferenceDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], - reference_key=data_key['reference'], - reference_is_synchronized=True, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=False, sample_rate=sample_rate, ) + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'reference_key': data_key['reference_signal'], + 'reference_is_synchronized': False, + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_reference_dataset(config) + for n in range(num_examples): item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) for signal in data: item_signal = item[signal].cpu().detach().numpy() golden_signal = data[signal][n] - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' # Test 2 # - Use fixed duration (random segment selection) @@ -1119,9 +1161,9 @@ def test_audio_to_target_with_reference_dataset(self): audio_duration_samples = int(np.floor(audio_duration * sample_rate)) dataset = AudioToTargetWithReferenceDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], - reference_key=data_key['reference'], + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], 
reference_is_synchronized=True, sample_rate=sample_rate, min_duration=audio_duration, @@ -1151,10 +1193,9 @@ def test_audio_to_target_with_reference_dataset(self): ), f'Test 2: Signal {signal} length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' # Test signal values - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' # Test 3 # - Use fixed duration (random segment selection) @@ -1163,9 +1204,9 @@ def test_audio_to_target_with_reference_dataset(self): audio_duration_samples = int(np.floor(audio_duration * sample_rate)) dataset = AudioToTargetWithReferenceDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], - reference_key=data_key['reference'], + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], reference_is_synchronized=False, sample_rate=sample_rate, min_duration=audio_duration, @@ -1182,7 +1223,7 @@ def test_audio_to_target_with_reference_dataset(self): item_signal = item[signal].cpu().detach().numpy() full_golden_signal = data[signal][filtered_examples[n]] - if signal == 'reference': + if signal == 'reference_signal': # Complete signal is loaded for reference golden_signal = full_golden_signal else: @@ -1201,10 +1242,9 @@ def test_audio_to_target_with_reference_dataset(self): ), f'Test 3: Signal {signal} length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' # Test signal values - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' @pytest.mark.unit def test_audio_to_target_with_embedding_dataset(self): @@ -1225,21 +1265,21 @@ def test_audio_to_target_with_embedding_dataset(self): sample_rate = 16000 num_examples = 25 data_num_channels = { - 'input': 4, - 'target': 2, - 'embedding': 1, + 'input_signal': 4, + 'target_signal': 2, + 'embedding_vector': 1, } data_min_duration = 2.0 data_max_duration = 8.0 embedding_length = 64 # 64-dimensional embedding vector data_key = { - 'input': 'input_filepath', - 'target': 'target_filepath', - 'embedding': 'embedding_filepath', + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + 'embedding_vector': 'embedding_filepath', } # Tolerance - max_diff_tol = 1e-6 + atol = 1e-6 # Generate random signals _rng = np.random.default_rng(seed=random_seed) @@ -1252,7 +1292,7 @@ def test_audio_to_target_with_embedding_dataset(self): for signal, num_channels in data_num_channels.items(): data[signal] = [] for n in range(num_examples): - data_length = embedding_length if signal == 'embedding' else data_duration_samples[n] + data_length = embedding_length if signal == 'embedding_vector' else data_duration_samples[n] if num_channels == 1: random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_length)) @@ -1270,7 +1310,7 @@ def test_audio_to_target_with_embedding_dataset(self): meta = dict() for signal in data: - if signal == 'embedding': + if signal == 'embedding_vector': signal_filename = 
f'{signal}_{n:02d}.npy' np.save(os.path.join(test_dir, signal_filename), data[signal][n]) @@ -1293,22 +1333,36 @@ def test_audio_to_target_with_embedding_dataset(self): # Test 1 # - No constraints on channels or duration - # - Reference is synchronized with input and target, so whole reference signal will be loaded dataset = AudioToTargetWithEmbeddingDataset( manifest_filepath=manifest_filepath, - input_key=data_key['input'], - target_key=data_key['target'], - embedding_key=data_key['embedding'], + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + embedding_key=data_key['embedding_vector'], sample_rate=sample_rate, ) + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'embedding_key': data_key['embedding_vector'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_embedding_dataset(config) + for n in range(num_examples): item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) for signal in data: item_signal = item[signal].cpu().detach().numpy() golden_signal = data[signal][n] - max_diff = np.max(np.abs(item_signal - golden_signal)) - assert ( - max_diff < max_diff_tol - ), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})'
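A note on the assertion rewrite used throughout these tests: `np.isclose(a, b, atol=atol)` checks `|a - b| <= atol + rtol * |b|` with a default `rtol` of 1e-05, so it is marginally looser than the strict `max(|a - b|) < atol` comparison it replaces. A standalone illustration of the difference:

```
import numpy as np

a = np.array([1.0, 2.0])
b = a + 5e-6  # difference above atol, but within the default rtol margin

# With rtol forced to zero (equivalent to the old max-abs-diff check),
# the comparison fails:
assert not np.isclose(a, b, atol=1e-6, rtol=0.0).all()
# With the default rtol=1e-05, the relative term absorbs the gap:
assert np.isclose(a, b, atol=1e-6).all()
```

Since the generated test signals lie in [-0.5, 0.5], the relative term is at most a few times `atol`, so the rewritten tests remain effectively as strict as before; passing `rtol=0.0` would make them exactly equivalent.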