Skip to content

Commit

Permalink
add ipython import guard (#11191) (#11195)
Browse files Browse the repository at this point in the history
* add ipython import guard



* handle notebooks that use the vad_utils plot function



* Apply isort and black reformatting



* decrease line length



* Apply isort and black reformatting



* small update to doc



* Apply isort and black reformatting



---------

Signed-off-by: Nithin Rao Koluguri <[email protected]>
Signed-off-by: nithinraok <[email protected]>
Co-authored-by: nithinraok <[email protected]>
  • Loading branch information
nithinraok and nithinraok authored Nov 6, 2024
1 parent 794d56a commit 5182968
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 22 deletions.
78 changes: 56 additions & 22 deletions nemo/collections/asr/parts/utils/vad_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import IPython.display as ipd
import librosa
import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -40,6 +39,15 @@
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
from nemo.utils import logging

# IPython is an optional dependency: it is imported here under a guard so that
# this module can still be imported when IPython is not installed. Code that
# needs `ipd` is expected to check HAVE_IPYTHON first.
try:
    import IPython.display as ipd

    HAVE_IPYTHON = True
except ImportError:
    # Catch only import failures (ModuleNotFoundError is a subclass of ImportError).
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit and mask
    # unrelated errors raised while importing.
    ipd = None  # keep the name bound so references fail predictably, not with NameError
    HAVE_IPYTHON = False


"""
This file contains all the utility functions required for voice activity detection.
"""
Expand All @@ -66,7 +74,8 @@ def prepare_manifest(config: dict) -> str:
input_list = config['input']
else:
raise ValueError(
"The input for manifest preparation would either be a string of the filepath to manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} "
"The input for manifest preparation would either be a string of the filepath to \
manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} "
)

args_func = {
Expand Down Expand Up @@ -195,7 +204,8 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list:

def get_vad_stream_status(data: list) -> list:
"""
Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status.
Generate a list of status for each snippet in manifest.
A snippet should be in single, start, next or end status.
Used for concatenating to full audio file.
Args:
data (list): list of filepath of audio snippet
Expand Down Expand Up @@ -246,7 +256,8 @@ def generate_overlap_vad_seq(
out_dir: str = None,
) -> str:
"""
Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows.
Generate predictions with overlapping input windows/segments.
Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows.
Two common smoothing filters are supported: majority vote (median) and average (mean).
This function uses multiprocessing to speed up.
Args:
Expand Down Expand Up @@ -310,7 +321,8 @@ def generate_overlap_vad_seq_per_tensor(
frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str
) -> torch.Tensor:
"""
Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments
Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms))
to generate prediction with overlapping input window/segments
See description in generate_overlap_vad_seq.
Use this for single instance pipeline.
"""
Expand Down Expand Up @@ -472,7 +484,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te
Binarize predictions to speech and non-speech
Reference
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \
InterSpeech 2015.
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
Args:
Expand All @@ -485,7 +498,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te
frame_length_in_sec (float): length of frame.
Returns:
speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) \
format.
"""
frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01)

Expand Down Expand Up @@ -536,7 +550,8 @@ def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: tor
"""
Remove speech segments list in to_be_removed_segments from original_segments.
For example,
remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]),
remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],\
[start3, end3], [start4, end4]]),
->
torch.Tensor([[start1, end1],[start3, end3]])
"""
Expand All @@ -562,17 +577,21 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc
Filter out short non_speech and speech segments.
Reference
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \
InterSpeech 2015.
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
Args:
speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], \
[start2, end2]]) format.
per_args:
min_duration_on (float): threshold for small non_speech deletion
min_duration_off (float): threshold for short speech segment deletion
filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True.
filter_speech_first (float): Whether to perform short speech segment deletion first. \
Use 1.0 to represent True.
Returns:
speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
speech_segments(torch.Tensor): A tensor of filtered speech segment in \
torch.Tensor([[start1, end1], [start2, end2]]) format.
"""
if speech_segments.shape == torch.Size([0]):
return speech_segments
Expand Down Expand Up @@ -709,7 +728,8 @@ def generate_vad_segment_table(
17,18, speech
Args:
vad_pred_dir (str): directory of prediction files to be processed.
postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering.
postprocessing_params (dict): dictionary of thresholds for prediction score.
See details in binarization and filtering.
frame_length_in_sec (float): frame length.
out_dir (str): output dir of generated table/csv file.
num_workers(float): number of process for multiprocessing
Expand Down Expand Up @@ -820,10 +840,12 @@ def vad_tune_threshold_on_dev(
num_workers: int = 20,
) -> Tuple[dict, dict]:
"""
Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate (DetER) in thresholds.
Tune thresholds on dev set. Return the best thresholds, which give the lowest detection error rate
(DetER) in thresholds.
Args:
params (dict): dictionary of parameters to be tuned on.
vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median".
vad_pred_method (str): suffix of prediction file. Use to locate file.
Should be either in "frame", "mean" or "median".
groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them.
focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS"
frame_length_in_sec (float): frame length.
Expand Down Expand Up @@ -914,7 +936,8 @@ def check_if_param_valid(params: dict) -> bool:
for j in params[i]:
if not j >= 0:
raise ValueError(
"Invalid inputs! All float parameters except pad_onset and pad_offset should be larger than 0!"
"Invalid inputs! All float parameters except pad_onset and pad_offset should be \
larger than 0!"
)

if not (all(i <= 1 for i in params['onset']) and all(i <= 1 for i in params['offset'])):
Expand Down Expand Up @@ -986,9 +1009,13 @@ def plot(
threshold (float): threshold for prediction score (from 0 to 1).
per_args(dict): a dict that stores the thresholds for postprocessing.
unit_frame_len (float): unit frame length in seconds for VAD predictions.
label_repeat (int): repeat the label for this number of times to match different frame lengths in preds and labels.
label_repeat (int): repeat the label for this number of times to match different \
frame lengths in preds and labels.
xticks_step (int): step size for xticks.
"""
if HAVE_IPYTHON is False:
raise ImportError("IPython is not installed. Please install IPython to use this function.")

plt.figure(figsize=[20, 2])

audio, sample_rate = librosa.load(
Expand Down Expand Up @@ -1254,7 +1281,8 @@ def stitch_segmented_asr_output(
fout.flush()

logging.info(
f"Finish stitch segmented ASR output to {stitched_output_manifest}, the speech segments info has been stored in directory {speech_segments_tensor_dir}"
f"Finish stitch segmented ASR output to {stitched_output_manifest}, \
the speech segments info has been stored in directory {speech_segments_tensor_dir}"
)
return stitched_output_manifest

Expand Down Expand Up @@ -1438,6 +1466,9 @@ def plot_sample_from_rttm(
"""
Plot audio signal and frame-level labels from RTTM file
"""
if HAVE_IPYTHON is False:
raise ImportError("IPython is not installed. Please install IPython to use this function.")

plt.figure(figsize=[20, 2])

audio, sample_rate = librosa.load(path=audio_file, sr=16000, mono=True, offset=offset, duration=max_duration)
Expand Down Expand Up @@ -1472,8 +1503,9 @@ def plot_sample_from_rttm(
def align_labels_to_frames(probs, labels, threshold=0.2):
"""
Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms).
The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label
lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid.
The threshold 0.2 is not important, since the actual ratio will always be close to an integer
unless using frame/label lengths that are not multiples of each other
(e.g., 15ms frame length and 20ms label length), which is not valid.
The value 0.2 here is just for easier unit testing.
Args:
probs (List[float]): list of probabilities
Expand Down Expand Up @@ -1511,11 +1543,13 @@ def align_labels_to_frames(probs, labels, threshold=0.2):
ratio = frames_len / labels_len
res = frames_len % labels_len
if ceil(ratio) - ratio < threshold:
# e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a multiple of 2, and discard the redundant labels
# e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a
# multiple of 2, and discard the redundant labels
labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist()
labels = labels[:frames_len]
else:
# e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of 2 and add additional labels
# e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of
# 2 and add additional labels
labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist()
if res > 0:
labels += labels[-res:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"!pip install wget\n",
"!apt-get install sox libsndfile1 ffmpeg\n",
"!pip install text-unidecode\n",
"!pip install ipython\n",
"\n",
"# ## Install NeMo\n",
"BRANCH = 'r2.0.0'\n",
Expand Down

0 comments on commit 5182968

Please sign in to comment.