From bbdf161966c69320c1ff048a4f6ddd5e204731b0 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Mon, 9 May 2022 17:32:39 +0200
Subject: [PATCH 1/2] Add OverlapAwareSpeakerEmbedding block

---
 src/diart/blocks.py    | 37 +++++++++++++++++++++++++++++++++++++
 src/diart/pipelines.py | 33 +++++++++++++--------------------
 2 files changed, 50 insertions(+), 20 deletions(-)

diff --git a/src/diart/blocks.py b/src/diart/blocks.py
index 2c02623c..19b334d5 100644
--- a/src/diart/blocks.py
+++ b/src/diart/blocks.py
@@ -165,6 +165,43 @@ def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
         return norm_embs.squeeze()
 
 
+class OverlapAwareSpeakerEmbedding:
+    """
+    Extract overlap-aware speaker embeddings given an audio chunk and its segmentation.
+
+    Parameters
+    ----------
+    model: pyannote.audio.Model, Text or Dict
+        The embedding model. It must take a waveform and weights as input.
+    gamma: float, optional
+        Exponent used to lower low-confidence predictions.
+        Defaults to 3.
+    beta: float, optional
+        Temperature parameter (actually 1/beta) of the softmax used to lower joint speaker activations.
+        Defaults to 10.
+    norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional
+        The target norm for the embeddings. It can be different for each speaker.
+        Defaults to 1.
+    device: Optional[torch.device]
+        The device on which to run the embedding model.
+        Defaults to GPU if available, otherwise CPU.
+    """
+    def __init__(
+        self,
+        model: PipelineModel,
+        gamma: float = 3,
+        beta: float = 10,
+        norm: Union[float, torch.Tensor] = 1,
+        device: Optional[torch.device] = None,
+    ):
+        self.embedding = ChunkwiseModel(model, device)
+        self.osp = OverlappedSpeechPenalty(gamma, beta)
+        self.normalize = EmbeddingNormalization(norm)
+
+    def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor:
+        return self.normalize(self.embedding(waveform, self.osp(segmentation)))
+
+
 class AggregationStrategy:
     """Abstract class representing a strategy to aggregate overlapping buffers"""
 
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index 919a371e..649b6a4f 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -117,11 +117,13 @@ def from_model_streams(
 class OnlineSpeakerDiarization:
     def __init__(self, config: PipelineConfig):
         self.config = config
-        self.segmentation = blocks.FramewiseModel(config.segmentation, self.config.device)
-        self.embedding = blocks.ChunkwiseModel(config.embedding, self.config.device)
+        self.segmentation = blocks.FramewiseModel(config.segmentation, config.device)
+        self.embedding = blocks.OverlapAwareSpeakerEmbedding(
+            config.embedding, config.gamma, config.beta, norm=1, device=config.device
+        )
         self.speaker_tracking = OnlineSpeakerTracking(config)
         msg = "Invalid latency requested"
-        assert self.config.step <= self.config.latency <= self.duration, msg
+        assert config.step <= config.latency <= self.duration, msg
 
     @property
     def sample_rate(self) -> int:
@@ -150,17 +152,12 @@ def from_source(
             dops.regularize_stream(self.duration, self.config.step, source.sample_rate)
         )
         # Branch the stream to calculate chunk segmentation
-        segmentation_stream = regular_stream.pipe(ops.map(self.segmentation))
-        # Join audio and segmentation stream to calculate speaker embeddings
-        osp = blocks.OverlappedSpeechPenalty(gamma=self.config.gamma, beta=self.config.beta)
-        embedding_stream = rx.zip(regular_stream, segmentation_stream).pipe(
-            ops.starmap(lambda wave, seg: (wave, osp(seg))),
-            ops.starmap(self.embedding),
-            ops.map(blocks.EmbeddingNormalization(norm=1))
-        )
+        seg_stream = regular_stream.pipe(ops.map(self.segmentation))
+        # Join audio and segmentation stream to calculate overlap-aware speaker embeddings
+        emb_stream = rx.zip(regular_stream, seg_stream).pipe(ops.starmap(self.embedding))
         chunk_stream = regular_stream if output_waveform else None
         return self.speaker_tracking.from_model_streams(
-            source.uri, source.duration, segmentation_stream, embedding_stream, chunk_stream
+            source.uri, source.duration, seg_stream, emb_stream, chunk_stream
         )
 
     def from_file(
@@ -176,10 +173,6 @@
             self.sample_rate, self.duration, self.config.step
         )
 
-        # Initialize pipeline modules
-        osp = blocks.OverlappedSpeechPenalty(self.config.gamma, self.config.beta)
-        emb_norm = blocks.EmbeddingNormalization(norm=1)
-
         # Split audio into chunks
         chunks = rearrange(
             chunk_loader.get_chunks(file),
@@ -205,7 +198,7 @@
             # Edge case: add batch dimension if i == i_end + 1
             if seg.ndim == 2:
                 seg = seg[np.newaxis]
-            emb = emb_norm(self.embedding(batch, osp(seg)))
+            emb = self.embedding(batch, seg)
             # Edge case: add batch dimension if i == i_end + 1
             if emb.ndim == 2:
                 emb = emb.unsqueeze(0)
@@ -216,12 +209,12 @@
 
         # Stream pre-calculated segmentation, embeddings and chunks
         resolution = self.duration / segmentation.shape[1]
-        segmentation_stream = rx.range(0, num_chunks).pipe(
+        seg_stream = rx.range(0, num_chunks).pipe(
             ops.map(lambda i: SlidingWindowFeature(
                 segmentation[i], SlidingWindow(resolution, resolution, i * self.config.step)
             ))
         )
-        embedding_stream = rx.range(0, num_chunks).pipe(ops.map(lambda i: embeddings[i]))
+        emb_stream = rx.range(0, num_chunks).pipe(ops.map(lambda i: embeddings[i]))
         wav_resolution = 1 / self.sample_rate
         chunk_stream = None
         if output_waveform:
@@ -234,5 +227,5 @@
         # Build speaker tracking pipeline
         duration = chunk_loader.audio.get_duration(file)
         return self.speaker_tracking.from_model_streams(
-            file.stem, duration, segmentation_stream, embedding_stream, chunk_stream
+            file.stem, duration, seg_stream, emb_stream, chunk_stream
         )
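As the `__call__` above shows, the new block is a straight composition of the three existing blocks. A minimal standalone sketch of the same computation, with names and defaults taken from this patch; the explicit CPU device and the `waveform`/`segmentation` placeholders are assumptions for illustration:

```python
# Sketch only: mirrors OverlapAwareSpeakerEmbedding.__call__ above.
import torch
from diart.blocks import ChunkwiseModel, EmbeddingNormalization, OverlappedSpeechPenalty

embedding = ChunkwiseModel("pyannote/embedding", torch.device("cpu"))
osp = OverlappedSpeechPenalty(gamma=3, beta=10)  # defaults from the docstring
normalize = EmbeddingNormalization(norm=1)

def overlap_aware_embeddings(waveform, segmentation):
    # Down-weight overlapped and low-confidence frames before pooling,
    # then rescale each speaker embedding to the target norm.
    return normalize(embedding(waveform, osp(segmentation)))
```
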
From c8c7db2a9994908f30fb6f5de171971b7d7e3aa0 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Mon, 9 May 2022 18:52:40 +0200
Subject: [PATCH 2/2] Update README.md

---
 README.md | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 2a73fb32..8d65ea50 100644
--- a/README.md
+++ b/README.md
@@ -46,29 +46,25 @@ Obtain overlap-aware speaker embeddings from a microphone stream
 ```python
 import rx
 import rx.operators as ops
-import diart.operators as myops
+import diart.operators as dops
 from diart.sources import MicrophoneAudioSource
-import diart.blocks as blocks
+from diart.blocks import FramewiseModel, OverlapAwareSpeakerEmbedding
 
 sample_rate = 16000
 mic = MicrophoneAudioSource(sample_rate)
 
 # Initialize independent modules
-segmentation = blocks.FramewiseModel("pyannote/segmentation")
-embedding = blocks.ChunkwiseModel("pyannote/embedding")
-osp = blocks.OverlappedSpeechPenalty(gamma=3, beta=10)
-normalization = blocks.EmbeddingNormalization(norm=1)
+segmentation = FramewiseModel("pyannote/segmentation")
+embedding = OverlapAwareSpeakerEmbedding("pyannote/embedding")
 
 # Reformat microphone stream. Defaults to 5s duration and 500ms shift
-regular_stream = mic.stream.pipe(myops.regularize_stream(sample_rate))
+regular_stream = mic.stream.pipe(dops.regularize_stream(sample_rate))
 # Branch the microphone stream to calculate segmentation
 segmentation_stream = regular_stream.pipe(ops.map(segmentation))
 # Join audio and segmentation stream to calculate speaker embeddings
-embedding_stream = rx.zip(regular_stream, segmentation_stream).pipe(
-    ops.starmap(lambda wave, seg: (wave, osp(seg))),
-    ops.starmap(embedding),
-    ops.map(normalization)
-)
+embedding_stream = rx.zip(
+    regular_stream, segmentation_stream
+).pipe(ops.starmap(embedding))
 
 embedding_stream.subscribe(on_next=lambda emb: print(emb.shape))
 
@@ -89,11 +85,11 @@ torch.Size([4, 512])
 
 1) Create environment:
 ```shell
-conda create -n diarization python=3.8
-conda activate diarization
+conda create -n diart python=3.8
+conda activate diart
 ```
 
-2) Install the latest PyTorch version following the [official instructions](https://pytorch.org/get-started/locally/#start-locally)
+2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)
 3) Install pyannote.audio 2.0 (currently in development)
 
 ```shell
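The blocks in the updated README snippet pick their device on their own (GPU if available, per the docstring in PATCH 1), but both accept an explicit device. A minimal sketch, assuming you want to pin the models to a specific device using the parameters shown in the diffs above:

```python
import torch
from diart.blocks import FramewiseModel, OverlapAwareSpeakerEmbedding

# FramewiseModel takes the device as its second positional argument
# (as in src/diart/pipelines.py above), while OverlapAwareSpeakerEmbedding
# exposes it as a keyword argument.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
segmentation = FramewiseModel("pyannote/segmentation", device)
embedding = OverlapAwareSpeakerEmbedding("pyannote/embedding", device=device)
```

Each element emitted by `embedding_stream` in the snippet is then a tensor of shape `(speakers, embedding_dim)`, e.g. `torch.Size([4, 512])` as in the sample output referenced by the second README hunk.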