Skip to content

Commit

Permalink
Dataset factories + review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
anteju committed Oct 27, 2022
1 parent e95c9f9 commit ca69595
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 47 deletions.
3 changes: 3 additions & 0 deletions nemo/collections/asr/data/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
]


# TODO: move utility functions to a more general module
# Local utility functions
def flatten_iterable(iter: Iterable[Union[str, Iterable[str]]]) -> Iterable[str]:
"""Flatten an iterable which contains strings or
Expand Down Expand Up @@ -97,6 +98,7 @@ def load_samples(
)
else:
# TODO: Load random subsegment of `duration` seconds
# TODO: Load segment of `duration` seconds starting at fix_segment (non-random)
raise NotImplementedError()

return segment.samples
Expand Down Expand Up @@ -157,6 +159,7 @@ def load_samples_synchronized(
raise RuntimeError(f'Unexpected audio_file type {type(audio_file)}')

else:
# TODO: Load segment of `duration` seconds starting at `fixed_offset` (non-random)
audio_durations = [librosa.get_duration(filename=f) for f in flatten(audio_files)]
min_duration = min(audio_durations)
available_duration = min_duration - fixed_offset
Expand Down
77 changes: 77 additions & 0 deletions nemo/collections/asr/data/audio_to_audio_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from nemo.collections.asr.data import audio_to_audio


def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset:
    """Build an `AudioToTargetDataset` from a configuration mapping.

    Args:
        config: Dataset configuration. The entries `manifest_filepath`,
            `sample_rate`, `input_key` and `target_key` are looked up with
            `config[...]` and are therefore mandatory. All other entries are
            optional: `audio_duration`, `max_duration`, `min_duration`,
            `input_channel_selector` and `target_channel_selector` default
            to `None`, and `max_utts` defaults to `0`.

    Returns:
        The constructed AudioToTargetDataset.
    """
    # Mandatory entries are read first, in the same order as the constructor
    # keywords, so a missing key is reported for the earliest one.
    mandatory = ('manifest_filepath', 'sample_rate', 'input_key', 'target_key')
    kwargs = {name: config[name] for name in mandatory}

    # Optional entries and the value used when the config does not provide them.
    optional_defaults = {
        'audio_duration': None,
        'max_duration': None,
        'min_duration': None,
        'max_utts': 0,
        'input_channel_selector': None,
        'target_channel_selector': None,
    }
    for name, fallback in optional_defaults.items():
        kwargs[name] = config.get(name, fallback)

    return audio_to_audio.AudioToTargetDataset(**kwargs)


def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.AudioToTargetWithReferenceDataset:
    """Build an `AudioToTargetWithReferenceDataset` from a configuration mapping.

    Args:
        config: Dataset configuration. The entries `manifest_filepath`,
            `sample_rate`, `input_key`, `target_key` and `reference_key`
            are looked up with `config[...]` and are therefore mandatory.
            All other entries are optional: `audio_duration`,
            `max_duration`, `min_duration` and the three
            `*_channel_selector` entries default to `None`, `max_utts`
            defaults to `0`, and `reference_is_synchronized` defaults to
            `True`.

    Returns:
        The constructed AudioToTargetWithReferenceDataset.
    """
    # Mandatory entries are read first, in the same order as the constructor
    # keywords, so a missing key is reported for the earliest one.
    mandatory = ('manifest_filepath', 'sample_rate', 'input_key', 'target_key', 'reference_key')
    kwargs = {name: config[name] for name in mandatory}

    # Optional entries and the value used when the config does not provide them.
    optional_defaults = {
        'audio_duration': None,
        'max_duration': None,
        'min_duration': None,
        'max_utts': 0,
        'input_channel_selector': None,
        'target_channel_selector': None,
        'reference_channel_selector': None,
        'reference_is_synchronized': True,
    }
    for name, fallback in optional_defaults.items():
        kwargs[name] = config.get(name, fallback)

    return audio_to_audio.AudioToTargetWithReferenceDataset(**kwargs)


def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.AudioToTargetWithEmbeddingDataset:
    """Build an `AudioToTargetWithEmbeddingDataset` from a configuration mapping.

    Args:
        config: Dataset configuration. The entries `manifest_filepath`,
            `sample_rate`, `input_key`, `target_key` and `embedding_key`
            are looked up with `config[...]` and are therefore mandatory.
            All other entries are optional: `audio_duration`,
            `max_duration`, `min_duration`, `input_channel_selector` and
            `target_channel_selector` default to `None`, and `max_utts`
            defaults to `0`.

    Returns:
        The constructed AudioToTargetWithEmbeddingDataset.
    """
    # Mandatory entries are read first, in the same order as the constructor
    # keywords, so a missing key is reported for the earliest one.
    mandatory = ('manifest_filepath', 'sample_rate', 'input_key', 'target_key', 'embedding_key')
    kwargs = {name: config[name] for name in mandatory}

    # Optional entries and the value used when the config does not provide them.
    optional_defaults = {
        'audio_duration': None,
        'max_duration': None,
        'min_duration': None,
        'max_utts': 0,
        'input_channel_selector': None,
        'target_channel_selector': None,
    }
    for name, fallback in optional_defaults.items():
        kwargs[name] = config.get(name, fallback)

    return audio_to_audio.AudioToTargetWithEmbeddingDataset(**kwargs)
2 changes: 1 addition & 1 deletion nemo/collections/common/parts/preprocessing/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def get_full_path(audio_file: str, manifest_file: str) -> str:
# Handle all audio files
audio_files = {}
for key in self.audio_keys:
audio_file = item.pop(key)
audio_file = item[key]
if isinstance(audio_file, str):
# This dictionary entry points to a single file
audio_files[key] = get_full_path(audio_file, manifest_file)
Expand Down
Loading

0 comments on commit ca69595

Please sign in to comment.