From a6ddb07907afe81dc9919c216e1beff4a791086a Mon Sep 17 00:00:00 2001 From: Khaled Zaouk Date: Mon, 20 Jun 2022 11:17:28 +0200 Subject: [PATCH] Add documentation for some classes and methods (#31) * Adds documentation for some of the classes and methods under functional.py and mapping.py Co-authored-by: Juan Coria --- src/diart/blocks/clustering.py | 62 ++++++++++++++++++++++++++++++++++ src/diart/mapping.py | 34 +++++++++++++++++++ src/diart/pipelines.py | 3 +- 3 files changed, 97 insertions(+), 2 deletions(-) diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py index 57a3ab2c..882001b9 100644 --- a/src/diart/blocks/clustering.py +++ b/src/diart/blocks/clustering.py @@ -8,6 +8,25 @@ class OnlineSpeakerClustering: + """Implements constrained incremental online clustering of speakers and manages cluster centers. + + Parameters + ---------- + tau_active:float + Threshold for detecting active speakers. This threshold is applied on the maximum value of per-speaker output + activation of the local segmentation model. + rho_update: float + Threshold for considering the extracted embedding when updating the centroid of the local speaker. + The centroid to which a local speaker is mapped is only updated if the ratio of speech/chunk duration + of a given local speaker is greater than this threshold. + delta_new: float + Threshold on the distance between a speaker embedding and a centroid. If the distance between a local speaker and all + centroids is larger than delta_new, then a new centroid is created for the current speaker. + metric: str. Defaults to "cosine". + The distance metric to use. + max_speakers: int + Maximum number of global speakers to track through a conversation. Defaults to 20. + """ def __init__( self, tau_active: float, @@ -51,17 +70,46 @@ def get_next_center_position(self) -> Optional[int]: return center def init_centers(self, dimension: int): + """Initializes the speaker centroid matrix + + Parameters + ---------- + dimension: int + Dimension of embeddings used for representing a speaker. + """ self.centers = np.zeros((self.max_speakers, dimension)) self.active_centers = set() self.blocked_centers = set() def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray): + """Updates the speaker centroids given a list of assignments and local speaker embeddings + + Parameters + ---------- + assignments: Iterable[Tuple[int, int]]) + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker. + embeddings: np.ndarray, shape (local_speakers, embedding_dim) + Matrix containing embeddings for all local speakers. + """ if self.centers is not None: for l_spk, g_spk in assignments: assert g_spk in self.active_centers, "Cannot update unknown centers" self.centers[g_spk] += embeddings[l_spk] def add_center(self, embedding: np.ndarray) -> int: + """Add a new speaker centroid initialized to a given embedding + + Parameters + ---------- + embedding: np.ndarray + Embedding vector of some local speaker + + Returns + ------- + center_index: int + Index of the created center + """ center = self.get_next_center_position() self.centers[center] = embedding self.active_centers.add(center) @@ -72,6 +120,20 @@ def identify( segmentation: SlidingWindowFeature, embeddings: torch.Tensor ) -> SpeakerMap: + """Identify the centroids to which the input speaker embeddings belong. + + Parameters + ---------- + segmentation: np.ndarray, shape (frames, local_speakers) + Matrix of segmentation outputs + embeddings: np.ndarray, shape (local_speakers, embedding_dim) + Matrix of embeddings + + Returns + ------- + speaker_map: SpeakerMap + A mapping from local speakers to global speakers. + """ embeddings = embeddings.detach().cpu().numpy() active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0] long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0] diff --git a/src/diart/mapping.py b/src/diart/mapping.py index 01465086..2795ba0b 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -22,6 +22,23 @@ def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]: def hard_speaker_map( self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]] ) -> SpeakerMap: + """Create a hard map object where the highest cost is put + everywhere except on hard assignments from ``assignments``. + + Parameters + ---------- + num_src: int + Number of source speakers + num_tgt: int + Number of target speakers + assignments: Iterable[Tuple[int, int]] + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker + + Returns + ------- + SpeakerMap + """ mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt)) for src, tgt in assignments: mapping_matrix[src, tgt] = self.best_possible_value @@ -82,6 +99,23 @@ class SpeakerMapBuilder: def hard_map( shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool ) -> SpeakerMap: + """Create a ``SpeakerMap`` object based on the given assignments. This is a "hard" map, meaning that the + highest cost is put everywhere except on hard assignments from ``assignments``. + + Parameters + ---------- + shape: Tuple[int, int]) + Shape of the mapping matrix + assignments: Iterable[Tuple[int, int]] + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker + maximize: bool + whether to use scores where higher is better (true) or where lower is better (false) + + Returns + ------- + SpeakerMap + """ num_src, num_tgt = shape objective = MaximizationObjective if maximize else MinimizationObjective return objective().hard_speaker_map(num_src, num_tgt, assignments) diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 4ccbf1d4..dea360d6 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -81,8 +81,7 @@ def from_namespace(args: Namespace) -> 'PipelineConfig': ) def last_chunk_end_time(self, conv_duration: float) -> Optional[float]: - """ - Return the end time of the last chunk for a given conversation duration. + """Return the end time of the last chunk for a given conversation duration. Parameters ----------