diff --git a/src/diart/functional.py b/src/diart/functional.py index 1d340c74..e15a85b3 100644 --- a/src/diart/functional.py +++ b/src/diart/functional.py @@ -280,6 +280,25 @@ def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature: class OnlineSpeakerClustering: + """Makes an object for handling constrained incremental online clustering of speakers and maintaining cluster centers. + + Parameters + ---------- + tau_active:float + Threshold for detecting active speakers. This threshold is applied on the maximum value of per-speaker output + activation of the local segmentation model. + rho_update: float + Threshold for considering the extracted embedding when updating the centroid of the local speaker. + The centroid to which a local speaker is mapped is only updated if the ratio of speech/chunk duration of + of a given local speaker is greater than this threshold. + delta_new: float + Threshold on the distance between a speaker embedding and a centroid. If the distance between a local speaker and all + centroids is larger than delta_new, then a new centroid is created for the current speaker. + metric: str. Defaults to "cosine". + The distance metric to use. + max_speakers: int. Defaults to 20. + Optional Maximum number of global speakers to track through a conversation. + """ def __init__( self, tau_active: float, @@ -323,17 +342,44 @@ def get_next_center_position(self) -> Optional[int]: return center def init_centers(self, dimension: int): + """Initializes the speaker centroid matrix + + Parameters + ---------- + dimension: int + Dimension of embeddings used for representing a speaker. + """ self.centers = np.zeros((self.max_speakers, dimension)) self.active_centers = set() self.blocked_centers = set() def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray): + """Updates the speaker centroids given a list of assignments and local speaker embeddings + + Parameters + ---------- + assignments: Iterable[Tuple[int, int]]) + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker. + embeddings: np.ndarray of shape (n_local_speakers, dim_embedding) + Matrix containing embeddings for all local speakers. + """ if self.centers is not None: for l_spk, g_spk in assignments: assert g_spk in self.active_centers, "Cannot update unknown centers" self.centers[g_spk] += embeddings[l_spk] def add_center(self, embedding: np.ndarray) -> int: + """Add a new speaker centroid initialized to a given embedding + + Parameters + ---------- + embedding: np.ndarray + Embedding vector of some local speaker + + Returns: + int: index of the created center + """ center = self.get_next_center_position() self.centers[center] = embedding self.active_centers.add(center) @@ -344,6 +390,18 @@ def identify( segmentation: SlidingWindowFeature, embeddings: torch.Tensor ) -> SpeakerMap: + """Identify the centroids to which the input speaker embeddings belong. + + Parameters + ---------- + segmentation: np.ndarray of shape (n_frames, n_local_speakers) + Matrix of segmentation outputs + embeddings: np.ndarray of shape (n_local_speakers, dim_embedding) + Matrix of embeddings + + Returns: + SpeakerMap: a mapping from local speakers to global speakers. + """ embeddings = embeddings.detach().cpu().numpy() active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0] long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0] diff --git a/src/diart/mapping.py b/src/diart/mapping.py index 7327ac6e..bad45366 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -22,6 +22,21 @@ def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]: def hard_speaker_map( self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]] ) -> SpeakerMap: + """Returns a hard map object where the highest cost is put everywhere except on hard assignments from assignments + + Parameters + ---------- + num_src: int + Number of source speakers + num_tgt: int + Number of target speakers + assignments: Iterable[Tuple[int, int]] + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker + + Returns: + SpeakerMap + """ mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt)) for src, tgt in assignments: mapping_matrix[src, tgt] = self.best_possible_value @@ -82,6 +97,21 @@ class SpeakerMapBuilder: def hard_map( shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool ) -> SpeakerMap: + """Returns a SpeakerMap object based on the given assignments. This is a "hard" map, meaning that the + highest cost is put everywhere except on hard assignments from assignments. + + Parameters + ---------- + shape: Tuple[int, int]) + shape of the mapping matrix + assignments: Iterable[Tuple[int, int]] + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker + maximize: bool + whether to use scores where higher is better (true) or where lower is better (false) + + Returns: a SpeakerMap + """ num_src, num_tgt = shape objective = MaximizationObjective if maximize else MinimizationObjective return objective().hard_speaker_map(num_src, num_tgt, assignments)