diff --git a/nemo/collections/asr/data/data_simulation.py b/nemo/collections/asr/data/data_simulation.py
index a6bee3e4efca..324f609f37f1 100644
--- a/nemo/collections/asr/data/data_simulation.py
+++ b/nemo/collections/asr/data/data_simulation.py
@@ -17,6 +17,7 @@
 import os
 import shutil
 import warnings
+from collections import defaultdict
 from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 import h5py
@@ -163,6 +164,8 @@ def __init__(self, cfg):
         self._params = cfg
         # internal params
         self._manifest = read_manifest(self._params.data_simulator.manifest_filepath)
+        self._speaker_samples = self._build_speaker_samples_map()
+        self._noise_samples = []
         self._sentence = None
         self._text = ""
         self._words = []
@@ -270,6 +273,14 @@ def _check_args(self):
         if len(self._manifest) == 0:
             raise Exception("Manifest file is empty. Check that the source path is correct.")
 
+    def clean_up(self):
+        self._sentence = None
+        self._words = []
+        self._alignments = []
+        self._audio_read_buffer_dict = {}
+        self._noise_read_buffer_dict = {}
+        torch.cuda.empty_cache()
+
     def _get_speaker_ids(self) -> List[str]:
         """
         Randomly select speaker IDs from the loaded manifest file.
@@ -277,50 +288,104 @@
 
         Returns:
             speaker_ids (list): Speaker IDs
         """
-        speaker_ids = []
-        s = 0
-        while s < self._params.data_simulator.session_config.num_speakers:
-            file = self._manifest[np.random.randint(0, len(self._manifest) - 1)]
-            speaker_id = file['speaker_id']
-            if speaker_id not in speaker_ids:  # ensure speaker IDs are not duplicated
-                speaker_ids.append(speaker_id)
-                s += 1
+        all_speaker_ids = list(self._speaker_samples.keys())
+        idx_list = np.random.permutation(len(all_speaker_ids))[
+            : self._params.data_simulator.session_config.num_speakers
+        ]
+        speaker_ids = [all_speaker_ids[i] for i in idx_list]
         return speaker_ids
 
+    def _build_speaker_samples_map(self) -> Dict:
+        """
+        Build a dictionary mapping each speaker ID to the list of their samples.
+
+        Returns:
+            speaker_samples (Dict[list]):
+                Dictionary mapping speaker ID to their list of samples
+        """
+        speaker_samples = defaultdict(list)
+        logging.info("Building speaker to samples map...")
+        for sample in tqdm(self._manifest, total=len(self._manifest)):
+            speaker_id = sample['speaker_id']
+            speaker_samples[speaker_id].append(sample)
+        return speaker_samples
+
+    def _sample_noise_manifest(self, noise_manifest) -> list:
+        """
+        Sample up to `num_noise_files` entries from the noise manifest for the current simulated audio session.
+
+        Args:
+            noise_manifest (list):
+                List of noise source samples to be sampled from.
+
+        Returns:
+            sampled_noise_manifest (list):
+                List of noise samples to be used for the current session.
+        """
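+        # Cap the draw at the available pool size so np.random.choice with replace=False cannot fail.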
+ """ + noise_manifest = [] + if self._params.data_simulator.background_noise.add_bg is True: + if self._params.data_simulator.background_noise.background_manifest is not None: + if os.path.exists(self._params.data_simulator.background_noise.background_manifest): + noise_manifest = read_manifest(self._params.data_simulator.background_noise.background_manifest) + else: + raise FileNotFoundError( + f"Noise manifest file: {self._params.data_simulator.background_noise.background_manifest} file not found." + ) + else: + raise FileNotFoundError( + f"Noise manifest file is null. Please provide a valid noise manifest file if add_bg=True." + ) + return noise_manifest + def _get_speaker_samples(self, speaker_ids: List[str]) -> Dict[str, list]: """ Get a list of the samples for each of the specified speakers. Args: speaker_ids (list): LibriSpeech speaker IDs for each speaker in the current session. + Returns: - speaker_lists (dict): Dictionary of manifest lines per speaker + speaker_wav_align_map (dict): Dictionary containing speaker IDs and their corresponding wav filepath and alignments. """ - speaker_lists = {} - for i in range(self._params.data_simulator.session_config.num_speakers): - speaker_lists[str(speaker_ids[i])] = [] - # loop over manifest and add files corresponding to each speaker to each sublist - for file in self._manifest: - new_speaker_id = file['speaker_id'] - if new_speaker_id in speaker_ids: - speaker_lists[str(new_speaker_id)].append(file) - return speaker_lists + speaker_wav_align_map = defaultdict(list) + for sid in speaker_ids: + speaker_wav_align_map[sid] = self._speaker_samples[sid] + return speaker_wav_align_map - def _load_speaker_sample(self, speaker_lists: List[dict], speaker_ids: List[str], speaker_turn: int) -> str: + def _load_speaker_sample( + self, speaker_wav_align_map: List[dict], speaker_ids: List[str], speaker_turn: int + ) -> str: """ Load a sample for the selected speaker ID. The first alignment and word must be silence that determines the start of the alignments. Args: - speaker_lists (list): List of samples for each speaker in the session. + speaker_wav_align_map (dict): Dictionary containing speaker IDs and their corresponding wav filepath and alignments. speaker_ids (list): LibriSpeech speaker IDs for each speaker in the current session. speaker_turn (int): Current speaker turn. + Returns: file_path (str): Path to the desired audio file """ speaker_id = speaker_ids[speaker_turn] - file_id = np.random.randint(0, max(len(speaker_lists[str(speaker_id)]) - 1, 1)) - file_dict = speaker_lists[str(speaker_id)][file_id] + file_id = np.random.randint(0, max(len(speaker_wav_align_map[str(speaker_id)]) - 1, 1)) + file_dict = speaker_wav_align_map[str(speaker_id)][file_id] # Check whether the first word is silence and insert a silence token if the first token is not silence. 
+        speaker_wav_align_map = defaultdict(list)
+        for sid in speaker_ids:
+            speaker_wav_align_map[sid] = self._speaker_samples[sid]
+        return speaker_wav_align_map
 
-    def _load_speaker_sample(self, speaker_lists: List[dict], speaker_ids: List[str], speaker_turn: int) -> str:
+    def _load_speaker_sample(
+        self, speaker_wav_align_map: List[dict], speaker_ids: List[str], speaker_turn: int
+    ) -> str:
         """
         Load a sample for the selected speaker ID.
         The first alignment and word must be silence that determines the start of the alignments.
 
         Args:
-            speaker_lists (list): List of samples for each speaker in the session.
+            speaker_wav_align_map (dict): Dictionary containing speaker IDs and their corresponding wav filepath and alignments.
             speaker_ids (list): LibriSpeech speaker IDs for each speaker in the current session.
             speaker_turn (int): Current speaker turn.
+
         Returns:
             file_path (str): Path to the desired audio file
         """
         speaker_id = speaker_ids[speaker_turn]
-        file_id = np.random.randint(0, max(len(speaker_lists[str(speaker_id)]) - 1, 1))
-        file_dict = speaker_lists[str(speaker_id)][file_id]
+        file_id = np.random.randint(0, max(len(speaker_wav_align_map[str(speaker_id)]) - 1, 1))
+        file_dict = speaker_wav_align_map[str(speaker_id)][file_id]
 
         # Check whether the first word is silence and insert a silence token if the first token is not silence.
         if file_dict['words'][0] != "":
@@ -348,7 +413,7 @@ def _get_speaker_dominance(self) -> List[float]:
         total = np.sum(dominance)
         if total == 0:
             for i in range(len(dominance)):
-                dominance[i] += min_dominance
+                dominance[i] += self._params.data_simulator.session_params.min_dominance
         # scale accounting for min_dominance which has to be added after
         dominance = (dominance / total) * (
             1
@@ -531,8 +596,8 @@ def _get_end_buffer_and_window(
 
     def _add_file(
         self,
-        file: dict,
-        audio_file: torch.Tensor,
+        audio_manifest: dict,
+        audio_file,
         sentence_word_count: int,
         max_word_count_in_sentence: int,
         max_samples_in_sentence: int,
@@ -542,7 +607,7 @@
         Uses the alignments to segment the audio file.
 
         Args:
-            file (dict): Line from manifest file for current audio file
+            audio_manifest (dict): Line from manifest file for current audio file
             audio_file (tensor): Current loaded audio file
             sentence_word_count (int): Running count for number of words in sentence
             max_word_count_in_sentence (int): Maximum count for number of words in sentence
@@ -551,12 +616,12 @@
             sentence_word_count+current_word_count (int): Running word count
             len(self._sentence) (tensor): Current length of the audio file
         """
-        if len(file['alignments']) <= 1:
-            raise ValueError(f"Alignment file has inappropriate length of {len(file['alignments'])}")
+        if len(audio_manifest['alignments']) <= 1:
+            raise ValueError(f"Alignment file has inappropriate length of {len(audio_manifest['alignments'])}")
 
-        offset_idx = np.random.randint(low=1, high=len(file['words']))
+        offset_idx = np.random.randint(low=1, high=len(audio_manifest['words']))
 
-        first_alignment = int(file['alignments'][offset_idx - 1] * self._params.data_simulator.sr)
+        first_alignment = int(audio_manifest['alignments'][offset_idx - 1] * self._params.data_simulator.sr)
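+        # Decide where to cut into the utterance and how many samples to ramp in with a window.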
         start_cutoff, start_window_amount = self._get_start_buffer_and_window(first_alignment)
         if not self._params.data_simulator.session_params.start_window:  # cut off the start of the sentence
             start_window_amount = 0
@@ -573,16 +638,18 @@
         while (
             current_word_count < remaining_duration
             and dur_sample_count < remaining_dur_sample_count
-            and word_idx < len(file['words'])
+            and word_idx < len(audio_manifest['words'])
         ):
-            dur_sample_count = int(file['alignments'][word_idx] * self._params.data_simulator.sr) - start_cutoff
+            dur_sample_count = (
+                int(audio_manifest['alignments'][word_idx] * self._params.data_simulator.sr) - start_cutoff
+            )
 
             # check the length of the generated sentence in terms of sample count (int).
             if curr_dur_sample_count + dur_sample_count > remaining_dur_sample_count:
                 # if the upcoming loop will exceed the remaining sample count, break out of the loop.
                 break
 
-            word = file['words'][word_idx]
+            word = audio_manifest['words'][word_idx]
 
             if silence_count > 0 and word == "":
                 break
@@ -591,7 +658,7 @@
             self._alignments.append(
                 float(sentence_sample_count * 1.0 / self._params.data_simulator.sr)
                 - float(start_cutoff * 1.0 / self._params.data_simulator.sr)
                 + audio_manifest['alignments'][word_idx]
             )
 
             if word == "":
@@ -620,24 +687,23 @@
                     ),
                     0,
                 )
-                self._sentence = self._sentence.to(self._device)
                 self._sentence = torch.cat(
                     (
                         self._sentence,
                         audio_file[start_cutoff + start_window_amount : start_cutoff + prev_dur_sample_count],
                     ),
                     0,
-                )
-                self._sentence = self._sentence.to(self._device)
+                ).to(self._device)
         else:
             self._sentence = torch.cat(
                 (self._sentence, audio_file[start_cutoff : start_cutoff + prev_dur_sample_count]), 0
-            )
-            self._sentence = self._sentence.to(self._device)
+            ).to(self._device)
 
         # windowing at the end of the sentence
-        if (word_idx < len(file['words'])) and self._params.data_simulator.session_params.window_type is not None:
+        if (
+            word_idx < len(audio_manifest['words'])
+        ) and self._params.data_simulator.session_params.window_type is not None:
             release_buffer, end_window_amount = self._get_end_buffer_and_window(
                 prev_dur_sample_count,
                 remaining_dur_sample_count,
@@ -651,8 +717,7 @@
                     ],
                 ),
                 0,
-            )
-            self._sentence = self._sentence.to(self._device)
+            ).to(self._device)
 
             if end_window_amount > 0:  # include window
                 window = self._get_window(end_window_amount, start=False)
@@ -672,13 +737,17 @@
                     ),
                 ),
                 0,
-            )
-            self._sentence = self._sentence.to(self._device)
+            ).to(self._device)
 
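+        # Drop the local reference; the cached copy in self._audio_read_buffer_dict keeps the tensor alive.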
+        del audio_file
         return sentence_word_count + current_word_count, len(self._sentence)
 
     def _build_sentence(
-        self, speaker_turn: int, speaker_ids: List[str], speaker_lists: List[dict], max_samples_in_sentence: int
+        self,
+        speaker_turn: int,
+        speaker_ids: List[str],
+        speaker_wav_align_map: Dict[str, list],
+        max_samples_in_sentence: int,
     ):
         """
         Build a new sentence by attaching utterance samples together until the sentence has reached a desired length.
 
         Args:
             speaker_turn (int): Current speaker turn.
             speaker_ids (list): LibriSpeech speaker IDs for each speaker in the current session.
-            speaker_lists (list): List of samples for each speaker in the session.
+            speaker_wav_align_map (dict): Dictionary containing speaker IDs and their corresponding wav filepath and alignments.
             max_samples_in_sentence (int): Maximum length for sentence in terms of samples
         """
         # select speaker length
@@ -704,22 +773,23 @@
         self._text = ""
         self._words = []
         self._alignments = []
-        sentence_word_count = sentence_sample_count = 0
+        sentence_word_count = 0
+        sentence_sample_count = 0
 
         # build sentence
         while sentence_word_count < sl and sentence_sample_count < max_samples_in_sentence:
-            file = self._load_speaker_sample(speaker_lists, speaker_ids, speaker_turn)
-            if file['audio_filepath'] in self._audio_read_buffer_dict:
-                audio_file, sr = self._audio_read_buffer_dict[file['audio_filepath']]
+            audio_manifest = self._load_speaker_sample(speaker_wav_align_map, speaker_ids, speaker_turn)
+            if audio_manifest['audio_filepath'] in self._audio_read_buffer_dict:
+                audio_file, sr = self._audio_read_buffer_dict[audio_manifest['audio_filepath']]
             else:
-                audio_file, sr = sf.read(file['audio_filepath'])
+                audio_file, sr = sf.read(audio_manifest['audio_filepath'])
                 audio_file = torch.from_numpy(audio_file).to(self._device)
                 if audio_file.ndim > 1:
-                    audio_file = torch.mean(audio_file, 1, False)
-                self._audio_read_buffer_dict[file['audio_filepath']] = (audio_file, sr)
+                    audio_file = torch.mean(audio_file, 1, False).to(self._device)
+                self._audio_read_buffer_dict[audio_manifest['audio_filepath']] = (audio_file, sr)
 
             sentence_word_count, sentence_sample_count = self._add_file(
-                file, audio_file, sentence_word_count, sl, max_samples_in_sentence
+                audio_manifest, audio_file, sentence_word_count, sl, max_samples_in_sentence
             )
 
         # look for split locations
@@ -745,7 +815,8 @@
         # per-speaker normalization (accounting for active speaker time)
         if self._params.data_simulator.session_params.normalize:
             if torch.max(torch.abs(self._sentence)) > 0:
-                split_length = split_sum = torch.tensor(0).to(self._device).double()
+                split_length = torch.tensor(0).to(self._device).double()
+                split_sum = torch.tensor(0).to(self._device).double()
                 for split in splits:
                     split_length += len(self._sentence[split[0] : split[1]])
                     split_sum += torch.sum(self._sentence[split[0] : split[1]] ** 2)
@@ -852,34 +923,34 @@
             return new_start
 
-    def _get_background(self, len_array: int, power_array: float) -> torch.Tensor:
+    def _get_background(self, len_array: int, power_array: float):
         """
         Augment with background noise (inserting ambient background noise up to the desired SNR for the full clip).
 
         Args:
             len_array (int): Length of background noise required.
             avg_power_array (float): Average power of the audio file.
+
         Returns:
             bg_array (tensor): Tensor containing background noise
         """
-
-        manifest = read_manifest(self._params.data_simulator.background_noise.background_manifest)
         bg_array = torch.zeros(len_array).to(self._device)
         desired_snr = self._params.data_simulator.background_noise.snr
         ratio = 10 ** (desired_snr / 20)
         desired_avg_power_noise = (power_array / ratio).to(self._device)
-        running_len = 0
+        running_len, file_id = 0, 0
         while running_len < len_array:  # build background audio stream (the same length as the full file)
-            file_id = np.random.randint(0, len(manifest) - 1)
-            file = manifest[file_id]
-            if file['audio_filepath'] in self._audio_read_buffer_dict:
-                audio_file, sr = self._noise_read_buffer_dict[file['audio_filepath']]
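+            # Cycle round-robin through the pre-sampled noise files instead of re-reading the noise manifest.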
+            audio_manifest = self._noise_samples[file_id % len(self._noise_samples)]
+            file_id += 1
+
+            if audio_manifest['audio_filepath'] in self._noise_read_buffer_dict:
+                audio_file, sr = self._noise_read_buffer_dict[audio_manifest['audio_filepath']]
             else:
-                audio_file, sr = sf.read(file['audio_filepath'])
+                audio_file, sr = sf.read(audio_manifest['audio_filepath'])
                 audio_file = torch.from_numpy(audio_file).to(self._device)
                 if audio_file.ndim > 1:
                     audio_file = torch.mean(audio_file, 1, False)
-                self._noise_read_buffer_dict[file['audio_filepath']] = (audio_file, sr)
+                self._noise_read_buffer_dict[audio_manifest['audio_filepath']] = (audio_file, sr)
 
             if running_len + len(audio_file) < len_array:
                 end_audio_file = running_len + len(audio_file)
@@ -904,6 +975,7 @@ def _create_new_rttm_entry(self, start: int, end: int, speaker_id: int) -> List[
             start (int): Current start of the audio file being inserted.
             end (int): End of the audio file being inserted.
             speaker_id (int): LibriSpeech speaker ID for the current entry.
+
         Returns:
             rttm_list (list): List of rttm entries
         """
@@ -940,6 +1012,7 @@ def _create_new_json_entry(
             speaker_id (int): LibriSpeech speaker ID for the current entry.
             rttm_filepath (str): Output rttm filepath.
             ctm_filepath (str): Output ctm filepath.
+
         Returns:
             dict (dict): JSON entry
         """
@@ -1031,7 +1104,17 @@ def create_segment_manifest_ds(self) -> str:
         self.segment_manifest_filepath = output_manifest_filepath
         return self.segment_manifest_filepath
 
-    def _generate_session(self, idx: int, basepath: str, filename: str, enforce_counter: int = 2):
+    def _generate_session(
+        self,
+        idx: int,
+        basepath: str,
+        filename: str,
+        speaker_ids: List[str],
+        speaker_wav_align_map: Dict[str, list],
+        noise_samples: list,
+        device: torch.device,
+        enforce_counter: int = 2,
+    ):
         """
         _generate_session function without RIR simulation.
         Generate a multispeaker audio session and corresponding label files.
 
@@ -1040,17 +1123,21 @@
             idx (int): Index for current session (out of total number of sessions).
             basepath (str): Path to output directory.
             filename (str): Filename for output files.
+            speaker_ids (list): List of speaker IDs that will be used in this session.
+            speaker_wav_align_map (dict): Dictionary containing speaker IDs and their corresponding wav filepath and alignments.
+            noise_samples (list): List of randomly sampled noise source files that will be used for generating this session.
+            device (torch.device): Device to use for generating this session.
             enforce_counter (int): In enforcement mode, dominance is increased by a factor of enforce_counter for unrepresented speakers
         """
-        speaker_ids = self._get_speaker_ids()  # randomly select speaker IDs
+        self._device = device
         speaker_dominance = self._get_speaker_dominance()  # randomly determine speaker dominance
         base_speaker_dominance = np.copy(speaker_dominance)
-        speaker_lists = self._get_speaker_samples(speaker_ids)  # get list of samples per speaker
         self._set_speaker_volume()
 
         running_len_sample_count, prev_len_sample_count = 0, 0
         prev_speaker = None
         rttm_list, json_list, ctm_list = [], [], []
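+        # Noise sources for this session were pre-sampled in generate_sessions() and passed in as arguments.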
+        self._noise_samples = noise_samples
         self._furthest_sample = [0 for n in range(self._params.data_simulator.session_config.num_speakers)]
         self._missing_overlap = 0
 
@@ -1087,7 +1174,7 @@
             ):
                 break
 
-            self._build_sentence(speaker_turn, speaker_ids, speaker_lists, max_samples_in_sentence)
+            self._build_sentence(speaker_turn, speaker_ids, speaker_wav_align_map, max_samples_in_sentence)
 
             length = len(self._sentence)
             start = self._add_silence_or_overlap(
@@ -1134,9 +1221,12 @@
 
         # background noise augmentation
         if self._params.data_simulator.background_noise.add_bg:
-            avg_power_array = torch.mean(array[is_speech == 1] ** 2)
-            bg = self._get_background(len(array), avg_power_array)
-            array += bg
+            if len(self._noise_samples) > 0:
+                avg_power_array = torch.mean(array[is_speech == 1] ** 2)
+                bg = self._get_background(len(array), avg_power_array)
+                array += bg
+            else:
+                raise ValueError('No background noise samples found in self._noise_samples.')
 
         array = array / (1.0 * torch.max(torch.abs(array)))  # normalize wav file to avoid clipping
         if torch.is_tensor(array):
@@ -1147,6 +1237,10 @@
         write_ctm(os.path.join(basepath, filename + '.ctm'), ctm_list)
         write_text(os.path.join(basepath, filename + '.txt'), ctm_list)
 
+        del array
+        self.clean_up()
+        return basepath, filename
+
     def generate_sessions(self, random_seed: int = None):
         """
         Generate several multispeaker audio sessions and corresponding list files.
@@ -1184,29 +1278,54 @@
         tp = concurrent.futures.ProcessPoolExecutor(max_workers=self._params.get("num_workers", 1))
         futures = []
 
-        for i in range(self._params.data_simulator.session_config.num_sessions):
+        num_sessions = self._params.data_simulator.session_config.num_sessions
+        source_noise_manifest = self._read_noise_manifest()
+        queue = []
+
+        # add randomly sampled arguments to a list (queue) for multiprocessing
+        for sess_idx in range(num_sessions):
+            filename = self._params.data_simulator.outputs.output_filename + f"_{sess_idx}"
+            speaker_ids = self._get_speaker_ids()
+            speaker_wav_align_map = self._get_speaker_samples(speaker_ids)
+            noise_samples = self._sample_noise_manifest(source_noise_manifest)
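+            # Assign sessions to GPUs round-robin; fall back to the default device when CUDA is unavailable.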
+            if torch.cuda.is_available():
+                device = torch.device(f"cuda:{sess_idx % torch.cuda.device_count()}")
+            else:
+                device = self._device
+            queue.append((sess_idx, basepath, filename, speaker_ids, speaker_wav_align_map, noise_samples, device))
+
+        # for multiprocessing speed, avoid loading the potentially huge manifest list and speaker-sample map into each worker process.
+        if num_workers > 1:
+            self._manifest = None
+            self._speaker_samples = None
+
+        for sess_idx in range(num_sessions):
             self._furthest_sample = [0 for n in range(self._params.data_simulator.session_config.num_speakers)]
             self._missing_overlap = 0
             self._audio_read_buffer_dict = {}
-            filename = self._params.data_simulator.outputs.output_filename + f"_{i}"
             if num_workers > 1:
-                futures.append([tp.submit(self._generate_session, i, basepath, filename), (basepath, filename)])
+                futures.append(tp.submit(self._generate_session, *queue[sess_idx]))
             else:
-                futures.append([None, (basepath, filename)])
 
-        for future in tqdm(futures, desc="Waiting for generators to finish", unit="jobs"):
-            basepath, filename = future[1]
+                futures.append(queue[sess_idx])
 
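+        # Consume finished sessions as they complete rather than in submission order.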
+        if num_workers > 1:
+            generator = concurrent.futures.as_completed(futures)
+        else:
+            generator = futures
+
+        for future in tqdm(generator, desc="Waiting for generators to finish", unit="jobs", total=len(futures)):
             if num_workers > 1:
-                future[0].result()
+                basepath, filename = future.result()
             else:
-                self._generate_session(i, basepath, filename)
+                self._noise_samples = self._sample_noise_manifest(source_noise_manifest)
+                basepath, filename = self._generate_session(*future)
+
             wavlist.write(os.path.join(basepath, filename + '.wav\n'))
             rttmlist.write(os.path.join(basepath, filename + '.rttm\n'))
             jsonlist.write(os.path.join(basepath, filename + '.json\n'))
             ctmlist.write(os.path.join(basepath, filename + '.ctm\n'))
             textlist.write(os.path.join(basepath, filename + '.txt\n'))
 
-            # throw error if number of speakers is less than requested
+            # throw a warning if the number of speakers is less than requested
             num_missing = 0
             for k in range(len(self._furthest_sample)):
                 if self._furthest_sample[k] == 0:
@@ -1215,6 +1334,7 @@
                 warnings.warn(
                     f"{self._params.data_simulator.session_config.num_speakers-num_missing} speakers were included in the clip instead of the requested amount of {self._params.data_simulator.session_config.num_speakers}"
                 )
+
         tp.shutdown()
 
         wavlist.close()
@@ -1223,6 +1343,8 @@
         ctmlist.close()
         textlist.close()
 
+        logging.info(f"Data simulation has been completed, results saved at: {basepath}")
+
 
 class RIRMultiSpeakerSimulator(MultiSpeakerSimulator):
     """
@@ -1312,7 +1434,7 @@
             if len(sublist) != 3:
                 raise Exception("Three coordinates must be provided for orientations")
 
-    def _generate_rir_gpuRIR(self) -> Tuple[torch.Tensor, int]:
+    def _generate_rir_gpuRIR(self):
         """
         Create simulated RIR using the gpuRIR library
 
@@ -1457,7 +1579,17 @@ def _convolve_rir(self, input, speaker_turn: int, RIR: torch.Tensor) -> Tuple[li
             output_sound.append(torch.tensor(out_channel))
         return output_sound, length
 
-    def _generate_session(self, idx: int, basepath: str, filename: str, enforce_counter: int = 2):
+    def _generate_session(
+        self,
+        idx: int,
+        basepath: str,
+        filename: str,
+        speaker_ids: list,
+        speaker_wav_align_map: dict,
+        noise_samples: list,
+        device: torch.device,
+        enforce_counter: int = 2,
+    ):
         """
         Generate a multispeaker audio session and corresponding label files.
 
@@ -1465,17 +1597,21 @@
             idx (int): Index for current session (out of total number of sessions).
             basepath (str): Path to output directory.
             filename (str): Filename for output files.
+            speaker_ids (list): List of speaker IDs that will be used in this session.
+            speaker_wav_align_map (dict): Dictionary containing speaker IDs and their corresponding wav filepath and alignments.
+            noise_samples (list): List of randomly sampled noise source files that will be used for generating this session.
+            device (torch.device): Device to use for generating this session.
             enforce_counter (int): In enforcement mode, dominance is increased by a factor of enforce_counter for unrepresented speakers
         """
-        speaker_ids = self._get_speaker_ids()  # randomly select speaker IDs
+        self._device = device
         speaker_dominance = self._get_speaker_dominance()  # randomly determine speaker dominance
         base_speaker_dominance = np.copy(speaker_dominance)
-        speaker_lists = self._get_speaker_samples(speaker_ids)  # get list of samples per speaker
         self._set_speaker_volume()
 
         running_len_sample_count, prev_len_sample_count = 0, 0  # starting point for each sentence
         prev_speaker = None
         rttm_list, json_list, ctm_list = [], [], []
+        self._noise_samples = noise_samples
         self._furthest_sample = [0 for n in range(self._params.data_simulator.session_config.num_speakers)]
         self._missing_overlap = 0
 
@@ -1523,7 +1659,7 @@
                 < self._params.data_simulator.session_params.end_buffer * self._params.data_simulator.sr
             ):
                 break
-            self._build_sentence(speaker_turn, speaker_ids, speaker_lists, max_samples_in_sentence)
+            self._build_sentence(speaker_turn, speaker_ids, speaker_wav_align_map, max_samples_in_sentence)
 
             augmented_sentence, length = self._convolve_rir(self._sentence, speaker_turn, RIR)
 
             start = self._add_silence_or_overlap(
@@ -1587,6 +1723,9 @@
         write_manifest(os.path.join(basepath, filename + '.json'), json_list)
         write_ctm(os.path.join(basepath, filename + '.ctm'), ctm_list)
         write_text(os.path.join(basepath, filename + '.txt'), ctm_list)
+        del array
+        self.clean_up()
+        return basepath, filename
 
 
 def check_angle(key: str, val: Union[float, Iterable[float]]) -> bool:
diff --git a/tools/speech_data_simulator/conf/data_simulator.yaml b/tools/speech_data_simulator/conf/data_simulator.yaml
index e0fb0c8b7588..ff13580493be 100644
--- a/tools/speech_data_simulator/conf/data_simulator.yaml
+++ b/tools/speech_data_simulator/conf/data_simulator.yaml
@@ -42,6 +42,7 @@ data_simulator:
   background_noise: # If bg noise is used, a noise source position must be passed for RIR mode
     add_bg: false # Add ambient background noise if true
     background_manifest: null # Path to background noise manifest file
+    num_noise_files: 10 # Number of randomly chosen noise source files to be potentially included in one session
     snr: 60 # SNR for background noise (using average speaker power)
 
   speaker_enforcement:
diff --git a/tools/speech_data_simulator/multispeaker_simulator.py b/tools/speech_data_simulator/multispeaker_simulator.py
index 9fc725e972c4..1f5d84b335d0 100644
--- a/tools/speech_data_simulator/multispeaker_simulator.py
+++ b/tools/speech_data_simulator/multispeaker_simulator.py
@@ -12,14 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from multiprocessing import set_start_method
+
 from nemo.collections.asr.data.data_simulation import MultiSpeakerSimulator, RIRMultiSpeakerSimulator
 from nemo.core.config import hydra_runner
 
 """
-This script creates a synthetic diarization session using the LibriSpeech dataset.
+This script creates synthetic diarization sessions using the provided audio dataset with CTM files.
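+    # CUDA initialization is not fork-safe, so the 'spawn' start method is forced before generating sessions.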
 
 Usage:
-    python create_diarization_dataset_librispeech.py
+    python /tools/speech_data_simulator/multispeaker_simulator.py \
+        num_workers=10 \
+        data_simulator.random_seed=42 \
+        data_simulator.manifest_filepath=manifest_with_alignment_file.json \
+        data_simulator.outputs.output_dir=./simulated_data \
+        data_simulator.outputs.output_filename=sim_spk2_sess20 \
+        data_simulator.session_config.num_sessions=1000 \
+        data_simulator.session_config.num_speakers=2 \
+        data_simulator.session_config.session_length=20 \
+        data_simulator.background_noise.add_bg=False \
+        data_simulator.background_noise.background_manifest=background_noise.json \
+        data_simulator.background_noise.snr=40
 
 Check out parameters in ./conf/data_simulator.yaml.
 """
 
 
@@ -27,10 +41,12 @@
 @hydra_runner(config_path="conf", config_name="data_simulator.yaml")
 def main(cfg):
     if cfg.data_simulator.rir_generation.use_rir:
-        lg = RIRMultiSpeakerSimulator(cfg=cfg)
+        simulator = RIRMultiSpeakerSimulator(cfg=cfg)
     else:
-        lg = MultiSpeakerSimulator(cfg=cfg)
-    lg.generate_sessions()
+        simulator = MultiSpeakerSimulator(cfg=cfg)
+
+    set_start_method('spawn', force=True)
+    simulator.generate_sessions()
 
 
 if __name__ == "__main__":