Add core classes and functions for online clustering diarizer part 1 (#5526)

* Add core classes and functions for online clustering diarizer

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add audio to labels code

Signed-off-by: Taejin Park <[email protected]>

* resolve type errors

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added unit-tests for very short audio

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Filled all missing docstrings

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolved conflict and added missing docstrings

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed unit-test errors

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the wrongly added file - megatron_gpt_model.py

Signed-off-by: Taejin Park <[email protected]>

* Fix wrongly included file - megatron_gpt_model.py

Signed-off-by: Taejin Park <[email protected]>

* resolve code quality issue

Signed-off-by: Taejin Park <[email protected]>

* Fixed unit-test errors and bugs

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changed total_sec for offline_clustering toy_data in unit-tests

Signed-off-by: Taejin Park <[email protected]>

* fixed merging index offset bug

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* only including part 1 files

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* removed unused function

Signed-off-by: Taejin Park <[email protected]>

* fixed unused imports

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* divided nmesc_clustering.py into two and reflected first-pass comments

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adding offline/online_clustering.py

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix code QL autocomment

Signed-off-by: Taejin Park <[email protected]>

* Removed unused imports

Signed-off-by: Taejin Park <[email protected]>

* Update nemo/collections/asr/parts/utils/online_clustering.py

Co-authored-by: Sean Naren <[email protected]>
Signed-off-by: Taejin Park <[email protected]>

* Reflected comments

Signed-off-by: Taejin Park <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolved code scanning issue

Signed-off-by: Taejin Park <[email protected]>

* Update nemo/collections/asr/parts/utils/offline_clustering.py

Co-authored-by: Sean Naren <[email protected]>
Signed-off-by: Taejin Park <[email protected]>

Signed-off-by: Taejin Park <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nithin Rao <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
4 people authored Dec 15, 2022
1 parent f502d1a commit 815fe5e
Showing 6 changed files with 1,654 additions and 146 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/data/audio_to_diar_label.py
@@ -19,7 +19,7 @@

import torch

from nemo.collections.asr.parts.utils.nmesc_clustering import get_argmin_mat
from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat
from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data
from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel
from nemo.core.classes import Dataset
21 changes: 13 additions & 8 deletions nemo/collections/asr/data/audio_to_label.py
@@ -30,20 +30,25 @@
VALID_FILE_FORMATS = ';'.join(['wav', 'mp3', 'flac'] + [fmt.lower() for fmt in valid_sf_formats.keys()])


def repeat_signal(signal, sig_len, required_length):
def repeat_signal(signal: torch.Tensor, sig_len: int, required_length: int) -> torch.Tensor:
"""Repeat a short signal so that it reaches required_length.
Args:
signal (FloatTensor): input signal
sig_len (LongTensor): length of input signal
required_length(float) : length of generated signal
signal (Tensor): input signal
sig_len (int): length of input signal
required_length (int): length of generated signal
Returns:
signal (FloatTensor): generated signal of required_length by repeating itself.
signal (Tensor): generated signal of required_length by repeating itself.
"""
sub: torch.Tensor = torch.tensor([])
repeat = int(required_length // sig_len)
rem = int(required_length % sig_len)
sub = signal[-rem:] if rem > 0 else torch.tensor([])
rep_sig = torch.cat(repeat * [signal])
signal = torch.cat((rep_sig, sub))
sub: torch.Tensor = torch.tensor([])
rep_sig: torch.Tensor = torch.cat(repeat * [signal])
if rem > 0:
sub = signal[-rem:]
signal = torch.cat((rep_sig, sub))
else:
signal = rep_sig
return signal
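The updated `repeat_signal` tiles the signal with whole repeats and then appends the *last* `rem` samples to cover the remainder. A minimal pure-Python restatement of that logic (plain lists instead of tensors, for illustration only):

```python
def repeat_signal(signal, sig_len, required_length):
    """Tile a short signal until it reaches required_length.

    Mirrors the diff above: whole repeats first, then the last
    `rem` samples are appended to cover the remainder.
    """
    repeat = required_length // sig_len
    rem = required_length % sig_len
    out = signal * repeat          # whole repeats
    if rem > 0:
        out = out + signal[-rem:]  # tail fills the remainder
    return out

print(repeat_signal([0, 1, 2], 3, 7))  # → [0, 1, 2, 0, 1, 2, 2]
```

Note the remainder is drawn from the signal's tail, so the output is not a strict prefix-truncation of endless tiling; that matches the tensor code in the hunk.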


@@ -38,7 +38,7 @@


@torch.jit.script
def cos_similarity(a: torch.Tensor, b: torch.Tensor, eps=torch.tensor(3.5e-4)) -> torch.Tensor:
def cos_similarity(emb_a: torch.Tensor, emb_b: torch.Tensor, eps=torch.tensor(3.5e-4)) -> torch.Tensor:
"""
Calculate cosine similarities of the given two sets of tensors. The output is an N by N
matrix where N is the number of feature vectors.
@@ -53,8 +53,11 @@ def cos_similarity(a: torch.Tensor, b: torch.Tensor, eps=torch.tensor(3.5e-4)) -
res (Tensor):
N by N matrix containing the cosine similarities of the values.
"""
a_norm = a / (torch.norm(a, dim=1).unsqueeze(1) + eps)
b_norm = b / (torch.norm(a, dim=1).unsqueeze(1) + eps)
# If the number of embeddings is 1, normalization creates NaN values
if emb_a.shape[0] == 1 or emb_b.shape[0] == 1:
raise ValueError(f"Number of feature vectors should be greater than 1 but got {emb_a.shape} and {emb_b.shape}")
a_norm = emb_a / (torch.norm(emb_a, dim=1).unsqueeze(1) + eps)
b_norm = emb_b / (torch.norm(emb_b, dim=1).unsqueeze(1) + eps)
res = torch.mm(a_norm, b_norm.transpose(0, 1))
res.fill_diagonal_(1)
return res
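The fix in this hunk is twofold: the old code normalized `b` with `a`'s norms (a copy-paste bug), and a single embedding now raises instead of producing NaNs. A pure-Python sketch of the corrected behavior (illustrative, not the NeMo API; vectors are lists of floats):

```python
import math

def cos_similarity(emb_a, emb_b, eps=3.5e-4):
    """N x N cosine-similarity matrix between two lists of vectors.

    As in the diff: each set is normalized with its *own* row norms
    (plus a small eps), and the diagonal is forced to 1. Fewer than
    two vectors raises ValueError, matching the new guard.
    """
    if len(emb_a) == 1 or len(emb_b) == 1:
        raise ValueError("Number of feature vectors should be greater than 1")

    def normalize(rows):
        return [[x / (math.sqrt(sum(v * v for v in row)) + eps) for x in row]
                for row in rows]

    a_n, b_n = normalize(emb_a), normalize(emb_b)
    res = [[sum(x * y for x, y in zip(ra, rb)) for rb in b_n] for ra in a_n]
    for i in range(min(len(res), len(res[0]))):
        res[i][i] = 1.0  # fill_diagonal_(1)
    return res
```

With the old `b_norm = b / (torch.norm(a, ...))` line, similarities were only correct when `a is b`; the corrected version works for any pair of embedding sets.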
@@ -280,14 +283,14 @@ def getTheLargestComponent(affinity_mat: torch.Tensor, seg_index: int, device: t
Returns:
connected_nodes (Tensor):
A tensor containing booleans that indicate whether the node is connected.
"""
num_of_segments = affinity_mat.shape[0]

connected_nodes = torch.zeros(num_of_segments, dtype=torch.bool).to(device)
nodes_to_explore = torch.zeros(num_of_segments, dtype=torch.bool).to(device)

nodes_to_explore[seg_index] = True
nodes_to_explore = nodes_to_explore.to(device)
for k in range(num_of_segments):
last_num_component = connected_nodes.sum()
torch.logical_or(connected_nodes, nodes_to_explore, out=connected_nodes)
@@ -298,7 +301,7 @@ def getTheLargestComponent(affinity_mat: torch.Tensor, seg_index: int, device: t
if len(indices.size()) == 0:
indices = indices.unsqueeze(0)
for i in indices:
neighbors = affinity_mat[i]
neighbors = affinity_mat[i].to(device)
torch.logical_or(nodes_to_explore, neighbors.squeeze(0), out=nodes_to_explore)
return connected_nodes
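`getTheLargestComponent` grows a connected component from a seed node by repeatedly OR-ing in the neighbors of the current frontier, stopping when the component stops growing. A simplified pure-Python sketch of the same frontier expansion (boolean lists instead of CUDA tensors; names are illustrative):

```python
def get_largest_component(affinity, seg_index):
    """Return booleans marking nodes reachable from seg_index.

    Iterative frontier expansion over a binary affinity matrix,
    mirroring the logical_or loop in the hunk above.
    """
    n = len(affinity)
    connected = [False] * n
    frontier = [False] * n
    frontier[seg_index] = True
    for _ in range(n):
        last = sum(connected)
        connected = [c or f for c, f in zip(connected, frontier)]
        if sum(connected) == last:  # no growth: component complete
            break
        for i, is_in in enumerate(connected):
            if is_in:  # absorb neighbors of every reached node
                frontier = [f or bool(affinity[i][j]) for j, f in enumerate(frontier)]
    return connected
```

The device-placement changes in the diff (`.to(device)` on the frontier and neighbor rows) matter only for the tensor version; the control flow is unchanged.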

@@ -331,7 +334,7 @@ def getAffinityGraphMat(affinity_mat_raw: torch.Tensor, p_value: int) -> torch.T
Calculate a binarized graph matrix and
symmetrize the binarized graph matrix.
"""
X = getKneighborsConnections(affinity_mat_raw, p_value)
X = affinity_mat_raw if p_value <= 0 else getKneighborsConnections(affinity_mat_raw, p_value)
symm_affinity_mat = 0.5 * (X + X.T)
return symm_affinity_mat
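The change here adds a guard so that `p_value <= 0` skips binarization entirely. A pure-Python sketch of the binarize-then-symmetrize step (keeping each row's top-`p_value` entries; list-of-lists stand-in for the tensor code):

```python
def get_affinity_graph_mat(aff, p_value):
    """Keep each row's p_value largest entries, then symmetrize.

    p_value <= 0 passes the matrix through unbinarized, matching
    the new guard in the hunk above.
    """
    n = len(aff)
    if p_value <= 0:
        x = [row[:] for row in aff]
    else:
        x = [[0.0] * n for _ in range(n)]
        for i, row in enumerate(aff):
            top = sorted(range(n), key=lambda j: row[j], reverse=True)[:p_value]
            for j in top:
                x[i][j] = row[j]
    # 0.5 * (X + X.T) makes the pruned graph symmetric again
    return [[0.5 * (x[i][j] + x[j][i]) for j in range(n)] for i in range(n)]
```

Symmetrization is needed because keeping top-k entries per row is not a symmetric operation: node i may keep edge (i, j) while node j drops (j, i).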

@@ -362,9 +365,9 @@ def getRepeatedList(mapping_argmat: torch.Tensor, score_mat_size: torch.Tensor)
repeated indices that will be used for creating a repeated affinity matrix.
This repeated matrix is then used for fusing multiple affinity values.
"""
repeat_list = torch.zeros(score_mat_size, dtype=torch.int32)
repeat_list = torch.zeros(score_mat_size, dtype=torch.int32).to(mapping_argmat.device)
idxs, counts = torch.unique(mapping_argmat, return_counts=True)
repeat_list[idxs] = counts.int()
repeat_list[idxs] = counts.int().to(mapping_argmat.device)
return repeat_list
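`getRepeatedList` is a scatter of per-index counts: for each coarse-scale segment it records how many base-scale segments map to it (`torch.unique(..., return_counts=True)` plus an indexed write). The same thing in plain Python, for illustration:

```python
from collections import Counter

def get_repeated_list(mapping_argmat, score_mat_size):
    """Count how many base-scale segments map to each coarse segment.

    Equivalent to torch.unique(return_counts=True) followed by the
    indexed assignment in the hunk above; unmapped slots stay 0.
    """
    repeat_list = [0] * score_mat_size
    for idx, count in Counter(mapping_argmat).items():
        repeat_list[idx] = count
    return repeat_list
```

The diff only moves these tensors onto `mapping_argmat.device`; the counting logic itself is unchanged.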


@@ -418,12 +421,64 @@ def getCosAffinityMatrix(emb: torch.Tensor) -> torch.Tensor:
Matrix containing cosine similarity values among the given embedding vectors.
dimension: (Number of embedding vectors) x (Number of embedding vectors)
"""
emb = emb.float()
sim_d = cos_similarity(emb, emb)
sim_d = ScalerMinMax(sim_d)
if emb.shape[0] == 1:
sim_d = torch.tensor([[1]]).to(emb.device)
else:
emb = emb.float()
sim_d = cos_similarity(emb, emb)
sim_d = ScalerMinMax(sim_d)
return sim_d
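The new single-embedding branch exists because `cos_similarity` now rejects inputs with one vector; a 1x1 affinity matrix is trivially `[[1]]`. For the normal path, the similarities are min-max scaled into [0, 1]. A pure-Python sketch of that `ScalerMinMax` step (illustrative name reuse, not the NeMo implementation):

```python
def scaler_min_max(mat):
    """Min-max scale a matrix into [0, 1] over all entries.

    The affinity matrix is scaled as one flat pool of values, not
    per-row, so relative ordering across the whole matrix survives.
    """
    flat = [v for row in mat for v in row]
    lo, hi = min(flat), max(flat)
    rng = (hi - lo) or 1.0  # guard against a constant matrix
    return [[(v - lo) / rng for v in row] for row in mat]
```

Scaling to [0, 1] keeps downstream thresholding (p-value binarization) on a consistent range regardless of the raw cosine spread.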


@torch.jit.script
def get_scale_interpolated_embs(
multiscale_weights: torch.Tensor,
embeddings_in_scales: List[torch.Tensor],
timestamps_in_scales: List[torch.Tensor],
device: torch.device = torch.device('cpu'),
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Generate a scale-interpolated single embedding vector by calculating the weighted sum
of the multiple embedding vectors from different scales. The output is a set of embedding
vectors corresponding to the base-scale segments.
Args:
multiscale_weights (Tensor):
Tensor containing Multiscale weights
Dimensions: (Number of scales) x 1
embeddings_in_scales (list):
List containing split embedding tensors by each scale
timestamps_in_scales (list):
List containing split timestamps tensors by each scale
device (torch.device):
Torch device variable
Returns:
context_emb (torch.tensor):
A set of scale-interpolated embedding vectors.
Dimensions: (Number of base-scale segments) x (Dimensions of embedding vector)
session_scale_mapping_list (list):
List containing argmin arrays indexed by scale index.
"""
rep_mat_list = []
multiscale_weights = multiscale_weights.to(device)
session_scale_mapping_list = get_argmin_mat(timestamps_in_scales)
scale_list = list(range(len(timestamps_in_scales)))
for scale_idx in scale_list:
mapping_argmat = session_scale_mapping_list[scale_idx]
emb_t = embeddings_in_scales[scale_idx].to(device)
mapping_argmat = mapping_argmat.to(device)
repeat_list = getRepeatedList(mapping_argmat, torch.tensor(emb_t.shape[0])).to(device)
rep_emb_t = torch.repeat_interleave(emb_t, repeats=repeat_list, dim=0)
rep_mat_list.append(rep_emb_t)
stacked_scale_embs = torch.stack(rep_mat_list)
context_emb = torch.matmul(stacked_scale_embs.permute(2, 1, 0), multiscale_weights.t()).squeeze().t()
if len(context_emb.shape) < 2:
context_emb = context_emb.unsqueeze(0)
context_emb = context_emb.to(device)
return context_emb, session_scale_mapping_list
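`get_scale_interpolated_embs` repeats each scale's embeddings out to base-scale length (using the argmin mapping) and then takes a weighted sum across scales. A pure-Python sketch of that interpolation, with precomputed repeat counts standing in for the argmin mapping (names and shapes are illustrative):

```python
def scale_interpolated_embs(weights, embs_in_scales, repeats_in_scales):
    """Weighted sum of per-scale embeddings at base-scale resolution.

    Each scale's embeddings are repeat-interleaved to base-scale
    length, then summed across scales with multiscale weights,
    mirroring the repeat_interleave + matmul in the hunk above.
    """
    expanded = []
    for embs, reps in zip(embs_in_scales, repeats_in_scales):
        rows = []
        for emb, r in zip(embs, reps):
            rows.extend([emb] * r)  # repeat_interleave along dim 0
        expanded.append(rows)
    n_base = len(expanded[0])
    dim = len(expanded[0][0])
    return [
        [sum(w * expanded[s][i][d] for s, w in enumerate(weights)) for d in range(dim)]
        for i in range(n_base)
    ]
```

For example, with weights `[0.5, 0.5]`, one coarse embedding covering two base segments, and two base-scale embeddings, each output row is the average of the coarse embedding and the corresponding base embedding.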


@torch.jit.script
def getMultiScaleCosAffinityMatrix(
multiscale_weights: torch.Tensor,
@@ -436,10 +491,15 @@ def getMultiScaleCosAffinityMatrix(
apply multiscale weights to calculate the fused similarity matrix.
Args:
uniq_embs_and_timestamps (dict):
The dictionary containing embeddings, timestamps and multiscale weights.
If uniq_embs_and_timestamps contains only one scale, single scale diarization
is performed.
multiscale_weights (Tensor):
Tensor containing Multiscale weights
Dimensions: (Number of scales) x 1
embeddings_in_scales (list):
List containing split embedding tensors by each scale
timestamps_in_scales (list):
List containing split timestamps tensors by each scale
device (torch.device):
Torch device variable
Returns:
fused_sim_d (Tensor):
@@ -539,10 +599,11 @@ def addAnchorEmb(emb: torch.Tensor, anchor_sample_n: int, anchor_spk_n: int, sig
"""
emb_dim = emb.shape[1]
std_org = torch.std(emb, dim=0)
sigma = torch.tensor(sigma).to(emb.device)
new_emb_list = []
for _ in range(anchor_spk_n):
emb_m = torch.tile(torch.randn(1, emb_dim), (anchor_sample_n, 1))
emb_noise = torch.randn(anchor_sample_n, emb_dim).T
emb_m = torch.tile(torch.randn(1, emb_dim), (anchor_sample_n, 1)).to(emb.device)
emb_noise = torch.randn(anchor_sample_n, emb_dim).T.to(emb.device)
emb_noise = torch.matmul(
torch.diag(std_org), emb_noise / torch.max(torch.abs(emb_noise), dim=0)[0].unsqueeze(0)
).T
@@ -602,14 +663,14 @@ def getEnhancedSpeakerCount(
max_num_speakers=emb.shape[0],
max_rp_threshold=0.15,
sparse_search=True,
sparse_search_volume=50,
sparse_search_volume=10,
fixed_thres=-1.0,
nme_mat_size=300,
cuda=cuda,
)
est_num_of_spk, _ = nmesc.forward()
est_num_of_spk_list.append(est_num_of_spk.item())
comp_est_num_of_spk = torch.mode(torch.tensor(est_num_of_spk_list))[0] - anchor_spk_n
comp_est_num_of_spk = torch.tensor(max(torch.mode(torch.tensor(est_num_of_spk_list))[0].item() - anchor_spk_n, 1))
return comp_est_num_of_spk
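The fix on the last line of this hunk clamps the compensated speaker count at 1: after majority-voting the per-trial estimates and subtracting the injected anchor speakers, the result could previously drop to 0 or below. A pure-Python sketch of the corrected compensation (illustrative helper, not the NeMo signature):

```python
from statistics import mode

def enhanced_speaker_count(est_counts, anchor_spk_n):
    """Majority-vote estimated counts, subtract injected anchors,
    and clamp at 1 — the clamp is the bug-fix in this commit."""
    return max(mode(est_counts) - anchor_spk_n, 1)
```

So a short recording where every trial estimates 3 speakers with 2 anchors injected yields 1 real speaker, and an estimate of 5 with 2 anchors yields 3.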


@@ -625,9 +686,9 @@ def split_input_data(
Args:
embeddings_in_scales (Tensor):
Concatenated Torch Tensor containing embeddings in multiple scales
Concatenated Torch tensor containing embeddings in multiple scales
timestamps_in_scales (Tensor):
Concatenated Torch Tensor containing timestamps in multiple scales
Concatenated Torch tensor containing timestamps in multiple scales
multiscale_segment_counts (LongTensor):
Concatenated Torch LongTensor containing number of segments per each scale
@@ -860,6 +921,8 @@ def __init__(
Number of p_values we search during NME analysis.
Default is 30. The lower the value, the faster NME-analysis becomes.
However, a value lower than 20 might cause a poor parameter estimation.
nme_mat_size (int):
Targeted size of matrix for NME analysis.
use_subsampling_for_nme (bool):
Use subsampling to reduce the calculational complexity.
Default is True.
@@ -875,8 +938,9 @@
If True, turn on parallelism based on torch.jit.script library.
cuda (bool):
Use cuda for Eigen decomposition if cuda=True.
nme_mat_size (int):
Targeted size of matrix for NME analysis.
device (torch.device):
Torch device variable
"""
self.max_num_speakers: int = max_num_speakers
self.max_rp_threshold: float = max_rp_threshold
@@ -886,18 +950,24 @@
self.sparse_search_volume: int = sparse_search_volume
self.min_p_value = torch.tensor(2)
self.fixed_thres: float = fixed_thres
self.cuda: bool = cuda
self.eps = 1e-10
self.max_N = torch.tensor(0)
self.mat: torch.Tensor = mat
self.p_value_list: torch.Tensor = self.min_p_value.unsqueeze(0)
self.device = device
self.cuda: bool = cuda
self.device: torch.device = device
self.maj_vote_spk_count: bool = maj_vote_spk_count
self.parallelism: bool = parallelism

def forward(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Subsample the input matrix to reduce the computational load.
Returns:
est_num_of_spk (Tensor):
Estimated number of speakers from NMESC approach
p_hat_value (Tensor):
Estimated p-value (determines how many neighboring values to be selected)
"""
if self.use_subsampling_for_nme:
subsample_ratio = self.subsampleAffinityMat(self.nme_mat_size)
@@ -1025,7 +1095,7 @@ def getPvalueList(self) -> torch.Tensor:
self.max_N = torch.max(
torch.floor(torch.tensor(self.mat.shape[0] * self.fixed_thres)).type(torch.int), self.min_p_value
)
p_value_list = torch.tensor(self.max_N).unsqueeze(0).int()
p_value_list = self.max_N.unsqueeze(0).int()
else:
self.max_N = torch.max(
torch.floor(torch.tensor(self.mat.shape[0] * self.max_rp_threshold)).type(torch.int), self.min_p_value
@@ -1088,7 +1158,7 @@ def __init__(
def forward(self, param_dict: Dict[str, torch.Tensor]) -> torch.LongTensor:
"""
A function wrapper designed for inference in exported script format.
Note:
Dict is used to allow easy inference of the exported jit model in Triton server using easy to understand
naming convention.
@@ -1138,7 +1208,7 @@ def forward_infer(
oracle_num_speakers: int = -1,
max_rp_threshold: float = 0.15,
max_num_speakers: int = 8,
enhanced_count_thres: int = 80,
enhanced_count_thres: int = 40,
sparse_search_volume: int = 30,
fixed_thres: float = -1.0,
) -> torch.LongTensor:
@@ -1180,7 +1250,7 @@ def forward_infer(
Limits the range of parameter search.
Clustering performance can vary depending on this range.
Default is 0.15.
enhanced_count_thres: (int)
enhanced_count_thres (int):
For the short audio recordings, clustering algorithm cannot
accumulate enough amount of speaker profile for each cluster.
Thus, function `getEnhancedSpeakerCount` employs anchor embeddings
@@ -1203,7 +1273,8 @@ def forward_infer(
embeddings_in_scales, timestamps_in_scales, multiscale_segment_counts
)

emb = self.embeddings_in_scales[multiscale_segment_counts.shape[0] - 1]
# Last slot is the base scale embeddings
emb = self.embeddings_in_scales[-1]

# Cases for extremely short sessions
if emb.shape[0] == 1:
@@ -1239,7 +1310,8 @@ def forward_infer(
est_num_of_spk, p_hat_value = nmesc.forward()
affinity_mat = getAffinityGraphMat(mat, p_hat_value)
else:
est_num_of_spk = torch.tensor(1)
nmesc.fixed_thres = max_rp_threshold
est_num_of_spk, p_hat_value = nmesc.forward()
affinity_mat = mat

# n_clusters is number of speakers estimated from spectral clustering.