Add OverlapAwareSpeakerEmbedding #51

Merged · 2 commits · May 11, 2022
README.md (26 changes: 11 additions & 15 deletions)
@@ -46,29 +46,25 @@ Obtain overlap-aware speaker embeddings from a microphone stream
```python
import rx
import rx.operators as ops
-import diart.operators as myops
+import diart.operators as dops
from diart.sources import MicrophoneAudioSource
-import diart.blocks as blocks
+from diart.blocks import FramewiseModel, OverlapAwareSpeakerEmbedding

sample_rate = 16000
mic = MicrophoneAudioSource(sample_rate)

# Initialize independent modules
-segmentation = blocks.FramewiseModel("pyannote/segmentation")
-embedding = blocks.ChunkwiseModel("pyannote/embedding")
-osp = blocks.OverlappedSpeechPenalty(gamma=3, beta=10)
-normalization = blocks.EmbeddingNormalization(norm=1)
+segmentation = FramewiseModel("pyannote/segmentation")
+embedding = OverlapAwareSpeakerEmbedding("pyannote/embedding")

# Reformat microphone stream. Defaults to 5s duration and 500ms shift
-regular_stream = mic.stream.pipe(myops.regularize_stream(sample_rate))
+regular_stream = mic.stream.pipe(dops.regularize_stream(sample_rate))
# Branch the microphone stream to calculate segmentation
segmentation_stream = regular_stream.pipe(ops.map(segmentation))
# Join audio and segmentation stream to calculate speaker embeddings
-embedding_stream = rx.zip(regular_stream, segmentation_stream).pipe(
-    ops.starmap(lambda wave, seg: (wave, osp(seg))),
-    ops.starmap(embedding),
-    ops.map(normalization)
-)
+embedding_stream = rx.zip(
+    regular_stream, segmentation_stream
+).pipe(ops.starmap(embedding))

embedding_stream.subscribe(on_next=lambda emb: print(emb.shape))

@@ -89,11 +85,11 @@ torch.Size([4, 512])
1) Create environment:

```shell
-conda create -n diarization python=3.8
-conda activate diarization
+conda create -n diart python=3.8
+conda activate diart
```

-2) Install the latest PyTorch version following the [official instructions](https://pytorch.org/get-started/locally/#start-locally)
+2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)

3) Install pyannote.audio 2.0 (currently in development)
```shell
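Note for anyone migrating from the removed three-block setup above: the `gamma`, `beta` and `norm` knobs are not gone, they are now constructor arguments of `OverlapAwareSpeakerEmbedding`. A short sketch based on the signature added in `src/diart/blocks.py` below (the values shown are simply the documented defaults):

```python
import torch
from diart.blocks import OverlapAwareSpeakerEmbedding

# Equivalent to OverlapAwareSpeakerEmbedding("pyannote/embedding"),
# with the previously separate knobs spelled out explicitly.
embedding = OverlapAwareSpeakerEmbedding(
    "pyannote/embedding",
    gamma=3,    # exponent that lowers low-confidence predictions
    beta=10,    # softmax temperature (1/beta) that penalizes joint speaker activations
    norm=1,     # target norm of the output embeddings
    device=torch.device("cuda") if torch.cuda.is_available() else None,
)
```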
src/diart/blocks.py (37 changes: 37 additions & 0 deletions)
@@ -165,6 +165,43 @@ def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
        return norm_embs.squeeze()


+class OverlapAwareSpeakerEmbedding:
+    """
+    Extract overlap-aware speaker embeddings given an audio chunk and its segmentation.
+
+    Parameters
+    ----------
+    model: pyannote.audio.Model, Text or Dict
+        The embedding model. It must take a waveform and weights as input.
+    gamma: float, optional
+        Exponent to lower low-confidence predictions.
+        Defaults to 3.
+    beta: float, optional
+        Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations.
+        Defaults to 10.
+    norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional
+        The target norm for the embeddings. It can be different for each speaker.
+        Defaults to 1.
+    device: Optional[torch.device]
+        The device on which to run the embedding model.
+        Defaults to GPU if available or CPU if not.
+    """
+    def __init__(
+        self,
+        model: PipelineModel,
+        gamma: float = 3,
+        beta: float = 10,
+        norm: Union[float, torch.Tensor] = 1,
+        device: Optional[torch.device] = None,
+    ):
+        self.embedding = ChunkwiseModel(model, device)
+        self.osp = OverlappedSpeechPenalty(gamma, beta)
+        self.normalize = EmbeddingNormalization(norm)
+
+    def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor:
+        return self.normalize(self.embedding(waveform, self.osp(segmentation)))
+
+
class AggregationStrategy:
    """Abstract class representing a strategy to aggregate overlapping buffers"""

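The `gamma` and `beta` parameters documented above are easiest to grasp with a toy example. Below is a minimal sketch of an overlapped-speech penalty written from the docstring alone; the exact formula implemented by `OverlappedSpeechPenalty` in `blocks.py` may differ:

```python
import torch

def overlap_penalty_sketch(
    segmentation: torch.Tensor, gamma: float = 3.0, beta: float = 10.0
) -> torch.Tensor:
    """Toy overlapped-speech penalty (illustrative only, not diart's exact formula).

    segmentation: (batch, frames, speakers) speaker activation probabilities.
    Returns per-frame, per-speaker weights used to pool frame-level features
    into one embedding per speaker.
    """
    # A low-temperature softmax over the speaker axis (temperature 1/beta) measures
    # how exclusively each frame belongs to a single speaker, so frames where
    # several speakers are jointly active are pushed down...
    exclusivity = torch.softmax(beta * segmentation, dim=-1)
    # ...and raising both terms to gamma further suppresses low-confidence frames.
    return (segmentation ** gamma) * (exclusivity ** gamma)

# Example: one chunk, 4 frames, 2 speakers
seg = torch.tensor([[[0.9, 0.1],    # speaker 0 alone -> high weight for speaker 0
                     [0.8, 0.7],    # overlapped speech -> both weights penalized
                     [0.1, 0.9],    # speaker 1 alone -> high weight for speaker 1
                     [0.2, 0.1]]])  # near silence -> low weights for both
print(overlap_penalty_sketch(seg))
```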
src/diart/pipelines.py (33 changes: 13 additions & 20 deletions)
@@ -117,11 +117,13 @@ def from_model_streams(
class OnlineSpeakerDiarization:
    def __init__(self, config: PipelineConfig):
        self.config = config
-        self.segmentation = blocks.FramewiseModel(config.segmentation, self.config.device)
-        self.embedding = blocks.ChunkwiseModel(config.embedding, self.config.device)
+        self.segmentation = blocks.FramewiseModel(config.segmentation, config.device)
+        self.embedding = blocks.OverlapAwareSpeakerEmbedding(
+            config.embedding, config.gamma, config.beta, norm=1, device=config.device
+        )
        self.speaker_tracking = OnlineSpeakerTracking(config)
        msg = "Invalid latency requested"
-        assert self.config.step <= self.config.latency <= self.duration, msg
+        assert config.step <= config.latency <= self.duration, msg

    @property
    def sample_rate(self) -> int:
@@ -150,17 +152,12 @@ def from_source(
            dops.regularize_stream(self.duration, self.config.step, source.sample_rate)
        )
        # Branch the stream to calculate chunk segmentation
-        segmentation_stream = regular_stream.pipe(ops.map(self.segmentation))
-        # Join audio and segmentation stream to calculate speaker embeddings
-        osp = blocks.OverlappedSpeechPenalty(gamma=self.config.gamma, beta=self.config.beta)
-        embedding_stream = rx.zip(regular_stream, segmentation_stream).pipe(
-            ops.starmap(lambda wave, seg: (wave, osp(seg))),
-            ops.starmap(self.embedding),
-            ops.map(blocks.EmbeddingNormalization(norm=1))
-        )
+        seg_stream = regular_stream.pipe(ops.map(self.segmentation))
+        # Join audio and segmentation stream to calculate overlap-aware speaker embeddings
+        emb_stream = rx.zip(regular_stream, seg_stream).pipe(ops.starmap(self.embedding))
        chunk_stream = regular_stream if output_waveform else None
        return self.speaker_tracking.from_model_streams(
-            source.uri, source.duration, segmentation_stream, embedding_stream, chunk_stream
+            source.uri, source.duration, seg_stream, emb_stream, chunk_stream
        )

    def from_file(
@@ -176,10 +173,6 @@ def from_file(
            self.sample_rate, self.duration, self.config.step
        )

-        # Initialize pipeline modules
-        osp = blocks.OverlappedSpeechPenalty(self.config.gamma, self.config.beta)
-        emb_norm = blocks.EmbeddingNormalization(norm=1)
-
        # Split audio into chunks
        chunks = rearrange(
            chunk_loader.get_chunks(file),
@@ -205,7 +198,7 @@
            # Edge case: add batch dimension if i == i_end + 1
            if seg.ndim == 2:
                seg = seg[np.newaxis]
-            emb = emb_norm(self.embedding(batch, osp(seg)))
+            emb = self.embedding(batch, seg)
            # Edge case: add batch dimension if i == i_end + 1
            if emb.ndim == 2:
                emb = emb.unsqueeze(0)
@@ -216,12 +209,12 @@

        # Stream pre-calculated segmentation, embeddings and chunks
        resolution = self.duration / segmentation.shape[1]
-        segmentation_stream = rx.range(0, num_chunks).pipe(
+        seg_stream = rx.range(0, num_chunks).pipe(
            ops.map(lambda i: SlidingWindowFeature(
                segmentation[i], SlidingWindow(resolution, resolution, i * self.config.step)
            ))
        )
-        embedding_stream = rx.range(0, num_chunks).pipe(ops.map(lambda i: embeddings[i]))
+        emb_stream = rx.range(0, num_chunks).pipe(ops.map(lambda i: embeddings[i]))
        wav_resolution = 1 / self.sample_rate
        chunk_stream = None
        if output_waveform:
@@ -234,5 +227,5 @@
        # Build speaker tracking pipeline
        duration = chunk_loader.audio.get_duration(file)
        return self.speaker_tracking.from_model_streams(
-            file.stem, duration, segmentation_stream, embedding_stream, chunk_stream
+            file.stem, duration, seg_stream, emb_stream, chunk_stream
        )
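As a side note, the tightened assertion in `OnlineSpeakerDiarization.__init__` enforces `step <= latency <= duration`. A tiny standalone illustration, using the 5s window and 500ms shift mentioned in the README comment (the latencies are arbitrary test values, not read from `PipelineConfig`):

```python
# Illustration of the latency constraint asserted in OnlineSpeakerDiarization.__init__.
# step/duration match the README defaults; the latencies are hypothetical.
step, duration = 0.5, 5.0
for latency in (0.25, 0.5, 3.0, 5.0, 6.0):
    ok = step <= latency <= duration
    print(f"latency={latency}s -> {'ok' if ok else 'Invalid latency requested'}")
```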