Skip to content

Commit

Permalink
Add documentation for some classes and methods (#31)
Browse files Browse the repository at this point in the history
* Adds documentation for some of the classes and methods under functional.py and mapping.py

Co-authored-by: Juan Coria <[email protected]>
  • Loading branch information
zaouk and juanmc2005 authored Jun 20, 2022
1 parent b91550b commit a6ddb07
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 2 deletions.
62 changes: 62 additions & 0 deletions src/diart/blocks/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@


class OnlineSpeakerClustering:
"""Implements constrained incremental online clustering of speakers and manages cluster centers.
Parameters
----------
tau_active:float
Threshold for detecting active speakers. This threshold is applied on the maximum value of per-speaker output
activation of the local segmentation model.
rho_update: float
Threshold for considering the extracted embedding when updating the centroid of the local speaker.
The centroid to which a local speaker is mapped is only updated if the ratio of speech/chunk duration
of a given local speaker is greater than this threshold.
delta_new: float
Threshold on the distance between a speaker embedding and a centroid. If the distance between a local speaker and all
centroids is larger than delta_new, then a new centroid is created for the current speaker.
metric: str. Defaults to "cosine".
The distance metric to use.
max_speakers: int
Maximum number of global speakers to track through a conversation. Defaults to 20.
"""
def __init__(
self,
tau_active: float,
Expand Down Expand Up @@ -51,17 +70,46 @@ def get_next_center_position(self) -> Optional[int]:
return center

def init_centers(self, dimension: int):
"""Initializes the speaker centroid matrix
Parameters
----------
dimension: int
Dimension of embeddings used for representing a speaker.
"""
self.centers = np.zeros((self.max_speakers, dimension))
self.active_centers = set()
self.blocked_centers = set()

def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray):
"""Updates the speaker centroids given a list of assignments and local speaker embeddings
Parameters
----------
assignments: Iterable[Tuple[int, int]])
An iterable of tuples with two elements having the first element as the source speaker
and the second element as the target speaker.
embeddings: np.ndarray, shape (local_speakers, embedding_dim)
Matrix containing embeddings for all local speakers.
"""
if self.centers is not None:
for l_spk, g_spk in assignments:
assert g_spk in self.active_centers, "Cannot update unknown centers"
self.centers[g_spk] += embeddings[l_spk]

def add_center(self, embedding: np.ndarray) -> int:
"""Add a new speaker centroid initialized to a given embedding
Parameters
----------
embedding: np.ndarray
Embedding vector of some local speaker
Returns
-------
center_index: int
Index of the created center
"""
center = self.get_next_center_position()
self.centers[center] = embedding
self.active_centers.add(center)
Expand All @@ -72,6 +120,20 @@ def identify(
segmentation: SlidingWindowFeature,
embeddings: torch.Tensor
) -> SpeakerMap:
"""Identify the centroids to which the input speaker embeddings belong.
Parameters
----------
segmentation: np.ndarray, shape (frames, local_speakers)
Matrix of segmentation outputs
embeddings: np.ndarray, shape (local_speakers, embedding_dim)
Matrix of embeddings
Returns
-------
speaker_map: SpeakerMap
A mapping from local speakers to global speakers.
"""
embeddings = embeddings.detach().cpu().numpy()
active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0]
long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0]
Expand Down
34 changes: 34 additions & 0 deletions src/diart/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]:
def hard_speaker_map(
self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]]
) -> SpeakerMap:
"""Create a hard map object where the highest cost is put
everywhere except on hard assignments from ``assignments``.
Parameters
----------
num_src: int
Number of source speakers
num_tgt: int
Number of target speakers
assignments: Iterable[Tuple[int, int]]
An iterable of tuples with two elements having the first element as the source speaker
and the second element as the target speaker
Returns
-------
SpeakerMap
"""
mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt))
for src, tgt in assignments:
mapping_matrix[src, tgt] = self.best_possible_value
Expand Down Expand Up @@ -82,6 +99,23 @@ class SpeakerMapBuilder:
def hard_map(
    shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool
) -> SpeakerMap:
    """Build a "hard" ``SpeakerMap``: the worst possible cost fills every
    cell except those explicitly fixed by ``assignments``.

    Parameters
    ----------
    shape: Tuple[int, int]
        Shape of the mapping matrix (source speakers, target speakers).
    assignments: Iterable[Tuple[int, int]]
        An iterable of tuples with two elements having the first element as
        the source speaker and the second element as the target speaker.
    maximize: bool
        Whether to use scores where higher is better (True) or where lower
        is better (False).

    Returns
    -------
    SpeakerMap
    """
    # Pick the objective that matches the direction of the score scale.
    if maximize:
        objective_cls = MaximizationObjective
    else:
        objective_cls = MinimizationObjective
    num_sources, num_targets = shape
    return objective_cls().hard_speaker_map(num_sources, num_targets, assignments)
Expand Down
3 changes: 1 addition & 2 deletions src/diart/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def from_namespace(args: Namespace) -> 'PipelineConfig':
)

def last_chunk_end_time(self, conv_duration: float) -> Optional[float]:
"""
Return the end time of the last chunk for a given conversation duration.
"""Return the end time of the last chunk for a given conversation duration.
Parameters
----------
Expand Down

0 comments on commit a6ddb07

Please sign in to comment.