
Dataset factories + review comments
Signed-off-by: Ante Jukić <[email protected]>
anteju committed Oct 27, 2022
1 parent e95c9f9 commit c42d9f1
Showing 4 changed files with 288 additions and 138 deletions.
63 changes: 34 additions & 29 deletions nemo/collections/asr/data/audio_to_audio.py
@@ -37,6 +37,7 @@
 ]
 
 
+# TODO: move utility functions to a more general module
 # Local utility functions
 def flatten_iterable(iter: Iterable[Union[str, Iterable[str]]]) -> Iterable[str]:
     """Flatten an iterable which contains strings or
@@ -97,6 +98,7 @@ def load_samples(
         )
     else:
         # TODO: Load random subsegment of `duration` seconds
+        # TODO: Load segment of `duration` seconds starting at fix_segment (non-random)
        raise NotImplementedError()
 
    return segment.samples
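
For context on the TODOs above: loading a fixed, non-random subsegment can be expressed with librosa's `offset` and `duration` arguments. A minimal sketch, with a hypothetical helper name (this is not part of the commit):

```python
import librosa

def load_fixed_segment(audio_file: str, sample_rate: int, fixed_offset: float, duration: float):
    # Load `duration` seconds starting at `fixed_offset` seconds (hypothetical helper)
    samples, _ = librosa.load(audio_file, sr=sample_rate, offset=fixed_offset, duration=duration)
    return samples
```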
@@ -157,6 +159,7 @@ def load_samples_synchronized(
             raise RuntimeError(f'Unexpected audio_file type {type(audio_file)}')
 
     else:
+        # TODO: Load segment of `duration` seconds starting at fix_segment (non-random)
         audio_durations = [librosa.get_duration(filename=f) for f in flatten(audio_files)]
         min_duration = min(audio_durations)
         available_duration = min_duration - fixed_offset
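
As a worked instance of the computation above, the shortest file bounds the synchronized segment:

```python
# Illustrative numbers, not from the repository
audio_durations = [3.0, 4.0]   # durations of the flattened audio files, in seconds
fixed_offset = 0.5
available_duration = min(audio_durations) - fixed_offset  # 2.5 seconds shared by all signals
```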
@@ -511,9 +514,9 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
             Dictionary in the following form:
             ```
             {
-                'input': single- or multi-channel format,
+                'input_signal': single- or multi-channel format,
                 'input_length': original length of each input signal
-                'target': single- or multi-channel format,
+                'target_signal': single- or multi-channel format,
                 'target_length': original length of each target signal
             }
             ```
@@ -522,9 +525,9 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
         mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal())
 
         return OrderedDict(
-            input=sc_audio_type if self.num_channels('input') == 1 else mc_audio_type,
+            input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
             input_length=NeuralType(('B',), LengthsType()),
-            target=sc_audio_type if self.num_channels('target') == 1 else mc_audio_type,
+            target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
             target_length=NeuralType(('B',), LengthsType()),
         )

@@ -538,8 +541,8 @@ def __getitem__(self, index: int) -> Dict[str, torch.tensor]:
             Dictionary providing mapping from signal to its tensor.
             ```
             {
-                'input': input_tensor,
-                'target': target_tensor,
+                'input_signal': input_tensor,
+                'target_signal': target_tensor,
             }
             ```
         """
@@ -561,7 +564,7 @@ def __getitem__(self, index: int) -> Dict[str, torch.tensor]:
         input_signal = list_to_multichannel(input_signal)
         target_signal = list_to_multichannel(target_signal)
 
-        output = OrderedDict(input=torch.tensor(input_signal), target=torch.tensor(target_signal),)
+        output = OrderedDict(input_signal=torch.tensor(input_signal), target_signal=torch.tensor(target_signal),)
 
         return output

@@ -657,11 +660,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
             Dictionary in the following form:
             ```
             {
-                'input': single- or multi-channel format,
+                'input_signal': single- or multi-channel format,
                 'input_length': original length of each input signal
-                'target': single- or multi-channel format,
+                'target_signal': single- or multi-channel format,
                 'target_length': original length of each target signal
-                'reference': single- or multi-channel format,
+                'reference_signal': single- or multi-channel format,
                 'reference_length': original length of each reference signal
             }
             ```
@@ -670,11 +673,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
         mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal())
 
         return OrderedDict(
-            input=sc_audio_type if self.num_channels('input') == 1 else mc_audio_type,
+            input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
             input_length=NeuralType(('B',), LengthsType()),
-            target=sc_audio_type if self.num_channels('target') == 1 else mc_audio_type,
+            target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
             target_length=NeuralType(('B',), LengthsType()),
-            reference=sc_audio_type if self.num_channels('reference') == 1 else mc_audio_type,
+            reference_signal=sc_audio_type if self.num_channels('reference_signal') == 1 else mc_audio_type,
             reference_length=NeuralType(('B',), LengthsType()),
         )

@@ -688,9 +691,9 @@ def __getitem__(self, index: int) -> Dict[str, torch.tensor]:
             Dictionary providing mapping from signal to its tensor.
             ```
             {
-                'input': input_tensor,
-                'target': target_tensor,
-                'reference': reference_tensor,
+                'input_signal': input_tensor,
+                'target_signal': target_tensor,
+                'reference_signal': reference_tensor,
             }
             ```
         """
@@ -739,9 +742,9 @@ def __getitem__(self, index: int) -> Dict[str, torch.tensor]:
 
         # Output dictionary
         output = OrderedDict(
-            input=torch.tensor(input_signal),
-            target=torch.tensor(target_signal),
-            reference=torch.tensor(reference_signal),
+            input_signal=torch.tensor(input_signal),
+            target_signal=torch.tensor(target_signal),
+            reference_signal=torch.tensor(reference_signal),
         )
 
         return output
@@ -823,11 +826,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
             Dictionary in the following form:
             ```
             {
-                'input': single- or multi-channel format,
+                'input_signal': single- or multi-channel format,
                 'input_length': original length of each input signal
-                'target': single- or multi-channel format,
+                'target_signal': single- or multi-channel format,
                 'target_length': original length of each target signal
-                'embedding': batched embedded vector format,
+                'embedding_vector': batched embedded vector format,
                 'embedding_length': original length of each embedding vector
             }
             ```
@@ -836,11 +839,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
         mc_audio_type = NeuralType(('B', 'T', 'C'), AudioSignal())
 
         return OrderedDict(
-            input=sc_audio_type if self.num_channels('input') == 1 else mc_audio_type,
+            input_signal=sc_audio_type if self.num_channels('input_signal') == 1 else mc_audio_type,
             input_length=NeuralType(('B',), LengthsType()),
-            target=sc_audio_type if self.num_channels('target') == 1 else mc_audio_type,
+            target_signal=sc_audio_type if self.num_channels('target_signal') == 1 else mc_audio_type,
             target_length=NeuralType(('B',), LengthsType()),
-            embedding=NeuralType(('B', 'D'), EncodedRepresentation()),
+            embedding_vector=NeuralType(('B', 'D'), EncodedRepresentation()),
             embedding_length=NeuralType(('B',), LengthsType()),
         )

@@ -854,9 +857,9 @@ def __getitem__(self, index):
             Dictionary providing mapping from signal to its tensor.
             ```
             {
-                'input': input_tensor,
-                'target': target_tensor,
-                'embedding': embedding_tensor,
+                'input_signal': input_tensor,
+                'target_signal': target_tensor,
+                'embedding_vector': embedding_tensor,
             }
             ```
         """
@@ -885,7 +888,9 @@ def __getitem__(self, index):
 
         # Output dictionary
         output = OrderedDict(
-            input=torch.tensor(input_signal), target=torch.tensor(target_signal), embedding=torch.tensor(embedding),
+            input_signal=torch.tensor(input_signal),
+            target_signal=torch.tensor(target_signal),
+            embedding_vector=torch.tensor(embedding),
         )
 
         return output
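
The renames in this file change the keys returned by `__getitem__`. A minimal sketch of reading an example under the new names, assuming a hypothetical manifest path and manifest keys:

```python
from nemo.collections.asr.data.audio_to_audio import AudioToTargetDataset

# 'manifest.json', 'input_filepath', and 'target_filepath' are hypothetical
dataset = AudioToTargetDataset(
    manifest_filepath='manifest.json', sample_rate=16000, input_key='input_filepath', target_key='target_filepath',
)

example = dataset[0]
input_signal = example['input_signal']    # was example['input'] before this commit
target_signal = example['target_signal']  # was example['target'] before this commit
```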
91 changes: 91 additions & 0 deletions nemo/collections/asr/data/audio_to_audio_dataset.py
@@ -0,0 +1,91 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.asr.data import audio_to_audio


def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset:
    """Instantiates an AudioToTargetDataset from a config dictionary.

    Args:
        config: Config of AudioToTargetDataset.

    Returns:
        An instance of AudioToTargetDataset.
    """
    dataset = audio_to_audio.AudioToTargetDataset(
        manifest_filepath=config['manifest_filepath'],
        sample_rate=config['sample_rate'],
        input_key=config['input_key'],
        target_key=config['target_key'],
        audio_duration=config.get('audio_duration', None),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        input_channel_selector=config.get('input_channel_selector', None),
        target_channel_selector=config.get('target_channel_selector', None),
    )
    return dataset
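
A sketch of calling this factory with a hypothetical config; the four indexed keys above are required, while the `config.get(...)` entries are optional:

```python
config = {
    'manifest_filepath': 'train_manifest.json',  # hypothetical path
    'sample_rate': 16000,
    'input_key': 'input_filepath',    # hypothetical manifest keys
    'target_key': 'target_filepath',
    'max_duration': 20.0,             # optional, seconds
}
dataset = get_audio_to_target_dataset(config)
```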


def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.AudioToTargetWithReferenceDataset:
    """Instantiates an AudioToTargetWithReferenceDataset from a config dictionary.

    Args:
        config: Config of AudioToTargetWithReferenceDataset.

    Returns:
        An instance of AudioToTargetWithReferenceDataset.
    """
    dataset = audio_to_audio.AudioToTargetWithReferenceDataset(
        manifest_filepath=config['manifest_filepath'],
        sample_rate=config['sample_rate'],
        input_key=config['input_key'],
        target_key=config['target_key'],
        reference_key=config['reference_key'],
        audio_duration=config.get('audio_duration', None),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        input_channel_selector=config.get('input_channel_selector', None),
        target_channel_selector=config.get('target_channel_selector', None),
        reference_channel_selector=config.get('reference_channel_selector', None),
        reference_is_synchronized=config.get('reference_is_synchronized', True),
    )
    return dataset


def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.AudioToTargetWithEmbeddingDataset:
    """Instantiates an AudioToTargetWithEmbeddingDataset from a config dictionary.

    Args:
        config: Config of AudioToTargetWithEmbeddingDataset.

    Returns:
        An instance of AudioToTargetWithEmbeddingDataset.
    """
    dataset = audio_to_audio.AudioToTargetWithEmbeddingDataset(
        manifest_filepath=config['manifest_filepath'],
        sample_rate=config['sample_rate'],
        input_key=config['input_key'],
        target_key=config['target_key'],
        embedding_key=config['embedding_key'],
        audio_duration=config.get('audio_duration', None),
        max_duration=config.get('max_duration', None),
        min_duration=config.get('min_duration', None),
        max_utts=config.get('max_utts', 0),
        input_channel_selector=config.get('input_channel_selector', None),
        target_channel_selector=config.get('target_channel_selector', None),
    )
    return dataset
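
The three factories differ only in the extra keys they forward. A hypothetical dispatch on the config, in the spirit of similar selection logic elsewhere in NeMo (not part of this commit):

```python
def get_audio_to_audio_dataset(config: dict):
    # Hypothetical: pick a factory based on which optional keys the config carries
    if 'reference_key' in config:
        return get_audio_to_target_with_reference_dataset(config)
    if 'embedding_key' in config:
        return get_audio_to_target_with_embedding_dataset(config)
    return get_audio_to_target_dataset(config)
```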
2 changes: 1 addition & 1 deletion nemo/collections/common/parts/preprocessing/collections.py
@@ -914,7 +914,7 @@ def get_full_path(audio_file: str, manifest_file: str) -> str:
         # Handle all audio files
         audio_files = {}
         for key in self.audio_keys:
-            audio_file = item.pop(key)
+            audio_file = item[key]
             if isinstance(audio_file, str):
                 # This dictionary entry points to a single file
                 audio_files[key] = get_full_path(audio_file, manifest_file)
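
The change from `item.pop(key)` to `item[key]` reads the audio entry without mutating the manifest item, presumably so the key remains available to later processing. A generic illustration of the difference:

```python
item = {'input_key': 'audio.wav', 'duration': 3.2}

popped = item.pop('input_key')  # returns 'audio.wav' and removes the key from item
assert 'input_key' not in item

item = {'input_key': 'audio.wav', 'duration': 3.2}
read = item['input_key']        # returns 'audio.wav'; item is unchanged
assert 'input_key' in item
```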
