diff --git a/nemo/collections/asr/parts/utils/nmesc_clustering.py b/nemo/collections/asr/parts/utils/nmesc_clustering.py
index b5ddf62fe6f1..c872ead22173 100644
--- a/nemo/collections/asr/parts/utils/nmesc_clustering.py
+++ b/nemo/collections/asr/parts/utils/nmesc_clustering.py
@@ -31,14 +31,14 @@
 # https://arxiv.org/pdf/2003.02405.pdf and the implementation from
 # https://github.com/tango4j/Auto-Tuning-Spectral-Clustering.

-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch
 from torch.linalg import eigh, eigvalsh


 @torch.jit.script
-def cos_similarity(a: torch.Tensor, b: torch.Tensor, eps=torch.tensor(3.5e-4)):
+def cos_similarity(a: torch.Tensor, b: torch.Tensor, eps=torch.tensor(3.5e-4)) -> torch.Tensor:
     """
     Calculate cosine similarities of the given two set of tensors. The output is an N by N
     matrix where N is the number of feature vectors.
@@ -61,7 +61,7 @@ def cos_similarity(a: torch.Tensor, b: torch.Tensor, eps=torch.tensor(3.5e-4)):


 @torch.jit.script
-def ScalerMinMax(X: torch.Tensor):
+def ScalerMinMax(X: torch.Tensor) -> torch.Tensor:
     """
     Min-max scale the input affinity matrix X, which will lead to a dynamic range of [0, 1].
@@ -79,7 +79,9 @@ def ScalerMinMax(X: torch.Tensor):


 @torch.jit.script
-def getEuclideanDistance(specEmbA: torch.Tensor, specEmbB: torch.Tensor, device: torch.device = torch.device('cpu')):
+def getEuclideanDistance(
+    specEmbA: torch.Tensor, specEmbB: torch.Tensor, device: torch.device = torch.device('cpu')
+) -> torch.Tensor:
     """
     Calculate Euclidean distances from the given feature tensors.
@@ -193,7 +195,7 @@ def kmeans_torch(
     iter_limit: int = 15,
     random_state: int = 0,
     device: torch.device = torch.device('cpu'),
-):
+) -> torch.Tensor:
     """
     Run k-means algorithm on the given set of spectral embeddings in X. The threshold
     and iter_limit variables are set to show the best performance on speaker diarization
@@ -264,7 +266,7 @@ def kmeans_torch(


 @torch.jit.script
-def getTheLargestComponent(affinity_mat: torch.Tensor, seg_index: int, device: torch.device):
+def getTheLargestComponent(affinity_mat: torch.Tensor, seg_index: int, device: torch.device) -> torch.Tensor:
     """
     Find the largest affinity_mat connected components for each given node.
     This is for checking whether the affinity_mat is fully connected.
@@ -302,7 +304,7 @@ def getTheLargestComponent(affinity_mat: torch.Tensor, seg_index: int, device: t


 @torch.jit.script
-def isGraphFullyConnected(affinity_mat: torch.Tensor, device: torch.device):
+def isGraphFullyConnected(affinity_mat: torch.Tensor, device: torch.device) -> torch.Tensor:
     """
     Check whether the given affinity matrix is a fully connected graph.
     """
@@ -310,7 +312,7 @@ def isGraphFullyConnected(affinity_mat: torch.Tensor, device: torch.device):


 @torch.jit.script
-def getKneighborsConnections(affinity_mat: torch.Tensor, p_value: int):
+def getKneighborsConnections(affinity_mat: torch.Tensor, p_value: int) -> torch.Tensor:
     """
     Binarize top-p values for each row from the given affinity matrix.
     """
@@ -324,7 +326,7 @@ def getKneighborsConnections(affinity_mat: torch.Tensor, p_value: int):


 @torch.jit.script
-def getAffinityGraphMat(affinity_mat_raw: torch.Tensor, p_value: int):
+def getAffinityGraphMat(affinity_mat_raw: torch.Tensor, p_value: int) -> torch.Tensor:
     """
     Calculate a binarized graph matrix and
     symmetrize the binarized graph matrix.
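Note on the hunks above: the functions gaining return-type annotations here build the sparse graph for NME-SC. A minimal standalone sketch of the top-p binarization and symmetrization that `getKneighborsConnections` and `getAffinityGraphMat` perform; `binarize_top_p` is a hypothetical name for illustration, not part of the patched module:

```python
import torch

def binarize_top_p(affinity: torch.Tensor, p_value: int) -> torch.Tensor:
    # Keep only the p_value largest affinities in each row, zero out the rest.
    binarized = torch.zeros_like(affinity)
    top_idx = torch.argsort(affinity, dim=1, descending=True)[:, :p_value]
    binarized.scatter_(1, top_idx, 1.0)
    # Symmetrize so the resulting graph is undirected.
    return 0.5 * (binarized + binarized.T)

graph = binarize_top_p(torch.rand(6, 6), p_value=2)
```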
@@ -335,7 +337,9 @@ def getAffinityGraphMat(affinity_mat_raw: torch.Tensor, p_value: int):


 @torch.jit.script
-def getMinimumConnection(mat: torch.Tensor, max_N: torch.Tensor, n_list: torch.Tensor, device: torch.device):
+def getMinimumConnection(
+    mat: torch.Tensor, max_N: torch.Tensor, n_list: torch.Tensor, device: torch.device
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Generate connections until fully connect all the nodes in the graph.
     If the graph is not fully connected, it might generate inaccurate results.
@@ -352,7 +356,7 @@ def getMinimumConnection(mat: torch.Tensor, max_N: torch.Tensor, n_list: torch.T


 @torch.jit.script
-def getRepeatedList(mapping_argmat: torch.Tensor, score_mat_size: torch.Tensor):
+def getRepeatedList(mapping_argmat: torch.Tensor, score_mat_size: torch.Tensor) -> torch.Tensor:
     """
     Count the numbers in the mapping dictionary and create lists that contain
     repeated indices that will be used for creating a repeated affinity matrix.
@@ -365,7 +369,7 @@ def getRepeatedList(mapping_argmat: torch.Tensor, score_mat_size: torch.Tensor):


 @torch.jit.script
-def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]):
+def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]) -> List[torch.Tensor]:
     """
     Calculate the mapping between the base scale and other scales. A segment from a longer scale is
     repeatedly mapped to a segment from a shorter scale or the base scale.
@@ -376,8 +380,8 @@ def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]):
         Each tensor has dimensions of (Number of base segments) x 2.

     Returns:
-        session_scale_mapping_dict (dict):
-            Dictionary containing argmin arrays indexed by scale index.
+        session_scale_mapping_list (list):
+            List containing argmin arrays indexed by scale index.
     """
     scale_list = list(range(len(timestamps_in_scales)))
     segment_anchor_list = []
@@ -387,18 +391,18 @@ def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]):
     base_scale_idx = max(scale_list)
     base_scale_anchor = segment_anchor_list[base_scale_idx]
-    session_scale_mapping_dict = []
+    session_scale_mapping_list = []
     for scale_idx in scale_list:
         curr_scale_anchor = segment_anchor_list[scale_idx]
         curr_mat = torch.tile(curr_scale_anchor, (base_scale_anchor.shape[0], 1))
         base_mat = torch.tile(base_scale_anchor, (curr_scale_anchor.shape[0], 1)).t()
         argmin_mat = torch.argmin(torch.abs(curr_mat - base_mat), dim=1)
-        session_scale_mapping_dict.append(argmin_mat)
-    return session_scale_mapping_dict
+        session_scale_mapping_list.append(argmin_mat)
+    return session_scale_mapping_list


 @torch.jit.script
-def getCosAffinityMatrix(emb: torch.Tensor):
+def getCosAffinityMatrix(emb: torch.Tensor) -> torch.Tensor:
     """
     Calculate cosine similarity values among speaker embeddings then min-max normalize
     the affinity matrix.
@@ -426,7 +430,7 @@ def getMultiScaleCosAffinityMatrix(
     embeddings_in_scales: List[torch.Tensor],
     timestamps_in_scales: List[torch.Tensor],
     device: torch.device = torch.device('cpu'),
-):
+) -> torch.Tensor:
     """
     Calculate cosine similarity values among speaker embeddings for each scale then
     apply multiscale weights to calculate the fused similarity matrix.
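The `session_scale_mapping_dict` to `session_scale_mapping_list` rename above is easier to follow with a toy version of the mapping `get_argmin_mat` returns: each base-scale segment is matched to the longer-scale segment whose center is closest. The numbers below are made up for illustration:

```python
import torch

long_scale = torch.tensor([[0.0, 2.0], [2.0, 4.0]])                          # 2 segments
base_scale = torch.tensor([[0.0, 1.0], [1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])  # 4 segments

long_centers = long_scale.mean(dim=1)   # tensor([1., 3.])
base_centers = base_scale.mean(dim=1)   # tensor([0.5, 1.5, 2.5, 3.5])

# For every base-scale segment, the index of the closest long-scale segment.
mapping = torch.argmin((base_centers.unsqueeze(1) - long_centers.unsqueeze(0)).abs(), dim=1)
print(mapping)  # tensor([0, 0, 1, 1])
```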
@@ -444,10 +448,10 @@ def getMultiScaleCosAffinityMatrix(
     """
     multiscale_weights = multiscale_weights.to(device)
     score_mat_list, repeated_tensor_list = [], []
-    session_scale_mapping_dict = get_argmin_mat(timestamps_in_scales)
+    session_scale_mapping_list = get_argmin_mat(timestamps_in_scales)
     scale_list = list(range(len(timestamps_in_scales)))
     for scale_idx in scale_list:
-        mapping_argmat = session_scale_mapping_dict[scale_idx]
+        mapping_argmat = session_scale_mapping_list[scale_idx]
         emb_t = embeddings_in_scales[scale_idx].half().to(device)
         score_mat_torch = getCosAffinityMatrix(emb_t)
         repeat_list = getRepeatedList(mapping_argmat, torch.tensor(score_mat_torch.shape[0])).to(device)
@@ -460,7 +464,7 @@ def getMultiScaleCosAffinityMatrix(


 @torch.jit.script
-def getLaplacian(X: torch.Tensor):
+def getLaplacian(X: torch.Tensor) -> torch.Tensor:
     """
     Calculate a laplacian matrix from an affinity matrix X.
     """
@@ -472,7 +476,9 @@ def getLaplacian(X: torch.Tensor):


 @torch.jit.script
-def eigDecompose(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')):
+def eigDecompose(
+    laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Calculate eigenvalues and eigenvectors from the Laplacian matrix.
     """
@@ -487,7 +493,7 @@ def eigDecompose(laplacian: torch.Tensor, cuda: bool, device: torch.device = tor


 @torch.jit.script
-def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')):
+def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')) -> torch.Tensor:
     """
     Calculate only eigenvalues from the Laplacian matrix.
     """
@@ -502,7 +508,7 @@ def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch


 @torch.jit.script
-def getLamdaGaplist(lambdas: torch.Tensor):
+def getLamdaGaplist(lambdas: torch.Tensor) -> torch.Tensor:
     """
     Calculate the gaps between lambda values.
     """
@@ -512,7 +518,7 @@ def getLamdaGaplist(lambdas: torch.Tensor):


 @torch.jit.script
-def addAnchorEmb(emb: torch.Tensor, anchor_sample_n: int, anchor_spk_n: int, sigma: float):
+def addAnchorEmb(emb: torch.Tensor, anchor_sample_n: int, anchor_spk_n: int, sigma: float) -> torch.Tensor:
     """
     Add randomly generated synthetic embeddings to make eigenanalysis more stable.
     We refer to these embeddings as anchor embeddings.
@@ -555,7 +561,7 @@ def getEnhancedSpeakerCount(
     anchor_sample_n: int = 10,
     sigma: float = 50,
     cuda: bool = False,
-):
+) -> torch.Tensor:
     """
     Calculate the number of speakers using NME analysis with anchor embeddings. Add dummy speaker
     embedding vectors and run speaker counting multiple times to enhance the speaker counting accuracy
@@ -593,7 +599,7 @@ def getEnhancedSpeakerCount(
         mat = getCosAffinityMatrix(emb_aug)
         nmesc = NMESC(
             mat,
-            max_num_speaker=emb.shape[0],
+            max_num_speakers=emb.shape[0],
             max_rp_threshold=0.15,
             sparse_search=True,
             sparse_search_volume=50,
@@ -612,7 +618,7 @@ def split_input_data(
     embeddings_in_scales: torch.Tensor,
     timestamps_in_scales: torch.Tensor,
     multiscale_segment_counts: torch.LongTensor,
-):
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     """
     Split multiscale embeddings and multiscale timestamps and put split scale-wise data into python lists.
     This formatting function is needed to make the input type as `torch.Tensor`.
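For context on `estimateNumofSpeakers` in the following hunks, here is a self-contained sketch of the eigengap criterion it applies: on an ideally block-structured affinity matrix, the Laplacian has as many near-zero eigenvalues as there are clusters, so the largest gap in the sorted eigenvalues marks the speaker count. The unnormalized Laplacian below is a simplification of the module's `getLaplacian`:

```python
import torch

affinity = torch.zeros(6, 6)
affinity[:3, :3] = 1.0   # speaker A block
affinity[3:, 3:] = 1.0   # speaker B block

laplacian = torch.diag(affinity.sum(dim=1)) - affinity
lambdas = torch.sort(torch.linalg.eigvalsh(laplacian))[0]
gaps = lambdas[1:] - lambdas[:-1]           # what getLamdaGaplist computes
num_speakers = int(torch.argmax(gaps)) + 1  # largest eigengap -> 2 speakers
print(num_speakers)  # 2
```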
@@ -639,14 +645,16 @@ def split_input_data(


 @torch.jit.script
-def estimateNumofSpeakers(affinity_mat: torch.Tensor, max_num_speaker: int, cuda: bool = False):
+def estimateNumofSpeakers(
+    affinity_mat: torch.Tensor, max_num_speakers: int, cuda: bool = False
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Estimate the number of speakers using eigendecomposition on the Laplacian Matrix.

     Args:
         affinity_mat (Tensor):
             N by N affinity matrix
-        max_num_speaker (int):
+        max_num_speakers (int):
             Maximum number of clusters to consider for each session
         cuda (bool):
             If cuda available eigendecomposition is computed on GPUs.
@@ -664,7 +672,7 @@ def estimateNumofSpeakers(affinity_mat: torch.Tensor, max_num_speaker: int, cuda
     lambdas = eigValueSh(laplacian, cuda=cuda)
     lambdas = torch.sort(lambdas)[0]
     lambda_gap = getLamdaGaplist(lambdas)
-    num_of_spk = torch.argmax(lambda_gap[: min(max_num_speaker, lambda_gap.shape[0])]) + 1
+    num_of_spk = torch.argmax(lambda_gap[: min(max_num_speakers, lambda_gap.shape[0])]) + 1
     return num_of_spk, lambdas, lambda_gap


@@ -706,7 +714,7 @@ def __init__(
         self.cuda = cuda
         self.device = device

-    def forward(self, X):
+    def forward(self, X) -> torch.Tensor:
         """
         Call self.clusterSpectralEmbeddings() function to predict cluster labels.
@@ -725,7 +733,7 @@ def forward(self, X):

     def clusterSpectralEmbeddings(
         self, affinity: torch.Tensor, cuda: bool = False, device: torch.device = torch.device('cpu')
-    ):
+    ) -> torch.Tensor:
         """
         Perform k-means clustering on spectral embeddings. To alleviate the effect of randomness,
         k-means clustering is performed for (self.n_random_trials) times then the final labels are obtained
@@ -757,7 +765,7 @@ def clusterSpectralEmbeddings(
             labels = stacked_labels[label_index]
         return labels

-    def getSpectralEmbeddings(self, affinity_mat: torch.Tensor, n_spks: int = 8, cuda: bool = False):
+    def getSpectralEmbeddings(self, affinity_mat: torch.Tensor, n_spks: int = 8, cuda: bool = False) -> torch.Tensor:
         """
         Calculate eigenvalues and eigenvectors to extract spectral embeddings.
@@ -807,19 +815,14 @@ class NMESC:
     Methods:
         NMEanalysis():
             Performs NME-analysis to estimate p_value and the number of speakers
-
         subsampleAffinityMat(nme_mat_size):
             Subsamples the number of speakers to reduce the computational load
-
         getPvalueList():
             Generates a list containing p-values that need to be examined.
-
         getEigRatio(p_neighbors):
             Calculates g_p, which is a ratio between p_neighbors and the maximum eigengap
-
         getLamdaGaplist(lambdas):
             Calculates lambda gap values from an array contains lambda values
-
         estimateNumofSpeakers(affinity_mat):
             Estimates the number of speakers using lambda gap list
     """
@@ -827,13 +830,13 @@ class NMESC:
     def __init__(
         self,
         mat: torch.Tensor,
-        max_num_speaker: int = 10,
+        max_num_speakers: int = 10,
         max_rp_threshold: float = 0.15,
         sparse_search: bool = True,
         sparse_search_volume: int = 30,
         nme_mat_size: int = 512,
         use_subsampling_for_nme: bool = True,
-        fixed_thres: float = 0.0,
+        fixed_thres: float = -1.0,
         maj_vote_spk_count: bool = False,
         parallelism: bool = True,
         cuda: bool = False,
@@ -843,7 +846,7 @@ def __init__(
         Args:
             mat (Tensor):
                 Cosine similarity matrix calculated from the provided speaker embeddings.
-            max_num_speaker (int):
+            max_num_speakers (int):
                 Maximum number of speakers for estimating number of speakers.
                 Shows stable performance under 20.
             max_rp_threshold (float):
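The `fixed_thres` default moving from 0.0 to -1.0 above reflects a TorchScript-friendly convention: instead of `Optional[float]`, a negative sentinel marks "not provided", and 0.0 is a poor sentinel because it is a legal threshold. A hypothetical helper mirroring that check (not code from this PR):

```python
import torch

@torch.jit.script
def resolve_threshold(fixed_thres: float, estimated_thres: float) -> float:
    # fixed_thres < 0 means "run NME analysis and use the estimated value".
    if fixed_thres >= 0.0:
        return fixed_thres
    return estimated_thres

assert resolve_threshold(-1.0, 0.12) == 0.12  # sentinel: fall back to estimate
assert resolve_threshold(0.08, 0.12) == 0.08  # explicit threshold wins
```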
@@ -868,29 +871,31 @@ def __init__(
                 If True, take a majority vote on all p-values in the given range to estimate the number of speakers.
                 The majority voting may contribute to surpress overcounting of the speakers and improve speaker counting accuracy.
+            parallelism (bool):
+                If True, turn on parallelism based on the torch.jit.script library.
             cuda (bool):
                 Use cuda for Eigen decomposition if cuda=True.
             nme_mat_size (int):
                 Targeted size of matrix for NME analysis.
         """
-        self.max_num_speaker: int = max_num_speaker
-        self.max_rp_threshold = max_rp_threshold
-        self.use_subsampling_for_nme = use_subsampling_for_nme
+        self.max_num_speakers: int = max_num_speakers
+        self.max_rp_threshold: float = max_rp_threshold
+        self.use_subsampling_for_nme: bool = use_subsampling_for_nme
         self.nme_mat_size: int = nme_mat_size
-        self.sparse_search = sparse_search
-        self.sparse_search_volume = sparse_search_volume
+        self.sparse_search: bool = sparse_search
+        self.sparse_search_volume: int = sparse_search_volume
         self.min_p_value = torch.tensor(2)
         self.fixed_thres: float = fixed_thres
         self.cuda: bool = cuda
         self.eps = 1e-10
         self.max_N = torch.tensor(0)
-        self.mat = mat
+        self.mat: torch.Tensor = mat
         self.p_value_list: torch.Tensor = self.min_p_value.unsqueeze(0)
         self.device = device
-        self.maj_vote_spk_count = maj_vote_spk_count
-        self.parallelism = parallelism
+        self.maj_vote_spk_count: bool = maj_vote_spk_count
+        self.parallelism: bool = parallelism

-    def forward(self):
+    def forward(self) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Subsample the input matrix to reduce the computational load.
         """
@@ -944,7 +949,7 @@ def forward(self):
             est_num_of_spk = est_spk_n_dict[rp_p_value.item()]
         return est_num_of_spk, p_hat_value

-    def subsampleAffinityMat(self, nme_mat_size: int):
+    def subsampleAffinityMat(self, nme_mat_size: int) -> torch.Tensor:
         """
         Perform subsampling of affinity matrix.
         This subsampling is for calculational complexity, not for performance.
@@ -969,7 +974,7 @@ def subsampleAffinityMat(self, nme_mat_size: int):
         self.mat = self.mat[:: subsample_ratio.item(), :: subsample_ratio.item()]
         return subsample_ratio

-    def getEigRatio(self, p_neighbors: int):
+    def getEigRatio(self, p_neighbors: int) -> torch.Tensor:
         """
         For a given p_neighbors value, calculate g_p, which is a ratio
         between p_neighbors and the maximum eigengap values.
@@ -989,14 +994,16 @@ def getEigRatio(self, p_neighbors: int):
                 The ratio between p_neighbors value and the maximum eigen gap value.
         """
         affinity_mat = getAffinityGraphMat(self.mat, p_neighbors)
-        est_num_of_spk, lambdas, lambda_gap_list = estimateNumofSpeakers(affinity_mat, self.max_num_speaker, self.cuda)
-        arg_sorted_idx = torch.argsort(lambda_gap_list[: self.max_num_speaker], descending=True)
+        est_num_of_spk, lambdas, lambda_gap_list = estimateNumofSpeakers(
+            affinity_mat, self.max_num_speakers, self.cuda
+        )
+        arg_sorted_idx = torch.argsort(lambda_gap_list[: self.max_num_speakers], descending=True)
         max_key = arg_sorted_idx[0]
         max_eig_gap = lambda_gap_list[max_key] / (torch.max(lambdas).item() + self.eps)
         g_p = (p_neighbors / self.mat.shape[0]) / (max_eig_gap + self.eps)
         return torch.stack([g_p, est_num_of_spk])

-    def getPvalueList(self):
+    def getPvalueList(self) -> torch.Tensor:
         """
         Generates a p-value (p_neighbour) list for searching. p_value_list must include 2 (min_p_value)
         since at least one neighboring segment should be selected other than itself.
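A rough sketch of the search grid `getPvalueList` produces, using the names from the diff; the exact NeMo logic may differ in details such as rounding and duplicate handling:

```python
import torch

mat_size, max_rp_threshold = 300, 0.15
min_p_value = torch.tensor(2)
sparse_search_volume = 30

# Largest neighbor count to consider: a fraction of the matrix size,
# but never below min_p_value.
max_N = max(int(mat_size * max_rp_threshold), int(min_p_value))
# Sparse search: probe only a fixed number of evenly spaced p-values
# instead of every integer in [1, max_N].
steps = min(sparse_search_volume, max_N)
p_value_list = torch.unique(torch.linspace(1, max_N, steps).int())
print(p_value_list)  # at most `sparse_search_volume` candidate p-values
```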
@@ -1040,16 +1047,10 @@ def getPvalueList(self):
 class SpeakerClustering(torch.nn.Module):
     def __init__(
         self,
-        max_num_speaker: int = 8,
         min_samples_for_nmesc: int = 6,
-        enhanced_count_thres: int = 80,
         nme_mat_size: int = 300,
-        max_rp_threshold: float = 0.15,
         sparse_search: bool = True,
-        sparse_search_volume: int = 30,
         maj_vote_spk_count: bool = False,
-        fixed_thres: float = 0.0,
-        multiscale_weights: torch.Tensor = torch.tensor(1).unsqueeze(0),
         parallelism: bool = False,
         cuda: bool = False,
     ):
@@ -1058,7 +1059,7 @@ def __init__(
         NME-SC part is converted to torch.tensor based operations in NeMo 1.9.

         Args:
-            max_num_speaker (int):
+            max_num_speakers (int):
                 The maximum number of clusters to consider for each session
             min_samples_for_nmesc (int):
                 The minimum number of samples required for NME clustering. This avoids
@@ -1096,18 +1097,12 @@ def __init__(
                 Boolean variable for toggling cuda availability.
         """
         super().__init__()
-        self.max_num_speaker: int = max_num_speaker
         self.min_samples_for_nmesc: int = min_samples_for_nmesc
-        self.enhanced_count_thres: int = enhanced_count_thres
         self.nme_mat_size: int = nme_mat_size
-        self.max_rp_threshold: float = max_rp_threshold
         self.sparse_search: bool = sparse_search
-        self.sparse_search_volume: int = sparse_search_volume
-        self.maj_vote_spk_count: bool = maj_vote_spk_count
-        self.fixed_thres: float = fixed_thres
-        self.multiscale_weights: torch.Tensor = multiscale_weights
         self.parallelism: bool = parallelism
         self.cuda: bool = cuda
+        self.maj_vote_spk_count: bool = maj_vote_spk_count
         self.embeddings_in_scales: List[torch.Tensor] = [torch.Tensor(0)]
         self.timestamps_in_scales: List[torch.Tensor] = [torch.Tensor(0)]
         self.device = torch.device("cuda") if self.cuda else torch.device("cpu")
@@ -1117,63 +1112,98 @@ def forward(
         embeddings_in_scales: torch.Tensor,
         timestamps_in_scales: torch.Tensor,
         multiscale_segment_counts: torch.LongTensor,
+        multiscale_weights: torch.Tensor,
         oracle_num_speakers: torch.LongTensor,
-    ):
+        max_rp_threshold: torch.Tensor = torch.tensor(0.15),
+        max_num_speakers: torch.LongTensor = torch.tensor(8, dtype=torch.long),
+        enhanced_count_thres: torch.LongTensor = torch.tensor(80, dtype=torch.long),
+        sparse_search_volume: torch.LongTensor = torch.tensor(30, dtype=torch.long),
+        fixed_thres: torch.Tensor = torch.tensor(-1.0),
+    ) -> torch.LongTensor:
         """
         Calculate affinity matrix using timestamps and speaker embeddings, run NME analysis to estimate the best
         p-value and perform spectral clustering based on the estimated p-value and the calculated affinity matrix.

+        Caution:
+            For the sake of compatibility with libtorch, Python boolean `False` is replaced with `torch.LongTensor(-1)`.
+
         Args:
             embeddings_in_scales (Tensor):
                 Concatenated Torch tensor containing embeddings in multiple scales
                 This tensor has dimensions of (Number of base segments) x (Embedding Dimension)
             timestamps_in_scales (Tensor):
-                Concatenated Torch tensor containing timestamps in multiple scales
-                This tensor has dimensions of (Number of base segments) x 2
+                Concatenated Torch tensor containing timestamps in multiple scales.
+                This tensor has dimensions of (Total number of segments in all scales) x 2
                 Example:
-                    [[0.4, 1.4], [0.9, 1.9], [1.4, 2.4], ... [121.2, 122.2]]
+                    >>> timestamps_in_scales = \
+                    >>> torch.tensor([[0.4, 1.4], [0.9, 1.9], [1.4, 2.4], ... [121.2, 122.2]])
+
             multiscale_segment_counts (LongTensor):
                 Concatenated Torch tensor containing number of segments per each scale
                 This tensor has dimensions of (Number of scales)
                 Example:
-                    [31, 52, 84, 105, 120]
+                    >>> multiscale_segment_counts = torch.LongTensor([31, 52, 84, 105, 120])
+
+            multiscale_weights (Tensor):
+                Multi-scale weights that are used when affinity scores are merged.
+                Example:
+                    >>> multiscale_weights = torch.tensor([1.4, 1.3, 1.2, 1.1, 1.0])
+
             oracle_num_speakers (LongTensor):
                 The number of speakers in a session from the reference transcript
+            max_num_speakers (LongTensor):
+                The upper bound for the number of speakers in each session
+            enhanced_count_thres (LongTensor):
+                For short audio recordings, the clustering algorithm cannot
+                accumulate a sufficient amount of speaker profile data for each cluster.
+                Thus, the function `getEnhancedSpeakerCount` employs anchor embeddings
+                (dummy representations) to mitigate the effect of cluster sparsity.
+                enhanced_count_thres = 80 is recommended.
+            max_rp_threshold (Tensor):
+                Limits the range of parameter search.
+                Clustering performance can vary depending on this range.
+                Default is 0.15.
+            sparse_search_volume (LongTensor):
+                Number of p_values we search during NME analysis.
+                Default is 30. The lower the value, the faster NME-analysis becomes.
+                Values lower than 20 might cause poor parameter estimation.
+            fixed_thres (Tensor):
+                If a fixed_thres value is provided, the NME-analysis process will be skipped.
+                This value should be optimized on a development set to obtain a quality result.
+                Default is -1.0, which performs NME-analysis to estimate the threshold.

         Returns:
-            Y (torch.tensor[int])
+            Y (Tensor):
                 Speaker label for each segment.
         """
-
         self.embeddings_in_scales, self.timestamps_in_scales = split_input_data(
             embeddings_in_scales, timestamps_in_scales, multiscale_segment_counts
         )
         emb = self.embeddings_in_scales[multiscale_segment_counts.shape[0] - 1]

+        # Cases for extremely short sessions
         if emb.shape[0] == 1:
             return torch.zeros((1,), dtype=torch.int64)
-        elif emb.shape[0] <= max(self.enhanced_count_thres, self.min_samples_for_nmesc) and oracle_num_speakers < 0:
+        elif emb.shape[0] <= max(enhanced_count_thres, self.min_samples_for_nmesc) and oracle_num_speakers < 0:
             est_num_of_spk_enhanced = getEnhancedSpeakerCount(emb=emb, cuda=self.cuda)
         else:
             est_num_of_spk_enhanced = torch.tensor(-1)

-        oracle_num_speakers = int(oracle_num_speakers.item())
-
         if oracle_num_speakers > 0:
-            self.max_num_speaker = oracle_num_speakers
+            max_num_speakers = oracle_num_speakers

         mat = getMultiScaleCosAffinityMatrix(
-            self.multiscale_weights, self.embeddings_in_scales, self.timestamps_in_scales, self.device
+            multiscale_weights, self.embeddings_in_scales, self.timestamps_in_scales, self.device
         )

         nmesc = NMESC(
             mat,
-            max_num_speaker=self.max_num_speaker,
-            max_rp_threshold=self.max_rp_threshold,
+            max_num_speakers=int(max_num_speakers.item()),
+            max_rp_threshold=float(max_rp_threshold.item()),
             sparse_search=self.sparse_search,
-            sparse_search_volume=self.sparse_search_volume,
-            fixed_thres=self.fixed_thres,
+            sparse_search_volume=int(sparse_search_volume.item()),
+            fixed_thres=float(fixed_thres.item()),
             nme_mat_size=self.nme_mat_size,
             maj_vote_spk_count=self.maj_vote_spk_count,
             parallelism=self.parallelism,
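Under the new signature above, every per-session option is passed as a tensor so the scripted module can be driven from libtorch without Python ints, floats, or booleans. A minimal usage sketch with dummy single-scale inputs; all values are illustrative, and with this few segments the enhanced speaker-count path is exercised:

```python
import torch
from nemo.collections.asr.parts.utils.nmesc_clustering import SpeakerClustering

speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, cuda=False)

# Dummy single-scale inputs: 10 segments with 192-dim embeddings.
embs = torch.randn(10, 192)
stamps = torch.stack([torch.arange(10.0), torch.arange(10.0) + 1.5], dim=1)

labels = speaker_clustering.forward(
    embeddings_in_scales=embs,
    timestamps_in_scales=stamps,
    multiscale_segment_counts=torch.LongTensor([10]),
    multiscale_weights=torch.tensor([1.0]),
    oracle_num_speakers=torch.tensor(-1, dtype=torch.long),  # -1: unknown
    max_num_speakers=torch.tensor(8, dtype=torch.long),
    fixed_thres=torch.tensor(-1.0),  # -1.0: estimate the threshold via NME
)
```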
@@ -1181,18 +1211,22 @@ def forward(
             device=self.device,
         )

+        # If there are fewer than `min_samples_for_nmesc` segments, est_num_of_spk is 1.
         if mat.shape[0] > self.min_samples_for_nmesc:
             est_num_of_spk, p_hat_value = nmesc.forward()
             affinity_mat = getAffinityGraphMat(mat, p_hat_value)
         else:
+            est_num_of_spk = torch.tensor(1)
             affinity_mat = mat
-            est_num_of_spk = torch.tensor(1).to(self.device)

+        # n_clusters is the number of speakers estimated from spectral clustering.
         if oracle_num_speakers > 0:
-            est_num_of_spk = torch.tensor(oracle_num_speakers).to(self.device)
+            n_clusters = int(oracle_num_speakers.item())
         elif est_num_of_spk_enhanced > 0:
-            est_num_of_spk = est_num_of_spk_enhanced
+            n_clusters = int(est_num_of_spk_enhanced.item())
+        else:
+            n_clusters = int(est_num_of_spk.item())

-        spectral_model = SpectralClustering(n_clusters=est_num_of_spk, cuda=self.cuda, device=self.device)
+        spectral_model = SpectralClustering(n_clusters=n_clusters, cuda=self.cuda, device=self.device)
         Y = spectral_model.forward(affinity_mat)
         return Y
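The nmesc_clustering.py changes above make the whole module scriptable; a minimal export and reload round trip, mirroring the `export_script_module` branch added to `perform_clustering` below:

```python
import torch
from nemo.collections.asr.parts.utils.nmesc_clustering import SpeakerClustering

module = torch.jit.script(SpeakerClustering(maj_vote_spk_count=False, cuda=False))
torch.jit.save(module, 'speaker_clustering_script.pt')
loaded = torch.jit.load('speaker_clustering_script.pt')  # also loadable from libtorch
```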
diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py
index fdeff6de26fd..cee6c0724e2c 100644
--- a/nemo/collections/asr/parts/utils/speaker_utils.py
+++ b/nemo/collections/asr/parts/utils/speaker_utils.py
@@ -418,7 +418,7 @@ def write_cluster_labels(base_scale_idx, lines_cluster_labels, out_rttm_dir):
         f.write(clus_label_line)


-def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, clustering_params, use_torch_script=False):
+def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, clustering_params):
     """
     Performs spectral clustering on embeddings with time stamps generated from VAD output
@@ -440,7 +440,6 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste
     all_hypothesis = []
     all_reference = []
     no_references = False
-    max_num_speakers = clustering_params['max_num_speakers']
     lines_cluster_labels = []

     cuda = True
@@ -448,31 +447,32 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste
         logging.warning("cuda=False, using CPU for Eigen decomposition. This might slow down the clustering process.")
         cuda = False

+    speaker_clustering = SpeakerClustering(maj_vote_spk_count=clustering_params.maj_vote_spk_count, cuda=cuda)
+
+    # If True, export the TorchScript module and save it to the base folder.
+    if clustering_params.get('export_script_module', False):
+        speaker_clustering = torch.jit.script(speaker_clustering)
+        torch.jit.save(speaker_clustering, 'speaker_clustering_script.pt')
+
     for uniq_id, audio_rttm_values in tqdm(AUDIO_RTTM_MAP.items(), desc='clustering', leave=False):
+        uniq_embs_and_timestamps = embs_and_timestamps[uniq_id]
+
         if clustering_params.oracle_num_speakers:
             num_speakers = audio_rttm_values.get('num_speakers', None)
             if num_speakers is None:
                 raise ValueError("Provided option as oracle num of speakers but num_speakers in manifest is null")
         else:
             num_speakers = -1

-        uniq_embs_and_timestamps = embs_and_timestamps[uniq_id]
-
-        num_speakers = torch.tensor(num_speakers, dtype=torch.long)
-
-        speaker_clustering = SpeakerClustering(
-            max_num_speaker=max_num_speakers,
-            max_rp_threshold=clustering_params.max_rp_threshold,
-            sparse_search_volume=clustering_params.sparse_search_volume,
-            multiscale_weights=uniq_embs_and_timestamps['multiscale_weights'],
-            cuda=cuda,
-        )
-        if use_torch_script:
-            speaker_clustering = torch.jit.script(speaker_clustering)

         cluster_labels = speaker_clustering.forward(
-            uniq_embs_and_timestamps['embeddings'],
-            uniq_embs_and_timestamps['timestamps'],
-            uniq_embs_and_timestamps['multiscale_segment_counts'],
-            oracle_num_speakers=num_speakers,
+            embeddings_in_scales=uniq_embs_and_timestamps['embeddings'],
+            timestamps_in_scales=uniq_embs_and_timestamps['timestamps'],
+            multiscale_segment_counts=uniq_embs_and_timestamps['multiscale_segment_counts'],
+            multiscale_weights=uniq_embs_and_timestamps['multiscale_weights'],
+            oracle_num_speakers=torch.tensor(num_speakers, dtype=torch.long),
+            max_num_speakers=torch.tensor(clustering_params.max_num_speakers, dtype=torch.long),
+            max_rp_threshold=torch.tensor(clustering_params.max_rp_threshold),
+            sparse_search_volume=torch.tensor(clustering_params.sparse_search_volume, dtype=torch.long),
        )

        base_scale_idx = uniq_embs_and_timestamps['multiscale_segment_counts'].shape[0] - 1
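The `clustering_params` object consumed above is an attribute-style config (`.maj_vote_spk_count`, `.get(...)`); a hypothetical OmegaConf equivalent listing the fields this hunk references, with illustrative rather than prescribed values:

```python
from omegaconf import OmegaConf

clustering_params = OmegaConf.create(
    {
        'oracle_num_speakers': False,     # use num_speakers from the manifest if True
        'max_num_speakers': 8,
        'max_rp_threshold': 0.15,
        'sparse_search_volume': 30,
        'maj_vote_spk_count': False,
        'export_script_module': False,    # export speaker_clustering_script.pt if True
    }
)
```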