diff --git a/nemo/collections/asr/data/audio_to_audio.py b/nemo/collections/asr/data/audio_to_audio.py new file mode 100644 index 0000000000000..9d84cd1e0d54f --- /dev/null +++ b/nemo/collections/asr/data/audio_to_audio.py @@ -0,0 +1,932 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import math +import random +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + +import librosa +import numpy as np +import torch + +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.common.parts.preprocessing.collections import AudioCollection +from nemo.collections.common.parts.utils import flatten +from nemo.core.classes import Dataset +from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType +from nemo.utils import logging +from nemo.utils.decorators import experimental + +__all__ = [ + 'AudioToTargetDataset', + 'AudioToTargetWithReferenceDataset', + 'AudioToTargetWithEmbeddingDataset', +] + + +def load_samples_synchronized( + audio_files: List[str], + sample_rate: int, + duration: Optional[float] = None, + channel_selectors: Optional[List[ChannelSelectorType]] = None, + fixed_offset: float = 0, + random_offset: bool = False, +) -> List[np.ndarray]: + """Load samples from multiple files with the same start and end point. + + Args: + audio_files: list of paths to audio files + sample_rate: desired sample rate for output samples + duration: Optional desired duration of output samples. + If `None`, the complete files will be loaded. + If set, a segment of `duration` seconds will be loaded from + all files. Segment is synchronized across files, so that + start and end points are the same. + channel_selectors: Optional channel selector for each signal, for selecting + a subset of channels. + fixed_offset: Optional fixed offset when loading samples. + random_offset: If `True`, offset will be randomized when loading a short segment + from a file. The value is randomized between fixed_offset and + max_offset (set depending on the duration and fixed_offset). + + Returns: + List with the same size as `audio_files` but containing numpy arrays + with samples from each audio file. + Each array has shape (num_samples, ) or (num_samples, num_channels), for single- + or multi-channel signal, respectively. + For example, if `audio_files = [path/to/file_1.wav, path/to/file_2.wav]`, + the output will be a list `output = [samples_1, samples_2]`. 
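+
+    Example (illustrative, hypothetical paths; assumes both files are at least
+    `fixed_offset + duration` seconds long and can be resampled to 16 kHz):
+    ```
+    input_samples, target_samples = load_samples_synchronized(
+        audio_files=['path/to/input.wav', 'path/to/target.wav'],
+        sample_rate=16000,
+        duration=4.0,
+        random_offset=True,
+    )
+    # Both arrays cover the same randomly selected 4-second segment.
+    ```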
+ """ + if channel_selectors is None: + channel_selectors = [None] * len(audio_files) + + output = [] + + if duration is None: + # Load complete files starting from a fixed offset + output = [] + + for audio_file, channel_selector in zip(audio_files, channel_selectors): + if isinstance(audio_file, str): + segment = AudioSegment.from_file( + audio_file=audio_file, + target_sr=sample_rate, + offset=fixed_offset, + channel_selector=channel_selector, + ) + output.append(segment.samples) + elif isinstance(audio_file, list): + samples_list = [] + for f in audio_file: + segment = AudioSegment.from_file( + audio_file=f, target_sr=sample_rate, offset=fixed_offset, channel_selector=channel_selector + ) + samples_list.append(segment.samples) + output.append(samples_list) + elif audio_file is None: + # Support for inference, when the target signal is `None` + output.append([]) + else: + raise RuntimeError(f'Unexpected audio_file type {type(audio_file)}') + + else: + audio_durations = [librosa.get_duration(filename=f) for f in flatten(audio_files)] + min_duration = min(audio_durations) + available_duration = min_duration - fixed_offset + + if available_duration <= 0: + raise ValueError(f'Fixed offset {fixed_offset}s is larger than shortest file {min_duration}s.') + elif min_duration < duration + fixed_offset: + logging.warning( + f'Shortest file ({min_duration}s) is less than desired duration {duration}s + fixed offset {fixed_offset}s. Returned signals will be shortened to {available_duration}s.' + ) + offset = fixed_offset + num_samples = math.floor(available_duration * sample_rate) + elif random_offset: + # Randomize offset based on the shortest file + max_offset = min_duration - duration + offset = random.uniform(fixed_offset, max_offset) + # Fixed number of samples + num_samples = math.floor(duration * sample_rate) + else: + # Fixed offset + offset = fixed_offset + # Fixed number of samples + num_samples = math.floor(duration * sample_rate) + + # Prepare segments + for audio_file, channel_selector in zip(audio_files, channel_selectors): + # Load segments starting from the same offset + if isinstance(audio_file, str): + segment = AudioSegment.segment_from_file( + audio_file=audio_file, + target_sr=sample_rate, + n_segments=num_samples, + offset=offset, + channel_selector=channel_selector, + ) + output.append(segment.samples) + elif isinstance(audio_file, list): + samples_list = [] + for f in audio_file: + segment = AudioSegment.segment_from_file( + audio_file=f, + target_sr=sample_rate, + n_segments=num_samples, + offset=offset, + channel_selector=channel_selector, + ) + samples_list.append(segment.samples) + output.append(samples_list) + else: + raise RuntimeError(f'Unexpected audio_file type {type(audio_file)}') + + return output + + +def load_samples( + audio_file: str, + sample_rate: int, + duration: Optional[float] = None, + channel_selector: ChannelSelectorType = None, + fixed_offset: float = 0, + random_offset: bool = False, +) -> np.ndarray: + """Load samples from an audio file. + For a single-channel signal, the output is shape (num_samples,). + For a multi-channel signal, the output is shape (num_samples, num_channels). + + Args: + audio_file: path to an audio file + sample_rate: desired sample rate for output samples + duration: Optional desired duration of output samples. + If `None`, the complete file will be loaded. + If set, a segment of `duration` seconds will be loaded. + channel_selector: Optional channel selector, for selecting a subset of channels. 
+ fixed_offset: Optional fixed offset when loading samples. + random_offset: If `True`, offset will be randomized when loading a short segment + from a file. The value is randomized between fixed_offset and + max_offset (set depending on the duration and fixed_offset). + + Returns: + Numpy array with samples from audio file. + The array has shape (num_samples,) for a single-channel signal + or (num_samples, num_channels) for a multi-channel signal. + """ + output = load_samples_synchronized( + audio_files=[audio_file], + sample_rate=sample_rate, + duration=duration, + channel_selectors=[channel_selector], + fixed_offset=fixed_offset, + random_offset=random_offset, + ) + + return output[0] + + +def load_embedding(filepath: str) -> np.ndarray: + """Load an embedding vector from a file. + + Args: + filepath: path to a file storing a vector. + Currently, it is assumed the file is a npy file. + + Returns: + Array loaded from filepath. + """ + if filepath.endswith('.npy'): + with open(filepath, 'rb') as f: + embedding = np.load(f) + else: + raise RuntimeError(f'Unknown embedding file format in file: {filepath}') + + return embedding + + +def list_to_multichannel(signal: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: + """Convert a list of signals into a multi-channel signal by concatenating + the elements of the list along the channel dimension. + + If input is not a list, it is returned unmodified. + + Args: + signal: list of arrays + + Returns: + Numpy array obrained by concatenating the elements of the list + along the channel dimension (axis=1). + """ + if not isinstance(signal, list): + # Nothing to do there + return signal + elif len(signal) == 0: + # Nothing to do, return as is + return signal + elif len(signal) == 1: + # Nothing to concatenate, return the original format + return signal[0] + + # If multiple signals are provided in a list, we concatenate them along the channel dimension + if signal[0].ndim == 1: + # Single-channel individual files + mc_signal = np.stack(signal, axis=1) + elif signal[0].ndim == 2: + # Multi-channel individual files + mc_signal = np.concatenate(signal, axis=1) + else: + raise RuntimeError(f'Unexpected target with {signal[0].ndim} dimensions.') + + return mc_signal + + +def _audio_collate_fn(batch: List[dict]) -> Tuple[torch.Tensor]: + """Collate a batch of items returned by __getitem__. + Examples for each signal are zero padded to the same length + (batch_length), which is determined by the longest example. + Lengths of the original signals are returned in the output. + + Args: + batch: List of dictionaries. Each element of the list + has the following format + ``` + { + 'signal_0': 1D or 2D tensor, + 'signal_1': 1D or 2D tensor, + ... + 'signal_N': 1D or 2D tensor, + } + ``` + 1D tensors have shape (num_samples,) and 2D tensors + have shape (num_samples, num_channels) + + Returns: + A tuple containing signal tensor and signal length tensor (in samples) + for each signal. 
+ The output has the following format: + ``` + (signal_0, signal_0_length, signal_1, signal_1_length, ..., signal_N, signal_N_length) + ``` + """ + signals = batch[0].keys() + + batched = tuple() + + for signal in signals: + signal_length = [b[signal].shape[0] for b in batch] + # Batch length is determined by the longest signal in the batch + batch_length = max(signal_length) + b_signal = [] + for s_len, b in zip(signal_length, batch): + # check if padding is necessary + if s_len < batch_length: + if b[signal].ndim == 1: + # single-channel signal + pad = (0, batch_length - s_len) + elif b[signal].ndim == 2: + # multi-channel signal + pad = (0, 0, 0, batch_length - s_len) + else: + raise RuntimeError( + f'Signal {signal} has unsupported dimensions {b[signal].shape}. Currently, only 1D and 2D arrays are supported.' + ) + b[signal] = torch.nn.functional.pad(b[signal], pad) + # append the current padded signal + b_signal.append(b[signal]) + # (signal_batched, signal_length) + batched += (torch.stack(b_signal), torch.tensor(signal_length, dtype=torch.int32)) + + # Currently, outputs are expected to be in a tuple, where each element must correspond + # to the output type in the OrderedDict returned by output_types. + # + # Therefore, we return batched signals by interleaving signals and their length: + # (signal_0, signal_0_length, signal_1, signal_1_length, ...) + return batched + + +@experimental +class BaseAudioDataset(Dataset): + """Base class of audio datasets, providing common functionality + for other audio datasets. + + Each line of the manifest file is expected to have the following format + ``` + { + audio_key[0]: 'path/to/audio_file_0', + audio_key[1]: 'path/to/audio_file_1', + ... + 'duration': duration_of_input, + } + ``` + + Args: + manifest_filepath: Path to manifest file in a format described above. + sample_rate: Sample rate for loaded audio signals. + audio_to_manifest_key: Dictionary mapping audio signal labels to manifest keys. + audio_duration: Optional duration of each item returned by __getitem__. + If `None`, complete audio will be loaded. + If set, a subsegment will be loaded synchronously from + target and audio, i.e., with the same start and end point. + random_offset: If `True`, offset will be randomized when loading a subsegment + from a file. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + """ + + @property + @abc.abstractmethod + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + """ + + def __init__( + self, + manifest_filepath: str, + sample_rate: int, + audio_to_manifest_key: Dict[str, Union[str, List[str]]], + audio_duration: Optional[float] = None, + random_offset: bool = False, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: int = 0, + ): + """Instantiates an audio_dataset.
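+
+        For example (illustrative), a derived dataset may pass a mapping such as
+        `audio_to_manifest_key = {'input': 'input_filepath', 'target': 'target_filepath'}`,
+        where the values are keys of the manifest file. Comma-separated strings are
+        supported for both `manifest_filepath` and the individual manifest keys.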
+ """ + super().__init__() + + if type(manifest_filepath) == str: + manifest_filepath = manifest_filepath.split(',') + + for audio_key, manifest_key in audio_to_manifest_key.items(): + if type(manifest_key) == str and ',' in manifest_key: + audio_to_manifest_key[audio_key] = manifest_key.split(',') + + self.collection = AudioCollection( + manifest_files=manifest_filepath, + audio_to_manifest_key=audio_to_manifest_key, + min_duration=min_duration, + max_duration=max_duration, + max_number=max_utts, + ) + + self.sample_rate = sample_rate + self.audio_duration = audio_duration + self.random_offset = random_offset + + def num_channels(self, signal_key) -> int: + """Returns the number of channels for a particular signal in + items prepared by this dictionary. + + More specifically, this will get the tensor from the first + item in the dataset, check if it's a one- or two-dimensional + tensor, and return the number of channels based on the size + of the second axis (shape[1]). + + NOTE: + This assumes that all examples returned by `__getitem__` + have the same number of channels. + + Args: + signal_key: string, used to select a signal from the dictionary + output by __getitem__ + + Returns: + Number of channels for the selected signal. + """ + # Assumption: whole dataset has the same number of channels + item = self.__getitem__(0) + if item[signal_key].ndim == 1: + return 1 + elif item[signal_key].ndim == 2: + return item[signal_key].shape[1] + else: + raise RuntimeError( + f'Unexpected number of dimension for signal {signal_key} with shape {item[signal_key].shape}' + ) + + @abc.abstractmethod + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Return a single example from the dataset. + + Args: + index: integer index of an example in the collection + + Returns: + Dictionary providing mapping from signal to its tensor. + For example: + ``` + { + 'input': input_tensor, + 'target': target_tensor, + } + ``` + """ + + def __len__(self) -> int: + """Return the number of examples in the dataset. + """ + return len(self.collection) + + def _collate_fn(self, batch) -> Tuple[torch.Tensor]: + """Collate items in a batch. + """ + return _audio_collate_fn(batch) + + +@experimental +class AudioToTargetDataset(BaseAudioDataset): + """A dataset for audio-to-audio tasks where the goal is to use + an input signal to recover the corresponding target signal. + + Each line of the manifest file is expected to have the following format + ``` + { + 'input_key': 'path/to/input.wav', + 'target_key': 'path/to/path_to_target.wav', + 'duration': duration_of_input, + } + ``` + + Additionally, multiple audio files may be provided for each key in the manifest, for example, + ``` + { + 'input_key': 'path/to/input.wav', + 'target_key': ['path/to/path_to_target_ch0.wav', 'path/to/path_to_target_ch1.wav'], + 'duration': duration_of_input, + } + ``` + + Keys for input and target signals can be configured in the constructor (`input_key` and `target_key`). + + Args: + manifest_filepath: Path to manifest file in a format described above. + sample_rate: Sample rate for loaded audio signals. + input_key: Key pointing to input audio files in the manifest + target_key: Key pointing to target audio files in manifest + audio_duration: Optional duration of each item returned by __getitem__. + If `None`, complete audio will be loaded. + If set, a random subsegment will be loaded synchronously from + target and audio, i.e., with the same start and end point. 
+ random_offset: If `True`, offset will be randomized when loading a subsegment + from a file. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + input_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + target_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + """ + + def __init__( + self, + manifest_filepath: str, + sample_rate: int, + input_key: str, + target_key: str, + audio_duration: Optional[float] = None, + random_offset: bool = False, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: Optional[int] = None, + input_channel_selector: Optional[int] = None, + target_channel_selector: Optional[int] = None, + ): + self.audio_to_manifest_key = { + 'input': input_key, + 'target': target_key, + } + self.input_channel_selector = input_channel_selector + self.target_channel_selector = target_channel_selector + + super().__init__( + manifest_filepath=manifest_filepath, + audio_to_manifest_key=self.audio_to_manifest_key, + sample_rate=sample_rate, + audio_duration=audio_duration, + random_offset=random_offset, + max_duration=max_duration, + min_duration=min_duration, + max_utts=max_utts, + ) + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + + Returns: + Ordered dictionary in the following form: + ``` + { + 'input_signal': batched single- or multi-channel format, + 'input_length': batched original length of each input signal + 'target_signal': batched single- or multi-channel format, + 'target_length': batched original length of each target signal + } + ``` + """ + sc_audio_type = NeuralType(('B', 'T'), AudioSignal()) + mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) + + return OrderedDict( + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, + input_length=NeuralType(('B',), LengthsType()), + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, + target_length=NeuralType(('B',), LengthsType()), + ) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Return a single example from the dataset. + + Args: + index: integer index of an example in the collection + + Returns: + Dictionary providing mapping from signal to its tensor. 
+ ``` + { + 'input_signal': input_tensor, + 'target_signal': target_tensor, + } + ``` + """ + example = self.collection[index] + + input_file = example.audio_files['input'] + target_file = example.audio_files['target'] + + # Load the same segment for different signals + input_signal, target_signal = load_samples_synchronized( + audio_files=[input_file, target_file], + channel_selectors=[self.input_channel_selector, self.target_channel_selector], + sample_rate=self.sample_rate, + duration=self.audio_duration, + fixed_offset=example.offset, + random_offset=self.random_offset, + ) + + # If necessary, convert a list of arrays into a multi-channel array + input_signal = list_to_multichannel(input_signal) + target_signal = list_to_multichannel(target_signal) + + # Output dictionary + output = OrderedDict(input_signal=torch.tensor(input_signal), target_signal=torch.tensor(target_signal),) + + return output + + +@experimental +class AudioToTargetWithReferenceDataset(BaseAudioDataset): + """A dataset for audio-to-audio tasks where the goal is to use + an input signal to recover the corresponding target signal and an + additional reference signal is available. + + This can be used, for example, when a reference signal is + available from + - enrollment utterance for the target signal + - echo reference from playback + - reference from another sensor that correlates with the target signal + + Each line of the manifest file is expected to have the following format + ``` + { + 'input_key': 'path/to/input.wav', + 'target_key': 'path/to/path_to_target.wav', + 'reference_key': 'path/to/path_to_reference.wav', + 'duration': duration_of_input, + } + ``` + + Keys for input, target and reference signals can be configured in the constructor. + + Args: + manifest_filepath: Path to manifest file in a format described above. + sample_rate: Sample rate for loaded audio signals. + input_key: Key pointing to input audio files in the manifest + target_key: Key pointing to target audio files in manifest + reference_key: Key pointing to reference audio files in manifest + audio_duration: Optional duration of each item returned by __getitem__. + If `None`, complete audio will be loaded. + If set, a random subsegment will be loaded synchronously from + target and audio, i.e., with the same start and end point. + random_offset: If `True`, offset will be randomized when loading a subsegment + from a file. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + input_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + target_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + reference_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + reference_is_synchronized: If True, it is assumed that the reference signal is synchronized + with the input signal, so the same subsegment will be loaded as for + input and target. If False, reference signal will be loaded independently + from input and target. 
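+
+    Example (illustrative paths and manifest keys):
+    ```
+    dataset = AudioToTargetWithReferenceDataset(
+        manifest_filepath='path/to/manifest.json',
+        sample_rate=16000,
+        input_key='input_filepath',
+        target_key='target_filepath',
+        reference_key='reference_filepath',
+        reference_is_synchronized=False,  # e.g., reference is an enrollment utterance
+    )
+    ```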
+ """ + + def __init__( + self, + manifest_filepath: str, + sample_rate: int, + input_key: str, + target_key: str, + reference_key: str, + audio_duration: Optional[float] = None, + random_offset: bool = False, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: Optional[int] = None, + input_channel_selector: Optional[int] = None, + target_channel_selector: Optional[int] = None, + reference_channel_selector: Optional[int] = None, + reference_is_synchronized: bool = True, # can be disable when reference is an enrollment utterance + ): + self.audio_to_manifest_key = { + 'input': input_key, + 'target': target_key, + 'reference': reference_key, + } + self.input_channel_selector = input_channel_selector + self.target_channel_selector = target_channel_selector + self.reference_channel_selector = reference_channel_selector + self.reference_is_synchronized = reference_is_synchronized + + super().__init__( + manifest_filepath=manifest_filepath, + audio_to_manifest_key=self.audio_to_manifest_key, + sample_rate=sample_rate, + audio_duration=audio_duration, + random_offset=random_offset, + max_duration=max_duration, + min_duration=min_duration, + max_utts=max_utts, + ) + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + + Returns: + Ordered dictionary in the following form: + ``` + { + 'input_signal': batched single- or multi-channel format, + 'input_length': batched original length of each input signal + 'target_signal': batched single- or multi-channel format, + 'target_length': batched original length of each target signal + 'reference_signal': single- or multi-channel format, + 'reference_length': original length of each reference signal + } + ``` + """ + sc_audio_type = NeuralType(('B', 'T'), AudioSignal()) + mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) + + return OrderedDict( + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, + input_length=NeuralType(('B',), LengthsType()), + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, + target_length=NeuralType(('B',), LengthsType()), + reference_signal=sc_audio_type if self.num_channels('reference_signal') == 1 else mc_audio_type, + reference_length=NeuralType(('B',), LengthsType()), + ) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Return a single example from the dataset. + + Args: + index: integer index of an example in the collection + + Returns: + Dictionary providing mapping from signal to its tensor. 
+ ``` + { + 'input_signal': input_tensor, + 'target_signal': target_tensor, + 'reference_signal': reference_tensor, + } + ``` + """ + + example = self.collection[index] + + input_file = example.audio_files['input'] + target_file = example.audio_files['target'] + reference_file = example.audio_files['reference'] + + if self.reference_is_synchronized: + # Load synchronized segments from input, target and reference + input_signal, target_signal, reference_signal = load_samples_synchronized( + audio_files=[input_file, target_file, reference_file], + channel_selectors=[ + self.input_channel_selector, + self.target_channel_selector, + self.reference_channel_selector, + ], + sample_rate=self.sample_rate, + duration=self.audio_duration, + fixed_offset=example.offset, + random_offset=self.random_offset, + ) + else: + # Load the synchronized segments from input and target + input_signal, target_signal = load_samples_synchronized( + audio_files=[input_file, target_file], + channel_selectors=[self.input_channel_selector, self.target_channel_selector], + sample_rate=self.sample_rate, + duration=self.audio_duration, + fixed_offset=example.offset, + random_offset=self.random_offset, + ) + + # Reference is not synchronized with input/target, get samples independently + reference_signal = load_samples( + audio_file=reference_file, + sample_rate=self.sample_rate, + duration=None, # TODO: add reference_duration to __init__ + channel_selector=self.reference_channel_selector, + ) + + # If necessary, convert a list of arrays into a multi-channel array + input_signal = list_to_multichannel(input_signal) + target_signal = list_to_multichannel(target_signal) + reference_signal = list_to_multichannel(reference_signal) + + # Output dictionary + output = OrderedDict( + input_signal=torch.tensor(input_signal), + target_signal=torch.tensor(target_signal), + reference_signal=torch.tensor(reference_signal), + ) + + return output + + +@experimental +class AudioToTargetWithEmbeddingDataset(BaseAudioDataset): + """A dataset for audio-to-audio tasks where the goal is to use + an input signal to recover the corresponding target signal and an + additional embedding signal. It is assumed that the embedding + is in a form of a vector. + + Each line of the manifest file is expected to have the following format + ``` + { + input_key: 'path/to/input.wav', + target_key: 'path/to/path_to_target.wav', + embedding_key: 'path/to/path_to_reference.npy', + 'duration': duration_of_input, + } + ``` + + Keys for input, target and embedding signals can be configured in the constructor. + + Args: + manifest_filepath: Path to manifest file in a format described above. + sample_rate: Sample rate for loaded audio signals. + input_key: Key pointing to input audio files in the manifest + target_key: Key pointing to target audio files in manifest + embedding_key: Key pointing to embedding files in manifest + audio_duration: Optional duration of each item returned by __getitem__. + If `None`, complete audio will be loaded. + If set, a random subsegment will be loaded synchronously from + target and audio, i.e., with the same start and end point. + random_offset: If `True`, offset will be randomized when loading a subsegment + from a file. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + input_channel_selector: Optional, select subset of channels from each input audio file. 
+ If `None`, all channels will be loaded. + target_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + """ + + def __init__( + self, + manifest_filepath: str, + sample_rate: int, + input_key: str, + target_key: str, + embedding_key: str, + audio_duration: Optional[float] = None, + random_offset: bool = False, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: Optional[int] = None, + input_channel_selector: Optional[int] = None, + target_channel_selector: Optional[int] = None, + ): + self.audio_to_manifest_key = { + 'input': input_key, + 'target': target_key, + 'embedding': embedding_key, + } + self.input_channel_selector = input_channel_selector + self.target_channel_selector = target_channel_selector + + super().__init__( + manifest_filepath=manifest_filepath, + audio_to_manifest_key=self.audio_to_manifest_key, + sample_rate=sample_rate, + audio_duration=audio_duration, + random_offset=random_offset, + max_duration=max_duration, + min_duration=min_duration, + max_utts=max_utts, + ) + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + + Returns: + Ordered dictionary in the following form: + ``` + { + 'input_signal': batched single- or multi-channel format, + 'input_length': batched original length of each input signal + 'target_signal': batched single- or multi-channel format, + 'target_length': batched original length of each target signal + 'embedding_vector': batched embedded vector format, + 'embedding_length': batched original length of each embedding vector + } + ``` + """ + sc_audio_type = NeuralType(('B', 'T'), AudioSignal()) + mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) + + return OrderedDict( + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, + input_length=NeuralType(('B',), LengthsType()), + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, + target_length=NeuralType(('B',), LengthsType()), + embedding_vector=NeuralType(('B', 'D'), EncodedRepresentation()), + embedding_length=NeuralType(('B',), LengthsType()), + ) + + def __getitem__(self, index): + """Return a single example from the dataset. + + Args: + index: integer index of an example in the collection + + Returns: + Dictionary providing mapping from signal to its tensor. 
+ ``` + { + 'input_signal': input_tensor, + 'target_signal': target_tensor, + 'embedding_vector': embedding_tensor, + } + ``` + """ + + example = self.collection[index] + + input_file = example.audio_files['input'] + target_file = example.audio_files['target'] + embedding_file = example.audio_files['embedding'] + + # Load synchronized segments from input and target + input_signal, target_signal = load_samples_synchronized( + audio_files=[input_file, target_file], + channel_selectors=[self.input_channel_selector, self.target_channel_selector,], + sample_rate=self.sample_rate, + duration=self.audio_duration, + fixed_offset=example.offset, + random_offset=self.random_offset, + ) + + # Load embedding + embedding_vector = load_embedding(embedding_file) + + # If necessary, convert a list of arrays into a multi-channel array + input_signal = list_to_multichannel(input_signal) + target_signal = list_to_multichannel(target_signal) + + # Output dictionary + output = OrderedDict( + input_signal=torch.tensor(input_signal), + target_signal=torch.tensor(target_signal), + embedding_vector=torch.tensor(embedding_vector), + ) + + return output diff --git a/nemo/collections/asr/data/audio_to_audio_dataset.py b/nemo/collections/asr/data/audio_to_audio_dataset.py new file mode 100644 index 0000000000000..5835d4d349088 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_audio_dataset.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.data import audio_to_audio + + +def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset: + """Instantiates an audio-to-audio dataset. + + Args: + config: Config of AudioToTargetDataset. + + Returns: + An instance of AudioToTargetDataset + """ + dataset = audio_to_audio.AudioToTargetDataset( + manifest_filepath=config['manifest_filepath'], + sample_rate=config['sample_rate'], + input_key=config['input_key'], + target_key=config['target_key'], + audio_duration=config.get('audio_duration', None), + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + max_utts=config.get('max_utts', 0), + input_channel_selector=config.get('input_channel_selector', None), + target_channel_selector=config.get('target_channel_selector', None), + ) + return dataset + + +def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.AudioToTargetWithReferenceDataset: + """Instantiates an audio-to-audio dataset. + + Args: + config: Config of AudioToTargetWithReferenceDataset. 
+ + Returns: + An instance of AudioToTargetWithReferenceDataset + """ + dataset = audio_to_audio.AudioToTargetWithReferenceDataset( + manifest_filepath=config['manifest_filepath'], + sample_rate=config['sample_rate'], + input_key=config['input_key'], + target_key=config['target_key'], + reference_key=config['reference_key'], + audio_duration=config.get('audio_duration', None), + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + max_utts=config.get('max_utts', 0), + input_channel_selector=config.get('input_channel_selector', None), + target_channel_selector=config.get('target_channel_selector', None), + reference_channel_selector=config.get('reference_channel_selector', None), + reference_is_synchronized=config.get('reference_is_synchronized', True), + ) + return dataset + + +def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.AudioToTargetWithEmbeddingDataset: + """Instantiates an audio-to-audio dataset. + + Args: + config: Config of AudioToTargetWithEmbeddingDataset. + + Returns: + An instance of AudioToTargetWithEmbeddingDataset + """ + dataset = audio_to_audio.AudioToTargetWithEmbeddingDataset( + manifest_filepath=config['manifest_filepath'], + sample_rate=config['sample_rate'], + input_key=config['input_key'], + target_key=config['target_key'], + embedding_key=config['embedding_key'], + audio_duration=config.get('audio_duration', None), + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + max_utts=config.get('max_utts', 0), + input_channel_selector=config.get('input_channel_selector', None), + target_channel_selector=config.get('target_channel_selector', None), + ) + return dataset diff --git a/nemo/collections/asr/parts/preprocessing/segment.py b/nemo/collections/asr/parts/preprocessing/segment.py index 2aca9bdf1b555..1b22dcf668402 100644 --- a/nemo/collections/asr/parts/preprocessing/segment.py +++ b/nemo/collections/asr/parts/preprocessing/segment.py @@ -257,12 +257,22 @@ def from_file( @classmethod def segment_from_file( - cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None, channel_selector=None, + cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None, channel_selector=None, offset=None ): - """Grabs n_segments number of samples from audio_file randomly from the - file as opposed to at a specified offset. + """Grabs n_segments number of samples from audio_file. + If offset is not provided, n_segments are selected randomly. + If offset is provided, it is used to calculate the starting sample. Note that audio_file can be either the file path, or a file-like object. + + :param audio_file: path to a file or a file-like object + :param target_sr: sample rate for the output samples + :param n_segments: desired number of samples + :param trim: if true, trim leading and trailing silence from an audio signal + :param orig_sr: the original sample rate + :param channel selector: select a subset of channels. If set to `None`, the original signal will be used. 
+ :param offset: fixed offset in seconds + :return: numpy array of samples """ is_segmented = False try: @@ -275,13 +285,20 @@ def segment_from_file( if 0 < n_segments_at_original_sr < len(f): max_audio_start = len(f) - n_segments_at_original_sr - audio_start = random.randint(0, max_audio_start) + if offset is None: + audio_start = random.randint(0, max_audio_start) + else: + audio_start = math.floor(offset * sample_rate) + if audio_start > max_audio_start: + raise RuntimeError( + f'Provided audio start ({offset} seconds = {audio_start} samples) is larger than the maximum possible ({max_audio_start})' + ) f.seek(audio_start) samples = f.read(n_segments_at_original_sr, dtype='float32') is_segmented = True - elif n_segments_at_original_sr >= len(f): + elif n_segments_at_original_sr > len(f): logging.warning( - f"Number of segments is greater than the length of the audio file {audio_file}. This may lead to shape mismatch errors." + f"Number of segments ({n_segments_at_original_sr}) is greater than the length ({len(f)}) of the audio file {audio_file}. This may lead to shape mismatch errors." ) samples = f.read(dtype='float32') else: @@ -363,8 +380,7 @@ def subsegment(self, start_time=None, end_time=None): :param end_time: End of subsegment in seconds. :type end_time: float :raise ValueError: If start_time or end_time is incorrectly set, - e.g. out - of bounds in time. + e.g. out of bounds in time. """ start_time = 0.0 if start_time is None else start_time end_time = self.duration if end_time is None else end_time diff --git a/nemo/collections/asr/parts/utils/audio_utils.py b/nemo/collections/asr/parts/utils/audio_utils.py index 32ae2505f0c5c..793f6690b7b40 100644 --- a/nemo/collections/asr/parts/utils/audio_utils.py +++ b/nemo/collections/asr/parts/utils/audio_utils.py @@ -17,6 +17,7 @@ import librosa import numpy as np import numpy.typing as npt +import scipy import soundfile as sf from scipy.spatial.distance import pdist, squareform @@ -378,3 +379,22 @@ def pow2db(power: float, eps: Optional[float] = 1e-16) -> float: Power in dB. """ return 10 * np.log10(power + eps) + + +def get_segment_start(signal: np.ndarray, segment: np.ndarray) -> int: + """Get starting point of `segment` in `signal`. + We assume that `segment` is a part of `signal`. + + Args: + signal: numpy array with shape (num_samples,) + segment: numpy array with shape (num_samples,) + + Returns: + Index of the start of `segment` in `signal`. + """ + if len(signal) <= len(segment): + raise ValueError( + f'segment must be shorter than signal: len(segment) = {len(segment)}, len(signal) = {len(signal)}' + ) + cc = scipy.signal.correlate(signal, segment, mode='valid') + return np.argmax(cc) diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 82f072c68b6af..acfe952d3298d 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -16,7 +16,8 @@ import json import os from itertools import combinations -from typing import Any, Dict, List, Optional, Union +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Union + import pandas as pd @@ -784,3 +785,196 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: offset=item.get('offset', None), ) return item + + +class Audio(_Collection): + """Prepare a list of all audio items, filtered by duration.
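+
+    Each entry is stored as a namedtuple with fields `(audio_files, duration, offset, text)`,
+    where `audio_files` is, for example (illustrative), a dict such as
+    `{'input': 'path/to/input.wav', 'target': 'path/to/target.wav'}`.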
+ """ + + OUTPUT_TYPE = collections.namedtuple(typename='Audio', field_names='audio_files duration offset text') + + def __init__( + self, + audio_files_list: List[Dict[str, str]], + duration_list: List[float], + offset_list: List[float], + text_list: List[str], + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + ): + """Instantiantes an list of audio files. + + Args: + audio_files_list: list of dictionaries with mapping from audio_key to audio_filepath + duration_list: list of durations of input files + offset_list: list of offsets + text_list: list of texts + min_duration: Minimum duration to keep entry with (default: None). + max_duration: Maximum duration to keep entry with (default: None). + max_number: Maximum number of samples to collect. + do_sort_by_duration: True if sort samples list by duration. + """ + + output_type = self.OUTPUT_TYPE + data, total_duration = [], 0.0 + num_filtered, duration_filtered = 0, 0.0 + + for audio_files, duration, offset, text in zip(audio_files_list, duration_list, offset_list, text_list): + # Duration filters + if min_duration is not None and duration < min_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + num_filtered += 1 + continue + + total_duration += duration + data.append(output_type(audio_files, duration, offset, text)) + + # Max number of entities filter + if len(data) == max_number: + break + + if do_sort_by_duration: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + + super().__init__(data) + + +class AudioCollection(Audio): + """List of audio files from a manifest file. + """ + + def __init__( + self, manifest_files: Union[str, List[str]], audio_to_manifest_key: Dict[str, str], *args, **kwargs, + ): + """Instantiates a list of audio files loaded from a manifest file. + + Args: + manifest_files: path to a single manifest file or a list of paths + audio_to_manifest_key: dictionary mapping audio signals to keys of the manifest + """ + # Keys from manifest which contain audio + self.audio_to_manifest_key = audio_to_manifest_key + + # Initialize data + audio_files_list, duration_list, offset_list, text_list = [], [], [], [] + + # Parse manifest files + for item in manifest.item_iter(manifest_files, parse_func=self.__parse_item): + audio_files_list.append(item['audio_files']) + duration_list.append(item['duration']) + offset_list.append(item['offset']) + text_list.append(item['text']) + + super().__init__(audio_files_list, duration_list, offset_list, text_list, *args, **kwargs) + + def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: + """Parse a single line from a manifest file. + + Args: + line: a string representing a line from a manifest file in JSON format + manifest_file: path to the manifest file. Used to resolve relative paths. + + Returns: + Dictionary with audio_files, duration, and offset. + """ + # Utility function + def get_full_path(audio_file: str, manifest_file: str) -> str: + """ # TODO move to some utility module, since this is + relatively general general + + Get full path to audio_file. 
+ If path in `audio_file` is not pointing to a valid file and it + is a relative path, we assume that the path is relative to + manifest_dir. + """ + audio_file = Path(audio_file) + manifest_dir = Path(manifest_file).parent + + if (len(str(audio_file)) < 255) and not audio_file.is_file() and not audio_file.is_absolute(): + # assume the path in manifest is relative to manifest_dir + audio_file_path = manifest_dir / audio_file + if audio_file_path.is_file(): + audio_file = str(audio_file_path.absolute()) + else: + audio_file = os.path.expanduser(audio_file) + else: + audio_file = os.path.expanduser(audio_file) + return audio_file + + # Local utility function + def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]): + """Get item[key] if key is string, or a list + of strings by combining item[key[0]], item[key[1]], etc. + """ + # Prepare audio file(s) + if manifest_key is None: + # Support for inference, when a target key is None + audio_file = None + elif isinstance(manifest_key, str): + # Load files from a single manifest key + audio_file = item[manifest_key] + elif isinstance(manifest_key, Iterable): + # Load files from multiple manifest keys + audio_file = [] + for key in manifest_key: + item_key = item[key] + if isinstance(item_key, str): + audio_file.append(item_key) + elif isinstance(item_key, list): + audio_file += item_key + else: + raise ValueError(f'Unexpected type {type(item_key)} of item for key {key}: {item_key}') + else: + raise ValueError(f'Unexpected type {type(manifest_key)} of manifest_key: {manifest_key}') + + return audio_file + + # Convert JSON line to a dictionary + item = json.loads(line) + + # Handle all audio files + audio_files = {} + for audio_key, manifest_key in self.audio_to_manifest_key.items(): + + audio_file = get_audio_file(item, manifest_key) + + # Get full path to audio file(s) + if isinstance(audio_file, str): + # This dictionary entry points to a single file + audio_files[audio_key] = get_full_path(audio_file, manifest_file) + elif isinstance(audio_file, Iterable): + # This dictionary entry points to multiple files + # Get the files and keep the list structure for this key + audio_files[audio_key] = [get_full_path(f, manifest_file) for f in audio_file] + elif audio_key == 'target' and audio_file is None: + # For inference, we don't need the target + audio_files[audio_key] = None + else: + raise ValueError(f'Unexpected type {type(audio_file)} of audio_file: {audio_file}') + item['audio_files'] = audio_files + + # Handle duration + if 'duration' not in item: + raise ValueError(f'Duration not available in line: {line}. 
Manifest file: {manifest_file}') + + # Handle offset + if 'offset' not in item: + item['offset'] = 0.0 + + # Handle text + if 'text' not in item: + item['text'] = None + + return dict( + audio_files=item['audio_files'], duration=item['duration'], offset=item['offset'], text=item['text'] + ) diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py index d6d8c1f95cacb..0bfcfaf9cdfb2 100644 --- a/nemo/collections/common/parts/utils.py +++ b/nemo/collections/common/parts/utils.py @@ -14,11 +14,11 @@ import math import os -from typing import List +from typing import Iterable, List import torch.nn as nn -__all__ = ['if_exist', '_compute_softmax'] +__all__ = ['if_exist', '_compute_softmax', 'flatten'] activation_registry = { "identity": nn.Identity, @@ -67,3 +67,32 @@ def _compute_softmax(scores): for score in exp_scores: probs.append(score / total_sum) return probs + + +def flatten_iterable(iter: Iterable) -> Iterable: + """Flatten an iterable which contains values or + iterables with values. + + Args: + iter: iterable containing values at the deepest level. + + Returns: + A flat iterable containing values. + """ + for it in iter: + if isinstance(it, str) or not isinstance(it, Iterable): + yield it + else: + yield from flatten_iterable(it) + + +def flatten(list_in: List) -> List: + """Flatten a list of (nested lists of) values into a flat list. + + Args: + list_in: list of values, possibly nested + + Returns: + A flat list of values. + """ + return list(flatten_iterable(list_in)) diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index 83c94fa59631b..547bb00d4b763 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -19,11 +19,19 @@ import numpy as np import pytest +import soundfile as sf import torch.cuda from omegaconf import DictConfig, OmegaConf from torch.utils.data import DataLoader -from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data import audio_to_audio_dataset, audio_to_text_dataset +from nemo.collections.asr.data.audio_to_audio import ( + AudioToTargetDataset, + AudioToTargetWithEmbeddingDataset, + AudioToTargetWithReferenceDataset, + _audio_collate_fn, + list_to_multichannel, +) from nemo.collections.asr.data.audio_to_text import TarredAudioToBPEDataset, TarredAudioToCharDataset from nemo.collections.asr.data.audio_to_text_dali import ( __DALI_MINIMUM_VERSION__, @@ -33,6 +41,8 @@ ) from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.parts.utils.audio_utils import get_segment_start +from nemo.collections.asr.parts.utils.manifest_utils import write_manifest from nemo.collections.common import tokenizers from nemo.utils import logging @@ -574,3 +584,946 @@ def test_dali_tarred_char_vs_ref_dataset(self, test_data_dir): err = np.abs(a - b) assert np.mean(err) < 0.0001 assert np.max(err) < 0.01 + + +class TestAudioDatasets: + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2]) + @pytest.mark.parametrize('num_targets', [1, 3]) + def test_list_to_multichannel(self, num_channels, num_targets): + """Test conversion of a list of arrays into + """ + random_seed = 42 + num_samples = 1000 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Multi-channel signal + golden_target = _rng.normal(size=(num_samples, num_channels * 
num_targets)) + + # Create a list of num_targets signals with num_channels channels + target_list = [golden_target[:, n * num_channels : (n + 1) * num_channels] for n in range(num_targets)] + + # Check the original signal is not modified + assert (list_to_multichannel(golden_target) == golden_target).all() + # Check the list is converted back to the original signal + assert (list_to_multichannel(target_list) == golden_target).all() + + @pytest.mark.unit + def test_audio_collate_fn(self): + """Test `_audio_collate_fn` + """ + batch_size = 16 + random_seed = 42 + atol = 1e-5 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + signal_to_channels = { + 'input_signal': 2, + 'target_signal': 1, + 'reference_signal': 1, + } + + signal_to_length = { + 'input_signal': _rng.integers(low=5, high=25, size=batch_size), + 'target_signal': _rng.integers(low=5, high=25, size=batch_size), + 'reference_signal': _rng.integers(low=5, high=25, size=batch_size), + } + + # Generate batch + batch = [] + for n in range(batch_size): + item = dict() + for signal, num_channels in signal_to_channels.items(): + random_signal = _rng.normal(size=(signal_to_length[signal][n], num_channels)) + random_signal = np.squeeze(random_signal) # get rid of channel dimention for single-channel + item[signal] = torch.tensor(random_signal) + batch.append(item) + + # Run UUT + batched = _audio_collate_fn(batch) + + batched_signals = { + 'input_signal': batched[0].cpu().detach().numpy(), + 'target_signal': batched[2].cpu().detach().numpy(), + 'reference_signal': batched[4].cpu().detach().numpy(), + } + + batched_lengths = { + 'input_signal': batched[1].cpu().detach().numpy(), + 'target_signal': batched[3].cpu().detach().numpy(), + 'reference_signal': batched[5].cpu().detach().numpy(), + } + + # Check outputs + for signal, b_signal in batched_signals.items(): + for n in range(batch_size): + # Check length + uut_length = batched_lengths[signal][n] + golden_length = signal_to_length[signal][n] + assert ( + uut_length == golden_length + ), f'Example {n} signal {signal} length mismatch: batched ({uut_length}) != golden ({golden_length})' + + uut_signal = b_signal[n][:uut_length, ...] + golden_signal = batch[n][signal][:uut_length, ...].cpu().detach().numpy() + assert np.isclose( + uut_signal, golden_signal, atol=atol + ).all(), f'Example {n} signal {signal} value mismatch.' + + @pytest.mark.unit + def test_audio_to_target_dataset(self): + """Test AudioWithTargetDataset in different configurations. 
+ + Test below cover the following: + 1) no constraints + 2) filtering based on signal duration + 3) use with channel selector + 4) use with fixed audio duration and random subsegments + 5) collate a batch of items + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + } + + # Tolerance + max_diff_tol = 1e-6 + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n], num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # - Filtering based on signal duration + min_duration = 3.5 + max_duration = 7.5 + + dataset = AudioToTargetDataset( + 
manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + min_duration=min_duration, + max_duration=max_duration, + sample_rate=sample_rate, + ) + + filtered_examples = [n for n, val in enumerate(data_duration) if min_duration <= val <= max_duration] + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][filtered_examples[n]] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 3 + # - Use channel selector + channel_selector = { + 'input_signal': [0, 2], + 'target_signal': 1, + } + + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + input_channel_selector=channel_selector['input_signal'], + target_channel_selector=channel_selector['target_signal'], + sample_rate=sample_rate, + ) + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + for signal in data: + cs = channel_selector[signal] + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n][..., cs] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 4 + # - Use fixed duration (random segment selection) + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for random_offset in [True, False]: + # Test subsegments with the default fixed offset and a random offset + + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=random_offset, # random offset when selecting subsegment + ) + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + full_golden_signal = data[signal][filtered_examples[n]] + + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start( + signal=full_golden_signal[:, 0], segment=item_signal[:, 0] + ) + if not random_offset: + assert ( + golden_start == 0 + ), f'Expecting the signal to start at 0 when random_offset is False' + + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[golden_start:golden_end, ...] 
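+                    # Note: get_segment_start locates the extracted subsegment within the full
+                    # signal via cross-correlation (scipy.signal.correlate with mode='valid');
+                    # for these random test signals, the argmax of the correlation recovers the
+                    # sample index at which the dataset started the (possibly random) subsegment.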
+ + # Test length is correct + assert ( + len(item_signal) == audio_duration_samples + ), f'Test 4: Signal length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' + + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + # Test signal values + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 4: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 5: + # - Test collate_fn + batch_size = 16 + batch = [dataset.__getitem__(n) for n in range(batch_size)] + batched = dataset.collate_fn(batch) + + for n, signal in enumerate(data.keys()): + signal_shape = batched[2 * n].shape + signal_len = batched[2 * n + 1] + + assert signal_shape == ( + batch_size, + audio_duration_samples, + data_num_channels[signal], + ), f'Test 5: Unexpected signal {signal} shape {signal_shape}' + assert len(signal_len) == batch_size, f'Test 5: Unexpected length of signal_len ({len(signal_len)})' + assert all(signal_len == audio_duration_samples), f'Test 5: Unexpected signal_len {signal_len}' + + @pytest.mark.unit + def test_audio_to_target_dataset_with_target_list(self): + """Test AudioWithTargetDataset when the input manifest has a list + of audio files in the target key. + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': ['path/to/path_to_target_ch0.wav', 'path/to/path_to_target_ch1.wav'], + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n], num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + if signal == 'target_signal': + # Save targets as individual files + signal_filename = [] + for ch in range(data_num_channels[signal]): + # add current filename + signal_filename.append(f'{signal}_{n:02d}_ch_{ch}.wav') + # write audio file + sf.write( + os.path.join(test_dir, signal_filename[-1]), + data[signal][n][:, ch], + sample_rate, + 'float', + ) + else: + # single file + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 
'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + ) + + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # Set target as the first channel of input_filepath and all files listed in target_filepath. + # In this case, the target will have 3 channels. + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=[data_key['input_signal'], data_key['target_signal']], + target_channel_selector=0, + sample_rate=sample_rate, + ) + + for n in range(num_examples): + item = dataset.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + if signal == 'target_signal': + # add the first channel of the input + golden_signal = np.concatenate([data['input_signal'][n][..., 0:1], golden_signal], axis=1) + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + + @pytest.mark.unit + def test_audio_to_target_dataset_for_inference(self): + """Test AudioWithTargetDataset when target_key is + not set, i.e., it is `None`. This is the case, e.g., when + running inference, and a target is not available. 
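+        In this case, the dataset is expected to return an empty (zero-element) tensor for the target signal.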
+ + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + } + + # Tolerance + max_diff_tol = 1e-6 + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n], num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + # Build metadata for manifest + metadata = [] + for n in range(num_examples): + meta = dict() + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + # update metadata + meta[data_key[signal]] = signal_filename + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=None, # target_signal will be empty + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': None, + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + # Check target is None + assert item['target_signal'].numel() == 0, 'target_signal is expected to be empty.' + assert item_factory['target_signal'].numel() == 0, 'target_signal is expected to be empty.' + + # Check valid signals + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + @pytest.mark.unit + def test_audio_to_target_with_reference_dataset(self): + """Test AudioWithTargetWithReferenceDataset in different configurations. 
+ + 1) reference synchronized with input and target + 2) reference not synchronized + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'reference_filepath': 'path/to/path_to_reference.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + 'reference_signal': 1, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + 'reference_signal': 'reference_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n], num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + # - Reference is not synchronized with input and target, so whole reference signal will be loaded + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=False, + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'reference_key': data_key['reference_signal'], + 'reference_is_synchronized': False, + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_reference_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, 
golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # - Use fixed duration (random segment selection) + # - Reference is synchronized with input and target, so the same segment of reference signal will be loaded + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=True, + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=True, + ) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + full_golden_signal = data[signal][filtered_examples[n]] + + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start(signal=full_golden_signal[:, 0], segment=item_signal[:, 0]) + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[golden_start:golden_end, ...] + + # Test length is correct + assert ( + len(item_signal) == audio_duration_samples + ), f'Test 2: Signal {signal} length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' + + # Test signal values + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 3 + # - Use fixed duration (random segment selection) + # - Reference is not synchronized with input and target, so whole reference signal will be loaded + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=False, + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=True, + ) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + full_golden_signal = data[signal][filtered_examples[n]] + + if signal == 'reference_signal': + # Complete signal is loaded for reference + golden_signal = full_golden_signal + else: + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start( + signal=full_golden_signal[:, 0], segment=item_signal[:, 0] + ) + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[golden_start:golden_end, ...] 
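+
+                    # Note: `data` preserves insertion order, so the crop offset is located
+                    # on 'input_signal' and reused for 'target_signal', while the
+                    # (unsynchronized) reference is compared against the complete signal.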
+ + # Test length is correct + assert ( + len(item_signal) == audio_duration_samples + ), f'Test 3: Signal {signal} length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + # Test signal values + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' + + @pytest.mark.unit + def test_audio_to_target_with_embedding_dataset(self): + """Test AudioWithTargetWithEmbeddingDataset. + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'embedding_filepath': 'path/to/path_to_embedding.npy', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + 'embedding_vector': 1, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + embedding_length = 64 # 64-dimensional embedding vector + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + 'embedding_vector': 'embedding_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + data_length = embedding_length if signal == 'embedding_vector' else data_duration_samples[n] + + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_length)) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_length, num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + if signal == 'embedding_vector': + signal_filename = f'{signal}_{n:02d}.npy' + np.save(os.path.join(test_dir, signal_filename), data[signal][n]) + + else: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetWithEmbeddingDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + embedding_key=data_key['embedding_vector'], + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'embedding_key': data_key['embedding_vector'], + 'sample_rate': sample_rate, + } + dataset_factory = 
audio_to_audio_dataset.get_audio_to_target_with_embedding_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' diff --git a/tests/collections/asr/utils/test_audio_utils.py b/tests/collections/asr/utils/test_audio_utils.py index e277f410c9d13..489a7a7740784 100644 --- a/tests/collections/asr/utils/test_audio_utils.py +++ b/tests/collections/asr/utils/test_audio_utils.py @@ -25,6 +25,7 @@ db2mag, estimated_coherence, generate_approximate_noise_field, + get_segment_start, mag2db, pow2db, rms, @@ -268,3 +269,28 @@ def test_db_conversion(self): assert all(np.abs(mag - 10 ** (mag_db / 20)) < abs_threshold) assert all(np.abs(db2mag(mag_db) - 10 ** (mag_db / 20)) < abs_threshold) assert all(np.abs(pow2db(mag ** 2) - mag_db) < abs_threshold) + + @pytest.mark.unit + def test_get_segment_start(self): + random_seed = 42 + num_examples = 50 + num_samples = 2000 + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_examples): + # Generate signal + signal = _rng.normal(size=num_samples) + # Random start in the first half + start = _rng.integers(low=0, high=num_samples // 2) + # Random length + end = _rng.integers(low=start, high=num_samples) + # Selected segment + segment = signal[start:end] + + # UUT + estimated_start = get_segment_start(signal=signal, segment=segment) + + assert ( + estimated_start == start + ), f'Example {n}: estimated start ({estimated_start}) not matching the actual start ({start})' diff --git a/tests/collections/common/test_utils.py b/tests/collections/common/test_utils.py new file mode 100644 index 0000000000000..93c33e92bf988 --- /dev/null +++ b/tests/collections/common/test_utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemo.collections.common.parts.utils import flatten + + +class TestListUtils: + @pytest.mark.unit + def test_flatten(self): + """Test flattening an iterable with different values: str, bool, int, float, complex. 
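+        Strings are kept as single elements and are not flattened into characters.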
+ """ + test_cases = [] + test_cases.append({'input': ['aa', 'bb', 'cc'], 'golden': ['aa', 'bb', 'cc']}) + test_cases.append({'input': ['aa', ['bb', 'cc']], 'golden': ['aa', 'bb', 'cc']}) + test_cases.append({'input': ['aa', [['bb'], [['cc']]]], 'golden': ['aa', 'bb', 'cc']}) + test_cases.append({'input': ['aa', [[1, 2], [[3]], 4]], 'golden': ['aa', 1, 2, 3, 4]}) + test_cases.append({'input': [True, [2.5, 2.0 + 1j]], 'golden': [True, 2.5, 2.0 + 1j]}) + + for n, test_case in enumerate(test_cases): + assert flatten(test_case['input']) == test_case['golden'], f'Test case {n} failed!'
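
For context, the `flatten` utility exercised above is expected to behave like a simple recursive flattener: nested lists are expanded in order, while non-list items (including strings, numbers, and booleans) are kept as single elements. The sketch below is illustrative only and is not the actual NeMo implementation; the name `flatten_sketch` is made up for this example.

```python
from typing import Any, List


def flatten_sketch(maybe_nested: List[Any]) -> List[Any]:
    """Recursively flatten nested lists, keeping non-list items as-is (sketch only)."""
    result = []
    for item in maybe_nested:
        if isinstance(item, list):
            # Expand nested lists in place
            result.extend(flatten_sketch(item))
        else:
            # Keep atomic items (str, bool, int, float, complex) unchanged
            result.append(item)
    return result


# Mirrors one of the test cases above
assert flatten_sketch(['aa', [[1, 2], [[3]], 4]]) == ['aa', 1, 2, 3, 4]
```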