Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 7, 2022
1 parent bcaf7bc commit 7534ad5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 14 deletions.
18 changes: 6 additions & 12 deletions nemo/collections/asr/parts/utils/online_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,7 @@ def merge_vectors(selected_inds: torch.Tensor, emb_ndx: torch.Tensor, pre_cluste


@torch.jit.script
def get_closest_embeddings(
affinity_mat: torch.Tensor, n_closest: int
) -> Tuple[torch.Tensor, torch.Tensor]:
def get_closest_embeddings(affinity_mat: torch.Tensor, n_closest: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the indices of the embedding vectors we want to merge.
Expand Down Expand Up @@ -387,9 +385,7 @@ def get_closest_embeddings(
"""
comb_limit = int(affinity_mat.shape[0] - 1)
if n_closest > comb_limit:
raise ValueError(
f"Got n_closest of {n_closest}: {n_closest} is bigger than comb_limit {comb_limit}"
)
raise ValueError(f"Got n_closest of {n_closest}: {n_closest} is bigger than comb_limit {comb_limit}")

# Take summed values over one axis
sum_cmat = affinity_mat.sum(0)
Expand Down Expand Up @@ -1046,7 +1042,9 @@ def match_labels(self, Y_new: torch.Tensor, add_new: bool) -> torch.Tensor:
try:
Y_out = torch.hstack((self.Y_fullhist[: self.history_buffer_seg_end], Y_matched[self.history_n :]))
except:
import ipdb; ipdb.set_trace()
import ipdb

ipdb.set_trace()
self.Y_fullhist = Y_out
else:
# Do not update cumulative labels since there are no new segments.
Expand All @@ -1058,11 +1056,7 @@ def match_labels(self, Y_new: torch.Tensor, add_new: bool) -> torch.Tensor:
return Y_out

def forward_infer(
self,
emb: torch.Tensor,
frame_index: int,
enhanced_count_thres: int = 40,
cuda: bool = False,
self, emb: torch.Tensor, frame_index: int, enhanced_count_thres: int = 40, cuda: bool = False,
) -> torch.Tensor:
"""
Perform speaker clustering in online mode. Embedding vector set `emb` is expected to be containing
Expand Down
3 changes: 1 addition & 2 deletions tests/collections/asr/test_diar_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def test_online_speaker_clustering(self, n_spks, total_sec, buffer_size, sigma,
assert add_new
cumul_label_acc = sum(evaluation_list) / len(evaluation_list)
assert cumul_label_acc > 0.9

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
@pytest.mark.parametrize("n_spks", [4])
Expand All @@ -592,4 +592,3 @@ def test_online_speaker_clustering(self, n_spks, total_sec, buffer_size, sigma,
@pytest.mark.parametrize("seed", [0])
def test_online_speaker_clustering_cpu(self, n_spks, total_sec, buffer_size, sigma, seed):
self.test_online_speaker_clustering(n_spks, total_sec, buffer_size, sigma, seed)

0 comments on commit 7534ad5

Please sign in to comment.