From 738da41fa7441fd26ae67a1412310b42228cce05 Mon Sep 17 00:00:00 2001 From: anteju <108555623+anteju@users.noreply.github.com> Date: Tue, 29 Nov 2022 19:11:25 -0800 Subject: [PATCH] [ASR] AudioToAudio datasets and related test (#5196) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * AudioToAudio datasets and related test Signed-off-by: Ante Jukić * Updated doc, created utility function in manifest to avoide code duplication Signed-off-by: Ante Jukić * Remove unused import Signed-off-by: Ante Jukić * Moved functionality to ASRAudioProcessor Signed-off-by: Ante Jukić * Addressed review comments Signed-off-by: Ante Jukić * Removed unused local variable Signed-off-by: Ante Jukić Signed-off-by: Ante Jukić --- nemo/collections/asr/data/audio_to_audio.py | 1119 +++++++++++++++++ .../asr/data/audio_to_audio_dataset.py | 92 ++ .../asr/parts/preprocessing/segment.py | 32 +- .../asr/parts/utils/audio_utils.py | 24 + .../common/parts/preprocessing/collections.py | 180 ++- .../common/parts/preprocessing/manifest.py | 44 +- nemo/collections/common/parts/utils.py | 33 +- tests/collections/asr/test_asr_datasets.py | 963 +++++++++++++- .../collections/asr/utils/test_audio_utils.py | 26 + tests/collections/common/test_utils.py | 33 + 10 files changed, 2523 insertions(+), 23 deletions(-) create mode 100644 nemo/collections/asr/data/audio_to_audio.py create mode 100644 nemo/collections/asr/data/audio_to_audio_dataset.py create mode 100644 tests/collections/common/test_utils.py diff --git a/nemo/collections/asr/data/audio_to_audio.py b/nemo/collections/asr/data/audio_to_audio.py new file mode 100644 index 0000000000000..19a2444f5f9fc --- /dev/null +++ b/nemo/collections/asr/data/audio_to_audio.py @@ -0,0 +1,1119 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import math +import random +from collections import OrderedDict +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import librosa +import numpy as np +import torch + +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +from nemo.collections.common.parts.preprocessing import collections +from nemo.collections.common.parts.utils import flatten +from nemo.core.classes import Dataset +from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType +from nemo.utils import logging +from nemo.utils.decorators import experimental + +__all__ = [ + 'AudioToTargetDataset', + 'AudioToTargetWithReferenceDataset', + 'AudioToTargetWithEmbeddingDataset', +] + + +def _audio_collate_fn(batch: List[dict]) -> Tuple[torch.Tensor]: + """Collate a batch of items returned by __getitem__. + Examples for each signal are zero padded to the same length + (batch_length), which is determined by the longest example. + Lengths of the original signals are returned in the output. 
+ + Args: + batch: List of dictionaries. Each element of the list + has the following format + ``` + { + 'signal_0': 1D or 2D tensor, + 'signal_1': 1D or 2D tensor, + ... + 'signal_N': 1D or 2D tensor, + } + ``` + 1D tensors have shape (num_samples,) and 2D tensors + have shape (num_samples, num_channels) + + Returns: + A tuple containing signal tensor and signal length tensor (in samples) + for each signal. + The output has the following format: + ``` + (signal_0, signal_0_length, signal_1, signal_1_length, ..., signal_N, signal_N_length) + ``` + Note that the output format is obtained by interleaving signals and their length. + """ + signals = batch[0].keys() + + batched = tuple() + + for signal in signals: + signal_length = [b[signal].shape[0] for b in batch] + # Batch length is determined by the longest signal in the batch + batch_length = max(signal_length) + b_signal = [] + for s_len, b in zip(signal_length, batch): + # check if padding is necessary + if s_len < batch_length: + if b[signal].ndim == 1: + # single-channel signal + pad = (0, batch_length - s_len) + elif b[signal].ndim == 2: + # multi-channel signal + pad = (0, 0, 0, batch_length - s_len) + else: + raise RuntimeError( + f'Signal {signal} has unsuported dimensions {signal.shape}. Currently, only 1D and 2D arrays are supported.' + ) + b[signal] = torch.nn.functional.pad(b[signal], pad) + # append the current padded signal + b_signal.append(b[signal]) + # (signal_batched, signal_length) + batched += (torch.stack(b_signal), torch.tensor(signal_length, dtype=torch.int32)) + + # Currently, outputs are expected to be in a tuple, where each element must correspond + # to the output type in the OrderedDict returned by output_types. + # + # Therefore, we return batched signals by interleaving signals and their length: + # (signal_0, signal_0_length, signal_1, signal_1_length, ...) + return batched + + +@dataclass +class SignalSetup: + signals: List[str] # signal names + duration: Optional[Union[float, list]] = None # duration for each signal + channel_selectors: Optional[List[ChannelSelectorType]] = None # channel selector for loading each signal + + +class ASRAudioProcessor: + """Class that processes an example from Audio collection and returns + a dictionary with prepared signals. + + For example, the output dictionary may be the following + ``` + { + 'input_signal': input_signal_tensor, + 'target_signal': target_signal_tensor, + 'reference_signal': reference_signal_tensor, + 'embedding_vector': embedding_vector + } + ``` + Keys in the output dictionary are ordered with synchronous signals given first, + followed by asynchronous signals and embedding. + + Args: + sample_rate: sample rate used for all audio signals + random_offset: If `True`, offset will be randomized when loading a subsegment + from a file. 
+ """ + + def __init__( + self, sample_rate: float, random_offset: bool, + ): + self.sample_rate = sample_rate + self.random_offset = random_offset + + self.sync_setup = None + self.async_setup = None + self.embedding_setup = None + + @property + def sample_rate(self) -> float: + return self._sample_rate + + @sample_rate.setter + def sample_rate(self, value: float): + if value <= 0: + raise ValueError(f'Sample rate must be positive, received {value}') + + self._sample_rate = value + + @property + def random_offset(self) -> bool: + return self._random_offset + + @random_offset.setter + def random_offset(self, value: bool): + self._random_offset = value + + @property + def sync_setup(self) -> SignalSetup: + """Return the current setup for synchronous signals. + + Returns: + A dataclass containing the list of signals, their + duration and channel selectors. + """ + return self._sync_setup + + @sync_setup.setter + def sync_setup(self, value: Optional[SignalSetup]): + """Setup signals to be loaded synchronously. + + Args: + value: An instance of SignalSetup with the following fields + - signals: list of signals (keys of example.audio_signals) which will be loaded + synchronously with the same start time and duration. + - duration: Duration for each signal to be loaded. + If duration is set to None, the whole file will be loaded. + - channel_selectors: A list of channel selector for each signal. If channel selector + is None, all channels in the audio file will be loaded. + """ + if value is None or isinstance(value, SignalSetup): + self._sync_setup = value + else: + raise ValueError(f'Unexpected type {type(value)} for value {value}.') + + @property + def async_setup(self) -> SignalSetup: + """Return the current setup for asynchronous signals. + + Returns: + A dataclass containing the list of signals, their + duration and channel selectors. + """ + return self._async_setup + + @async_setup.setter + def async_setup(self, value: Optional[SignalSetup]): + """Setup signals to be loaded asynchronously. + + Args: + Args: + value: An instance of SignalSetup with the following fields + - signals: list of signals (keys of example.audio_signals) which will be loaded + asynchronously with signals possibly having different start and duration + - duration: Duration for each signal to be loaded. + If duration is set to None, the whole file will be loaded. + - channel_selectors: A list of channel selector for each signal. If channel selector + is None, all channels in the audio file will be loaded. + """ + if value is None or isinstance(value, SignalSetup): + self._async_setup = value + else: + raise ValueError(f'Unexpected type {type(value)} for value {value}.') + + @property + def embedding_setup(self) -> SignalSetup: + """Setup signals corresponding to an embedding vector. + """ + return self._embedding_setup + + @embedding_setup.setter + def embedding_setup(self, value: SignalSetup): + """Setup signals corresponding to an embedding vector. + + Args: + value: An instance of SignalSetup with the following fields + - signals: list of signals (keys of example.audio_signals) which will be loaded + as embedding vectors. + """ + if value is None or isinstance(value, SignalSetup): + self._embedding_setup = value + else: + raise ValueError(f'Unexpected type {type(value)} for value {value}.') + + def process(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]: + """Process an example from a collection of audio examples. + + Args: + example: an example from Audio collection. 
+ + Returns: + An ordered dictionary of signals and their tensors. + For example, the output dictionary may be the following + ``` + { + 'input_signal': input_signal_tensor, + 'target_signal': target_signal_tensor, + 'reference_signal': reference_signal_tensor, + 'embedding_vector': embedding_vector + } + ``` + Keys in the output dictionary are ordered with synchronous signals given first, + followed by asynchronous signals and embedding. + """ + audio = self.load_audio(example=example) + audio = self.process_audio(audio=audio) + return audio + + def load_audio(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]: + """Given an example, load audio from `example.audio_files` and prepare + the output dictionary. + + Args: + example: An example from an audio collection + + Returns: + An ordered dictionary of signals and their tensors. + For example, the output dictionary may be the following + ``` + { + 'input_signal': input_signal_tensor, + 'target_signal': target_signal_tensor, + 'reference_signal': reference_signal_tensor, + 'embedding_vector': embedding_vector + } + ``` + Keys in the output dictionary are ordered with synchronous signals given first, + followed by asynchronous signals and embedding. + """ + output = OrderedDict() + + if self.sync_setup is not None: + # Load all signals with the same start and duration + sync_signals = self.load_sync_signals(example) + output.update(sync_signals) + + if self.async_setup is not None: + # Load each signal independently + async_signals = self.load_async_signals(example) + output.update(async_signals) + + # Load embedding vector + if self.embedding_setup is not None: + embedding = self.load_embedding(example) + output.update(embedding) + + if not output: + raise RuntimeError('Output dictionary is empty. Please use `_setup` methods to setup signals to be loaded') + + return output + + def process_audio(self, audio: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Process audio signals available in the input dictionary. + + Args: + audio: A dictionary containing loaded signals `signal: tensor` + + Returns: + An ordered dictionary of signals and their tensors. + """ + # Currently, not doing any processing of the loaded signals. + return audio + + def load_sync_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]: + """Load signals with the same start and duration. + + Args: + example: an example from audio collection + + Returns: + An ordered dictionary of signals and their tensors. + """ + output = OrderedDict() + sync_audio_files = [example.audio_files[s] for s in self.sync_setup.signals] + + sync_samples = self.get_samples_synchronized( + audio_files=sync_audio_files, + channel_selectors=self.sync_setup.channel_selectors, + sample_rate=self.sample_rate, + duration=self.sync_setup.duration, + fixed_offset=example.offset, + random_offset=self.random_offset, + ) + + for signal, samples in zip(self.sync_setup.signals, sync_samples): + output[signal] = torch.tensor(samples) + + return output + + def load_async_signals(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]: + """Load each async signal independently, no constraints on starting + from the same time. + + Args: + example: an example from audio collection + + Returns: + An ordered dictionary of signals and their tensors. 
+ """ + output = OrderedDict() + for idx, signal in enumerate(self.async_setup.signals): + samples = self.get_samples( + audio_file=example.audio_files[signal], + sample_rate=self.sample_rate, + duration=self.async_setup.duration[idx], + channel_selector=self.async_setup.channel_selectors[idx], + fixed_offset=example.offset, + random_offset=self.random_offset, + ) + output[signal] = torch.tensor(samples) + return output + + @classmethod + def get_samples( + cls, + audio_file: str, + sample_rate: int, + duration: Optional[float] = None, + channel_selector: ChannelSelectorType = None, + fixed_offset: float = 0, + random_offset: bool = False, + ) -> np.ndarray: + """Get samples from an audio file. + For a single-channel signal, the output is shape (num_samples,). + For a multi-channel signal, the output is shape (num_samples, num_channels). + + Args: + audio_file: path to an audio file + sample_rate: desired sample rate for output samples + duration: Optional desired duration of output samples. + If `None`, the complete file will be loaded. + If set, a segment of `duration` seconds will be loaded. + channel_selector: Optional channel selector, for selecting a subset of channels. + fixed_offset: Optional fixed offset when loading samples. + random_offset: If `True`, offset will be randomized when loading a short segment + from a file. The value is randomized between fixed_offset and + max_offset (set depending on the duration and fixed_offset). + + Returns: + Numpy array with samples from audio file. + The array has shape (num_samples,) for a single-channel signal + or (num_samples, num_channels) for a multi-channel signal. + """ + output = cls.get_samples_synchronized( + audio_files=[audio_file], + sample_rate=sample_rate, + duration=duration, + channel_selectors=[channel_selector], + fixed_offset=fixed_offset, + random_offset=random_offset, + ) + + return output[0] + + @classmethod + def get_samples_synchronized( + cls, + audio_files: List[str], + sample_rate: int, + duration: Optional[float] = None, + channel_selectors: Optional[List[ChannelSelectorType]] = None, + fixed_offset: float = 0, + random_offset: bool = False, + ) -> List[np.ndarray]: + """Get samples from multiple files with the same start and end point. + + Args: + audio_files: list of paths to audio files + sample_rate: desired sample rate for output samples + duration: Optional desired duration of output samples. + If `None`, the complete files will be loaded. + If set, a segment of `duration` seconds will be loaded from + all files. Segment is synchronized across files, so that + start and end points are the same. + channel_selectors: Optional channel selector for each signal, for selecting + a subset of channels. + fixed_offset: Optional fixed offset when loading samples. + random_offset: If `True`, offset will be randomized when loading a short segment + from a file. The value is randomized between fixed_offset and + max_offset (set depending on the duration and fixed_offset). + + Returns: + List with the same size as `audio_files` but containing numpy arrays + with samples from each audio file. + Each array has shape (num_samples, ) or (num_samples, num_channels), for single- + or multi-channel signal, respectively. + For example, if `audio_files = [path/to/file_1.wav, path/to/file_2.wav]`, + the output will be a list `output = [samples_file_1, samples_file_2]`. 
+ """ + if channel_selectors is None: + channel_selectors = [None] * len(audio_files) + + if duration is None: + # Load complete files starting from a fixed offset + offset = fixed_offset # fixed offset + num_samples = None # no constrain on the number of samples + + else: + # Fixed duration of the output + audio_durations = cls.get_duration(audio_files) + min_audio_duration = min(audio_durations) + available_duration = min_audio_duration - fixed_offset + + if available_duration <= 0: + raise ValueError(f'Fixed offset {fixed_offset}s is larger than shortest file {min_duration}s.') + + if duration + fixed_offset > min_audio_duration: + # The shortest file is shorter than the requested duration + logging.warning( + f'Shortest file ({min_audio_duration}s) is less than the desired duration {duration}s + fixed offset {fixed_offset}s. Returned signals will be shortened to {available_duration} seconds.' + ) + offset = fixed_offset + duration = available_duration + elif random_offset: + # Randomize offset based on the shortest file + max_offset = min_audio_duration - duration + offset = random.uniform(fixed_offset, max_offset) + else: + # Fixed offset + offset = fixed_offset + + # Fixed number of samples + num_samples = math.floor(duration * sample_rate) + + output = [] + + # Prepare segments + for idx, audio_file in enumerate(audio_files): + segment_samples = cls.get_samples_from_file( + audio_file=audio_file, + sample_rate=sample_rate, + offset=offset, + num_samples=num_samples, + channel_selector=channel_selectors[idx], + ) + output.append(segment_samples) + + return output + + @classmethod + def get_samples_from_file( + cls, + audio_file: Union[str, List[str]], + sample_rate: int, + offset: float, + num_samples: Optional[int] = None, + channel_selector: Optional[ChannelSelectorType] = None, + ) -> np.ndarray: + """Get samples from a single or multiple files. + If loading samples from multiple files, they will + be concatenated along the channel dimension. + + Args: + audio_file: path or a list of paths. + sample_rate: sample rate of the loaded samples + offset: fixed offset in seconds + num_samples: Optional, number of samples to load. + If `None`, all available samples will be loaded. + channel_selector: Select a subset of available channels. + + Returns: + An array with shape (samples,) or (samples, channels) + """ + if isinstance(audio_file, str): + # Load samples from a single file + segment_samples = cls.get_segment_from_file( + audio_file=audio_file, + sample_rate=sample_rate, + offset=offset, + num_samples=num_samples, + channel_selector=channel_selector, + ) + elif isinstance(audio_file, list): + # Load samples from multiple files and form a multi-channel signal + segment_samples = [] + for a_file in audio_file: + a_file_samples = cls.get_segment_from_file( + audio_file=a_file, + sample_rate=sample_rate, + offset=offset, + num_samples=num_samples, + channel_selector=channel_selector, + ) + segment_samples.append(a_file_samples) + segment_samples = cls.list_to_multichannel(segment_samples) + elif audio_file is None: + # Support for inference, when the target signal is `None` + segment_samples = [] + else: + raise RuntimeError(f'Unexpected audio_file type {type(audio_file)}') + return segment_samples + + @staticmethod + def get_segment_from_file( + audio_file: str, + sample_rate: int, + offset: float, + num_samples: Optional[int] = None, + channel_selector: Optional[ChannelSelectorType] = None, + ) -> np.ndarray: + """Get a segment of samples from a single audio file. 
+ + Args: + audio_file: path to an audio file + sample_rate: sample rate of the loaded samples + offset: fixed offset in seconds + num_samples: Optional, number of samples to load. + If `None`, all available samples will be loaded. + channel_selector: Select a subset of available channels. + + Returns: + An array with shape (samples,) or (samples, channels) + """ + if num_samples is None: + segment = AudioSegment.from_file( + audio_file=audio_file, target_sr=sample_rate, offset=offset, channel_selector=channel_selector, + ) + + else: + segment = AudioSegment.segment_from_file( + audio_file=audio_file, + target_sr=sample_rate, + n_segments=num_samples, + offset=offset, + channel_selector=channel_selector, + ) + return segment.samples + + @staticmethod + def list_to_multichannel(signal: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: + """Convert a list of signals into a multi-channel signal by concatenating + the elements of the list along the channel dimension. + + If input is not a list, it is returned unmodified. + + Args: + signal: list of arrays + + Returns: + Numpy array obtained by concatenating the elements of the list + along the channel dimension (axis=1). + """ + if not isinstance(signal, list): + # Nothing to do there + return signal + elif len(signal) == 0: + # Nothing to do, return as is + return signal + elif len(signal) == 1: + # Nothing to concatenate, return the original format + return signal[0] + + # If multiple signals are provided in a list, we concatenate them along the channel dimension + if signal[0].ndim == 1: + # Single-channel individual files + mc_signal = np.stack(signal, axis=1) + elif signal[0].ndim == 2: + # Multi-channel individual files + mc_signal = np.concatenate(signal, axis=1) + else: + raise RuntimeError(f'Unexpected target with {signal[0].ndim} dimensions.') + + return mc_signal + + @staticmethod + def get_duration(audio_files: List[str]) -> List[float]: + """Get duration for each audio file in `audio_files`. + + Args: + audio_files: list of paths to audio files + + Returns: + List of durations in seconds. + """ + duration = [librosa.get_duration(filename=f) for f in flatten(audio_files)] + return duration + + def load_embedding(self, example: collections.Audio.OUTPUT_TYPE) -> Dict[str, torch.Tensor]: + """Given an example, load embedding from `example.audio_files[embedding]` + and return it in a dictionary. + + Args: + example: An example from audio collection + + Returns: + An dictionary of embedding keys and their tensors. + """ + output = OrderedDict() + for idx, signal in enumerate(self.embedding_setup.signals): + embedding_file = example.audio_files[signal] + embedding = self.load_embedding_vector(embedding_file) + output[signal] = torch.tensor(embedding) + return output + + @staticmethod + def load_embedding_vector(filepath: str) -> np.ndarray: + """Load an embedding vector from a file. + + Args: + filepath: path to a file storing a vector. + Currently, it is assumed the file is a npy file. + + Returns: + Array loaded from filepath. + """ + if filepath.endswith('.npy'): + with open(filepath, 'rb') as f: + embedding = np.load(f) + else: + raise RuntimeError(f'Unknown embedding file format in file: {filepath}') + + return embedding + + +@experimental +class BaseAudioDataset(Dataset): + """Base class of audio datasets, providing common functionality + for other audio datasets. + + Args: + collection: Collection of audio examples prepared from manifest files. + audio_processor: Used to process every example from the collection. 
+ A callable with `process` method. For reference, + please check ASRAudioProcessor. + """ + + @property + @abc.abstractmethod + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + """ + + def __init__( + self, collection: collections.Audio, audio_processor: Callable, + ): + """Instantiates an audio dataset. + """ + super().__init__() + + self.collection = collection + self.audio_processor = audio_processor + + def num_channels(self, signal_key) -> int: + """Returns the number of channels for a particular signal in + items prepared by this dictionary. + + More specifically, this will get the tensor from the first + item in the dataset, check if it's a one- or two-dimensional + tensor, and return the number of channels based on the size + of the second axis (shape[1]). + + NOTE: + This assumes that all examples have the same number of channels. + + Args: + signal_key: string, used to select a signal from the dictionary + output by __getitem__ + + Returns: + Number of channels for the selected signal. + """ + # Assumption: whole dataset has the same number of channels + item = self.__getitem__(0) + + if item[signal_key].ndim == 1: + return 1 + elif item[signal_key].ndim == 2: + return item[signal_key].shape[1] + else: + raise RuntimeError( + f'Unexpected number of dimension for signal {signal_key} with shape {item[signal_key].shape}' + ) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Return a single example from the dataset. + + Args: + index: integer index of an example in the collection + + Returns: + Dictionary providing mapping from signal to its tensor. + For example: + ``` + { + 'input_signal': input_signal_tensor, + 'target_signal': target_signal_tensor, + } + ``` + """ + example = self.collection[index] + output = self.audio_processor.process(example=example) + + return output + + def __len__(self) -> int: + """Return the number of examples in the dataset. + """ + return len(self.collection) + + def _collate_fn(self, batch) -> Tuple[torch.Tensor]: + """Collate items in a batch. + """ + return _audio_collate_fn(batch) + + +@experimental +class AudioToTargetDataset(BaseAudioDataset): + """A dataset for audio-to-audio tasks where the goal is to use + an input signal to recover the corresponding target signal. + + Each line of the manifest file is expected to have the following format + ``` + { + 'input_key': 'path/to/input.wav', + 'target_key': 'path/to/path_to_target.wav', + 'duration': duration_of_input, + } + ``` + + Additionally, multiple audio files may be provided for each key in the manifest, for example, + ``` + { + 'input_key': 'path/to/input.wav', + 'target_key': ['path/to/path_to_target_ch0.wav', 'path/to/path_to_target_ch1.wav'], + 'duration': duration_of_input, + } + ``` + + Keys for input and target signals can be configured in the constructor (`input_key` and `target_key`). + + Args: + manifest_filepath: Path to manifest file in a format described above. + sample_rate: Sample rate for loaded audio signals. + input_key: Key pointing to input audio files in the manifest + target_key: Key pointing to target audio files in manifest + audio_duration: Optional duration of each item returned by __getitem__. + If `None`, complete audio will be loaded. + If set, a random subsegment will be loaded synchronously from + target and audio, i.e., with the same start and end point. + random_offset: If `True`, offset will be randomized when loading a subsegment + from a file. 
+ max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + input_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + target_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + """ + + def __init__( + self, + manifest_filepath: str, + sample_rate: int, + input_key: str, + target_key: str, + audio_duration: Optional[float] = None, + random_offset: bool = False, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: Optional[int] = None, + input_channel_selector: Optional[int] = None, + target_channel_selector: Optional[int] = None, + ): + audio_to_manifest_key = { + 'input_signal': input_key, + 'target_signal': target_key, + } + + collection = collections.AudioCollection( + manifest_files=manifest_filepath, + audio_to_manifest_key=audio_to_manifest_key, + min_duration=min_duration, + max_duration=max_duration, + max_number=max_utts, + ) + + audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + audio_processor.sync_setup = SignalSetup( + signals=['input_signal', 'target_signal'], + duration=audio_duration, + channel_selectors=[input_channel_selector, target_channel_selector], + ) + + super().__init__( + collection=collection, audio_processor=audio_processor, + ) + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + + Returns: + Ordered dictionary in the following form: + ``` + { + 'input_signal': batched single- or multi-channel format, + 'input_length': batched original length of each input signal + 'target_signal': batched single- or multi-channel format, + 'target_length': batched original length of each target signal + } + ``` + """ + sc_audio_type = NeuralType(('B', 'T'), AudioSignal()) + mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) + + return OrderedDict( + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, + input_length=NeuralType(('B',), LengthsType()), + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, + target_length=NeuralType(('B',), LengthsType()), + ) + + +@experimental +class AudioToTargetWithReferenceDataset(BaseAudioDataset): + """A dataset for audio-to-audio tasks where the goal is to use + an input signal to recover the corresponding target signal and an + additional reference signal is available. + + This can be used, for example, when a reference signal is + available from + - enrollment utterance for the target signal + - echo reference from playback + - reference from another sensor that correlates with the target signal + + Each line of the manifest file is expected to have the following format + ``` + { + 'input_key': 'path/to/input.wav', + 'target_key': 'path/to/path_to_target.wav', + 'reference_key': 'path/to/path_to_reference.wav', + 'duration': duration_of_input, + } + ``` + + Keys for input, target and reference signals can be configured in the constructor. + + Args: + manifest_filepath: Path to manifest file in a format described above. + sample_rate: Sample rate for loaded audio signals. 
+ input_key: Key pointing to input audio files in the manifest + target_key: Key pointing to target audio files in manifest + reference_key: Key pointing to reference audio files in manifest + audio_duration: Optional duration of each item returned by __getitem__. + If `None`, complete audio will be loaded. + If set, a random subsegment will be loaded synchronously from + target and audio, i.e., with the same start and end point. + random_offset: If `True`, offset will be randomized when loading a subsegment + from a file. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + input_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + target_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + reference_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + reference_is_synchronized: If True, it is assumed that the reference signal is synchronized + with the input signal, so the same subsegment will be loaded as for + input and target. If False, reference signal will be loaded independently + from input and target. + reference_duration: Optional, can be used to set a fixed duration of the reference utterance. If `None`, + complete audio file will be loaded. + """ + + def __init__( + self, + manifest_filepath: str, + sample_rate: int, + input_key: str, + target_key: str, + reference_key: str, + audio_duration: Optional[float] = None, + random_offset: bool = False, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: Optional[int] = None, + input_channel_selector: Optional[int] = None, + target_channel_selector: Optional[int] = None, + reference_channel_selector: Optional[int] = None, + reference_is_synchronized: bool = True, + reference_duration: Optional[float] = None, + ): + audio_to_manifest_key = { + 'input_signal': input_key, + 'target_signal': target_key, + 'reference_signal': reference_key, + } + + collection = collections.AudioCollection( + manifest_files=manifest_filepath, + audio_to_manifest_key=audio_to_manifest_key, + min_duration=min_duration, + max_duration=max_duration, + max_number=max_utts, + ) + + audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + + if reference_is_synchronized: + audio_processor.sync_setup = SignalSetup( + signals=['input_signal', 'target_signal', 'reference_signal'], + duration=audio_duration, + channel_selectors=[input_channel_selector, target_channel_selector, reference_channel_selector], + ) + else: + audio_processor.sync_setup = SignalSetup( + signals=['input_signal', 'target_signal'], + duration=audio_duration, + channel_selectors=[input_channel_selector, target_channel_selector], + ) + audio_processor.async_setup = SignalSetup( + signals=['reference_signal'], + duration=[reference_duration], + channel_selectors=[reference_channel_selector], + ) + + super().__init__( + collection=collection, audio_processor=audio_processor, + ) + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. 
+ + Returns: + Ordered dictionary in the following form: + ``` + { + 'input_signal': batched single- or multi-channel format, + 'input_length': batched original length of each input signal + 'target_signal': batched single- or multi-channel format, + 'target_length': batched original length of each target signal + 'reference_signal': single- or multi-channel format, + 'reference_length': original length of each reference signal + } + ``` + """ + sc_audio_type = NeuralType(('B', 'T'), AudioSignal()) + mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) + + return OrderedDict( + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, + input_length=NeuralType(('B',), LengthsType()), + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, + target_length=NeuralType(('B',), LengthsType()), + reference_signal=sc_audio_type if self.num_channels('reference_signal') == 1 else mc_audio_type, + reference_length=NeuralType(('B',), LengthsType()), + ) + + +@experimental +class AudioToTargetWithEmbeddingDataset(BaseAudioDataset): + """A dataset for audio-to-audio tasks where the goal is to use + an input signal to recover the corresponding target signal and an + additional embedding signal. It is assumed that the embedding + is in a form of a vector. + + Each line of the manifest file is expected to have the following format + ``` + { + input_key: 'path/to/input.wav', + target_key: 'path/to/path_to_target.wav', + embedding_key: 'path/to/path_to_reference.npy', + 'duration': duration_of_input, + } + ``` + + Keys for input, target and embedding signals can be configured in the constructor. + + Args: + manifest_filepath: Path to manifest file in a format described above. + sample_rate: Sample rate for loaded audio signals. + input_key: Key pointing to input audio files in the manifest + target_key: Key pointing to target audio files in manifest + embedding_key: Key pointing to embedding files in manifest + audio_duration: Optional duration of each item returned by __getitem__. + If `None`, complete audio will be loaded. + If set, a random subsegment will be loaded synchronously from + target and audio, i.e., with the same start and end point. + random_offset: If `True`, offset will be randomized when loading a subsegment + from a file. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + input_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. + target_channel_selector: Optional, select subset of channels from each input audio file. + If `None`, all channels will be loaded. 
+ """ + + def __init__( + self, + manifest_filepath: str, + sample_rate: int, + input_key: str, + target_key: str, + embedding_key: str, + audio_duration: Optional[float] = None, + random_offset: bool = False, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: Optional[int] = None, + input_channel_selector: Optional[int] = None, + target_channel_selector: Optional[int] = None, + ): + audio_to_manifest_key = { + 'input_signal': input_key, + 'target_signal': target_key, + 'embedding_vector': embedding_key, + } + + collection = collections.AudioCollection( + manifest_files=manifest_filepath, + audio_to_manifest_key=audio_to_manifest_key, + min_duration=min_duration, + max_duration=max_duration, + max_number=max_utts, + ) + + audio_processor = ASRAudioProcessor(sample_rate=sample_rate, random_offset=random_offset,) + audio_processor.sync_setup = SignalSetup( + signals=['input_signal', 'target_signal'], + duration=audio_duration, + channel_selectors=[input_channel_selector, target_channel_selector], + ) + audio_processor.embedding_setup = SignalSetup(signals=['embedding_vector']) + + super().__init__( + collection=collection, audio_processor=audio_processor, + ) + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + + Returns: + Ordered dictionary in the following form: + ``` + { + 'input_signal': batched single- or multi-channel format, + 'input_length': batched original length of each input signal + 'target_signal': batched single- or multi-channel format, + 'target_length': batched original length of each target signal + 'embedding_vector': batched embedded vector format, + 'embedding_length': batched original length of each embedding vector + } + ``` + """ + sc_audio_type = NeuralType(('B', 'T'), AudioSignal()) + mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal()) + + return OrderedDict( + input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type, + input_length=NeuralType(('B',), LengthsType()), + target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type, + target_length=NeuralType(('B',), LengthsType()), + embedding_vector=NeuralType(('B', 'D'), EncodedRepresentation()), + embedding_length=NeuralType(('B',), LengthsType()), + ) diff --git a/nemo/collections/asr/data/audio_to_audio_dataset.py b/nemo/collections/asr/data/audio_to_audio_dataset.py new file mode 100644 index 0000000000000..52c2e429858d7 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_audio_dataset.py @@ -0,0 +1,92 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.data import audio_to_audio + + +def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset: + """Instantiates an audio-to-audio dataset. + + Args: + config: Config of AudioToTargetDataset. 
+ + Returns: + An instance of AudioToTargetDataset + """ + dataset = audio_to_audio.AudioToTargetDataset( + manifest_filepath=config['manifest_filepath'], + sample_rate=config['sample_rate'], + input_key=config['input_key'], + target_key=config['target_key'], + audio_duration=config.get('audio_duration', None), + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + max_utts=config.get('max_utts', 0), + input_channel_selector=config.get('input_channel_selector', None), + target_channel_selector=config.get('target_channel_selector', None), + ) + return dataset + + +def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.AudioToTargetWithReferenceDataset: + """Instantiates an audio-to-audio dataset. + + Args: + config: Config of AudioToTargetWithReferenceDataset. + + Returns: + An instance of AudioToTargetWithReferenceDataset + """ + dataset = audio_to_audio.AudioToTargetWithReferenceDataset( + manifest_filepath=config['manifest_filepath'], + sample_rate=config['sample_rate'], + input_key=config['input_key'], + target_key=config['target_key'], + reference_key=config['reference_key'], + audio_duration=config.get('audio_duration', None), + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + max_utts=config.get('max_utts', 0), + input_channel_selector=config.get('input_channel_selector', None), + target_channel_selector=config.get('target_channel_selector', None), + reference_channel_selector=config.get('reference_channel_selector', None), + reference_is_synchronized=config.get('reference_is_synchronized', True), + reference_duration=config.get('reference_duration', None), + ) + return dataset + + +def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.AudioToTargetWithEmbeddingDataset: + """Instantiates an audio-to-audio dataset. + + Args: + config: Config of AudioToTargetWithEmbeddingDataset. + + Returns: + An instance of AudioToTargetWithEmbeddingDataset + """ + dataset = audio_to_audio.AudioToTargetWithEmbeddingDataset( + manifest_filepath=config['manifest_filepath'], + sample_rate=config['sample_rate'], + input_key=config['input_key'], + target_key=config['target_key'], + embedding_key=config['embedding_key'], + audio_duration=config.get('audio_duration', None), + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + max_utts=config.get('max_utts', 0), + input_channel_selector=config.get('input_channel_selector', None), + target_channel_selector=config.get('target_channel_selector', None), + ) + return dataset diff --git a/nemo/collections/asr/parts/preprocessing/segment.py b/nemo/collections/asr/parts/preprocessing/segment.py index 2aca9bdf1b555..1b22dcf668402 100644 --- a/nemo/collections/asr/parts/preprocessing/segment.py +++ b/nemo/collections/asr/parts/preprocessing/segment.py @@ -257,12 +257,22 @@ def from_file( @classmethod def segment_from_file( - cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None, channel_selector=None, + cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None, channel_selector=None, offset=None ): - """Grabs n_segments number of samples from audio_file randomly from the - file as opposed to at a specified offset. + """Grabs n_segments number of samples from audio_file. + If offset is not provided, n_segments are selected randomly. + If offset is provided, it is used to calculate the starting sample. 
Note that audio_file can be either the file path, or a file-like object. + + :param audio_file: path to a file or a file-like object + :param target_sr: sample rate for the output samples + :param n_segments: desired number of samples + :param trim: if true, trim leading and trailing silence from an audio signal + :param orig_sr: the original sample rate + :param channel selector: select a subset of channels. If set to `None`, the original signal will be used. + :param offset: fixed offset in seconds + :return: numpy array of samples """ is_segmented = False try: @@ -275,13 +285,20 @@ def segment_from_file( if 0 < n_segments_at_original_sr < len(f): max_audio_start = len(f) - n_segments_at_original_sr - audio_start = random.randint(0, max_audio_start) + if offset is None: + audio_start = random.randint(0, max_audio_start) + else: + audio_start = math.floor(offset * sample_rate) + if audio_start > max_audio_start: + raise RuntimeError( + f'Provided audio start ({audio_start_seconds} seconds = {audio_start} samples) is larger than the maximum possible ({max_audio_start})' + ) f.seek(audio_start) samples = f.read(n_segments_at_original_sr, dtype='float32') is_segmented = True - elif n_segments_at_original_sr >= len(f): + elif n_segments_at_original_sr > len(f): logging.warning( - f"Number of segments is greater than the length of the audio file {audio_file}. This may lead to shape mismatch errors." + f"Number of segments ({n_segments_at_original_sr}) is greater than the length ({len(f)}) of the audio file {audio_file}. This may lead to shape mismatch errors." ) samples = f.read(dtype='float32') else: @@ -363,8 +380,7 @@ def subsegment(self, start_time=None, end_time=None): :param end_time: End of subsegment in seconds. :type end_time: float :raise ValueError: If start_time or end_time is incorrectly set, - e.g. out - of bounds in time. + e.g. out of bounds in time. """ start_time = 0.0 if start_time is None else start_time end_time = self.duration if end_time is None else end_time diff --git a/nemo/collections/asr/parts/utils/audio_utils.py b/nemo/collections/asr/parts/utils/audio_utils.py index be4b0f570d110..93d399740be1e 100644 --- a/nemo/collections/asr/parts/utils/audio_utils.py +++ b/nemo/collections/asr/parts/utils/audio_utils.py @@ -17,6 +17,7 @@ import librosa import numpy as np import numpy.typing as npt +import scipy import soundfile as sf from scipy.spatial.distance import pdist, squareform @@ -376,3 +377,26 @@ def pow2db(power: float, eps: Optional[float] = 1e-16) -> float: Power in dB. """ return 10 * np.log10(power + eps) + + +def get_segment_start(signal: np.ndarray, segment: np.ndarray) -> int: + """Get starting point of `segment` in `signal`. + We assume that `segment` is a sub-segment of `signal`. + For example, `signal` may be a 10 second audio signal, + and `segment` could be the signal between 2 seconds and + 5 seconds. This function will then return the index of + the sample where `segment` starts (at 2 seconds). + + Args: + signal: numpy array with shape (num_samples,) + segment: numpy array with shape (num_samples,) + + Returns: + Index of the start of `segment` in `signal`. 
+ """ + if len(signal) <= len(segment): + raise ValueError( + f'segment must be shorter than signal: len(segment) = {len(segment)}, len(signal) = {len(signal)}' + ) + cc = scipy.signal.correlate(signal, segment, mode='valid') + return np.argmax(cc) diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 82f072c68b6af..c5d5d6eee1578 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -16,7 +16,7 @@ import json import os from itertools import combinations -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union import pandas as pd @@ -784,3 +784,181 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: offset=item.get('offset', None), ) return item + + +class Audio(_Collection): + """Prepare a list of all audio items, filtered by duration. + """ + + OUTPUT_TYPE = collections.namedtuple(typename='Audio', field_names='audio_files duration offset text') + + def __init__( + self, + audio_files_list: List[Dict[str, str]], + duration_list: List[float], + offset_list: List[float], + text_list: List[str], + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + ): + """Instantiantes an list of audio files. + + Args: + audio_files_list: list of dictionaries with mapping from audio_key to audio_filepath + duration_list: list of durations of input files + offset_list: list of offsets + text_list: list of texts + min_duration: Minimum duration to keep entry with (default: None). + max_duration: Maximum duration to keep entry with (default: None). + max_number: Maximum number of samples to collect. + do_sort_by_duration: True if sort samples list by duration. + """ + + output_type = self.OUTPUT_TYPE + data, total_duration = [], 0.0 + num_filtered, duration_filtered = 0, 0.0 + + for audio_files, duration, offset, text in zip(audio_files_list, duration_list, offset_list, text_list): + # Duration filters + if min_duration is not None and duration < min_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + num_filtered += 1 + continue + + total_duration += duration + data.append(output_type(audio_files, duration, offset, text)) + + # Max number of entities filter + if len(data) == max_number: + break + + if do_sort_by_duration: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + + super().__init__(data) + + +class AudioCollection(Audio): + """List of audio files from a manifest file. + """ + + def __init__( + self, manifest_files: Union[str, List[str]], audio_to_manifest_key: Dict[str, str], *args, **kwargs, + ): + """Instantiates a list of audio files loaded from a manifest file. 
+ + Args: + manifest_files: path to a single manifest file or a list of paths + audio_to_manifest_key: dictionary mapping audio signals to keys of the manifest + """ + # Support for comma-separated manifests + if type(manifest_files) == str: + manifest_files = manifest_files.split(',') + + for audio_key, manifest_key in audio_to_manifest_key.items(): + # Support for comma-separated keys + if type(manifest_key) == str and ',' in manifest_key: + audio_to_manifest_key[audio_key] = manifest_key.split(',') + + # Keys from manifest which contain audio + self.audio_to_manifest_key = audio_to_manifest_key + + # Initialize data + audio_files_list, duration_list, offset_list, text_list = [], [], [], [] + + # Parse manifest files + for item in manifest.item_iter(manifest_files, parse_func=self.__parse_item): + audio_files_list.append(item['audio_files']) + duration_list.append(item['duration']) + offset_list.append(item['offset']) + text_list.append(item['text']) + + super().__init__(audio_files_list, duration_list, offset_list, text_list, *args, **kwargs) + + def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: + """Parse a single line from a manifest file. + + Args: + line: a string representing a line from a manifest file in JSON format + manifest_file: path to the manifest file. Used to resolve relative paths. + + Returns: + Dictionary with audio_files, duration, and offset. + """ + # Local utility function + def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]): + """Get item[key] if key is string, or a list + of strings by combining item[key[0]], item[key[1]], etc. + """ + # Prepare audio file(s) + if manifest_key is None: + # Support for inference, when a target key is None + audio_file = None + elif isinstance(manifest_key, str): + # Load files from a single manifest key + audio_file = item[manifest_key] + elif isinstance(manifest_key, Iterable): + # Load files from multiple manifest keys + audio_file = [] + for key in manifest_key: + item_key = item[key] + if isinstance(item_key, str): + audio_file.append(item_key) + elif isinstance(item_key, list): + audio_file += item_key + else: + raise ValueError(f'Unexpected type {type(item_key)} of item for key {key}: {item_key}') + else: + raise ValueError(f'Unexpected type {type(manifest_key)} of manifest_key: {manifest_key}') + + return audio_file + + # Convert JSON line to a dictionary + item = json.loads(line) + + # Handle all audio files + audio_files = {} + for audio_key, manifest_key in self.audio_to_manifest_key.items(): + + audio_file = get_audio_file(item, manifest_key) + + # Get full path to audio file(s) + if isinstance(audio_file, str): + # This dictionary entry points to a single file + audio_files[audio_key] = manifest.get_full_path(audio_file, manifest_file) + elif isinstance(audio_file, Iterable): + # This dictionary entry points to multiple files + # Get the files and keep the list structure for this key + audio_files[audio_key] = [manifest.get_full_path(f, manifest_file) for f in audio_file] + elif audio_file is None and audio_key.startswith('target'): + # For inference, we don't need the target + audio_files[audio_key] = None + else: + raise ValueError(f'Unexpected type {type(audio_file)} of audio_file: {audio_file}') + item['audio_files'] = audio_files + + # Handle duration + if 'duration' not in item: + raise ValueError(f'Duration not available in line: {line}. 
Manifest file: {manifest_file}') + + # Handle offset + if 'offset' not in item: + item['offset'] = 0.0 + + # Handle text + if 'text' not in item: + item['text'] = None + + return dict( + audio_files=item['audio_files'], duration=item['duration'], offset=item['offset'], text=item['text'] + ) diff --git a/nemo/collections/common/parts/preprocessing/manifest.py b/nemo/collections/common/parts/preprocessing/manifest.py index d63a3ae63bcb4..fc7768d26d29c 100644 --- a/nemo/collections/common/parts/preprocessing/manifest.py +++ b/nemo/collections/common/parts/preprocessing/manifest.py @@ -93,17 +93,7 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: # try to attach the parent directory of manifest to the audio path. # Revert to the original path if the new path still doesn't exist. # Assume that the audio path is like "wavs/xxxxxx.wav". - manifest_dir = Path(manifest_file).parent - audio_file = Path(item['audio_file']) - if (len(str(audio_file)) < 255) and not audio_file.is_file() and not audio_file.is_absolute(): - # assume the "wavs/" dir and manifest are under the same parent dir - audio_file = manifest_dir / audio_file - if audio_file.is_file(): - item['audio_file'] = str(audio_file.absolute()) - else: - item['audio_file'] = expanduser(item['audio_file']) - else: - item['audio_file'] = expanduser(item['audio_file']) + item['audio_file'] = get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file) # Duration. if 'duration' not in item: @@ -132,3 +122,35 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: ) return item + + +def get_full_path(audio_file: str, manifest_file: str, audio_file_len_limit: int = 255) -> str: + """Get full path to audio_file. + + If the audio_file is a relative path and does not exist, + try to attach the parent directory of manifest to the audio path. + Revert to the original path if the new path still doesn't exist. + Assume that the audio path is like "wavs/xxxxxx.wav". + + Args: + audio_file: path to an audio file, either absolute or assumed relative + to the manifest directory + manifest_file: path to a manifest file + audio_file_len_limit: limit for length of audio_file when using relative paths + + Returns: + Full path to audio_file. + """ + audio_file = Path(audio_file) + manifest_dir = Path(manifest_file).parent + + if (len(str(audio_file)) < audio_file_len_limit) and not audio_file.is_file() and not audio_file.is_absolute(): + # assume audio_file path is relative to manifest_dir + audio_file_path = manifest_dir / audio_file + if audio_file_path.is_file(): + audio_file = str(audio_file_path.absolute()) + else: + audio_file = expanduser(audio_file) + else: + audio_file = expanduser(audio_file) + return audio_file diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py index d6d8c1f95cacb..0bfcfaf9cdfb2 100644 --- a/nemo/collections/common/parts/utils.py +++ b/nemo/collections/common/parts/utils.py @@ -14,11 +14,11 @@ import math import os -from typing import List +from typing import Iterable, List import torch.nn as nn -__all__ = ['if_exist', '_compute_softmax'] +__all__ = ['if_exist', '_compute_softmax', 'flatten'] activation_registry = { "identity": nn.Identity, @@ -67,3 +67,32 @@ def _compute_softmax(scores): for score in exp_scores: probs.append(score / total_sum) return probs + + +def flatten_iterable(iter: Iterable) -> Iterable: + """Flatten an iterable which contains values or + iterables with values. 
+ + Args: + iter: iterable containing values at the deepest level. + + Returns: + A flat iterable containing values. + """ + for it in iter: + if isinstance(it, str) or not isinstance(it, Iterable): + yield it + else: + yield from flatten_iterable(it) + + +def flatten(list_in: List) -> List: + """Flatten a list of (nested lists of) values into a flat list. + + Args: + list_in: list of values, possibly nested + + Returns: + A flat list of values. + """ + return list(flatten_iterable(list_in)) diff --git a/tests/collections/asr/test_asr_datasets.py b/tests/collections/asr/test_asr_datasets.py index 83c94fa59631b..88b0ab4127f1d 100644 --- a/tests/collections/asr/test_asr_datasets.py +++ b/tests/collections/asr/test_asr_datasets.py @@ -19,11 +19,19 @@ import numpy as np import pytest +import soundfile as sf import torch.cuda from omegaconf import DictConfig, OmegaConf from torch.utils.data import DataLoader -from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data import audio_to_audio_dataset, audio_to_text_dataset +from nemo.collections.asr.data.audio_to_audio import ( + ASRAudioProcessor, + AudioToTargetDataset, + AudioToTargetWithEmbeddingDataset, + AudioToTargetWithReferenceDataset, + _audio_collate_fn, +) from nemo.collections.asr.data.audio_to_text import TarredAudioToBPEDataset, TarredAudioToCharDataset from nemo.collections.asr.data.audio_to_text_dali import ( __DALI_MINIMUM_VERSION__, @@ -33,6 +41,8 @@ ) from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.parts.utils.audio_utils import get_segment_start +from nemo.collections.asr.parts.utils.manifest_utils import write_manifest from nemo.collections.common import tokenizers from nemo.utils import logging @@ -574,3 +584,954 @@ def test_dali_tarred_char_vs_ref_dataset(self, test_data_dir): err = np.abs(a - b) assert np.mean(err) < 0.0001 assert np.max(err) < 0.01 + + +class TestAudioDatasets: + @pytest.mark.unit + @pytest.mark.parametrize('num_channels', [1, 2]) + @pytest.mark.parametrize('num_targets', [1, 3]) + def test_list_to_multichannel(self, num_channels, num_targets): + """Test conversion of a list of arrays into + """ + random_seed = 42 + num_samples = 1000 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Multi-channel signal + golden_target = _rng.normal(size=(num_samples, num_channels * num_targets)) + + # Create a list of num_targets signals with num_channels channels + target_list = [golden_target[:, n * num_channels : (n + 1) * num_channels] for n in range(num_targets)] + + # Check the original signal is not modified + assert (ASRAudioProcessor.list_to_multichannel(golden_target) == golden_target).all() + # Check the list is converted back to the original signal + assert (ASRAudioProcessor.list_to_multichannel(target_list) == golden_target).all() + + @pytest.mark.unit + def test_audio_collate_fn(self): + """Test `_audio_collate_fn` + """ + batch_size = 16 + random_seed = 42 + atol = 1e-5 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + signal_to_channels = { + 'input_signal': 2, + 'target_signal': 1, + 'reference_signal': 1, + } + + signal_to_length = { + 'input_signal': _rng.integers(low=5, high=25, size=batch_size), + 'target_signal': _rng.integers(low=5, high=25, size=batch_size), + 'reference_signal': _rng.integers(low=5, high=25, size=batch_size), + } + + # Generate batch + 
batch = [] + for n in range(batch_size): + item = dict() + for signal, num_channels in signal_to_channels.items(): + random_signal = _rng.normal(size=(signal_to_length[signal][n], num_channels)) + random_signal = np.squeeze(random_signal) # get rid of channel dimention for single-channel + item[signal] = torch.tensor(random_signal) + batch.append(item) + + # Run UUT + batched = _audio_collate_fn(batch) + + batched_signals = { + 'input_signal': batched[0].cpu().detach().numpy(), + 'target_signal': batched[2].cpu().detach().numpy(), + 'reference_signal': batched[4].cpu().detach().numpy(), + } + + batched_lengths = { + 'input_signal': batched[1].cpu().detach().numpy(), + 'target_signal': batched[3].cpu().detach().numpy(), + 'reference_signal': batched[5].cpu().detach().numpy(), + } + + # Check outputs + for signal, b_signal in batched_signals.items(): + for n in range(batch_size): + # Check length + uut_length = batched_lengths[signal][n] + golden_length = signal_to_length[signal][n] + assert ( + uut_length == golden_length + ), f'Example {n} signal {signal} length mismatch: batched ({uut_length}) != golden ({golden_length})' + + uut_signal = b_signal[n][:uut_length, ...] + golden_signal = batch[n][signal][:uut_length, ...].cpu().detach().numpy() + assert np.isclose( + uut_signal, golden_signal, atol=atol + ).all(), f'Example {n} signal {signal} value mismatch.' + + @pytest.mark.unit + def test_audio_to_target_dataset(self): + """Test AudioWithTargetDataset in different configurations. + + Test below cover the following: + 1) no constraints + 2) filtering based on signal duration + 3) use with channel selector + 4) use with fixed audio duration and random subsegments + 5) collate a batch of items + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n], num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + 
write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + # Test number of channels + for signal in data: + assert data_num_channels[signal] == dataset.num_channels( + signal + ), f'Num channels not correct for signal {signal}' + assert data_num_channels[signal] == dataset_factory.num_channels( + signal + ), f'Num channels not correct for signal {signal}' + + # Test returned examples + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # - Filtering based on signal duration + min_duration = 3.5 + max_duration = 7.5 + + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + min_duration=min_duration, + max_duration=max_duration, + sample_rate=sample_rate, + ) + + filtered_examples = [n for n, val in enumerate(data_duration) if min_duration <= val <= max_duration] + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][filtered_examples[n]] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 3 + # - Use channel selector + channel_selector = { + 'input_signal': [0, 2], + 'target_signal': 1, + } + + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + input_channel_selector=channel_selector['input_signal'], + target_channel_selector=channel_selector['target_signal'], + sample_rate=sample_rate, + ) + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + for signal in data: + cs = channel_selector[signal] + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n][..., cs] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 3: Failed for example {n}, signal {signal} 
(random seed {random_seed})' + + # Test 4 + # - Use fixed duration (random segment selection) + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for random_offset in [True, False]: + # Test subsegments with the default fixed offset and a random offset + + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=random_offset, # random offset when selecting subsegment + ) + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + full_golden_signal = data[signal][filtered_examples[n]] + + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start( + signal=full_golden_signal[:, 0], segment=item_signal[:, 0] + ) + if not random_offset: + assert ( + golden_start == 0 + ), f'Expecting the signal to start at 0 when random_offset is False' + + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[golden_start:golden_end, ...] + + # Test length is correct + assert ( + len(item_signal) == audio_duration_samples + ), f'Test 4: Signal length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' + + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + # Test signal values + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 4: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 5: + # - Test collate_fn + batch_size = 16 + batch = [dataset.__getitem__(n) for n in range(batch_size)] + batched = dataset.collate_fn(batch) + + for n, signal in enumerate(data.keys()): + signal_shape = batched[2 * n].shape + signal_len = batched[2 * n + 1] + + assert signal_shape == ( + batch_size, + audio_duration_samples, + data_num_channels[signal], + ), f'Test 5: Unexpected signal {signal} shape {signal_shape}' + assert len(signal_len) == batch_size, f'Test 5: Unexpected length of signal_len ({len(signal_len)})' + assert all(signal_len == audio_duration_samples), f'Test 5: Unexpected signal_len {signal_len}' + + @pytest.mark.unit + def test_audio_to_target_dataset_with_target_list(self): + """Test AudioWithTargetDataset when the input manifest has a list + of audio files in the target key. 
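+        The individual target files are loaded and combined into a single
+        multichannel target signal.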
+ + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': ['path/to/path_to_target_ch0.wav', 'path/to/path_to_target_ch1.wav'], + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n], num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + if signal == 'target_signal': + # Save targets as individual files + signal_filename = [] + for ch in range(data_num_channels[signal]): + # add current filename + signal_filename.append(f'{signal}_{n:02d}_ch_{ch}.wav') + # write audio file + sf.write( + os.path.join(test_dir, signal_filename[-1]), + data[signal][n][:, ch], + sample_rate, + 'float', + ) + else: + # single file + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + sample_rate=sample_rate, + ) + + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # Set 
target as the first channel of input_filepath and all files listed in target_filepath. + # In this case, the target will have 3 channels. + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=[data_key['input_signal'], data_key['target_signal']], + target_channel_selector=0, + sample_rate=sample_rate, + ) + + for n in range(num_examples): + item = dataset.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + if signal == 'target_signal': + # add the first channel of the input + golden_signal = np.concatenate([data['input_signal'][n][..., 0:1], golden_signal], axis=1) + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + + @pytest.mark.unit + def test_audio_to_target_dataset_for_inference(self): + """Test AudioWithTargetDataset when target_key is + not set, i.e., it is `None`. This is the case, e.g., when + running inference, and a target is not available. + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n], num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + # Build metadata for manifest + metadata = [] + for n in range(num_examples): + meta = dict() + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + # update metadata + meta[data_key[signal]] = signal_filename + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=None, # target_signal will be empty + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': None, + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_dataset(config) + + for n in 
range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + # Check target is None + assert item['target_signal'].numel() == 0, 'target_signal is expected to be empty.' + assert item_factory['target_signal'].numel() == 0, 'target_signal is expected to be empty.' + + # Check valid signals + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + @pytest.mark.unit + def test_audio_to_target_with_reference_dataset(self): + """Test AudioWithTargetWithReferenceDataset in different configurations. + + 1) reference synchronized with input and target + 2) reference not synchronized + + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'reference_filepath': 'path/to/path_to_reference.wav', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + 'reference_signal': 1, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + 'reference_signal': 'reference_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n])) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_duration_samples[n], num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + # - Reference is not synchronized with input and target, so whole reference signal will be loaded + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + 
reference_key=data_key['reference_signal'], + reference_is_synchronized=False, + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'reference_key': data_key['reference_signal'], + 'reference_is_synchronized': False, + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_reference_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' + + # Test 2 + # - Use fixed duration (random segment selection) + # - Reference is synchronized with input and target, so the same segment of reference signal will be loaded + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=True, + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=True, + ) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + full_golden_signal = data[signal][filtered_examples[n]] + + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start(signal=full_golden_signal[:, 0], segment=item_signal[:, 0]) + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[golden_start:golden_end, ...] 
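+                    # With reference_is_synchronized=True, the reference signal is expected
+                    # to be cut to the same (golden_start, golden_end) window as input and target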
+ + # Test length is correct + assert ( + len(item_signal) == audio_duration_samples + ), f'Test 2: Signal {signal} length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' + + # Test signal values + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 2: Failed for example {n}, signal {signal} (random seed {random_seed})' + + # Test 3 + # - Use fixed duration (random segment selection) + # - Reference is not synchronized with input and target, so whole reference signal will be loaded + audio_duration = 4.0 + audio_duration_samples = int(np.floor(audio_duration * sample_rate)) + dataset = AudioToTargetWithReferenceDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + reference_key=data_key['reference_signal'], + reference_is_synchronized=False, + sample_rate=sample_rate, + min_duration=audio_duration, + audio_duration=audio_duration, + random_offset=True, + ) + + filtered_examples = [n for n, val in enumerate(data_duration) if val >= audio_duration] + + for n in range(len(dataset)): + item = dataset.__getitem__(n) + + golden_start = golden_end = None + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + full_golden_signal = data[signal][filtered_examples[n]] + + if signal == 'reference_signal': + # Complete signal is loaded for reference + golden_signal = full_golden_signal + else: + # Find random segment using correlation on the first channel + # of the first signal, and then use it fixed for other signals + if golden_start is None: + golden_start = get_segment_start( + signal=full_golden_signal[:, 0], segment=item_signal[:, 0] + ) + golden_end = golden_start + audio_duration_samples + golden_signal = full_golden_signal[golden_start:golden_end, ...] + + # Test length is correct + assert ( + len(item_signal) == audio_duration_samples + ), f'Test 3: Signal {signal} length ({len(item_signal)}) not matching the expected length ({audio_duration_samples})' + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + # Test signal values + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 3: Failed for example {n}, signal {signal} (random seed {random_seed})' + + @pytest.mark.unit + def test_audio_to_target_with_embedding_dataset(self): + """Test AudioWithTargetWithEmbeddingDataset. 
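+        The embedding is stored in a numpy file (.npy) and is loaded as a
+        fixed-length vector, independent of the audio duration.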
+ + In this use case, each line of the manifest file has the following format: + ``` + { + 'input_filepath': 'path/to/input.wav', + 'target_filepath': 'path/to/path_to_target.wav', + 'embedding_filepath': 'path/to/path_to_embedding.npy', + 'duration': duration_of_input, + } + ``` + """ + # Data setup + random_seed = 42 + sample_rate = 16000 + num_examples = 25 + data_num_channels = { + 'input_signal': 4, + 'target_signal': 2, + 'embedding_vector': 1, + } + data_min_duration = 2.0 + data_max_duration = 8.0 + embedding_length = 64 # 64-dimensional embedding vector + data_key = { + 'input_signal': 'input_filepath', + 'target_signal': 'target_filepath', + 'embedding_vector': 'embedding_filepath', + } + + # Tolerance + atol = 1e-6 + + # Generate random signals + _rng = np.random.default_rng(seed=random_seed) + + # Input and target signals have the same duration + data_duration = np.round(_rng.uniform(low=data_min_duration, high=data_max_duration, size=num_examples), 3) + data_duration_samples = np.floor(data_duration * sample_rate).astype(int) + + data = dict() + for signal, num_channels in data_num_channels.items(): + data[signal] = [] + for n in range(num_examples): + data_length = embedding_length if signal == 'embedding_vector' else data_duration_samples[n] + + if num_channels == 1: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_length)) + else: + random_signal = _rng.uniform(low=-0.5, high=0.5, size=(data_length, num_channels)) + data[signal].append(random_signal) + + with tempfile.TemporaryDirectory() as test_dir: + + # Build metadata for manifest + metadata = [] + + for n in range(num_examples): + + meta = dict() + + for signal in data: + if signal == 'embedding_vector': + signal_filename = f'{signal}_{n:02d}.npy' + np.save(os.path.join(test_dir, signal_filename), data[signal][n]) + + else: + # filenames + signal_filename = f'{signal}_{n:02d}.wav' + + # write audio files + sf.write(os.path.join(test_dir, signal_filename), data[signal][n], sample_rate, 'float') + + # update metadata + meta[data_key[signal]] = signal_filename + + meta['duration'] = data_duration[n] + metadata.append(meta) + + # Save manifest + manifest_filepath = os.path.join(test_dir, 'manifest.json') + write_manifest(manifest_filepath, metadata) + + # Test 1 + # - No constraints on channels or duration + dataset = AudioToTargetWithEmbeddingDataset( + manifest_filepath=manifest_filepath, + input_key=data_key['input_signal'], + target_key=data_key['target_signal'], + embedding_key=data_key['embedding_vector'], + sample_rate=sample_rate, + ) + + # Also test the corresponding factory + config = { + 'manifest_filepath': manifest_filepath, + 'input_key': data_key['input_signal'], + 'target_key': data_key['target_signal'], + 'embedding_key': data_key['embedding_vector'], + 'sample_rate': sample_rate, + } + dataset_factory = audio_to_audio_dataset.get_audio_to_target_with_embedding_dataset(config) + + for n in range(num_examples): + item = dataset.__getitem__(n) + item_factory = dataset_factory.__getitem__(n) + + for signal in data: + item_signal = item[signal].cpu().detach().numpy() + golden_signal = data[signal][n] + assert ( + item_signal.shape == golden_signal.shape + ), f'Signal {signal}: item shape {item_signal.shape} not matching reference shape {golden_signal.shape}' + assert np.isclose( + item_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for example {n}, signal {signal} (random seed {random_seed})' + + item_factory_signal = item_factory[signal].cpu().detach().numpy() + assert np.isclose( + 
item_factory_signal, golden_signal, atol=atol + ).all(), f'Test 1: Failed for factory example {n}, signal {signal} (random seed {random_seed})' diff --git a/tests/collections/asr/utils/test_audio_utils.py b/tests/collections/asr/utils/test_audio_utils.py index e277f410c9d13..489a7a7740784 100644 --- a/tests/collections/asr/utils/test_audio_utils.py +++ b/tests/collections/asr/utils/test_audio_utils.py @@ -25,6 +25,7 @@ db2mag, estimated_coherence, generate_approximate_noise_field, + get_segment_start, mag2db, pow2db, rms, @@ -268,3 +269,28 @@ def test_db_conversion(self): assert all(np.abs(mag - 10 ** (mag_db / 20)) < abs_threshold) assert all(np.abs(db2mag(mag_db) - 10 ** (mag_db / 20)) < abs_threshold) assert all(np.abs(pow2db(mag ** 2) - mag_db) < abs_threshold) + + @pytest.mark.unit + def test_get_segment_start(self): + random_seed = 42 + num_examples = 50 + num_samples = 2000 + + _rng = np.random.default_rng(seed=random_seed) + + for n in range(num_examples): + # Generate signal + signal = _rng.normal(size=num_samples) + # Random start in the first half + start = _rng.integers(low=0, high=num_samples // 2) + # Random length + end = _rng.integers(low=start, high=num_samples) + # Selected segment + segment = signal[start:end] + + # UUT + estimated_start = get_segment_start(signal=signal, segment=segment) + + assert ( + estimated_start == start + ), f'Example {n}: estimated start ({estimated_start}) not matching the actual start ({start})' diff --git a/tests/collections/common/test_utils.py b/tests/collections/common/test_utils.py new file mode 100644 index 0000000000000..93c33e92bf988 --- /dev/null +++ b/tests/collections/common/test_utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemo.collections.common.parts.utils import flatten + + +class TestListUtils: + @pytest.mark.unit + def test_flatten(self): + """Test flattening an iterable with different values: str, bool, int, float, complex. + """ + test_cases = [] + test_cases.append({'input': ['aa', 'bb', 'cc'], 'golden': ['aa', 'bb', 'cc']}) + test_cases.append({'input': ['aa', ['bb', 'cc']], 'golden': ['aa', 'bb', 'cc']}) + test_cases.append({'input': ['aa', [['bb'], [['cc']]]], 'golden': ['aa', 'bb', 'cc']}) + test_cases.append({'input': ['aa', [[1, 2], [[3]], 4]], 'golden': ['aa', 1, 2, 3, 4]}) + test_cases.append({'input': [True, [2.5, 2.0 + 1j]], 'golden': [True, 2.5, 2.0 + 1j]}) + + for n, test_case in enumerate(test_cases): + assert flatten(test_case['input']) == test_case['golden'], f'Test case {n} failed!'
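
A minimal usage sketch of the new `AudioToTargetDataset`, assuming a manifest written with `write_manifest` as in the tests above; the file names, keys, and durations below are illustrative only:

```python
from nemo.collections.asr.data.audio_to_audio import AudioToTargetDataset
from nemo.collections.asr.parts.utils.manifest_utils import write_manifest

# Each manifest line pairs an input recording with a target recording
# (illustrative paths; the referenced .wav files are assumed to exist)
metadata = [
    {'input_filepath': 'input_00.wav', 'target_filepath': 'target_00.wav', 'duration': 4.0},
    {'input_filepath': 'input_01.wav', 'target_filepath': 'target_01.wav', 'duration': 6.0},
]
write_manifest('manifest.json', metadata)

dataset = AudioToTargetDataset(
    manifest_filepath='manifest.json',
    input_key='input_filepath',
    target_key='target_filepath',
    sample_rate=16000,
)

# Each item is a dict with 'input_signal' and 'target_signal' tensors
item = dataset[0]

# collate_fn zero-pads each signal to the batch maximum and interleaves
# signals with their lengths: (input_signal, input_length, target_signal, target_length)
batch = dataset.collate_fn([dataset[n] for n in range(len(metadata))])
```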