Skip to content

Commit

Permalink
Adds documentation for some of the classes and methods under function…
Browse files Browse the repository at this point in the history
…al.py and mapping.py
  • Loading branch information
zaouk committed Jun 20, 2022
1 parent 9e51718 commit 18983aa
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
58 changes: 58 additions & 0 deletions src/diart/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,25 @@ def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature:


class OnlineSpeakerClustering:
"""Makes an object for handling constrained incremental online clustering of speakers and maintaining cluster centers.
Parameters
----------
tau_active:float
Threshold for detecting active speakers. This threshold is applied on the maximum value of per-speaker output
activation of the local segmentation model.
rho_update: float
Threshold for considering the extracted embedding when updating the centroid of the local speaker.
The centroid to which a local speaker is mapped is only updated if the ratio of speech/chunk duration of
of a given local speaker is greater than this threshold.
delta_new: float
Threshold on the distance between a speaker embedding and a centroid. If the distance between a local speaker and all
centroids is larger than delta_new, then a new centroid is created for the current speaker.
metric: str. Defaults to "cosine".
The distance metric to use.
max_speakers: int. Defaults to 20.
Optional Maximum number of global speakers to track through a conversation.
"""
def __init__(
self,
tau_active: float,
Expand Down Expand Up @@ -323,17 +342,44 @@ def get_next_center_position(self) -> Optional[int]:
return center

def init_centers(self, dimension: int):
"""Initializes the speaker centroid matrix
Parameters
----------
dimension: int
Dimension of embeddings used for representing a speaker.
"""
self.centers = np.zeros((self.max_speakers, dimension))
self.active_centers = set()
self.blocked_centers = set()

def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray):
"""Updates the speaker centroids given a list of assignments and local speaker embeddings
Parameters
----------
assignments: Iterable[Tuple[int, int]])
An iterable of tuples with two elements having the first element as the source speaker
and the second element as the target speaker.
embeddings: np.ndarray of shape (n_local_speakers, dim_embedding)
Matrix containing embeddings for all local speakers.
"""
if self.centers is not None:
for l_spk, g_spk in assignments:
assert g_spk in self.active_centers, "Cannot update unknown centers"
self.centers[g_spk] += embeddings[l_spk]

def add_center(self, embedding: np.ndarray) -> int:
"""Add a new speaker centroid initialized to a given embedding
Parameters
----------
embedding: np.ndarray
Embedding vector of some local speaker
Returns:
int: index of the created center
"""
center = self.get_next_center_position()
self.centers[center] = embedding
self.active_centers.add(center)
Expand All @@ -344,6 +390,18 @@ def identify(
segmentation: SlidingWindowFeature,
embeddings: torch.Tensor
) -> SpeakerMap:
"""Identify the centroids to which the input speaker embeddings belong.
Parameters
----------
segmentation: np.ndarray of shape (n_frames, n_local_speakers)
Matrix of segmentation outputs
embeddings: np.ndarray of shape (n_local_speakers, dim_embedding)
Matrix of embeddings
Returns:
SpeakerMap: a mapping from local speakers to global speakers.
"""
embeddings = embeddings.detach().cpu().numpy()
active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0]
long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0]
Expand Down
30 changes: 30 additions & 0 deletions src/diart/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]:
def hard_speaker_map(
self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]]
) -> SpeakerMap:
"""Returns a hard map object where the highest cost is put everywhere except on hard assignments from assignments
Parameters
----------
num_src: int
Number of source speakers
num_tgt: int
Number of target speakers
assignments: Iterable[Tuple[int, int]]
An iterable of tuples with two elements having the first element as the source speaker
and the second element as the target speaker
Returns:
SpeakerMap
"""
mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt))
for src, tgt in assignments:
mapping_matrix[src, tgt] = self.best_possible_value
Expand Down Expand Up @@ -82,6 +97,21 @@ class SpeakerMapBuilder:
def hard_map(
shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool
) -> SpeakerMap:
"""Returns a SpeakerMap object based on the given assignments. This is a "hard" map, meaning that the
highest cost is put everywhere except on hard assignments from assignments.
Parameters
----------
shape: Tuple[int, int])
shape of the mapping matrix
assignments: Iterable[Tuple[int, int]]
An iterable of tuples with two elements having the first element as the source speaker
and the second element as the target speaker
maximize: bool
whether to use scores where higher is better (true) or where lower is better (false)
Returns: a SpeakerMap
"""
num_src, num_tgt = shape
objective = MaximizationObjective if maximize else MinimizationObjective
return objective().hard_speaker_map(num_src, num_tgt, assignments)
Expand Down

0 comments on commit 18983aa

Please sign in to comment.