Reduce memory usage in getMultiScaleCosAffinityMatrix function (NVIDIA#5876)

* Updated the getMultiScaleCosAffinityMatrix function in offline_clustering.py to reduce memory usage

Signed-off-by: gabitza-tech <[email protected]>

* torch.cuda.empty_cache() outside forward_infer()

Signed-off-by: Taejin Park <[email protected]>

* Removed unnecessary lines

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Speed up for non-torch.jit.script execution

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* parallelism is default off

Signed-off-by: Taejin Park <[email protected]>

* nme_mat_size is unified as 512, removing redundant docstring

Signed-off-by: Taejin Park <[email protected]>

---------

Signed-off-by: gabitza-tech <[email protected]>
Signed-off-by: Taejin Park <[email protected]>
Co-authored-by: Taejin Park <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Jason <[email protected]>
3 people authored and blisc committed Feb 10, 2023
1 parent 672cbcc commit 7bf8d9b
Showing 2 changed files with 34 additions and 51 deletions.
72 changes: 25 additions & 47 deletions nemo/collections/asr/parts/utils/offline_clustering.py
@@ -37,7 +37,6 @@
from torch.linalg import eigh, eigvalsh


- @torch.jit.script
def cos_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor, eps=torch.tensor(3.5e-4)) -> torch.Tensor:
"""
Calculate cosine similarities of the given two set of tensors. The output is an N by N
@@ -63,7 +62,6 @@ def cos_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor, eps=torch.tensor(3.5e-4)) -> torch.Tensor:
return res
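For orientation, here is a minimal standalone sketch of the pairwise cosine-similarity computation this function performs. The helper name is hypothetical and the eps handling is an assumption; the actual NeMo cos_similarity takes eps as a tensor and may apply it differently:

```python
import torch

def cos_sim_sketch(emb_a: torch.Tensor, emb_b: torch.Tensor, eps: float = 3.5e-4) -> torch.Tensor:
    # Row-normalize both embedding sets; eps guards against zero-norm rows.
    a = emb_a / torch.clamp(emb_a.norm(dim=1, keepdim=True), min=eps)
    b = emb_b / torch.clamp(emb_b.norm(dim=1, keepdim=True), min=eps)
    # (N, D) @ (D, M) -> (N, M) matrix of pairwise cosine similarities.
    return a @ b.T
```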


- @torch.jit.script
def ScalerMinMax(X: torch.Tensor) -> torch.Tensor:
"""
Min-max scale the input affinity matrix X, which will lead to a dynamic range of [0, 1].
@@ -81,7 +79,6 @@ def ScalerMinMax(X: torch.Tensor) -> torch.Tensor:
return v_norm


- @torch.jit.script
def getEuclideanDistance(
specEmbA: torch.Tensor, specEmbB: torch.Tensor, device: torch.device = torch.device('cpu')
) -> torch.Tensor:
@@ -105,7 +102,6 @@
return dis


- @torch.jit.script
def kmeans_plusplus_torch(
X: torch.Tensor,
n_clusters: int,
@@ -190,7 +186,6 @@ def kmeans_plusplus_torch(
return centers, indices


- @torch.jit.script
def kmeans_torch(
X: torch.Tensor,
num_clusters: int,
@@ -235,7 +230,6 @@ def kmeans_torch(
plusplus_init_states = kmeans_plusplus_torch(X, n_clusters=num_clusters, random_state=random_state, device=device)
centers = plusplus_init_states[0]

- iter_count = 0
selected_cluster_indices = torch.zeros(input_size).long()

for iter_count in range(iter_limit):
@@ -268,7 +262,6 @@
return selected_cluster_indices


- @torch.jit.script
def getTheLargestComponent(affinity_mat: torch.Tensor, seg_index: int, device: torch.device) -> torch.Tensor:
"""
Find the largest affinity_mat connected components for each given node.
@@ -306,15 +299,13 @@ def getTheLargestComponent(affinity_mat: torch.Tensor, seg_index: int, device: torch.device) -> torch.Tensor:
return connected_nodes


- @torch.jit.script
def isGraphFullyConnected(affinity_mat: torch.Tensor, device: torch.device) -> torch.Tensor:
"""
Check whether the given affinity matrix is a fully connected graph.
"""
return getTheLargestComponent(affinity_mat, 0, device).sum() == affinity_mat.shape[0]


- @torch.jit.script
def getKneighborsConnections(affinity_mat: torch.Tensor, p_value: int) -> torch.Tensor:
"""
Binarize top-p values for each row from the given affinity matrix.
@@ -328,7 +319,6 @@ def getKneighborsConnections(affinity_mat: torch.Tensor, p_value: int) -> torch.Tensor:
return binarized_affinity_mat
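A short sketch of the top-p binarization the docstring describes, using a straightforward torch.topk formulation (hypothetical helper; the NeMo function may build the mask differently):

```python
import torch

def binarize_top_p(affinity: torch.Tensor, p_value: int) -> torch.Tensor:
    # Keep only the p largest similarities in each row as a {0, 1} mask.
    binarized = torch.zeros_like(affinity)
    top_idx = torch.topk(affinity, k=p_value, dim=1).indices
    binarized.scatter_(1, top_idx, 1.0)
    return binarized

print(binarize_top_p(torch.rand(5, 5), p_value=2))  # two ones per row
```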


- @torch.jit.script
def getAffinityGraphMat(affinity_mat_raw: torch.Tensor, p_value: int) -> torch.Tensor:
"""
Calculate a binarized graph matrix and
@@ -339,7 +329,6 @@ def getAffinityGraphMat(affinity_mat_raw: torch.Tensor, p_value: int) -> torch.Tensor:
return symm_affinity_mat


- @torch.jit.script
def getMinimumConnection(
mat: torch.Tensor, max_N: torch.Tensor, n_list: torch.Tensor, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -358,7 +347,6 @@
return affinity_mat, p_value


- @torch.jit.script
def getRepeatedList(mapping_argmat: torch.Tensor, score_mat_size: torch.Tensor) -> torch.Tensor:
"""
Count the numbers in the mapping dictionary and create lists that contain
@@ -371,7 +359,6 @@ def getRepeatedList(mapping_argmat: torch.Tensor, score_mat_size: torch.Tensor) -> torch.Tensor:
return repeat_list


- @torch.jit.script
def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Calculate the mapping between the base scale and other scales. A segment from a longer scale is
@@ -404,7 +391,6 @@ def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]) -> List[torch.Tensor]:
return session_scale_mapping_list
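A toy illustration of the mapping idea: each base-scale segment is assigned the longer-scale segment whose center is nearest. The tensors and layout here are illustrative, not the exact NeMo data format:

```python
import torch

# (start, end) rows for a longer scale and for the base (finest) scale.
long_scale_ts = torch.tensor([[0.0, 3.0], [3.0, 6.0]])
base_scale_ts = torch.tensor([[0.0, 1.5], [1.5, 3.0], [3.0, 4.5], [4.5, 6.0]])

long_centers = long_scale_ts.mean(dim=1)  # segment midpoints
base_centers = base_scale_ts.mean(dim=1)

# For each base-scale segment, the index of the nearest longer-scale segment.
argmin_map = torch.argmin(torch.abs(base_centers.unsqueeze(1) - long_centers.unsqueeze(0)), dim=1)
print(argmin_map)  # tensor([0, 0, 1, 1])
```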


- @torch.jit.script
def getCosAffinityMatrix(emb: torch.Tensor) -> torch.Tensor:
"""
Calculate cosine similarity values among speaker embeddings then min-max normalize
@@ -430,7 +416,6 @@ def getCosAffinityMatrix(emb: torch.Tensor) -> torch.Tensor:
return sim_d


- @torch.jit.script
def get_scale_interpolated_embs(
multiscale_weights: torch.Tensor,
embeddings_in_scales: List[torch.Tensor],
@@ -479,7 +464,6 @@ def get_scale_interpolated_embs(
return context_emb, session_scale_mapping_list


- @torch.jit.script
def getMultiScaleCosAffinityMatrix(
multiscale_weights: torch.Tensor,
embeddings_in_scales: List[torch.Tensor],
@@ -489,10 +473,11 @@
"""
Calculate cosine similarity values among speaker embeddings for each scale then
apply multiscale weights to calculate the fused similarity matrix.
+ NOTE: Due to the CUDA memory limit, the embedding vectors in `embeddings_in_scales` are stored on the `cpu` device.
Args:
multiscale_weights (Tensor):
- Tensor containing Multiscale weights
+ Tensor containing multiscale weights
Dimensions: (Number of scales) x 1
embeddings_in_scales (list):
List containing split embedding tensors by each scale
@@ -503,27 +488,24 @@
Returns:
fused_sim_d (Tensor):
- This function generates an affinity matrix that is obtained by calculating
- the weighted sum of the affinity matrices from the different scales.
+ An affinity matrix that is obtained by calculating the weighted sum of
+ the multiple affinity matrices from the different scales.
"""
- score_mat_list, repeated_tensor_list = [], []
- multiscale_weights = torch.squeeze(multiscale_weights, dim=0).to(device)
+ multiscale_weights = multiscale_weights.to(device)
session_scale_mapping_list = get_argmin_mat(timestamps_in_scales)
scale_list = list(range(len(timestamps_in_scales)))
+ fused_sim_d = torch.zeros(len(timestamps_in_scales[-1]), len(timestamps_in_scales[-1])).to(device)
for scale_idx in scale_list:
mapping_argmat = session_scale_mapping_list[scale_idx]
emb_t = embeddings_in_scales[scale_idx].half().to(device)
score_mat_torch = getCosAffinityMatrix(emb_t)
repeat_list = getRepeatedList(mapping_argmat, torch.tensor(score_mat_torch.shape[0])).to(device)
- repeated_tensor_0 = torch.repeat_interleave(score_mat_torch, repeats=repeat_list, dim=0)
- repeated_tensor_1 = torch.repeat_interleave(repeated_tensor_0, repeats=repeat_list, dim=1)
- repeated_tensor_list.append(repeated_tensor_1)
- repp = torch.stack(repeated_tensor_list).float()
- fused_sim_d = torch.matmul(repp.permute(2, 1, 0), multiscale_weights.t()).squeeze(2).t()
+ repeated_tensor_0 = torch.repeat_interleave(score_mat_torch, repeats=repeat_list, dim=0).to(device)
+ repeated_tensor_1 = torch.repeat_interleave(repeated_tensor_0, repeats=repeat_list, dim=1).to(device)
+ fused_sim_d += multiscale_weights[scale_idx] * repeated_tensor_1
return fused_sim_d
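To see where the memory saving comes from, the sketch below contrasts the two strategies on toy data: the old path stacks all per-scale affinity matrices into one (scales, N, N) tensor before a single matmul with the weight vector, while the new path accumulates one weighted N x N sum in place. This is a self-contained reconstruction under simplified shapes, not the NeMo code itself:

```python
import torch

N, n_scales = 4, 3
weights = torch.tensor([0.5, 0.3, 0.2])
# Stand-ins for the per-scale affinity matrices after repeat_interleave.
mats = [torch.rand(N, N) for _ in range(n_scales)]

# Old approach: peak memory holds all n_scales N x N matrices at once.
stacked = torch.stack(mats).float()  # (n_scales, N, N)
fused_old = torch.matmul(stacked.permute(2, 1, 0), weights.unsqueeze(1)).squeeze(2).t()

# New approach: peak memory holds one accumulator plus one scale's matrix.
fused_new = torch.zeros(N, N)
for k in range(n_scales):
    fused_new += weights[k] * mats[k]

assert torch.allclose(fused_old, fused_new, atol=1e-6)
```

The weighted sums are identical; only the peak allocation changes, which matters when N (the number of base-scale segments) grows large.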


- @torch.jit.script
def getLaplacian(X: torch.Tensor) -> torch.Tensor:
"""
Calculate a laplacian matrix from an affinity matrix X.
@@ -535,7 +517,6 @@ def getLaplacian(X: torch.Tensor) -> torch.Tensor:
return L


- @torch.jit.script
def eigDecompose(
laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -552,7 +533,6 @@
return lambdas, diffusion_map


- @torch.jit.script
def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')) -> torch.Tensor:
"""
Calculate only eigenvalues from the Laplacian matrix.
@@ -567,7 +547,6 @@ def eigValueSh(laplacian: torch.Tensor, cuda: bool, device: torch.device = torch.device('cpu')) -> torch.Tensor:
return lambdas


- @torch.jit.script
def getLamdaGaplist(lambdas: torch.Tensor) -> torch.Tensor:
"""
Calculate the gaps between lambda values.
@@ -577,7 +556,6 @@ def getLamdaGaplist(lambdas: torch.Tensor) -> torch.Tensor:
return lambdas[1:] - lambdas[:-1]


- @torch.jit.script
def addAnchorEmb(emb: torch.Tensor, anchor_sample_n: int, anchor_spk_n: int, sigma: float) -> torch.Tensor:
"""
Add randomly generated synthetic embeddings to make eigenanalysis more stable.
@@ -674,7 +652,6 @@ def getEnhancedSpeakerCount(
return comp_est_num_of_spk


- @torch.jit.script
def split_input_data(
embeddings_in_scales: torch.Tensor,
timestamps_in_scales: torch.Tensor,
@@ -705,7 +682,6 @@
return embeddings_in_scales, timestamps_in_scales


- @torch.jit.script
def estimateNumofSpeakers(
affinity_mat: torch.Tensor, max_num_speakers: int, cuda: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -728,16 +704,14 @@
lambda_gap (Tensor):
The gap between the lambda values from eigendecomposition
"""
- with torch.no_grad():
-     laplacian = getLaplacian(affinity_mat)
-     lambdas = eigValueSh(laplacian, cuda=cuda)
-     lambdas = torch.sort(lambdas)[0]
-     lambda_gap = getLamdaGaplist(lambdas)
-     num_of_spk = torch.argmax(lambda_gap[: min(max_num_speakers, lambda_gap.shape[0])]) + 1
+ laplacian = getLaplacian(affinity_mat)
+ lambdas = eigValueSh(laplacian, cuda=cuda, device=affinity_mat.device)
+ lambdas = torch.sort(lambdas)[0]
+ lambda_gap = getLamdaGaplist(lambdas)
+ num_of_spk = torch.argmax(lambda_gap[: min(max_num_speakers, lambda_gap.shape[0])]) + 1
return num_of_spk, lambdas, lambda_gap
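The estimate rests on the eigengap heuristic: sort the Laplacian's eigenvalues and take the index of the largest gap among the first max_num_speakers entries. A minimal sketch of the idea, calling torch.linalg directly instead of NeMo's wrappers (helper name is hypothetical):

```python
import torch

def eigengap_num_speakers(laplacian: torch.Tensor, max_num_speakers: int = 8) -> int:
    # Eigenvalues of a symmetric matrix, returned in ascending order.
    lambdas = torch.linalg.eigvalsh(laplacian)
    gaps = lambdas[1:] - lambdas[:-1]
    # The largest early gap marks the likely number of clusters.
    return int(torch.argmax(gaps[:max_num_speakers]) + 1)
```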


- @torch.jit.script
class SpectralClustering:
"""
Perform spectral clustering by calculating spectral embeddings then run k-means clustering
@@ -816,6 +790,7 @@ def clusterSpectralEmbeddings(
"""
spectral_emb = self.getSpectralEmbeddings(affinity, n_spks=self.n_clusters, cuda=cuda)
labels_set = []

for random_state_seed in range(self.random_state, self.random_state + self.n_random_trials):
_labels = kmeans_torch(
X=spectral_emb, num_clusters=self.n_clusters, random_state=random_state_seed, device=device
@@ -843,14 +818,13 @@ def getSpectralEmbeddings(self, affinity_mat: torch.Tensor, n_spks: int = 8, cud
clustering label output
"""
laplacian = getLaplacian(affinity_mat)
- lambdas_, diffusion_map_ = eigDecompose(laplacian, cuda=cuda)
+ _, diffusion_map_ = eigDecompose(laplacian, cuda=cuda, device=affinity_mat.device)
diffusion_map = diffusion_map_[:, :n_spks]
inv_idx = torch.arange(diffusion_map.size(1) - 1, -1, -1).long()
embedding = diffusion_map.T[inv_idx, :]
return embedding[:n_spks].T


- @torch.jit.script
class NMESC:
"""
Normalized Maximum Eigengap based Spectral Clustering (NME-SC)
@@ -1118,10 +1092,10 @@ class SpeakerClustering(torch.nn.Module):
def __init__(
self,
min_samples_for_nmesc: int = 6,
- nme_mat_size: int = 300,
+ nme_mat_size: int = 512,
sparse_search: bool = True,
maj_vote_spk_count: bool = False,
- parallelism: bool = True,
+ parallelism: bool = False,
cuda: bool = False,
):
"""
@@ -1164,7 +1138,6 @@ def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor:
naming convention.
See https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#special-conventions-for-pytorch-backend
Args:
param_dict (dict):
Dictionary containing the arguments for speaker clustering.
@@ -1211,6 +1184,7 @@ def forward_infer(
enhanced_count_thres: int = 40,
sparse_search_volume: int = 30,
fixed_thres: float = -1.0,
+ kmeans_random_trials: int = 1,
) -> torch.LongTensor:
"""
Calculate affinity matrix using timestamps and speaker embeddings, run NME analysis to estimate the best
Expand Down Expand Up @@ -1264,6 +1238,9 @@ def forward_infer(
If fixed_thres value is provided, NME-analysis process will be skipped.
This value should be optimized on a development set to obtain a quality result.
Default is None and performs NME-analysis to estimate the threshold.
+ kmeans_random_trials (int):
+     Number of random trials for initializing k-means clustering. More trials
+     will result in a more stable clustering result. Default is 1.
Returns:
Y (LongTensor):
@@ -1272,7 +1249,6 @@
self.embeddings_in_scales, self.timestamps_in_scales = split_input_data(
embeddings_in_scales, timestamps_in_scales, multiscale_segment_counts
)

# Last slot is the base scale embeddings
emb = self.embeddings_in_scales[-1]

@@ -1322,6 +1298,8 @@
else:
n_clusters = int(est_num_of_spk.item())

- spectral_model = SpectralClustering(n_clusters=n_clusters, cuda=self.cuda, device=self.device)
+ spectral_model = SpectralClustering(
+     n_clusters=n_clusters, n_random_trials=kmeans_random_trials, cuda=self.cuda, device=self.device
+ )
Y = spectral_model.forward(affinity_mat)
return Y
13 changes: 9 additions & 4 deletions nemo/collections/asr/parts/utils/speaker_utils.py
@@ -410,8 +410,6 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste
AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path
out_rttm_dir (str): Path to write predicted rttms
clustering_params (dict): clustering parameters provided through config that contains max_num_speakers (int),
- oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int)
- use_torch_script (bool): Boolean that determines whether to use torch.jit.script for speaker clustering
Returns:
all_reference (list[uniq_name,Annotation]): reference annotations for score calculation
@@ -428,7 +426,7 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste
logging.warning("cuda=False, using CPU for eigen decomposition. This might slow down the clustering process.")
cuda = False

- speaker_clustering = SpeakerClustering(maj_vote_spk_count=clustering_params.maj_vote_spk_count, cuda=cuda)
+ speaker_clustering = SpeakerClustering(cuda=cuda)

# If True, export torch script module and save it to the base folder.
if clustering_params.get('export_script_module', False):
@@ -445,6 +443,8 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste
else:
num_speakers = -1

+ base_scale_idx = uniq_embs_and_timestamps['multiscale_segment_counts'].shape[0] - 1

cluster_labels = speaker_clustering.forward_infer(
embeddings_in_scales=uniq_embs_and_timestamps['embeddings'],
timestamps_in_scales=uniq_embs_and_timestamps['timestamps'],
@@ -456,7 +456,12 @@
sparse_search_volume=int(clustering_params.sparse_search_volume),
)

- base_scale_idx = uniq_embs_and_timestamps['multiscale_segment_counts'].shape[0] - 1
+ del uniq_embs_and_timestamps
+ if cuda:
+     torch.cuda.empty_cache()
+ else:
+     gc.collect()

timestamps = speaker_clustering.timestamps_in_scales[base_scale_idx]
cluster_labels = cluster_labels.cpu().numpy()
if len(cluster_labels) != timestamps.shape[0]:
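The cleanup added above follows a common PyTorch pattern: drop the last Python reference to the large buffers, then release cached GPU blocks (or fall back to the garbage collector on CPU). A minimal sketch of the pattern:

```python
import gc
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
big = torch.rand(4096, 4096, device=device)  # stand-in for clustering buffers
# ... use `big` ...
del big                       # drop the last reference
if device == 'cuda':
    torch.cuda.empty_cache()  # return cached blocks to the CUDA allocator
else:
    gc.collect()              # prompt Python's GC to reclaim host memory
```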
