From 6a43b7619499150b4e568f40d057610d49684feb Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Tue, 31 May 2022 19:05:39 +0200
Subject: [PATCH 01/29] Add base classes for custom segmentation and embedding
models. Unify pipeline specifications in PipelineConfig
---
README.md | 20 +++++-----
src/diart/argdoc.py | 2 +-
src/diart/benchmark.py | 4 +-
src/diart/blocks.py | 90 ++++++++++++++++++++++++++----------------
src/diart/inference.py | 18 ++++++---
src/diart/models.py | 83 ++++++++++++++++++++++++++++++++++++++
src/diart/pipelines.py | 60 +++++++++++++++-------------
src/diart/stream.py | 10 +++--
8 files changed, 205 insertions(+), 82 deletions(-)
create mode 100644 src/diart/models.py
diff --git a/README.md b/README.md
index f742bd09..5faf55d5 100644
--- a/README.md
+++ b/README.md
@@ -63,8 +63,9 @@ from diart.sources import MicrophoneAudioSource
from diart.inference import RealTimeInference
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
-pipeline = OnlineSpeakerDiarization(PipelineConfig())
-audio_source = MicrophoneAudioSource(pipeline.sample_rate)
+config = PipelineConfig() # Default parameters
+pipeline = OnlineSpeakerDiarization(config)
+audio_source = MicrophoneAudioSource(config.sample_rate)
inference = RealTimeInference("/output/path", do_plot=True)
inference(pipeline, audio_source)
@@ -86,17 +87,18 @@ import rx
import rx.operators as ops
import diart.operators as dops
from diart.sources import MicrophoneAudioSource
-from diart.blocks import FramewiseModel, OverlapAwareSpeakerEmbedding
-
-sample_rate = 16000
-mic = MicrophoneAudioSource(sample_rate)
+from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
+from diart.models import SegmentationModel, EmbeddingModel
# Initialize independent modules
-segmentation = FramewiseModel("pyannote/segmentation")
-embedding = OverlapAwareSpeakerEmbedding("pyannote/embedding")
+seg_model = SegmentationModel.from_pyannote("pyannote/segmentation")
+segmentation = SpeakerSegmentation(seg_model)
+emb_model = EmbeddingModel.from_pyannote("pyannote/embedding")
+embedding = OverlapAwareSpeakerEmbedding(emb_model)
+mic = MicrophoneAudioSource(seg_model.get_sample_rate())
# Reformat microphone stream. Defaults to 5s duration and 500ms shift
-regular_stream = mic.stream.pipe(dops.regularize_stream(sample_rate))
+regular_stream = mic.stream.pipe(dops.regularize_stream(seg_model.get_sample_rate()))
# Branch the microphone stream to calculate segmentation
segmentation_stream = regular_stream.pipe(ops.map(segmentation))
# Join audio and segmentation stream to calculate speaker embeddings
diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py
index 933f96ca..640dae14 100644
--- a/src/diart/argdoc.py
+++ b/src/diart/argdoc.py
@@ -6,5 +6,5 @@
GAMMA = "Parameter gamma for overlapped speech penalty"
BETA = "Parameter beta for overlapped speech penalty"
MAX_SPEAKERS = "Maximum number of speakers"
-GPU = "Run on GPU"
+CPU = "Force models to run on CPU"
OUTPUT = "Directory to store the system's output in RTTM format"
diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index fc07f754..b161b708 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -20,7 +20,7 @@
parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
parser.add_argument("--batch-size", default=32, type=int, help="For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency. Defaults to 32")
- parser.add_argument("--gpu", dest="gpu", action="store_true", help=argdoc.GPU)
+ parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
args = parser.parse_args()
@@ -37,7 +37,7 @@
gamma=args.gamma,
beta=args.beta,
max_speakers=args.max_speakers,
- device=torch.device("cuda") if args.gpu else None,
+ device=torch.device("cpu") if args.cpu else None,
))
benchmark(pipeline, args.batch_size)
diff --git a/src/diart/blocks.py b/src/diart/blocks.py
index 2da18939..b5117d07 100644
--- a/src/diart/blocks.py
+++ b/src/diart/blocks.py
@@ -1,15 +1,14 @@
from typing import Union, Optional, List, Iterable, Tuple
-from typing_extensions import Literal
import numpy as np
import torch
-from pyannote.audio.pipelines.utils import PipelineModel, get_model, get_devices
+from einops import rearrange
from pyannote.audio.utils.signal import Binarize as PyanBinarize
from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature
-from einops import rearrange
+from typing_extensions import Literal
from .mapping import SpeakerMap, SpeakerMapBuilder
-
+from .models import SegmentationModel, EmbeddingModel
TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor]
@@ -42,26 +41,31 @@ def resolve_features(features: TemporalFeatures) -> torch.Tensor:
return data.float()
-class FramewiseModel:
- def __init__(self, model: PipelineModel, device: Optional[torch.device] = None):
- self.model = get_model(model)
+class SpeakerSegmentation:
+ def __init__(self, model: SegmentationModel, device: Optional[torch.device] = None):
+ self.model = model
self.model.eval()
- if device is None:
- device = get_devices(needs=1)[0]
- self.model.to(device)
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cpu")
+ self.model.to(self.device)
- @property
- def sample_rate(self) -> int:
- return self.model.audio.sample_rate
+ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
+ """
+ Calculate the speaker segmentation of input audio.
- @property
- def duration(self) -> float:
- return self.model.specifications.duration
+ Parameters
+ ----------
+ waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
- def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
+ Returns
+ -------
+ speaker_segmentation: TemporalFeatures, shape (batch, frames, speakers)
+ The batch dimension is omitted if waveform is a `SlidingWindowFeature`.
+ """
with torch.no_grad():
wave = rearrange(resolve_features(waveform), "batch sample channel -> batch channel sample")
- output = self.model(wave.to(self.model.device)).cpu()
+ output = self.model(wave.to(self.device)).cpu()
batch_size, num_frames, _ = output.shape
@@ -72,7 +76,7 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
# Wrap if a SlidingWindowFeature was given as input
if isinstance(waveform, SlidingWindowFeature):
# Temporal resolution of the output
- duration = wave.shape[-1] / self.sample_rate
+ duration = wave.shape[-1] / self.model.get_sample_rate()
resolution = duration / num_frames
# Temporal shift to keep track of current start time
resolution = SlidingWindow(
@@ -88,26 +92,44 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
return output
-class ChunkwiseModel:
- def __init__(self, model: PipelineModel, device: Optional[torch.device] = None):
- self.model = get_model(model)
+class SpeakerEmbedding:
+ def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None):
+ self.model = model
self.model.eval()
- if device is None:
- device = get_devices(needs=1)[0]
- self.model.to(device)
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cpu")
+ self.model.to(self.device)
+
+ def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor:
+ """
+ Calculate speaker embeddings of input audio.
+ If weights are given, calculate many speaker embeddings from the same waveform.
+
+ Parameters
+ ----------
+ waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
+ weights: Optional[TemporalFeatures], shape (frames, speakers) or (batch, frames, speakers)
+ Per-speaker and per-frame weights. Defaults to no weights.
- def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures]) -> torch.Tensor:
+ Returns
+ -------
+ embeddings: torch.Tensor
+ If weights are provided, the shape is (batch, speakers, embedding_dim),
+ otherwise the shape is (batch, embedding_dim).
+ If batch size == 1, the batch dimension is omitted.
+ """
with torch.no_grad():
- inputs = resolve_features(waveform).to(self.model.device)
+ inputs = resolve_features(waveform).to(self.device)
inputs = rearrange(inputs, "batch sample channel -> batch channel sample")
if weights is not None:
- weights = resolve_features(weights).to(self.model.device)
+ weights = resolve_features(weights).to(self.device)
batch_size, _, num_speakers = weights.shape
inputs = inputs.repeat(1, num_speakers, 1)
weights = rearrange(weights, "batch frame spk -> (batch spk) frame")
inputs = rearrange(inputs, "batch spk sample -> (batch spk) 1 sample")
output = rearrange(
- self.model(inputs, weights=weights),
+ self.model(inputs, weights),
"(batch spk) feat -> batch spk feat",
batch=batch_size,
spk=num_speakers
@@ -125,7 +147,7 @@ class OverlappedSpeechPenalty:
Exponent to lower low-confidence predictions.
Defaults to 3.
beta: float, optional
- Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations.
+ Temperature parameter (actually 1/beta) to lower joint speaker activations.
Defaults to 10.
"""
def __init__(self, gamma: float = 3, beta: float = 10):
@@ -171,8 +193,8 @@ class OverlapAwareSpeakerEmbedding:
Parameters
----------
- model: pyannote.audio.Model, Text or Dict
- The embedding model. It must take a waveform and weights as input.
+ model: EmbeddingModel
+ A pre-trained embedding model.
gamma: float, optional
Exponent to lower low-confidence predictions.
Defaults to 3.
@@ -188,13 +210,13 @@ class OverlapAwareSpeakerEmbedding:
"""
def __init__(
self,
- model: PipelineModel,
+ model: EmbeddingModel,
gamma: float = 3,
beta: float = 10,
norm: Union[float, torch.Tensor] = 1,
device: Optional[torch.device] = None,
):
- self.embedding = ChunkwiseModel(model, device)
+ self.embedding = SpeakerEmbedding(model, device)
self.osp = OverlappedSpeechPenalty(gamma, beta)
self.normalize = EmbeddingNormalization(norm)
diff --git a/src/diart/inference.py b/src/diart/inference.py
index a9c79589..4574279b 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -59,12 +59,12 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, source: src.AudioSource)
observable.pipe(
ops.do(rttm_writer),
dops.buffer_output(
- duration=pipeline.duration,
+ duration=pipeline.config.duration,
step=pipeline.config.step,
latency=pipeline.config.latency,
- sample_rate=pipeline.sample_rate
+ sample_rate=pipeline.config.sample_rate
),
- ).subscribe(RealTimePlot(pipeline.duration, pipeline.config.latency))
+ ).subscribe(RealTimePlot(pipeline.config.duration, pipeline.config.latency))
# Stream audio through the pipeline
source.read()
@@ -125,7 +125,11 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
DataFrame with detailed performance on each file, as well as average performance.
None if the reference is not provided.
"""
- chunk_loader = src.ChunkLoader(pipeline.sample_rate, pipeline.duration, pipeline.config.step)
+ chunk_loader = src.ChunkLoader(
+ pipeline.config.sample_rate,
+ pipeline.config.duration,
+ pipeline.config.step
+ )
audio_file_paths = list(self.speech_path.iterdir())
num_audio_files = len(audio_file_paths)
for i, filepath in enumerate(audio_file_paths):
@@ -137,7 +141,11 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
source = src.FileAudioSource(
filepath,
filepath.stem,
- src.RegularAudioFileReader(pipeline.sample_rate, pipeline.duration, pipeline.config.step),
+ src.RegularAudioFileReader(
+ pipeline.config.sample_rate,
+ pipeline.config.duration,
+ pipeline.config.step
+ ),
# Benchmark the processing time of a single chunk
profile=True,
)
diff --git a/src/diart/models.py b/src/diart/models.py
new file mode 100644
index 00000000..74ae5c9b
--- /dev/null
+++ b/src/diart/models.py
@@ -0,0 +1,83 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from pyannote.audio.pipelines.utils import PipelineModel, get_model
+
+
+class SegmentationModel(nn.Module):
+ """
+ Minimal interface for a segmentation model.
+ """
+
+ @staticmethod
+ def from_pyannote(model: PipelineModel) -> 'SegmentationModel':
+ class PyannoteSegmentationModel(SegmentationModel):
+ def __init__(self, pyannote_model: PipelineModel):
+ super().__init__()
+ self.model = get_model(pyannote_model)
+
+ def get_sample_rate(self) -> int:
+ return self.model.audio.sample_rate
+
+ def get_duration(self) -> float:
+ return self.model.specifications.duration
+
+ def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
+ return self.model(waveform)
+
+ return PyannoteSegmentationModel(model)
+
+ def get_sample_rate(self) -> int:
+ """Return the sample rate expected for model inputs"""
+ raise NotImplementedError
+
+ def get_duration(self) -> float:
+ """Return the input duration by default (usually the one used during training)"""
+ raise NotImplementedError
+
+ def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of a segmentation model.
+
+ Parameters
+ ----------
+ waveform: torch.Tensor, shape (batch, channels, samples)
+
+ Returns
+ -------
+ speaker_segmentation: torch.Tensor, shape (batch, frames, speakers)
+ """
+ raise NotImplementedError
+
+
+class EmbeddingModel(nn.Module):
+ """Minimal interface for an embedding model."""
+
+ @staticmethod
+ def from_pyannote(model: PipelineModel) -> 'EmbeddingModel':
+ class PyannoteEmbeddingModel(EmbeddingModel):
+ def __init__(self, pyannote_model: PipelineModel):
+ super().__init__()
+ self.model = get_model(pyannote_model)
+
+ def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
+ return self.model(waveform, weights=weights)
+
+ return PyannoteEmbeddingModel(model)
+
+ def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Forward pass of an embedding model with optional weights.
+
+ Parameters
+ ----------
+ waveform: torch.Tensor, shape (batch, channels, samples)
+ weights: Optional[torch.Tensor], shape (batch, frames)
+ Temporal weights for each sample in the batch. Defaults to no weights.
+
+ Returns
+ -------
+ speaker_embeddings: torch.Tensor, shape (batch, speakers, embedding_dim)
+ """
+ raise NotImplementedError
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index b56b235a..2678d182 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -7,11 +7,12 @@
import rx.operators as ops
import torch
from einops import rearrange
-from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.core import SlidingWindowFeature, SlidingWindow
+from pyannote.audio.pipelines.utils import get_devices
from tqdm import tqdm
from . import blocks
+from . import models as m
from . import operators as dops
from . import sources as src
@@ -19,8 +20,8 @@
class PipelineConfig:
def __init__(
self,
- segmentation: PipelineModel = "pyannote/segmentation",
- embedding: PipelineModel = "pyannote/embedding",
+ segmentation: Optional[m.SegmentationModel] = None,
+ embedding: Optional[m.EmbeddingModel] = None,
duration: Optional[float] = None,
step: float = 0.5,
latency: Optional[float] = None,
@@ -32,22 +33,40 @@ def __init__(
max_speakers: int = 20,
device: Optional[torch.device] = None,
):
+ # Default segmentation model is pyannote/segmentation
self.segmentation = segmentation
+ if self.segmentation is None:
+ self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation")
+
+ # Default duration is the one given by the segmentation model
+ self.duration = duration
+ if self.duration is None:
+ self.duration = self.segmentation.get_duration()
+
+ # Expected sample rate is given by the segmentation model
+ self.sample_rate = self.segmentation.get_sample_rate()
+
+ # Default embedding model is pyannote/embedding
self.embedding = embedding
- self.requested_duration = duration
+ if self.embedding is None:
+ self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding")
+
+ # Latency defaults to the step duration
self.step = step
self.latency = latency
if self.latency is None:
self.latency = self.step
+
self.tau_active = tau_active
self.rho_update = rho_update
self.delta_new = delta_new
self.gamma = gamma
self.beta = beta
self.max_speakers = max_speakers
+
self.device = device
if self.device is None:
- self.device = torch.device("cpu")
+ self.device = get_devices(needs=1)[0]
def last_chunk_end_time(self, conv_duration: float) -> Optional[float]:
"""
@@ -117,39 +136,26 @@ def from_model_streams(
class OnlineSpeakerDiarization:
def __init__(self, config: PipelineConfig):
self.config = config
- self.segmentation = blocks.FramewiseModel(config.segmentation, config.device)
+ self.segmentation = blocks.SpeakerSegmentation(config.segmentation, config.device)
self.embedding = blocks.OverlapAwareSpeakerEmbedding(
config.embedding, config.gamma, config.beta, norm=1, device=config.device
)
self.speaker_tracking = OnlineSpeakerTracking(config)
- msg = "Invalid latency requested"
- assert config.step <= config.latency <= self.duration, msg
-
- @property
- def sample_rate(self) -> int:
- """Sample rate expected by the segmentation model"""
- return self.segmentation.sample_rate
-
- @property
- def duration(self) -> float:
- """Chunk duration (in seconds). Defaults to segmentation model duration"""
- duration = self.config.requested_duration
- if duration is None:
- duration = self.segmentation.duration
- return duration
+ msg = f"Latency should be in the range [{config.step}, {config.duration}]"
+ assert config.step <= config.latency <= config.duration, msg
def from_source(
self,
source: src.AudioSource,
output_waveform: bool = True
) -> rx.Observable:
- msg = f"Audio source has sample rate {source.sample_rate}, expected {self.sample_rate}"
- assert source.sample_rate == self.sample_rate, msg
+ msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}"
+ assert source.sample_rate == self.config.sample_rate, msg
# Regularize the stream to a specific chunk duration and step
regular_stream = source.stream
if not source.is_regular:
regular_stream = source.stream.pipe(
- dops.regularize_stream(self.duration, self.config.step, source.sample_rate)
+ dops.regularize_stream(self.config.duration, self.config.step, source.sample_rate)
)
# Branch the stream to calculate chunk segmentation
seg_stream = regular_stream.pipe(ops.map(self.segmentation))
@@ -170,7 +176,7 @@ def from_file(
# Audio file information
file = Path(file)
chunk_loader = src.ChunkLoader(
- self.sample_rate, self.duration, self.config.step
+ self.config.sample_rate, self.config.duration, self.config.step
)
# Split audio into chunks
@@ -207,14 +213,14 @@ def from_file(
embeddings = torch.vstack(embeddings)
# Stream pre-calculated segmentation, embeddings and chunks
- resolution = self.duration / segmentation.shape[1]
+ resolution = self.config.duration / segmentation.shape[1]
seg_stream = rx.range(0, num_chunks).pipe(
ops.map(lambda i: SlidingWindowFeature(
segmentation[i], SlidingWindow(resolution, resolution, i * self.config.step)
))
)
emb_stream = rx.range(0, num_chunks).pipe(ops.map(lambda i: embeddings[i]))
- wav_resolution = 1 / self.sample_rate
+ wav_resolution = 1 / self.config.sample_rate
chunk_stream = None
if output_waveform:
chunk_stream = rx.range(0, num_chunks).pipe(
diff --git a/src/diart/stream.py b/src/diart/stream.py
index d98d2c8f..4d4bb819 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -21,7 +21,7 @@
parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference")
- parser.add_argument("--gpu", dest="gpu", action="store_true", help=argdoc.GPU)
+ parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
args = parser.parse_args()
@@ -35,7 +35,7 @@
gamma=args.gamma,
beta=args.beta,
max_speakers=args.max_speakers,
- device=torch.device("cuda") if args.gpu else None,
+ device=torch.device("cpu") if args.cpu else None,
))
# Manage audio source
@@ -46,12 +46,14 @@
file=args.source,
uri=args.source.stem,
reader=src.RegularAudioFileReader(
- pipeline.sample_rate, pipeline.duration, pipeline.config.step
+ pipeline.config.sample_rate,
+ pipeline.config.duration,
+ pipeline.config.step,
),
)
else:
args.output = Path("~/").expanduser() if args.output is None else Path(args.output)
- audio_source = src.MicrophoneAudioSource(pipeline.sample_rate)
+ audio_source = src.MicrophoneAudioSource(pipeline.config.sample_rate)
# Run online inference
RealTimeInference(args.output, do_plot=not args.no_plot)(pipeline, audio_source)
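The patch above turns `SegmentationModel` into an extension point; below is a minimal sketch of how a custom segmentation network might be plugged in, assuming only the interface added in `src/diart/models.py` (the `DummySegNet` network, its frame resolution, and the wrapper name are illustrative, not part of the patch):

```python
import torch
from diart.models import SegmentationModel

class DummySegNet(torch.nn.Module):
    """Toy stand-in for a real segmentation network (illustrative only)."""
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        batch, _, samples = waveform.shape
        num_frames = samples // 270  # arbitrary frame resolution for the sketch
        return torch.rand(batch, num_frames, 3)  # (batch, frames, speakers)

class CustomSegmentationModel(SegmentationModel):
    """Wraps any network mapping (batch, channels, samples) -> (batch, frames, speakers)."""
    def __init__(self, net: torch.nn.Module, sample_rate: int = 16000, duration: float = 5.0):
        super().__init__()
        self.net = net
        self.sample_rate = sample_rate
        self.duration = duration

    def get_sample_rate(self) -> int:
        return self.sample_rate

    def get_duration(self) -> float:
        return self.duration

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        return self.net(waveform)

# The custom model can then replace the default pyannote/segmentation:
# config = PipelineConfig(segmentation=CustomSegmentationModel(DummySegNet()))
# pipeline = OnlineSpeakerDiarization(config)
```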
From 6c0f5713b6ab0650d70506a9e20f61a8095910c8 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 1 Jun 2022 00:19:39 +0200
Subject: [PATCH 02/29] Add own binarize implementation
---
src/diart/blocks.py | 36 +++++++++++++++++-------------------
1 file changed, 17 insertions(+), 19 deletions(-)
diff --git a/src/diart/blocks.py b/src/diart/blocks.py
index b5117d07..b79fcdd9 100644
--- a/src/diart/blocks.py
+++ b/src/diart/blocks.py
@@ -3,7 +3,6 @@
import numpy as np
import torch
from einops import rearrange
-from pyannote.audio.utils.signal import Binarize as PyanBinarize
from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature
from typing_extensions import Literal
@@ -561,26 +560,25 @@ def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor)
class Binarize:
- def __init__(self, uri: str, tau_active: float):
+ def __init__(self, uri: str, threshold: float):
self.uri = uri
- self._binarize = PyanBinarize(
- onset=tau_active,
- offset=tau_active,
- min_duration_on=0,
- min_duration_off=0,
- )
-
- def _select(
- self, scores: SlidingWindowFeature, speaker: int
- ) -> SlidingWindowFeature:
- return SlidingWindowFeature(
- scores[:, speaker].reshape(-1, 1), scores.sliding_window
- )
+ self.threshold = threshold
def __call__(self, segmentation: SlidingWindowFeature) -> Annotation:
+ num_frames, num_speakers = segmentation.data.shape
+ timestamps = segmentation.sliding_window
+ is_active = segmentation.data.T > self.threshold # shape (speakers, frames)
+ # Artificially add last inactive frame to close any remaining speaker turns
+ is_active = np.append(is_active, [[False]] * num_speakers, axis=1)
annotation = Annotation(uri=self.uri, modality="speech")
- for speaker in range(segmentation.data.shape[1]):
- turns = self._binarize(self._select(segmentation, speaker))
- for speaker_turn in turns.itersegments():
- annotation[speaker_turn, speaker] = f"speaker{speaker}"
+ for spk in range(num_speakers):
+ start = timestamps[0].middle
+ for t in range(num_frames):
+ # Any (False, True) start a speaker turn at "True" index
+ if not is_active[spk, t] and is_active[spk, t + 1]:
+ start = timestamps[t + 1].middle
+ # Any (True, False) end a speaker turn at "False" index
+ elif is_active[spk, t] and not is_active[spk, t + 1]:
+ region = Segment(start, timestamps[t + 1].middle)
+ annotation[region, spk] = f"speaker{spk}"
return annotation
From 7f771731b5cb9e99f80bf933d116d2137f5f72df Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 1 Jun 2022 00:48:54 +0200
Subject: [PATCH 03/29] Make Binarize more efficient by looking at all speakers
at the same time. Add docstring to Binarize
---
src/diart/blocks.py | 54 +++++++++++++++++++++++++++++++++------------
1 file changed, 40 insertions(+), 14 deletions(-)
diff --git a/src/diart/blocks.py b/src/diart/blocks.py
index b79fcdd9..83f099fb 100644
--- a/src/diart/blocks.py
+++ b/src/diart/blocks.py
@@ -1,4 +1,4 @@
-from typing import Union, Optional, List, Iterable, Tuple
+from typing import Union, Optional, List, Iterable, Tuple, Text
import numpy as np
import torch
@@ -560,25 +560,51 @@ def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor)
class Binarize:
- def __init__(self, uri: str, threshold: float):
+ """
+ Transform a speaker segmentation from the discrete-time domain
+ into a continuous-time speaker segmentation.
+
+ Parameters
+ ----------
+ uri: Text
+ Uri of the audio stream.
+ threshold: float
+ Probability threshold to determine if a speaker is active at a given frame.
+ """
+
+ def __init__(self, uri: Text, threshold: float):
self.uri = uri
self.threshold = threshold
def __call__(self, segmentation: SlidingWindowFeature) -> Annotation:
+ """
+ Return the continuous-time segmentation
+ corresponding to the discrete-time input segmentation.
+
+ Parameters
+ ----------
+ segmentation: SlidingWindowFeature
+ Discrete-time speaker segmentation.
+
+ Returns
+ -------
+ annotation: Annotation
+ Continuous-time speaker segmentation.
+ """
num_frames, num_speakers = segmentation.data.shape
timestamps = segmentation.sliding_window
- is_active = segmentation.data.T > self.threshold # shape (speakers, frames)
+ is_active = segmentation.data > self.threshold
# Artificially add last inactive frame to close any remaining speaker turns
- is_active = np.append(is_active, [[False]] * num_speakers, axis=1)
+ is_active = np.append(is_active, [[False] * num_speakers], axis=0)
+ start_times = np.zeros(num_speakers) + timestamps[0].middle
annotation = Annotation(uri=self.uri, modality="speech")
- for spk in range(num_speakers):
- start = timestamps[0].middle
- for t in range(num_frames):
- # Any (False, True) start a speaker turn at "True" index
- if not is_active[spk, t] and is_active[spk, t + 1]:
- start = timestamps[t + 1].middle
- # Any (True, False) end a speaker turn at "False" index
- elif is_active[spk, t] and not is_active[spk, t + 1]:
- region = Segment(start, timestamps[t + 1].middle)
- annotation[region, spk] = f"speaker{spk}"
+ for t in range(num_frames):
+ # Any (False, True) starts a speaker turn at "True" index
+ onsets = np.logical_and(np.logical_not(is_active[t]), is_active[t + 1])
+ start_times[onsets] = timestamps[t + 1].middle
+ # Any (True, False) ends a speaker turn at "False" index
+ offsets = np.logical_and(is_active[t], np.logical_not(is_active[t + 1]))
+ for spk in np.where(offsets)[0]:
+ region = Segment(start_times[spk], timestamps[t + 1].middle)
+ annotation[region, spk] = f"speaker{spk}"
return annotation
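A quick sanity-check sketch of the vectorized `Binarize` above (the scores, threshold, and frame resolution are made up for illustration):

```python
import numpy as np
from pyannote.core import SlidingWindow, SlidingWindowFeature
from diart.blocks import Binarize

# Fake discrete-time segmentation: 10 frames of 100ms, 2 speakers
scores = np.zeros((10, 2))
scores[2:6, 0] = 0.9  # speaker 0 active in frames 2-5
scores[7:9, 1] = 0.8  # speaker 1 active in frames 7-8
frames = SlidingWindow(start=0.0, duration=0.1, step=0.1)
segmentation = SlidingWindowFeature(scores, frames)

binarize = Binarize(uri="example", threshold=0.5)
annotation = binarize(segmentation)
print(annotation)  # expected: one speaker turn per active region
```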
From f2866c515deaaf8724501ad33d93d3188dc33ba9 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 1 Jun 2022 00:54:06 +0200
Subject: [PATCH 04/29] Simplify default device logic
---
src/diart/pipelines.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index 2678d182..2087f986 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -8,7 +8,6 @@
import torch
from einops import rearrange
from pyannote.core import SlidingWindowFeature, SlidingWindow
-from pyannote.audio.pipelines.utils import get_devices
from tqdm import tqdm
from . import blocks
@@ -66,7 +65,7 @@ def __init__(
self.device = device
if self.device is None:
- self.device = get_devices(needs=1)[0]
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def last_chunk_end_time(self, conv_duration: float) -> Optional[float]:
"""
From 3f8bd6b50522f80b72177fb46fb5c49723bc8565 Mon Sep 17 00:00:00 2001
From: Juan Coria
Date: Fri, 3 Jun 2022 10:39:01 +0200
Subject: [PATCH 05/29] Replace `resolve_features` with
`TemporalFeatureFormatter` (#59)
* Add TemporalFeatureFormatter to cast temporal inputs to pytorch tensors and then restore outputs to the original type
* Update docstrings
* Add cleaner implementation of TemporalFeatureFormatter
* Bug fixes
---
src/diart/blocks.py | 75 ++++-------------------
src/diart/features.py | 135 ++++++++++++++++++++++++++++++++++++++++++
2 files changed, 146 insertions(+), 64 deletions(-)
create mode 100644 src/diart/features.py
diff --git a/src/diart/blocks.py b/src/diart/blocks.py
index 2da18939..df2edf1c 100644
--- a/src/diart/blocks.py
+++ b/src/diart/blocks.py
@@ -9,37 +9,7 @@
from einops import rearrange
from .mapping import SpeakerMap, SpeakerMapBuilder
-
-
-TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor]
-
-
-def resolve_features(features: TemporalFeatures) -> torch.Tensor:
- """
- Transform features into a `torch.Tensor` and add batch dimension if missing.
-
- Parameters
- ----------
- features: Union[SlidingWindowFeature, np.ndarray, torch.Tensor]
- Shape (frames, channels) or (batch, frames, channels)
-
- Returns
- -------
- transformed_features: torch.Tensor, shape (batch, frames, channels)
- """
- # As torch.Tensor with shape (..., channels, frames)
- if isinstance(features, SlidingWindowFeature):
- data = torch.from_numpy(features.data)
- elif isinstance(features, np.ndarray):
- data = torch.from_numpy(features)
- else:
- data = features
- # Make sure there's a batch dimension
- msg = "Temporal features must be 2D or 3D"
- assert data.ndim in (2, 3), msg
- if data.ndim == 2:
- data = data.unsqueeze(0)
- return data.float()
+from .features import TemporalFeatures, TemporalFeatureFormatter
class FramewiseModel:
@@ -49,6 +19,7 @@ def __init__(self, model: PipelineModel, device: Optional[torch.device] = None):
if device is None:
device = get_devices(needs=1)[0]
self.model.to(device)
+ self.formatter = TemporalFeatureFormatter()
@property
def sample_rate(self) -> int:
@@ -60,32 +31,9 @@ def duration(self) -> float:
def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
with torch.no_grad():
- wave = rearrange(resolve_features(waveform), "batch sample channel -> batch channel sample")
+ wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample")
output = self.model(wave.to(self.model.device)).cpu()
-
- batch_size, num_frames, _ = output.shape
-
- # Remove batch dimension if batch size is 1
- if output.shape[0] == 1:
- output = output[0]
-
- # Wrap if a SlidingWindowFeature was given as input
- if isinstance(waveform, SlidingWindowFeature):
- # Temporal resolution of the output
- duration = wave.shape[-1] / self.sample_rate
- resolution = duration / num_frames
- # Temporal shift to keep track of current start time
- resolution = SlidingWindow(
- start=waveform.sliding_window.start,
- duration=resolution,
- step=resolution
- )
- return SlidingWindowFeature(output.numpy(), resolution)
-
- if isinstance(waveform, np.ndarray):
- return output.numpy()
-
- return output
+ return self.formatter.restore_type(output)
class ChunkwiseModel:
@@ -95,13 +43,15 @@ def __init__(self, model: PipelineModel, device: Optional[torch.device] = None):
if device is None:
device = get_devices(needs=1)[0]
self.model.to(device)
+ self.waveform_formatter = TemporalFeatureFormatter()
+ self.weights_formatter = TemporalFeatureFormatter()
def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures]) -> torch.Tensor:
with torch.no_grad():
- inputs = resolve_features(waveform).to(self.model.device)
+ inputs = self.waveform_formatter.cast(waveform).to(self.model.device)
inputs = rearrange(inputs, "batch sample channel -> batch channel sample")
if weights is not None:
- weights = resolve_features(weights).to(self.model.device)
+ weights = self.weights_formatter.cast(weights).to(self.model.device)
batch_size, _, num_speakers = weights.shape
inputs = inputs.repeat(1, num_speakers, 1)
weights = rearrange(weights, "batch frame spk -> (batch spk) frame")
@@ -131,18 +81,15 @@ class OverlappedSpeechPenalty:
def __init__(self, gamma: float = 3, beta: float = 10):
self.gamma = gamma
self.beta = beta
+ self.formatter = TemporalFeatureFormatter()
def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
- weights = resolve_features(segmentation) # shape (batch, frames, speakers)
+ weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers)
with torch.no_grad():
probs = torch.softmax(self.beta * weights, dim=-1)
weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma)
weights[weights < 1e-8] = 1e-8
- if isinstance(segmentation, SlidingWindowFeature):
- return SlidingWindowFeature(weights.cpu().numpy(), segmentation.sliding_window)
- if isinstance(segmentation, np.ndarray):
- return weights.cpu().numpy()
- return weights
+ return self.formatter.restore_type(weights)
class EmbeddingNormalization:
diff --git a/src/diart/features.py b/src/diart/features.py
new file mode 100644
index 00000000..eaf9d55e
--- /dev/null
+++ b/src/diart/features.py
@@ -0,0 +1,135 @@
+from typing import Union, Optional
+
+import numpy as np
+import torch
+from pyannote.core import SlidingWindow, SlidingWindowFeature
+
+
+TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor]
+
+
+class TemporalFeatureFormatterState:
+ """
+ Represents the type recorded by a temporal feature formatter.
+ Its job is to transform temporal features into tensors and
+ to restore the original format on other temporal features.
+ """
+ def to_tensor(self, features: TemporalFeatures) -> torch.Tensor:
+ raise NotImplementedError
+
+ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
+ """
+ Cast `features` back to the represented type and remove the batch dimension if required.
+
+ Parameters
+ ----------
+ features: torch.Tensor, shape (batch, frames, dim)
+ Batched temporal features.
+ Returns
+ -------
+ new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim)
+ """
+ raise NotImplementedError
+
+
+class SlidingWindowFeatureFormatterState(TemporalFeatureFormatterState):
+ def __init__(self, duration: float):
+ self.duration = duration
+ self._cur_start_time = 0
+
+ def to_tensor(self, features: SlidingWindowFeature) -> torch.Tensor:
+ msg = "Features sliding window duration and step must be equal"
+ assert features.sliding_window.duration == features.sliding_window.step, msg
+ self._cur_start_time = features.sliding_window.start
+ return torch.from_numpy(features.data)
+
+ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
+ batch_size, num_frames, _ = features.shape
+ assert batch_size == 1, "Batched SlidingWindowFeature objects are not supported"
+ # Calculate resolution
+ resolution = self.duration / num_frames
+ # Temporal shift to keep track of current start time
+ resolution = SlidingWindow(start=self._cur_start_time, duration=resolution, step=resolution)
+ return SlidingWindowFeature(features.squeeze(dim=0).cpu().numpy(), resolution)
+
+
+class NumpyArrayFormatterState(TemporalFeatureFormatterState):
+ def to_tensor(self, features: np.ndarray) -> torch.Tensor:
+ return torch.from_numpy(features)
+
+ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
+ return features.cpu().numpy()
+
+
+class PytorchTensorFormatterState(TemporalFeatureFormatterState):
+ def to_tensor(self, features: torch.Tensor) -> torch.Tensor:
+ return features
+
+ def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
+ return features
+
+
+class TemporalFeatureFormatter:
+ """
+ Manages the typing and format of temporal features.
+ When casting temporal features to torch.Tensor, it remembers their
+ original type and format so it can later restore them on other temporal features.
+ """
+ def __init__(self):
+ self.state: Optional[TemporalFeatureFormatterState] = None
+
+ def set_state(self, features: TemporalFeatures):
+ if self.state is not None:
+ return
+
+ if isinstance(features, SlidingWindowFeature):
+ msg = "Features sliding window duration and step must be equal"
+ assert features.sliding_window.duration == features.sliding_window.step, msg
+ self.state = SlidingWindowFeatureFormatterState(
+ features.data.shape[0] * features.sliding_window.duration,
+ )
+ elif isinstance(features, np.ndarray):
+ self.state = NumpyArrayFormatterState()
+ elif isinstance(features, torch.Tensor):
+ self.state = PytorchTensorFormatterState()
+ else:
+ msg = "Unknown format. Provide one of SlidingWindowFeature, numpy.ndarray, torch.Tensor"
+ raise ValueError(msg)
+
+ def cast(self, features: TemporalFeatures) -> torch.Tensor:
+ """
+ Transform features into a `torch.Tensor` and add batch dimension if missing.
+
+ Parameters
+ ----------
+ features: SlidingWindowFeature or numpy.ndarray or torch.Tensor
+ Shape (frames, dim) or (batch, frames, dim)
+
+ Returns
+ -------
+ features: torch.Tensor, shape (batch, frames, dim)
+ """
+ # Set state if not initialized
+ self.set_state(features)
+ # Convert features to tensor
+ data = self.state.to_tensor(features)
+ # Make sure there's a batch dimension
+ msg = "Temporal features must be 2D or 3D"
+ assert data.ndim in (2, 3), msg
+ if data.ndim == 2:
+ data = data.unsqueeze(0)
+ return data.float()
+
+ def restore_type(self, features: torch.Tensor) -> TemporalFeatures:
+ """
+ Cast `features` to the internal type and remove batch dimension if required.
+
+ Parameters
+ ----------
+ features: torch.Tensor, shape (batch, frames, dim)
+ Batched temporal features.
+ Returns
+ -------
+ new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim)
+ """
+ return self.state.to_internal_type(features)
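A minimal round-trip sketch of `TemporalFeatureFormatter` as introduced above (shapes chosen arbitrarily):

```python
import numpy as np
import torch
from diart.features import TemporalFeatureFormatter

formatter = TemporalFeatureFormatter()

# A 2D numpy array is cast to a batched float32 tensor...
features = np.random.rand(50, 4)   # (frames, dim)
tensor = formatter.cast(features)  # shape (1, 50, 4)
assert isinstance(tensor, torch.Tensor) and tuple(tensor.shape) == (1, 50, 4)

# ...and later tensors are restored to the recorded type (numpy here)
restored = formatter.restore_type(tensor * 2)
assert isinstance(restored, np.ndarray)
```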
From 7842258a3cf3b196364409590a6bb6520681cbc8 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Fri, 3 Jun 2022 10:42:27 +0200
Subject: [PATCH 06/29] Start reducing strong dependency on pyannote.audio
non-model classes
---
src/diart/models.py | 57 +++++++++++++++++++++++++++++++++++++-------
src/diart/sinks.py | 2 +-
src/diart/sources.py | 31 ++++++++++++++----------
3 files changed, 67 insertions(+), 23 deletions(-)
diff --git a/src/diart/models.py b/src/diart/models.py
index 74ae5c9b..db777d9d 100644
--- a/src/diart/models.py
+++ b/src/diart/models.py
@@ -2,7 +2,12 @@
import torch
import torch.nn as nn
-from pyannote.audio.pipelines.utils import PipelineModel, get_model
+
+try:
+ import pyannote.audio.pipelines.utils as pyannote
+ _has_pyannote = True
+except ImportError:
+ _has_pyannote = False
class SegmentationModel(nn.Module):
@@ -11,11 +16,24 @@ class SegmentationModel(nn.Module):
"""
@staticmethod
- def from_pyannote(model: PipelineModel) -> 'SegmentationModel':
+ def from_pyannote(model) -> 'SegmentationModel':
+ """
+ Returns a `SegmentationModel` wrapping a pyannote model.
+
+ Parameters
+ ----------
+ model: pyannote.PipelineModel
+
+ Returns
+ -------
+ wrapper: SegmentationModel
+ """
+ assert _has_pyannote, "No pyannote.audio installation found"
+
class PyannoteSegmentationModel(SegmentationModel):
- def __init__(self, pyannote_model: PipelineModel):
+ def __init__(self, pyannote_model):
super().__init__()
- self.model = get_model(pyannote_model)
+ self.model = pyannote.get_model(pyannote_model)
def get_sample_rate(self) -> int:
return self.model.audio.sample_rate
@@ -55,18 +73,39 @@ class EmbeddingModel(nn.Module):
"""Minimal interface for an embedding model."""
@staticmethod
- def from_pyannote(model: PipelineModel) -> 'EmbeddingModel':
+ def from_pyannote(model) -> 'EmbeddingModel':
+ """
+ Returns an `EmbeddingModel` wrapping a pyannote model.
+
+ Parameters
+ ----------
+ model: pyannote.PipelineModel
+
+ Returns
+ -------
+ wrapper: EmbeddingModel
+ """
+ assert _has_pyannote, "No pyannote.audio installation found"
+
class PyannoteEmbeddingModel(EmbeddingModel):
- def __init__(self, pyannote_model: PipelineModel):
+ def __init__(self, pyannote_model):
super().__init__()
- self.model = get_model(pyannote_model)
+ self.model = pyannote.get_model(pyannote_model)
- def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def __call__(
+ self,
+ waveform: torch.Tensor,
+ weights: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
return self.model(waveform, weights=weights)
return PyannoteEmbeddingModel(model)
- def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def __call__(
+ self,
+ waveform: torch.Tensor,
+ weights: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
"""
Forward pass of an embedding model with optional weights.
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index 7d0bb8d2..0d33bf01 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -1,13 +1,13 @@
from pathlib import Path
from traceback import print_exc
from typing import Union, Text, Optional, Tuple
-from typing_extensions import Literal
import matplotlib.pyplot as plt
from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate
from rx.core import Observer
+from typing_extensions import Literal
class RTTMWriter(Observer):
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 472c2a09..88012794 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -1,16 +1,21 @@
import random
import time
from queue import SimpleQueue
-from typing import Tuple, Text, Optional, Iterable, List
+from typing import Tuple, Text, Optional, Iterable, List, Union
+from pathlib import Path
import numpy as np
import sounddevice as sd
from einops import rearrange
-from pyannote.audio.core.io import Audio, AudioFile
+# TODO replace with torchaudio
+from pyannote.audio.core.io import Audio
from pyannote.core import SlidingWindowFeature, SlidingWindow
from rx.subject import Subject
+FilePath = Union[Text, Path]
+
+
class ChunkLoader:
"""Loads an audio file and chunks it according to a given window and step size.
@@ -36,7 +41,7 @@ def __init__(
self.window_samples = int(round(window_duration * sample_rate))
self.step_samples = int(round(step_duration * sample_rate))
- def get_chunks(self, file: AudioFile) -> np.ndarray:
+ def get_chunks(self, file: FilePath) -> np.ndarray:
waveform, _ = self.audio(file)
_, num_samples = waveform.shape
chunks = rearrange(
@@ -51,7 +56,7 @@ def get_chunks(self, file: AudioFile) -> np.ndarray:
return np.vstack([chunks, last_chunk])
return chunks
- def num_chunks(self, file: AudioFile) -> int:
+ def num_chunks(self, file: FilePath) -> int:
numerator = self.audio.get_duration(file) - self.window_duration + self.step_duration
return int(np.ceil(numerator / self.step_duration))
@@ -115,13 +120,13 @@ def is_regular(self) -> bool:
A regular reading method always yields the same amount of samples."""
return False
- def get_duration(self, file: AudioFile) -> float:
+ def get_duration(self, file: FilePath) -> float:
return self.audio.get_duration(file)
- def get_num_chunks(self, file: AudioFile) -> Optional[int]:
+ def get_num_chunks(self, file: FilePath) -> Optional[int]:
return None
- def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]:
+ def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]:
"""Return an iterable over the file's samples"""
raise NotImplementedError
@@ -153,11 +158,11 @@ def __init__(
def is_regular(self) -> bool:
return True
- def get_num_chunks(self, file: AudioFile) -> Optional[int]:
+ def get_num_chunks(self, file: FilePath) -> Optional[int]:
"""Return the number of chunks emitted for `file`"""
return self.chunk_loader.num_chunks(file)
- def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]:
+ def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]:
chunks = self.chunk_loader.get_chunks(file)
for i, chunk in enumerate(chunks):
w = SlidingWindow(
@@ -192,7 +197,7 @@ def __init__(
self.start, self.end = refresh_rate_range
self.delay = simulate_delay
- def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]:
+ def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]:
waveform, _ = self.audio(file)
total_samples = waveform.shape[1]
i = 0
@@ -211,8 +216,8 @@ class FileAudioSource(AudioSource):
Parameters
----------
- file: AudioFile
- The file to stream.
+ file: FilePath
+ Path to the file to stream.
uri: Text
Unique identifier of the audio source.
reader: AudioFileReader
@@ -222,7 +227,7 @@ class FileAudioSource(AudioSource):
"""
def __init__(
self,
- file: AudioFile,
+ file: FilePath,
uri: Text,
reader: AudioFileReader,
profile: bool = False,
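With pyannote.audio now optional in `models.py`, an `EmbeddingModel` can be implemented without it; a hedged sketch (the wrapped network and its handling of the optional weights are hypothetical):

```python
from typing import Optional

import torch
from diart.models import EmbeddingModel

class CustomEmbeddingModel(EmbeddingModel):
    """Wraps a network mapping (batch, channels, samples) -> (batch, embedding_dim)."""
    def __init__(self, net: torch.nn.Module):
        super().__init__()
        self.net = net

    def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        # This sketch ignores the per-frame weights; a real model could use
        # them e.g. for weighted statistics pooling before the embedding layer.
        return self.net(waveform)
```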
From f521a9e22ac97c489be84fc5ffd0ee383472795f Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Fri, 3 Jun 2022 18:11:16 +0200
Subject: [PATCH 07/29] Replace pyannote's Audio class with simpler custom
implementation to avoid hard dependency on pyannote.audio
---
requirements.txt | 4 +++
setup.cfg | 6 +++-
src/diart/pipelines.py | 2 +-
src/diart/sources.py | 66 +++++++++++++++++++++++++++++++++++-------
4 files changed, 66 insertions(+), 12 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index bc7472a7..2a267b34 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,8 @@ sounddevice>=0.4.2
einops>=0.3.0
tqdm>=4.64.0
pandas>=1.4.2
+torchaudio>=0.10,<1.0
+pyannote.core>=4.4
+pyannote.database>=4.1.1
+pyannote.metrics>=3.2
git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
diff --git a/setup.cfg b/setup.cfg
index ee69d733..9b6f5490 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -28,7 +28,11 @@ install_requires =
einops>=0.3.0
tqdm>=4.64.0
pandas>=1.4.2
- pyannote-audio @ git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
+ torchaudio>=0.10,<1.0
+ pyannote.core>=4.4
+ pyannote.database>=4.1.1
+ pyannote.metrics>=3.2
+ git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
[options.packages.find]
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index 2087f986..01d0d6a4 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -229,7 +229,7 @@ def from_file(
)
# Build speaker tracking pipeline
- duration = chunk_loader.audio.get_duration(file)
+ duration = chunk_loader.loader.get_duration(file)
return self.speaker_tracking.from_model_streams(
file.stem, duration, seg_stream, emb_stream, chunk_stream
)
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 88012794..15d2f535 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -1,21 +1,67 @@
import random
import time
+from pathlib import Path
from queue import SimpleQueue
from typing import Tuple, Text, Optional, Iterable, List, Union
-from pathlib import Path
import numpy as np
import sounddevice as sd
+import torch
+import torchaudio
from einops import rearrange
-# TODO replace with torchaudio
-from pyannote.audio.core.io import Audio
from pyannote.core import SlidingWindowFeature, SlidingWindow
from rx.subject import Subject
+from torchaudio.functional import resample
+
+torchaudio.set_audio_backend("soundfile")
FilePath = Union[Text, Path]
+class AudioLoader:
+ def __init__(self, sample_rate: int, mono: bool = True):
+ self.sample_rate = sample_rate
+ self.mono = mono
+
+ def load(self, filepath: FilePath) -> torch.Tensor:
+ """
+ Load an audio file into a torch.Tensor.
+
+ Parameters
+ ----------
+ filepath : FilePath
+
+ Returns
+ -------
+ waveform : torch.Tensor, shape (channels, samples)
+ """
+ waveform, sample_rate = torchaudio.load(filepath)
+ # Get channel mean if mono
+ if self.mono and waveform.shape[0] > 1:
+ waveform = waveform.mean(dim=0, keepdim=True)
+ # Resample if needed
+ if self.sample_rate != sample_rate:
+ waveform = resample(waveform, sample_rate, self.sample_rate)
+ return waveform
+
+ @staticmethod
+ def get_duration(filepath: FilePath) -> float:
+ """Get audio file duration in seconds.
+
+ Parameters
+ ----------
+ filepath : FilePath
+
+ Returns
+ -------
+ duration : float
+ Duration in seconds.
+ """
+ info = torchaudio.info(filepath)
+ return info.num_frames / info.sample_rate
+
+
class ChunkLoader:
"""Loads an audio file and chunks it according to a given window and step size.
@@ -35,14 +81,14 @@ def __init__(
window_duration: float,
step_duration: float,
):
- self.audio = Audio(sample_rate, mono=True)
+ self.loader = AudioLoader(sample_rate, mono=True)
self.window_duration = window_duration
self.step_duration = step_duration
self.window_samples = int(round(window_duration * sample_rate))
self.step_samples = int(round(step_duration * sample_rate))
def get_chunks(self, file: FilePath) -> np.ndarray:
- waveform, _ = self.audio(file)
+ waveform = self.loader.load(file)
_, num_samples = waveform.shape
chunks = rearrange(
waveform.unfold(1, self.window_samples, self.step_samples),
@@ -57,7 +103,7 @@ def get_chunks(self, file: FilePath) -> np.ndarray:
return chunks
def num_chunks(self, file: FilePath) -> int:
- numerator = self.audio.get_duration(file) - self.window_duration + self.step_duration
+ numerator = self.loader.get_duration(file) - self.window_duration + self.step_duration
return int(np.ceil(numerator / self.step_duration))
@@ -107,12 +153,12 @@ class AudioFileReader:
Sample rate of the audio file.
"""
def __init__(self, sample_rate: int):
- self.audio = Audio(sample_rate=sample_rate, mono=True)
+ self.loader = AudioLoader(sample_rate, mono=True)
self.resolution = 1 / sample_rate
@property
def sample_rate(self) -> int:
- return self.audio.sample_rate
+ return self.loader.sample_rate
@property
def is_regular(self) -> bool:
@@ -121,7 +167,7 @@ def is_regular(self) -> bool:
return False
def get_duration(self, file: FilePath) -> float:
- return self.audio.get_duration(file)
+ return self.loader.get_duration(file)
def get_num_chunks(self, file: FilePath) -> Optional[int]:
return None
@@ -198,7 +244,7 @@ def __init__(
self.delay = simulate_delay
def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]:
- waveform, _ = self.audio(file)
+ waveform = self.loader.load(file)
total_samples = waveform.shape[1]
i = 0
while i < total_samples:
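A short usage sketch of the new torchaudio-based `AudioLoader` (the file path is a placeholder; in patch 08 below the class moves to `diart.audio`):

```python
from diart.sources import AudioLoader

loader = AudioLoader(sample_rate=16000, mono=True)
duration = loader.get_duration("recording.wav")  # seconds, from file metadata only
waveform = loader.load("recording.wav")          # torch.Tensor, shape (1, samples) at 16 kHz
```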
From 230cda7f0d7fb5e3b85414c9cdc60bdd9093505d Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Mon, 6 Jun 2022 17:22:00 +0200
Subject: [PATCH 08/29] Merge AudioLoader and ChunkLoader
---
README.md | 3 ++
setup.cfg | 1 -
src/diart/audio.py | 101 ++++++++++++++++++++++++++++++++++
src/diart/inference.py | 10 ++--
src/diart/pipelines.py | 15 ++----
src/diart/sources.py | 119 ++++-------------------------------------
6 files changed, 124 insertions(+), 125 deletions(-)
create mode 100644 src/diart/audio.py
diff --git a/README.md b/README.md
index 5faf55d5..f12e0753 100644
--- a/README.md
+++ b/README.md
@@ -29,10 +29,13 @@ conda activate diart
2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)
3) Install pyannote.audio 2.0 (currently in development)
+
```shell
pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
```
+*Note:* starting from version 0.4, installing pyannote.audio is only required to run the default system or to use pyannote-based models. Otherwise, this step can be skipped.
+
4) Install diart:
```shell
pip install diart
diff --git a/setup.cfg b/setup.cfg
index 9b6f5490..d2d5c5b3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,7 +32,6 @@ install_requires =
pyannote.core>=4.4
pyannote.database>=4.1.1
pyannote.metrics>=3.2
- git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
[options.packages.find]
diff --git a/src/diart/audio.py b/src/diart/audio.py
new file mode 100644
index 00000000..e607ba5f
--- /dev/null
+++ b/src/diart/audio.py
@@ -0,0 +1,101 @@
+from pathlib import Path
+from typing import Text, Union
+
+import numpy as np
+import torch
+import torchaudio
+from einops import rearrange
+from torchaudio.functional import resample
+
+torchaudio.set_audio_backend("soundfile")
+
+
+FilePath = Union[Text, Path]
+
+
+class AudioLoader:
+ def __init__(self, sample_rate: int, mono: bool = True):
+ self.sample_rate = sample_rate
+ self.mono = mono
+
+ def load(self, filepath: FilePath) -> torch.Tensor:
+ """Load an audio file into a torch.Tensor.
+
+ Parameters
+ ----------
+ filepath : FilePath
+ Path to an audio file
+
+ Returns
+ -------
+ waveform : torch.Tensor, shape (channels, samples)
+ """
+ waveform, sample_rate = torchaudio.load(filepath)
+ # Get channel mean if mono
+ if self.mono and waveform.shape[0] > 1:
+ waveform = waveform.mean(dim=0, keepdim=True)
+ # Resample if needed
+ if self.sample_rate != sample_rate:
+ waveform = resample(waveform, sample_rate, self.sample_rate)
+ return waveform
+
+ @staticmethod
+ def get_duration(filepath: FilePath) -> float:
+ """Get audio file duration in seconds.
+
+ Parameters
+ ----------
+ filepath : FilePath
+ Path to an audio file
+
+ Returns
+ -------
+ duration : float
+ Duration in seconds.
+ """
+ info = torchaudio.info(filepath)
+ return info.num_frames / info.sample_rate
+
+ def load_sliding_chunks(self, filepath: FilePath, chunk_duration: float, step_duration: float) -> np.ndarray:
+ """Load an audio file and extract sliding chunks of a given duration with a given step duration.
+
+ Parameters
+ ----------
+ filepath : FilePath
+ Path to an audio file
+ chunk_duration: float
+ Duration of the chunk in seconds.
+ step_duration: float
+ Duration of the step between chunks in seconds.
+ """
+ chunk_samples = int(round(chunk_duration * self.sample_rate))
+ step_samples = int(round(step_duration * self.sample_rate))
+ waveform = self.load(filepath)
+ _, num_samples = waveform.shape
+ chunks = rearrange(
+ waveform.unfold(1, chunk_samples, step_samples),
+ "channel chunk sample -> chunk channel sample",
+ ).numpy()
+ # Add padded last chunk
+ if (num_samples - chunk_samples) % step_samples > 0:
+ last_chunk = waveform[:, chunks.shape[0] * step_samples:].unsqueeze(0).numpy()
+ diff_samples = chunk_samples - last_chunk.shape[-1]
+ last_chunk = np.concatenate([last_chunk, np.zeros((1, 1, diff_samples))], axis=-1)
+ return np.vstack([chunks, last_chunk])
+ return chunks
+
+ def get_num_sliding_chunks(self, filepath: FilePath, chunk_duration: float, step_duration: float) -> int:
+ """Estimate the number of sliding chunks of a
+ given chunk duration and step without loading the audio.
+
+ Parameters
+ ----------
+ filepath : FilePath
+ Path to an audio file
+ chunk_duration: float
+ Duration of the chunk in seconds.
+ step_duration: float
+ Duration of the step between chunks in seconds.
+ """
+ numerator = self.get_duration(filepath) - chunk_duration + step_duration
+ return int(np.ceil(numerator / step_duration))
\ No newline at end of file
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 4574279b..da9dfbc2 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -125,15 +125,13 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
DataFrame with detailed performance on each file, as well as average performance.
None if the reference is not provided.
"""
- chunk_loader = src.ChunkLoader(
- pipeline.config.sample_rate,
- pipeline.config.duration,
- pipeline.config.step
- )
+ loader = src.AudioLoader(pipeline.config.sample_rate, mono=True)
audio_file_paths = list(self.speech_path.iterdir())
num_audio_files = len(audio_file_paths)
for i, filepath in enumerate(audio_file_paths):
- num_chunks = chunk_loader.num_chunks(filepath)
+ num_chunks = loader.get_num_sliding_chunks(
+ filepath, pipeline.config.duration, pipeline.config.step
+ )
# Stream fully online if batch size is 1 or lower
source = None
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index 01d0d6a4..4ff453ea 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -1,6 +1,6 @@
import math
from pathlib import Path
-from typing import Optional, Union, Text
+from typing import Optional, Text
import numpy as np
import rx
@@ -167,20 +167,16 @@ def from_source(
def from_file(
self,
- file: Union[Text, Path],
+ filepath: src.FilePath,
output_waveform: bool = False,
batch_size: int = 32,
desc: Optional[Text] = None,
) -> rx.Observable:
- # Audio file information
- file = Path(file)
- chunk_loader = src.ChunkLoader(
- self.config.sample_rate, self.config.duration, self.config.step
- )
+ loader = src.AudioLoader(self.config.sample_rate, mono=True)
# Split audio into chunks
chunks = rearrange(
- chunk_loader.get_chunks(file),
+ loader.load_sliding_chunks(filepath, self.config.duration, self.config.step),
"chunk channel sample -> chunk sample channel"
)
num_chunks = chunks.shape[0]
@@ -229,7 +225,6 @@ def from_file(
)
# Build speaker tracking pipeline
- duration = chunk_loader.loader.get_duration(file)
return self.speaker_tracking.from_model_streams(
- file.stem, duration, seg_stream, emb_stream, chunk_stream
+ Path(filepath).stem, loader.get_duration(filepath), seg_stream, emb_stream, chunk_stream
)
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 15d2f535..93a33b3a 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -1,110 +1,14 @@
import random
import time
-from pathlib import Path
from queue import SimpleQueue
-from typing import Tuple, Text, Optional, Iterable, List, Union
+from typing import Tuple, Text, Optional, Iterable, List
import numpy as np
import sounddevice as sd
-import torch
-import torchaudio
-from einops import rearrange
from pyannote.core import SlidingWindowFeature, SlidingWindow
from rx.subject import Subject
-from torchaudio.functional import resample
-torchaudio.set_audio_backend("soundfile")
-
-
-FilePath = Union[Text, Path]
-
-
-class AudioLoader:
- def __init__(self, sample_rate: int, mono: bool = True):
- self.sample_rate = sample_rate
- self.mono = mono
-
- def load(self, filepath: FilePath) -> torch.Tensor:
- """
- Load an audio file into a torch.Tensor.
-
- Parameters
- ----------
- filepath : FilePath
-
- Returns
- -------
- waveform : torch.Tensor, shape (channels, samples)
- """
- waveform, sample_rate = torchaudio.load(filepath)
- # Get channel mean if mono
- if self.mono and waveform.shape[0] > 1:
- waveform = waveform.mean(dim=0, keepdim=True)
- # Resample if needed
- if self.sample_rate != sample_rate:
- waveform = resample(waveform, sample_rate, self.sample_rate)
- return waveform
-
- @staticmethod
- def get_duration(filepath: FilePath) -> float:
- """Get audio file duration in seconds.
-
- Parameters
- ----------
- filepath : FilePath
-
- Returns
- -------
- duration : float
- Duration in seconds.
- """
- info = torchaudio.info(filepath)
- return info.num_frames / info.sample_rate
-
-
-class ChunkLoader:
- """Loads an audio file and chunks it according to a given window and step size.
-
- Parameters
- ----------
- sample_rate: int
- Sample rate to load audio.
- window_duration: float
- Duration of the chunk in seconds.
- step_duration: float
- Duration of the step between chunks in seconds.
- """
-
- def __init__(
- self,
- sample_rate: int,
- window_duration: float,
- step_duration: float,
- ):
- self.loader = AudioLoader(sample_rate, mono=True)
- self.window_duration = window_duration
- self.step_duration = step_duration
- self.window_samples = int(round(window_duration * sample_rate))
- self.step_samples = int(round(step_duration * sample_rate))
-
- def get_chunks(self, file: FilePath) -> np.ndarray:
- waveform = self.loader.load(file)
- _, num_samples = waveform.shape
- chunks = rearrange(
- waveform.unfold(1, self.window_samples, self.step_samples),
- "channel chunk frame -> chunk channel frame",
- ).numpy()
- # Add padded last chunk
- if num_samples - self.window_samples % self.step_samples > 0:
- last_chunk = waveform[:, chunks.shape[0] * self.step_samples:].unsqueeze(0).numpy()
- diff_samples = self.window_samples - last_chunk.shape[-1]
- last_chunk = np.concatenate([last_chunk, np.zeros((1, 1, diff_samples))], axis=-1)
- return np.vstack([chunks, last_chunk])
- return chunks
-
- def num_chunks(self, file: FilePath) -> int:
- numerator = self.loader.get_duration(file) - self.window_duration + self.step_duration
- return int(np.ceil(numerator / self.step_duration))
+from .audio import FilePath, AudioLoader
class AudioSource:
@@ -184,7 +88,7 @@ class RegularAudioFileReader(AudioFileReader):
----------
sample_rate: int
Sample rate of the audio file.
- window_duration: float
+ chunk_duration: float
Duration of each chunk of samples (window) in seconds.
step_duration: float
Step duration between chunks in seconds.
@@ -192,27 +96,26 @@ class RegularAudioFileReader(AudioFileReader):
def __init__(
self,
sample_rate: int,
- window_duration: float,
+ chunk_duration: float,
step_duration: float,
):
super().__init__(sample_rate)
- self.chunk_loader = ChunkLoader(
- sample_rate, window_duration, step_duration
- )
+ self.chunk_duration = chunk_duration
+ self.step_duration = step_duration
@property
def is_regular(self) -> bool:
return True
- def get_num_chunks(self, file: FilePath) -> Optional[int]:
- """Return the number of chunks emitted for `file`"""
- return self.chunk_loader.num_chunks(file)
+ def get_num_chunks(self, filepath: FilePath) -> Optional[int]:
+ """Return the number of chunks that will be emitted for a given file"""
+ return self.loader.get_num_sliding_chunks(filepath, self.chunk_duration, self.step_duration)
def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]:
- chunks = self.chunk_loader.get_chunks(file)
+ chunks = self.loader.load_sliding_chunks(file, self.chunk_duration, self.step_duration)
for i, chunk in enumerate(chunks):
w = SlidingWindow(
- start=i * self.chunk_loader.step_duration,
+ start=i * self.step_duration,
duration=self.resolution,
step=self.resolution
)
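The refactored `RegularAudioFileReader.iterate` above now only timestamps chunks that `AudioLoader` has already produced. A self-contained sketch of that windowing step, using synthetic chunks instead of file I/O:

```python
import numpy as np
from pyannote.core import SlidingWindow, SlidingWindowFeature

# Each chunk of shape (channels, samples) becomes a SlidingWindowFeature
# whose window starts at i * step_duration, at one-sample resolution.
sample_rate, step_duration, resolution = 16000, 0.5, 1 / 16000
chunks = np.zeros((3, 1, 5 * sample_rate))  # 3 synthetic 5-second mono chunks

for i, chunk in enumerate(chunks):
    window = SlidingWindow(start=i * step_duration, duration=resolution, step=resolution)
    feature = SlidingWindowFeature(chunk.T, window)
    print(feature.data.shape, window.start)
```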
From 0e9e718cfed41d177df93ebef390ade129d35a15 Mon Sep 17 00:00:00 2001
From: Juan Coria
Date: Thu, 9 Jun 2022 14:05:00 +0200
Subject: [PATCH 09/29] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index f12e0753..bbab03fa 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ conda activate diart
pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
```
-*Note:* starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored.
+**Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored.
4) Install diart:
```shell
From 9110d957be7cd70e8f39884f4f6485b6ab3c5a25 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 9 Jun 2022 15:33:11 +0200
Subject: [PATCH 10/29] Remove file audio readers and make all file reading
regular
---
src/diart/inference.py | 8 +--
src/diart/sources.py | 160 +++++++++--------------------------------
src/diart/stream.py | 12 ++--
3 files changed, 41 insertions(+), 139 deletions(-)
diff --git a/src/diart/inference.py b/src/diart/inference.py
index da9dfbc2..b45de89f 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -139,11 +139,9 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
source = src.FileAudioSource(
filepath,
filepath.stem,
- src.RegularAudioFileReader(
- pipeline.config.sample_rate,
- pipeline.config.duration,
- pipeline.config.step
- ),
+ pipeline.config.sample_rate,
+ pipeline.config.duration,
+ pipeline.config.step,
# Benchmark the processing time of a single chunk
profile=True,
)
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 93a33b3a..8af379e9 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -1,7 +1,6 @@
-import random
import time
from queue import SimpleQueue
-from typing import Tuple, Text, Optional, Iterable, List
+from typing import Text, Optional, List
import numpy as np
import sounddevice as sd
@@ -48,118 +47,6 @@ def read(self):
raise NotImplementedError
-class AudioFileReader:
- """Represents a method for reading an audio file.
-
- Parameters
- ----------
- sample_rate: int
- Sample rate of the audio file.
- """
- def __init__(self, sample_rate: int):
- self.loader = AudioLoader(sample_rate, mono=True)
- self.resolution = 1 / sample_rate
-
- @property
- def sample_rate(self) -> int:
- return self.loader.sample_rate
-
- @property
- def is_regular(self) -> bool:
- """Whether the reading is regular. Defaults to False.
- A regular reading method always yields the same amount of samples."""
- return False
-
- def get_duration(self, file: FilePath) -> float:
- return self.loader.get_duration(file)
-
- def get_num_chunks(self, file: FilePath) -> Optional[int]:
- return None
-
- def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]:
- """Return an iterable over the file's samples"""
- raise NotImplementedError
-
-
-class RegularAudioFileReader(AudioFileReader):
- """Reads a file always yielding the same number of samples with a given step.
-
- Parameters
- ----------
- sample_rate: int
- Sample rate of the audio file.
- chunk_duration: float
- Duration of each chunk of samples (window) in seconds.
- step_duration: float
- Step duration between chunks in seconds.
- """
- def __init__(
- self,
- sample_rate: int,
- chunk_duration: float,
- step_duration: float,
- ):
- super().__init__(sample_rate)
- self.chunk_duration = chunk_duration
- self.step_duration = step_duration
-
- @property
- def is_regular(self) -> bool:
- return True
-
- def get_num_chunks(self, filepath: FilePath) -> Optional[int]:
- """Return the number of chunks that will be emitted for a given file"""
- return self.loader.get_num_sliding_chunks(filepath, self.chunk_duration, self.step_duration)
-
- def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]:
- chunks = self.loader.load_sliding_chunks(file, self.chunk_duration, self.step_duration)
- for i, chunk in enumerate(chunks):
- w = SlidingWindow(
- start=i * self.step_duration,
- duration=self.resolution,
- step=self.resolution
- )
- yield SlidingWindowFeature(chunk.T, w)
-
-
-class IrregularAudioFileReader(AudioFileReader):
- """Reads an audio file yielding a different number of non-overlapping samples in each event.
- This class is useful to simulate how a system would work in unreliable reading conditions.
-
- Parameters
- ----------
- sample_rate: int
- Sample rate of the audio file.
- refresh_rate_range: (float, float)
- Duration range within which to determine the number of samples to yield (in seconds).
- simulate_delay: bool
- Whether to simulate that the samples are being read in real time before they are yielded.
- Defaults to False (no delay).
- """
- def __init__(
- self,
- sample_rate: int,
- refresh_rate_range: Tuple[float, float],
- simulate_delay: bool = False,
- ):
- super().__init__(sample_rate)
- self.start, self.end = refresh_rate_range
- self.delay = simulate_delay
-
- def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]:
- waveform = self.loader.load(file)
- total_samples = waveform.shape[1]
- i = 0
- while i < total_samples:
- rnd_duration = random.uniform(self.start, self.end)
- if self.delay:
- time.sleep(rnd_duration)
- num_samples = int(round(rnd_duration * self.sample_rate))
- last_i = i
- i += num_samples
- yield waveform[:, last_i:i]
-
-
class FileAudioSource(AudioSource):
"""Represents an audio source tied to a file.
@@ -169,8 +56,12 @@ class FileAudioSource(AudioSource):
Path to the file to stream.
uri: Text
Unique identifier of the audio source.
- reader: AudioFileReader
- Determines how the file will be read.
+ sample_rate: int
+ Sample rate of the chunks emitted.
+ chunk_duration: float
+ Duration of each chunk in seconds. Defaults to 5s.
+ step_duration: float
+ Duration of the step between consecutive chunks in seconds. Defaults to 500ms.
profile: bool
If True, prints the average processing time of emitting a chunk. Defaults to False.
"""
@@ -178,19 +69,24 @@ def __init__(
self,
file: FilePath,
uri: Text,
- reader: AudioFileReader,
+ sample_rate: int,
+ chunk_duration: float = 5,
+ step_duration: float = 0.5,
profile: bool = False,
):
- super().__init__(uri, reader.sample_rate)
- self.reader = reader
- self._duration = self.reader.get_duration(file)
+ super().__init__(uri, sample_rate)
+ self.loader = AudioLoader(sample_rate, mono=True)
+ self._duration = self.loader.get_duration(file)
self.file = file
+ self.chunk_duration = chunk_duration
+ self.step_duration = step_duration
self.profile = profile
+ self.resolution = 1 / sample_rate
@property
def is_regular(self) -> bool:
- # The regularity depends on the reader
- return self.reader.is_regular
+ # An audio file is always a regular source
+ return True
@property
def duration(self) -> Optional[float]:
@@ -199,8 +95,9 @@ def duration(self) -> Optional[float]:
@property
def length(self) -> Optional[int]:
- # Only the reader can know how many chunks are going to be emitted
- return self.reader.get_num_chunks(self.file)
+ return self.loader.get_num_sliding_chunks(
+ self.file, self.chunk_duration, self.step_duration
+ )
def _check_print_time(self, times: List[float]):
if self.profile:
@@ -213,15 +110,24 @@ def _check_print_time(self, times: List[float]):
def read(self):
"""Send each chunk of samples through the stream"""
times = []
- for waveform in self.reader.iterate(self.file):
+ chunks = self.loader.load_sliding_chunks(
+ self.file, self.chunk_duration, self.step_duration
+ )
+ for i, waveform in enumerate(chunks):
+ window = SlidingWindow(
+ start=i * self.step_duration,
+ duration=self.resolution,
+ step=self.resolution
+ )
+ chunk = SlidingWindowFeature(waveform.T, window)
try:
if self.profile:
# Profiling assumes that on_next is blocking
start_time = time.monotonic()
- self.stream.on_next(waveform)
+ self.stream.on_next(chunk)
times.append(time.monotonic() - start_time)
else:
- self.stream.on_next(waveform)
+ self.stream.on_next(chunk)
except Exception as e:
self._check_print_time(times)
self.stream.on_error(e)
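With the reader classes removed, `FileAudioSource` is configured directly with a sample rate, chunk duration and step. A hedged sketch of the new call site as of this commit (placeholder path; the `uri` argument is still positional here and is dropped later in the series):

```python
import diart.sources as src

# Hedged sketch of the simplified constructor ("/path/to/audio.wav" is a placeholder).
source = src.FileAudioSource(
    "/path/to/audio.wav",
    "conversation",          # uri (removed later in the series)
    sample_rate=16000,
    chunk_duration=5.0,
    step_duration=0.5,
)
source.stream.subscribe(on_next=lambda chunk: print(chunk.data.shape))
source.read()  # emits SlidingWindowFeature chunks, then completes
```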
diff --git a/src/diart/stream.py b/src/diart/stream.py
index 4d4bb819..da6ba97c 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -43,13 +43,11 @@
args.source = Path(args.source).expanduser()
args.output = args.source.parent if args.output is None else Path(args.output)
audio_source = src.FileAudioSource(
- file=args.source,
- uri=args.source.stem,
- reader=src.RegularAudioFileReader(
- pipeline.config.sample_rate,
- pipeline.config.duration,
- pipeline.config.step,
- ),
+ args.source,
+ args.source.stem,
+ pipeline.config.sample_rate,
+ pipeline.config.duration,
+ pipeline.config.step,
)
else:
args.output = Path("~/").expanduser() if args.output is None else Path(args.output)
From 935f8f45d7bef3ad8056837f98d60fb00fb8761d Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 9 Jun 2022 17:40:32 +0200
Subject: [PATCH 11/29] Move real-time profiling from audio source to
OnlineSpeakerDiarization pipeline
---
src/diart/benchmark.py | 2 +-
src/diart/inference.py | 2 -
src/diart/operators.py | 34 ++++++++++++++
src/diart/pipelines.py | 103 ++++++++++++++++++++---------------------
src/diart/sources.py | 27 +----------
src/diart/stream.py | 2 +-
src/diart/utils.py | 6 ++-
7 files changed, 92 insertions(+), 84 deletions(-)
diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index b161b708..713b54b6 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -38,6 +38,6 @@
beta=args.beta,
max_speakers=args.max_speakers,
device=torch.device("cpu") if args.cpu else None,
- ))
+ ), profile=True)
benchmark(pipeline, args.batch_size)
diff --git a/src/diart/inference.py b/src/diart/inference.py
index b45de89f..e1cd95a3 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -142,8 +142,6 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
pipeline.config.sample_rate,
pipeline.config.duration,
pipeline.config.step,
- # Benchmark the processing time of a single chunk
- profile=True,
)
observable = pipeline.from_source(source, output_waveform=False)
else:
diff --git a/src/diart/operators.py b/src/diart/operators.py
index 646add85..cdd92dba 100644
--- a/src/diart/operators.py
+++ b/src/diart/operators.py
@@ -1,3 +1,4 @@
+import time
from dataclasses import dataclass
from typing import Callable, Optional, List, Any, Tuple, Text
@@ -300,3 +301,36 @@ def progress(
on_completed=lambda: pbar.close(),
)
)
+
+
+class _Chronometer:
+ def __init__(self):
+ self.current_start_time = None
+ self.history = []
+
+ def start(self):
+ self.current_start_time = time.monotonic()
+
+ def stop(self):
+        elapsed = time.monotonic() - self.current_start_time
+        self.current_start_time = None
+        self.history.append(elapsed)
+
+ def report(self):
+ print(
+ f"Stream took {np.mean(self.history).item():.3f} "
+ f"(+/-{np.std(self.history).item():.3f}) seconds/chunk "
+ f"-- based on {len(self.history)} chunks"
+ )
+
+
+def profile(observable: rx.Observable, operations: List[Operator]) -> rx.Observable:
+ chronometer = _Chronometer()
+ return observable.pipe(
+ ops.do_action(lambda _: chronometer.start()),
+ *operations,
+ ops.do_action(
+ on_next=lambda _: chronometer.stop(),
+ on_completed=chronometer.report,
+ )
+ )
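A toy sketch of the new `profile` helper: it times each element from emission to pipeline output and reports mean and standard deviation on completion. The doubling map below is only a stand-in for the real segmentation, embedding and clustering operators:

```python
import rx
import rx.operators as ops
import diart.operators as dops

# Time each element of a toy stream with the profile() helper added above.
stream = rx.of(1, 2, 3)
profiled = dops.profile(stream, [ops.map(lambda x: x * 2)])
profiled.subscribe(on_next=print)
# On completion the chronometer reports mean (+/- std) seconds per chunk.
```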
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index 4ff453ea..1c61de96 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -1,6 +1,6 @@
import math
from pathlib import Path
-from typing import Optional, Text
+from typing import Optional, Text, List, Tuple
import numpy as np
import rx
@@ -14,6 +14,7 @@
from . import models as m
from . import operators as dops
from . import sources as src
+from . import utils
class PipelineConfig:
@@ -83,18 +84,15 @@ class OnlineSpeakerTracking:
def __init__(self, config: PipelineConfig):
self.config = config
- def from_model_streams(
+ def get_end_time(self, duration: Optional[float]) -> Optional[float]:
+ return None if duration is None else self.config.last_chunk_end_time(duration)
+
+ def get_operators(
self,
- uri: Text,
+ source_uri: Text,
source_duration: Optional[float],
- segmentation_stream: rx.Observable,
- embedding_stream: rx.Observable,
- audio_chunk_stream: Optional[rx.Observable] = None,
- ) -> rx.Observable:
- end_time = None
- if source_duration is not None:
- end_time = self.config.last_chunk_end_time(source_duration)
- # Initialize clustering and aggregation modules
+ output_waveform: bool = True
+ ) -> List[dops.Operator]:
clustering = blocks.OnlineSpeakerClustering(
self.config.tau_active,
self.config.rho_update,
@@ -102,39 +100,33 @@ def from_model_streams(
"cosine",
self.config.max_speakers,
)
- aggregation = blocks.DelayedAggregation(
+ end_time = self.get_end_time(source_duration)
+ pred_aggregation = blocks.DelayedAggregation(
self.config.step, self.config.latency, strategy="hamming", stream_end=end_time
)
- binarize = blocks.Binarize(uri, self.config.tau_active)
-
- # Join segmentation and embedding streams to update a background clustering model
- # while regulating latency and binarizing the output
- pipeline = rx.zip(segmentation_stream, embedding_stream).pipe(
- ops.starmap(clustering),
+ audio_aggregation = blocks.DelayedAggregation(
+ self.config.step, self.config.latency, strategy="first", stream_end=end_time
+ )
+ binarize = blocks.Binarize(source_uri, self.config.tau_active)
+ return [
+ # Identify global speakers with online clustering
+ ops.starmap(lambda wav, seg, emb: (wav, clustering(seg, emb))),
# Buffer 'num_overlapping' sliding chunks with a step of 1 chunk
- dops.buffer_slide(aggregation.num_overlapping_windows),
+ dops.buffer_slide(pred_aggregation.num_overlapping_windows),
# Aggregate overlapping output windows
- ops.map(aggregation),
+ ops.map(lambda buffers: utils.unzip(buffers)),
+ ops.starmap(lambda wav_buffer, pred_buffer: (
+ audio_aggregation(wav_buffer), pred_aggregation(pred_buffer)
+ )),
# Binarize output
- ops.map(binarize),
- )
- # Add corresponding waveform to the output
- if audio_chunk_stream is not None:
- window_selector = blocks.DelayedAggregation(
- self.config.step, self.config.latency, strategy="first", stream_end=end_time
- )
- waveform_stream = audio_chunk_stream.pipe(
- dops.buffer_slide(window_selector.num_overlapping_windows),
- ops.map(window_selector),
- )
- return rx.zip(pipeline, waveform_stream)
- # No waveform needed, add None for consistency
- return pipeline.pipe(ops.map(lambda ann: (ann, None)))
+ ops.starmap(lambda wav, pred: (binarize(pred), wav if output_waveform else None)),
+ ]
class OnlineSpeakerDiarization:
- def __init__(self, config: PipelineConfig):
+ def __init__(self, config: PipelineConfig, profile: bool = False):
self.config = config
+ self.profile = profile
self.segmentation = blocks.SpeakerSegmentation(config.segmentation, config.device)
self.embedding = blocks.OverlapAwareSpeakerEmbedding(
config.embedding, config.gamma, config.beta, norm=1, device=config.device
@@ -143,27 +135,26 @@ def __init__(self, config: PipelineConfig):
msg = f"Latency should be in the range [{config.step}, {config.duration}]"
assert config.step <= config.latency <= config.duration, msg
- def from_source(
- self,
- source: src.AudioSource,
- output_waveform: bool = True
- ) -> rx.Observable:
+ def from_source(self, source: src.AudioSource, output_waveform: bool = True) -> rx.Observable:
msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}"
assert source.sample_rate == self.config.sample_rate, msg
+ operators = []
# Regularize the stream to a specific chunk duration and step
- regular_stream = source.stream
if not source.is_regular:
- regular_stream = source.stream.pipe(
- dops.regularize_stream(self.config.duration, self.config.step, source.sample_rate)
- )
- # Branch the stream to calculate chunk segmentation
- seg_stream = regular_stream.pipe(ops.map(self.segmentation))
- # Join audio and segmentation stream to calculate overlap-aware speaker embeddings
- emb_stream = rx.zip(regular_stream, seg_stream).pipe(ops.starmap(self.embedding))
- chunk_stream = regular_stream if output_waveform else None
- return self.speaker_tracking.from_model_streams(
- source.uri, source.duration, seg_stream, emb_stream, chunk_stream
- )
+ operators.append(dops.regularize_stream(
+ self.config.duration, self.config.step, source.sample_rate
+ ))
+ operators += [
+ # Extract segmentation and keep audio
+ ops.map(lambda wav: (wav, self.segmentation(wav))),
+ # Extract embeddings and keep segmentation
+ ops.starmap(lambda wav, seg: (wav, seg, self.embedding(wav, seg))),
+ ]
+ # Add speaker tracking
+ operators += self.speaker_tracking.get_operators(source.uri, source.duration, output_waveform)
+ if self.profile:
+ return dops.profile(source.stream, operators)
+ return source.stream.pipe(*operators)
def from_file(
self,
@@ -225,6 +216,10 @@ def from_file(
)
# Build speaker tracking pipeline
- return self.speaker_tracking.from_model_streams(
- Path(filepath).stem, loader.get_duration(filepath), seg_stream, emb_stream, chunk_stream
+ return rx.zip(chunk_stream, seg_stream, emb_stream).pipe(
+ *self.speaker_tracking.get_operators(
+ Path(filepath).stem,
+ loader.get_duration(filepath),
+ output_waveform,
+ )
)
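The pipeline now assembles a flat list of rx operators and applies them with a single `pipe()` call (or hands them to `dops.profile` when profiling). A toy sketch of that composition pattern with placeholder stages:

```python
import rx
import rx.operators as ops

# The string-producing lambdas are placeholders for segmentation, embedding and tracking.
operators = [
    ops.map(lambda wav: (wav, f"seg({wav})")),
    ops.starmap(lambda wav, seg: (wav, seg, f"emb({seg})")),
    ops.starmap(lambda wav, seg, emb: (f"pred({emb})", wav)),
]
rx.of("chunk0", "chunk1").pipe(*operators).subscribe(on_next=print)
```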
diff --git a/src/diart/sources.py b/src/diart/sources.py
index 8af379e9..ad6c961a 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -1,8 +1,6 @@
-import time
from queue import SimpleQueue
-from typing import Text, Optional, List
+from typing import Text, Optional
-import numpy as np
import sounddevice as sd
from pyannote.core import SlidingWindowFeature, SlidingWindow
from rx.subject import Subject
@@ -62,8 +60,6 @@ class FileAudioSource(AudioSource):
Duration of each chunk in seconds. Defaults to 5s.
step_duration: float
Duration of the step between consecutive chunks in seconds. Defaults to 500ms.
- profile: bool
- If True, prints the average processing time of emitting a chunk. Defaults to False.
"""
def __init__(
self,
@@ -72,7 +68,6 @@ def __init__(
sample_rate: int,
chunk_duration: float = 5,
step_duration: float = 0.5,
- profile: bool = False,
):
super().__init__(uri, sample_rate)
self.loader = AudioLoader(sample_rate, mono=True)
@@ -80,7 +75,6 @@ def __init__(
self.file = file
self.chunk_duration = chunk_duration
self.step_duration = step_duration
- self.profile = profile
self.resolution = 1 / sample_rate
@property
@@ -99,17 +93,8 @@ def length(self) -> Optional[int]:
self.file, self.chunk_duration, self.step_duration
)
- def _check_print_time(self, times: List[float]):
- if self.profile:
- print(
- f"File {self.uri}: took {np.mean(times).item():.2f} seconds/chunk "
- f"(+/- {np.std(times).item():.2f} seconds/chunk) "
- f"-- based on {len(times)} inputs"
- )
-
def read(self):
"""Send each chunk of samples through the stream"""
- times = []
chunks = self.loader.load_sliding_chunks(
self.file, self.chunk_duration, self.step_duration
)
@@ -121,18 +106,10 @@ def read(self):
)
chunk = SlidingWindowFeature(waveform.T, window)
try:
- if self.profile:
- # Profiling assumes that on_next is blocking
- start_time = time.monotonic()
- self.stream.on_next(chunk)
- times.append(time.monotonic() - start_time)
- else:
- self.stream.on_next(chunk)
+ self.stream.on_next(chunk)
except Exception as e:
- self._check_print_time(times)
self.stream.on_error(e)
self.stream.on_completed()
- self._check_print_time(times)
class MicrophoneAudioSource(AudioSource):
diff --git a/src/diart/stream.py b/src/diart/stream.py
index da6ba97c..e843ed08 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -36,7 +36,7 @@
beta=args.beta,
max_speakers=args.max_speakers,
device=torch.device("cpu") if args.cpu else None,
- ))
+ ), profile=True)
# Manage audio source
if args.source != "microphone":
diff --git a/src/diart/utils.py b/src/diart/utils.py
index 7e8214a9..ff843193 100644
--- a/src/diart/utils.py
+++ b/src/diart/utils.py
@@ -1,9 +1,13 @@
-from typing import Optional
+from typing import Optional, List, Tuple
import matplotlib.pyplot as plt
from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
+def unzip(zipped: List[Tuple]) -> Tuple:
+ return tuple(zip(*zipped))
+
+
def visualize_feature(duration: Optional[float] = None):
def apply(feature: SlidingWindowFeature):
if duration is None:
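`unzip` is the small helper the aggregation step relies on: it turns a buffer of (waveform, prediction) pairs into two aligned tuples. A minimal illustration with placeholder strings:

```python
from diart.utils import unzip

# Placeholder values stand in for real waveforms and predictions.
buffers = [("wav0", "pred0"), ("wav1", "pred1"), ("wav2", "pred2")]
wav_buffer, pred_buffer = unzip(buffers)
print(wav_buffer)   # ('wav0', 'wav1', 'wav2')
print(pred_buffer)  # ('pred0', 'pred1', 'pred2')
```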
From f841b5643d384d8f35f83ee230facd2c704dcbf2 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 9 Jun 2022 17:48:58 +0200
Subject: [PATCH 12/29] Make pipeline always output the waveform alongside the
prediction
---
src/diart/inference.py | 2 +-
src/diart/pipelines.py | 7 +++----
src/diart/sinks.py | 27 +++++++++------------------
3 files changed, 13 insertions(+), 23 deletions(-)
diff --git a/src/diart/inference.py b/src/diart/inference.py
index e1cd95a3..9d384f73 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -143,7 +143,7 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
pipeline.config.duration,
pipeline.config.step,
)
- observable = pipeline.from_source(source, output_waveform=False)
+ observable = pipeline.from_source(source)
else:
observable = pipeline.from_file(
filepath,
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index 1c61de96..ddaf484f 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -91,7 +91,6 @@ def get_operators(
self,
source_uri: Text,
source_duration: Optional[float],
- output_waveform: bool = True
) -> List[dops.Operator]:
clustering = blocks.OnlineSpeakerClustering(
self.config.tau_active,
@@ -119,7 +118,7 @@ def get_operators(
audio_aggregation(wav_buffer), pred_aggregation(pred_buffer)
)),
# Binarize output
- ops.starmap(lambda wav, pred: (binarize(pred), wav if output_waveform else None)),
+ ops.starmap(lambda wav, pred: (binarize(pred), wav)),
]
@@ -135,7 +134,7 @@ def __init__(self, config: PipelineConfig, profile: bool = False):
msg = f"Latency should be in the range [{config.step}, {config.duration}]"
assert config.step <= config.latency <= config.duration, msg
- def from_source(self, source: src.AudioSource, output_waveform: bool = True) -> rx.Observable:
+ def from_source(self, source: src.AudioSource) -> rx.Observable:
msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}"
assert source.sample_rate == self.config.sample_rate, msg
operators = []
@@ -151,7 +150,7 @@ def from_source(self, source: src.AudioSource, output_waveform: bool = True) ->
ops.starmap(lambda wav, seg: (wav, seg, self.embedding(wav, seg))),
]
# Add speaker tracking
- operators += self.speaker_tracking.get_operators(source.uri, source.duration, output_waveform)
+ operators += self.speaker_tracking.get_operators(source.uri, source.duration)
if self.profile:
return dops.profile(source.stream, operators)
return source.stream.pipe(*operators)
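After this change every pipeline emission is a (prediction, waveform) pair, never (prediction, None), so observers can unpack both unconditionally. A hedged sketch of a minimal subscriber (the function name is illustrative):

```python
# Minimal observer for the new output contract: every item is a (prediction, waveform) pair.
def print_speech_regions(outputs):
    prediction, waveform = outputs
    # prediction is a pyannote Annotation, waveform a SlidingWindowFeature
    print(prediction.get_timeline().duration(), waveform.data.shape)

# e.g. pipeline.from_audio_source(source).subscribe(on_next=print_speech_regions)
```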
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index 0d33bf01..993e95a8 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -24,7 +24,7 @@ def patch_rttm(self):
with open(self.path, 'w') as file:
annotation.support(self.patch_collar).write_rttm(file)
- def on_next(self, value: Tuple[Annotation, Optional[SlidingWindowFeature]]):
+ def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]):
with open(self.path, 'a') as file:
value[0].write_rttm(file)
@@ -58,16 +58,14 @@ def __init__(
self.latency = latency
self.figure, self.axs, self.num_axs = None, None, -1
- def _init_num_axs(self, waveform: Optional[SlidingWindowFeature]):
+ def _init_num_axs(self):
if self.num_axs == -1:
- self.num_axs = 1
- if waveform is not None:
- self.num_axs += 1
+ self.num_axs = 2
if self.reference is not None:
self.num_axs += 1
- def _init_figure(self, waveform: Optional[SlidingWindowFeature]):
- self._init_num_axs(waveform)
+ def _init_figure(self):
+ self._init_num_axs()
self.figure, self.axs = plt.subplots(self.num_axs, 1, figsize=(10, 2 * self.num_axs))
if self.num_axs == 1:
self.axs = [self.axs]
@@ -87,7 +85,7 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
prediction, waveform, real_time = values
# Initialize figure if first call
if self.figure is None:
- self._init_figure(waveform)
+ self._init_figure()
# Clear previous plots
self._clear_axs()
# Set plot bounds
@@ -100,16 +98,9 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
prediction.rename_labels(mapping=mapping, copy=False)
notebook.plot_annotation(prediction, self.axs[0])
self.axs[0].set_title("Output")
- if self.num_axs == 2:
- if waveform is not None:
- notebook.plot_feature(waveform, self.axs[1])
- self.axs[1].set_title("Audio")
- elif self.reference is not None:
- notebook.plot_annotation(self.reference, self.axs[1])
- self.axs[1].set_title("Reference")
- elif self.num_axs == 3:
- notebook.plot_feature(waveform, self.axs[1])
- self.axs[1].set_title("Audio")
+ notebook.plot_feature(waveform, self.axs[1])
+ self.axs[1].set_title("Audio")
+ if self.num_axs == 3:
notebook.plot_annotation(self.reference, self.axs[2])
self.axs[2].set_title("Reference")
From 0b0116b7b9d447711fbfb63a7e689437b960c867 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 9 Jun 2022 18:24:08 +0200
Subject: [PATCH 13/29] Move feature pre-calculation from
OnlineSpeakerDiarization to a new type of streaming source:
PrecalculatedFeaturesAudioSource
---
src/diart/inference.py | 28 +++++++------
src/diart/pipelines.py | 83 ++++-----------------------------------
src/diart/sources.py | 89 +++++++++++++++++++++++++++++++++++++++++-
3 files changed, 111 insertions(+), 89 deletions(-)
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 9d384f73..3e98e51f 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -48,7 +48,7 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, source: src.AudioSource)
"""
rttm_path = self.output_path / f"{source.uri}.rttm"
rttm_writer = RTTMWriter(path=rttm_path)
- observable = pipeline.from_source(source).pipe(
+ observable = pipeline.from_audio_source(source).pipe(
dops.progress(f"Streaming {source.uri}", total=source.length, leave=True)
)
if not self.do_plot:
@@ -134,7 +134,6 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
)
# Stream fully online if batch size is 1 or lower
- source = None
if batch_size < 2:
source = src.FileAudioSource(
filepath,
@@ -143,26 +142,31 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
pipeline.config.duration,
pipeline.config.step,
)
- observable = pipeline.from_source(source)
+ observable = pipeline.from_audio_source(source)
else:
- observable = pipeline.from_file(
+ source = src.PrecalculatedFeaturesAudioSource(
filepath,
- batch_size=batch_size,
- desc=f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})",
+ filepath.stem,
+ pipeline.config.sample_rate,
+ pipeline.segmentation,
+ pipeline.embedding,
+ pipeline.config.duration,
+ pipeline.config.step,
+ batch_size,
+ # TODO decouple progress bar
+ progress_msg=f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})",
)
+ observable = pipeline.from_feature_source(source)
observable.pipe(
dops.progress(
desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})",
total=num_chunks,
- leave=source is None
+ leave=isinstance(source, src.PrecalculatedFeaturesAudioSource)
)
- ).subscribe(
- RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm")
- )
+ ).subscribe(RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm"))
- if source is not None:
- source.read()
+ source.read()
# Run evaluation
if self.reference_path is not None:
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index ddaf484f..e4c9a987 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -87,11 +87,7 @@ def __init__(self, config: PipelineConfig):
def get_end_time(self, duration: Optional[float]) -> Optional[float]:
return None if duration is None else self.config.last_chunk_end_time(duration)
- def get_operators(
- self,
- source_uri: Text,
- source_duration: Optional[float],
- ) -> List[dops.Operator]:
+ def get_operators(self, source: src.AudioSource) -> List[dops.Operator]:
clustering = blocks.OnlineSpeakerClustering(
self.config.tau_active,
self.config.rho_update,
@@ -99,14 +95,14 @@ def get_operators(
"cosine",
self.config.max_speakers,
)
- end_time = self.get_end_time(source_duration)
+ end_time = self.get_end_time(source.duration)
pred_aggregation = blocks.DelayedAggregation(
self.config.step, self.config.latency, strategy="hamming", stream_end=end_time
)
audio_aggregation = blocks.DelayedAggregation(
self.config.step, self.config.latency, strategy="first", stream_end=end_time
)
- binarize = blocks.Binarize(source_uri, self.config.tau_active)
+ binarize = blocks.Binarize(source.uri, self.config.tau_active)
return [
# Identify global speakers with online clustering
ops.starmap(lambda wav, seg, emb: (wav, clustering(seg, emb))),
@@ -134,7 +130,7 @@ def __init__(self, config: PipelineConfig, profile: bool = False):
msg = f"Latency should be in the range [{config.step}, {config.duration}]"
assert config.step <= config.latency <= config.duration, msg
- def from_source(self, source: src.AudioSource) -> rx.Observable:
+ def from_audio_source(self, source: src.AudioSource) -> rx.Observable:
msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}"
assert source.sample_rate == self.config.sample_rate, msg
operators = []
@@ -150,75 +146,10 @@ def from_source(self, source: src.AudioSource) -> rx.Observable:
ops.starmap(lambda wav, seg: (wav, seg, self.embedding(wav, seg))),
]
# Add speaker tracking
- operators += self.speaker_tracking.get_operators(source.uri, source.duration)
+ operators += self.speaker_tracking.get_operators(source)
if self.profile:
return dops.profile(source.stream, operators)
return source.stream.pipe(*operators)
- def from_file(
- self,
- filepath: src.FilePath,
- output_waveform: bool = False,
- batch_size: int = 32,
- desc: Optional[Text] = None,
- ) -> rx.Observable:
- loader = src.AudioLoader(self.config.sample_rate, mono=True)
-
- # Split audio into chunks
- chunks = rearrange(
- loader.load_sliding_chunks(filepath, self.config.duration, self.config.step),
- "chunk channel sample -> chunk sample channel"
- )
- num_chunks = chunks.shape[0]
-
- # Set progress if needed
- iterator = range(0, num_chunks, batch_size)
- if desc is not None:
- total = int(math.ceil(num_chunks / batch_size))
- iterator = tqdm(iterator, desc=desc, total=total, unit="batch", leave=False)
-
- # Pre-calculate segmentation and embeddings
- segmentation, embeddings = [], []
- for i in iterator:
- i_end = i + batch_size
- if i_end > num_chunks:
- i_end = num_chunks
- batch = chunks[i:i_end]
- seg = self.segmentation(batch)
- # Edge case: add batch dimension if i == i_end + 1
- if seg.ndim == 2:
- seg = seg[np.newaxis]
- emb = self.embedding(batch, seg)
- # Edge case: add batch dimension if i == i_end + 1
- if emb.ndim == 2:
- emb = emb.unsqueeze(0)
- segmentation.append(seg)
- embeddings.append(emb)
- segmentation = np.vstack(segmentation)
- embeddings = torch.vstack(embeddings)
-
- # Stream pre-calculated segmentation, embeddings and chunks
- resolution = self.config.duration / segmentation.shape[1]
- seg_stream = rx.range(0, num_chunks).pipe(
- ops.map(lambda i: SlidingWindowFeature(
- segmentation[i], SlidingWindow(resolution, resolution, i * self.config.step)
- ))
- )
- emb_stream = rx.range(0, num_chunks).pipe(ops.map(lambda i: embeddings[i]))
- wav_resolution = 1 / self.config.sample_rate
- chunk_stream = None
- if output_waveform:
- chunk_stream = rx.range(0, num_chunks).pipe(
- ops.map(lambda i: SlidingWindowFeature(
- chunks[i], SlidingWindow(wav_resolution, wav_resolution, i * self.config.step)
- ))
- )
-
- # Build speaker tracking pipeline
- return rx.zip(chunk_stream, seg_stream, emb_stream).pipe(
- *self.speaker_tracking.get_operators(
- Path(filepath).stem,
- loader.get_duration(filepath),
- output_waveform,
- )
- )
+ def from_feature_source(self, source: src.AudioSource) -> rx.Observable:
+ return source.stream.pipe(*self.speaker_tracking.get_operators(source))
diff --git a/src/diart/sources.py b/src/diart/sources.py
index ad6c961a..a87b8ce7 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -1,13 +1,20 @@
+import math
from queue import SimpleQueue
-from typing import Text, Optional
+from typing import Text, Optional, Callable
+import numpy as np
import sounddevice as sd
+import torch
+from einops import rearrange
from pyannote.core import SlidingWindowFeature, SlidingWindow
from rx.subject import Subject
+from tqdm import tqdm
from .audio import FilePath, AudioLoader
+from .features import TemporalFeatures
+# TODO rename this to something else since the same API is also used to stream features
class AudioSource:
"""Represents a source of audio that can start streaming via the `stream` property.
@@ -112,6 +119,86 @@ def read(self):
self.stream.on_completed()
+class PrecalculatedFeaturesAudioSource(FileAudioSource):
+ def __init__(
+ self,
+ file: FilePath,
+ uri: Text,
+ sample_rate: int,
+ segmentation: Callable[[TemporalFeatures], TemporalFeatures],
+ embedding: Callable[[TemporalFeatures, TemporalFeatures], TemporalFeatures],
+ chunk_duration: float = 5,
+ step_duration: float = 0.5,
+ batch_size: int = 32,
+ progress_msg: Optional[Text] = None,
+ ):
+ super().__init__(file, uri, sample_rate, chunk_duration, step_duration)
+ self.segmentation = segmentation
+ self.embedding = embedding
+ self.batch_size = batch_size
+ self.progress_msg = progress_msg
+
+ def read(self):
+ # Split audio into chunks
+ chunks = rearrange(
+ self.loader.load_sliding_chunks(
+ self.file, self.chunk_duration, self.step_duration
+ ),
+ "chunk channel sample -> chunk sample channel"
+ )
+ num_chunks = chunks.shape[0]
+
+ # Set progress if needed
+ iterator = range(0, num_chunks, self.batch_size)
+ if self.progress_msg is not None:
+ total = int(math.ceil(num_chunks / self.batch_size))
+ iterator = tqdm(iterator, desc=self.progress_msg, total=total, unit="batch", leave=False)
+
+ # Pre-calculate segmentation and embeddings
+ segmentation, embeddings = [], []
+ for i in iterator:
+ i_end = i + self.batch_size
+ if i_end > num_chunks:
+ i_end = num_chunks
+ batch = chunks[i:i_end]
+ seg = self.segmentation(batch)
+            # Edge case: add a batch dimension when the batch contains a single chunk
+ if seg.ndim == 2:
+ seg = seg[np.newaxis]
+ emb = self.embedding(batch, seg)
+            # Edge case: add a batch dimension when the batch contains a single chunk
+ if emb.ndim == 2:
+ emb = emb.unsqueeze(0)
+ segmentation.append(seg)
+ embeddings.append(emb)
+ segmentation = np.vstack(segmentation)
+ embeddings = torch.vstack(embeddings)
+
+ # Stream pre-calculated segmentation, embeddings and chunks
+ seg_resolution = self.chunk_duration / segmentation.shape[1]
+ for i in range(num_chunks):
+ chunk_window = SlidingWindow(
+ start=i * self.step_duration,
+ duration=self.resolution,
+ step=self.resolution,
+ )
+ seg_window = SlidingWindow(
+ start=i * self.step_duration,
+ duration=seg_resolution,
+ step=seg_resolution,
+ )
+ try:
+ self.stream.on_next((
+ SlidingWindowFeature(chunks[i], chunk_window),
+ SlidingWindowFeature(segmentation[i], seg_window),
+ embeddings[i]
+ ))
+ except Exception as e:
+ self.stream.on_error(e)
+
+ self.stream.on_completed()
+
+
class MicrophoneAudioSource(AudioSource):
"""Represents an audio source tied to the default microphone available"""
From ba3a53b41e3782288fd75789e51ae7e406645681 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Fri, 10 Jun 2022 10:43:27 +0200
Subject: [PATCH 14/29] Refactoring and minor improvements
---
README.md | 2 +-
src/diart/audio.py | 2 +-
src/diart/benchmark.py | 24 ++++--------------------
src/diart/inference.py | 3 ---
src/diart/mapping.py | 24 ++----------------------
src/diart/operators.py | 3 ++-
src/diart/pipelines.py | 30 +++++++++++++++++++++---------
src/diart/sinks.py | 16 ++++++++++++++--
src/diart/sources.py | 10 +++-------
src/diart/stream.py | 24 ++++--------------------
10 files changed, 52 insertions(+), 86 deletions(-)
diff --git a/README.md b/README.md
index 29012f29..eb90e123 100644
--- a/README.md
+++ b/README.md
@@ -102,7 +102,7 @@ embedding = OverlapAwareSpeakerEmbedding(emb_model)
mic = MicrophoneAudioSource(seg_model.get_sample_rate())
# Reformat microphone stream. Defaults to 5s duration and 500ms shift
-regular_stream = mic.stream.pipe(dops.regularize_stream(seg_model.get_sample_rate()))
+regular_stream = mic.stream.pipe(dops.regularize_audio_stream(seg_model.get_sample_rate()))
# Branch the microphone stream to calculate segmentation
segmentation_stream = regular_stream.pipe(ops.map(segmentation))
# Join audio and segmentation stream to calculate speaker embeddings
diff --git a/src/diart/audio.py b/src/diart/audio.py
index e607ba5f..764ac3ee 100644
--- a/src/diart/audio.py
+++ b/src/diart/audio.py
@@ -98,4 +98,4 @@ def get_num_sliding_chunks(self, filepath: FilePath, chunk_duration: float, step
Duration of the step between chunks in seconds.
"""
numerator = self.get_duration(filepath) - chunk_duration + step_duration
- return int(np.ceil(numerator / step_duration))
\ No newline at end of file
+ return int(np.ceil(numerator / step_duration))
diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index 713b54b6..1ac0bafd 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -1,13 +1,10 @@
import argparse
-import torch
-
import diart.argdoc as argdoc
from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
if __name__ == "__main__":
- # Define script arguments
parser = argparse.ArgumentParser()
parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
parser.add_argument("--reference", type=str, help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
@@ -24,20 +21,7 @@
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
args = parser.parse_args()
- # Set benchmark configuration
- benchmark = Benchmark(args.root, args.reference, args.output)
-
- # Define online speaker diarization pipeline
- pipeline = OnlineSpeakerDiarization(PipelineConfig(
- step=args.step,
- latency=args.latency,
- tau_active=args.tau,
- rho_update=args.rho,
- delta_new=args.delta,
- gamma=args.gamma,
- beta=args.beta,
- max_speakers=args.max_speakers,
- device=torch.device("cpu") if args.cpu else None,
- ), profile=True)
-
- benchmark(pipeline, args.batch_size)
+ Benchmark(args.root, args.reference, args.output)(
+ OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True),
+ args.batch_size
+ )
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 3e98e51f..6e204247 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -137,7 +137,6 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
if batch_size < 2:
source = src.FileAudioSource(
filepath,
- filepath.stem,
pipeline.config.sample_rate,
pipeline.config.duration,
pipeline.config.step,
@@ -146,14 +145,12 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
else:
source = src.PrecalculatedFeaturesAudioSource(
filepath,
- filepath.stem,
pipeline.config.sample_rate,
pipeline.segmentation,
pipeline.embedding,
pipeline.config.duration,
pipeline.config.step,
batch_size,
- # TODO decouple progress bar
progress_msg=f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})",
)
observable = pipeline.from_feature_source(source)
diff --git a/src/diart/mapping.py b/src/diart/mapping.py
index 7327ac6e..01465086 100644
--- a/src/diart/mapping.py
+++ b/src/diart/mapping.py
@@ -4,7 +4,7 @@
import numpy as np
from pyannote.core.utils.distance import cdist
-from scipy.optimize import linear_sum_assignment
+from scipy.optimize import linear_sum_assignment as lsap
class MappingMatrixObjective:
@@ -12,7 +12,7 @@ def invalid_tensor(self, shape: Union[Tuple, int]) -> np.ndarray:
return np.ones(shape) * self.invalid_value
def optimal_assignments(self, matrix: np.ndarray) -> List[int]:
- return list(linear_sum_assignment(matrix, self.maximize)[1])
+ return list(lsap(matrix, self.maximize)[1])
def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]:
# Entries full of invalid_value are not mapped
@@ -137,26 +137,6 @@ def dist(
dist_matrix = cdist(embeddings1, embeddings2, metric=metric)
return SpeakerMap(dist_matrix, MinimizationObjective())
- @staticmethod
- def clf_output(predictions: np.ndarray, pad_to: Optional[int] = None) -> SpeakerMap:
- """
- Parameters
- ----------
- predictions : np.ndarray, (num_local_speakers, num_global_speakers)
- Probability outputs of a speaker embedding classifier
- pad_to : int, optional
- Pad num_global_speakers to this value.
- Useful to deal with unknown speakers that may appear in the future.
- Defaults to no padding
- """
- num_locals, num_globals = predictions.shape
- objective = MaximizationObjective(max_value=1)
- if pad_to is not None and num_globals < pad_to:
- padding = np.ones((num_locals, pad_to - num_globals))
- padding = objective.invalid_value * padding
- predictions = np.concatenate([predictions, padding], axis=1)
- return SpeakerMap(predictions, objective)
-
class SpeakerMap:
def __init__(self, mapping_matrix: np.ndarray, objective: MappingMatrixObjective):
diff --git a/src/diart/operators.py b/src/diart/operators.py
index cdd92dba..c8575866 100644
--- a/src/diart/operators.py
+++ b/src/diart/operators.py
@@ -37,7 +37,7 @@ def call_fn(state) -> SlidingWindowFeature:
return call_fn
-def regularize_stream(
+def regularize_audio_stream(
duration: float = 5,
step: float = 0.5,
sample_rate: int = 16000
@@ -331,6 +331,7 @@ def profile(observable: rx.Observable, operations: List[Operator]) -> rx.Observa
*operations,
ops.do_action(
on_next=lambda _: chronometer.stop(),
+ on_error=lambda _: chronometer.report(),
on_completed=chronometer.report,
)
)
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index e4c9a987..4ccbf1d4 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -1,14 +1,9 @@
-import math
-from pathlib import Path
-from typing import Optional, Text, List, Tuple
+from argparse import Namespace
+from typing import Optional, List
-import numpy as np
import rx
import rx.operators as ops
import torch
-from einops import rearrange
-from pyannote.core import SlidingWindowFeature, SlidingWindow
-from tqdm import tqdm
from . import blocks
from . import models as m
@@ -68,6 +63,23 @@ def __init__(
if self.device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ @staticmethod
+ def from_namespace(args: Namespace) -> 'PipelineConfig':
+ return PipelineConfig(
+ segmentation=getattr(args, "segmentation", None),
+ embedding=getattr(args, "embedding", None),
+ duration=getattr(args, "duration", None),
+ step=args.step,
+ latency=args.latency,
+ tau_active=args.tau,
+ rho_update=args.rho,
+ delta_new=args.delta,
+ gamma=args.gamma,
+ beta=args.beta,
+ max_speakers=args.max_speakers,
+ device=torch.device("cpu") if args.cpu else None,
+ )
+
def last_chunk_end_time(self, conv_duration: float) -> Optional[float]:
"""
Return the end time of the last chunk for a given conversation duration.
@@ -109,7 +121,7 @@ def get_operators(self, source: src.AudioSource) -> List[dops.Operator]:
# Buffer 'num_overlapping' sliding chunks with a step of 1 chunk
dops.buffer_slide(pred_aggregation.num_overlapping_windows),
# Aggregate overlapping output windows
- ops.map(lambda buffers: utils.unzip(buffers)),
+ ops.map(utils.unzip),
ops.starmap(lambda wav_buffer, pred_buffer: (
audio_aggregation(wav_buffer), pred_aggregation(pred_buffer)
)),
@@ -136,7 +148,7 @@ def from_audio_source(self, source: src.AudioSource) -> rx.Observable:
operators = []
# Regularize the stream to a specific chunk duration and step
if not source.is_regular:
- operators.append(dops.regularize_stream(
+ operators.append(dops.regularize_audio_stream(
self.config.duration, self.config.step, source.sample_rate
))
operators += [
diff --git a/src/diart/sinks.py b/src/diart/sinks.py
index 993e95a8..e859aa77 100644
--- a/src/diart/sinks.py
+++ b/src/diart/sinks.py
@@ -10,6 +10,10 @@
from typing_extensions import Literal
+class WindowClosedException(Exception):
+ pass
+
+
class RTTMWriter(Observer):
def __init__(self, path: Union[Path, Text], patch_collar: float = 0.05):
super().__init__()
@@ -57,6 +61,10 @@ def __init__(
self.window_duration = duration
self.latency = latency
self.figure, self.axs, self.num_axs = None, None, -1
+ self.window_closed = False
+
+ def _on_window_closed(self, event):
+ self.window_closed = True
def _init_num_axs(self):
if self.num_axs == -1:
@@ -69,6 +77,7 @@ def _init_figure(self):
self.figure, self.axs = plt.subplots(self.num_axs, 1, figsize=(10, 2 * self.num_axs))
if self.num_axs == 1:
self.axs = [self.axs]
+ self.figure.canvas.mpl_connect('close_event', self._on_window_closed)
def _clear_axs(self):
for i in range(self.num_axs):
@@ -82,6 +91,9 @@ def get_plot_bounds(self, real_time: float) -> Segment:
return Segment(start_time, end_time)
def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
+ if self.window_closed:
+ raise WindowClosedException
+
prediction, waveform, real_time = values
# Initialize figure if first call
if self.figure is None:
@@ -111,5 +123,5 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
plt.pause(0.05)
def on_error(self, error: Exception):
- print_exc()
- exit(1)
+ if not isinstance(error, WindowClosedException):
+ print_exc()
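The plotting sink now listens for Matplotlib's `close_event` and raises `WindowClosedException` so the stream stops quietly instead of printing a traceback. A standalone sketch of that pattern outside diart:

```python
import matplotlib.pyplot as plt

# A flag flips when the user closes the window, so later calls can stop gracefully.
class ClosableFigure:
    def __init__(self):
        self.closed = False
        self.figure, self.ax = plt.subplots()
        self.figure.canvas.mpl_connect("close_event", self._on_close)

    def _on_close(self, event):
        self.closed = True
```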
diff --git a/src/diart/sources.py b/src/diart/sources.py
index a87b8ce7..89fe2634 100644
--- a/src/diart/sources.py
+++ b/src/diart/sources.py
@@ -1,4 +1,5 @@
import math
+from pathlib import Path
from queue import SimpleQueue
from typing import Text, Optional, Callable
@@ -14,7 +15,6 @@
from .features import TemporalFeatures
-# TODO rename this to something else since the same API is also used to stream features
class AudioSource:
"""Represents a source of audio that can start streaming via the `stream` property.
@@ -59,8 +59,6 @@ class FileAudioSource(AudioSource):
----------
file: FilePath
Path to the file to stream.
- uri: Text
- Unique identifier of the audio source.
sample_rate: int
Sample rate of the chunks emitted.
chunk_duration: float
@@ -71,12 +69,11 @@ class FileAudioSource(AudioSource):
def __init__(
self,
file: FilePath,
- uri: Text,
sample_rate: int,
chunk_duration: float = 5,
step_duration: float = 0.5,
):
- super().__init__(uri, sample_rate)
+ super().__init__(Path(file).stem, sample_rate)
self.loader = AudioLoader(sample_rate, mono=True)
self._duration = self.loader.get_duration(file)
self.file = file
@@ -123,7 +120,6 @@ class PrecalculatedFeaturesAudioSource(FileAudioSource):
def __init__(
self,
file: FilePath,
- uri: Text,
sample_rate: int,
segmentation: Callable[[TemporalFeatures], TemporalFeatures],
embedding: Callable[[TemporalFeatures, TemporalFeatures], TemporalFeatures],
@@ -132,7 +128,7 @@ def __init__(
batch_size: int = 32,
progress_msg: Optional[Text] = None,
):
- super().__init__(file, uri, sample_rate, chunk_duration, step_duration)
+ super().__init__(file, sample_rate, chunk_duration, step_duration)
self.segmentation = segmentation
self.embedding = embedding
self.batch_size = batch_size
diff --git a/src/diart/stream.py b/src/diart/stream.py
index e843ed08..0ef8a5a6 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -1,15 +1,12 @@
import argparse
from pathlib import Path
-import torch
-
import diart.argdoc as argdoc
import diart.sources as src
from diart.inference import RealTimeInference
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
if __name__ == "__main__":
- # Define script arguments
parser = argparse.ArgumentParser()
parser.add_argument("source", type=str, help="Path to an audio file | 'microphone'")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
@@ -26,32 +23,19 @@
args = parser.parse_args()
# Define online speaker diarization pipeline
- pipeline = OnlineSpeakerDiarization(PipelineConfig(
- step=args.step,
- latency=args.latency,
- tau_active=args.tau,
- rho_update=args.rho,
- delta_new=args.delta,
- gamma=args.gamma,
- beta=args.beta,
- max_speakers=args.max_speakers,
- device=torch.device("cpu") if args.cpu else None,
- ), profile=True)
+ config = PipelineConfig.from_namespace(args)
+ pipeline = OnlineSpeakerDiarization(config, profile=True)
# Manage audio source
if args.source != "microphone":
args.source = Path(args.source).expanduser()
args.output = args.source.parent if args.output is None else Path(args.output)
audio_source = src.FileAudioSource(
- args.source,
- args.source.stem,
- pipeline.config.sample_rate,
- pipeline.config.duration,
- pipeline.config.step,
+ args.source, config.sample_rate, config.duration, config.step
)
else:
args.output = Path("~/").expanduser() if args.output is None else Path(args.output)
- audio_source = src.MicrophoneAudioSource(pipeline.config.sample_rate)
+ audio_source = src.MicrophoneAudioSource(config.sample_rate)
# Run online inference
RealTimeInference(args.output, do_plot=not args.no_plot)(pipeline, audio_source)
From afd3448307fbfa127a4e00e31772188c42e5f864 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Mon, 13 Jun 2022 18:23:59 +0200
Subject: [PATCH 15/29] Reorganize blocks into a package
---
src/diart/blocks.py | 558 -------------------------------
src/diart/blocks/__init__.py | 16 +
src/diart/blocks/aggregation.py | 198 +++++++++++
src/diart/blocks/clustering.py | 150 +++++++++
src/diart/blocks/embedding.py | 138 ++++++++
src/diart/blocks/segmentation.py | 36 ++
src/diart/blocks/utils.py | 55 +++
7 files changed, 593 insertions(+), 558 deletions(-)
delete mode 100644 src/diart/blocks.py
create mode 100644 src/diart/blocks/__init__.py
create mode 100644 src/diart/blocks/aggregation.py
create mode 100644 src/diart/blocks/clustering.py
create mode 100644 src/diart/blocks/embedding.py
create mode 100644 src/diart/blocks/segmentation.py
create mode 100644 src/diart/blocks/utils.py
diff --git a/src/diart/blocks.py b/src/diart/blocks.py
deleted file mode 100644
index 33d2b72a..00000000
--- a/src/diart/blocks.py
+++ /dev/null
@@ -1,558 +0,0 @@
-from typing import Union, Optional, List, Iterable, Tuple, Text
-
-import numpy as np
-import torch
-from einops import rearrange
-from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature
-from typing_extensions import Literal
-
-from .features import TemporalFeatures, TemporalFeatureFormatter
-from .mapping import SpeakerMap, SpeakerMapBuilder
-from .models import SegmentationModel, EmbeddingModel
-
-
-class SpeakerSegmentation:
- def __init__(self, model: SegmentationModel, device: Optional[torch.device] = None):
- self.model = model
- self.model.eval()
- self.device = device
- if self.device is None:
- self.device = torch.device("cpu")
- self.model.to(self.device)
- self.formatter = TemporalFeatureFormatter()
-
- def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
- """
- Calculate the speaker segmentation of input audio.
-
- Parameters
- ----------
- waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
-
- Returns
- -------
- speaker_segmentation: TemporalFeatures, shape (batch, frames, speakers)
- The batch dimension is omitted if waveform is a `SlidingWindowFeature`.
- """
- with torch.no_grad():
- wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample")
- output = self.model(wave.to(self.device)).cpu()
- return self.formatter.restore_type(output)
-
-
-class SpeakerEmbedding:
- def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None):
- self.model = model
- self.model.eval()
- self.device = device
- if self.device is None:
- self.device = torch.device("cpu")
- self.model.to(self.device)
- self.waveform_formatter = TemporalFeatureFormatter()
- self.weights_formatter = TemporalFeatureFormatter()
-
- def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor:
- """
- Calculate speaker embeddings of input audio.
- If weights are given, calculate many speaker embeddings from the same waveform.
-
- Parameters
- ----------
- waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
- weights: Optional[TemporalFeatures], shape (frames, speakers) or (batch, frames, speakers)
- Per-speaker and per-frame weights. Defaults to no weights.
-
- Returns
- -------
- embeddings: torch.Tensor
- If weights are provided, the shape is (batch, speakers, embedding_dim),
- otherwise the shape is (batch, embedding_dim).
- If batch size == 1, the batch dimension is omitted.
- """
- with torch.no_grad():
- inputs = self.waveform_formatter.cast(waveform).to(self.device)
- inputs = rearrange(inputs, "batch sample channel -> batch channel sample")
- if weights is not None:
- weights = self.weights_formatter.cast(weights).to(self.device)
- batch_size, _, num_speakers = weights.shape
- inputs = inputs.repeat(1, num_speakers, 1)
- weights = rearrange(weights, "batch frame spk -> (batch spk) frame")
- inputs = rearrange(inputs, "batch spk sample -> (batch spk) 1 sample")
- output = rearrange(
- self.model(inputs, weights),
- "(batch spk) feat -> batch spk feat",
- batch=batch_size,
- spk=num_speakers
- )
- else:
- output = self.model(inputs)
- return output.squeeze().cpu()
-
-
-class OverlappedSpeechPenalty:
- """
- Parameters
- ----------
- gamma: float, optional
- Exponent to lower low-confidence predictions.
- Defaults to 3.
- beta: float, optional
- Temperature parameter (actually 1/beta) to lower joint speaker activations.
- Defaults to 10.
- """
- def __init__(self, gamma: float = 3, beta: float = 10):
- self.gamma = gamma
- self.beta = beta
- self.formatter = TemporalFeatureFormatter()
-
- def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
- weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers)
- with torch.no_grad():
- probs = torch.softmax(self.beta * weights, dim=-1)
- weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma)
- weights[weights < 1e-8] = 1e-8
- return self.formatter.restore_type(weights)
-
-
-class EmbeddingNormalization:
- def __init__(self, norm: Union[float, torch.Tensor] = 1):
- self.norm = norm
- # Add batch dimension if missing
- if isinstance(self.norm, torch.Tensor) and self.norm.ndim == 2:
- self.norm = self.norm.unsqueeze(0)
-
- def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
- # Add batch dimension if missing
- if embeddings.ndim == 2:
- embeddings = embeddings.unsqueeze(0)
- if isinstance(self.norm, torch.Tensor):
- batch_size1, num_speakers1, _ = self.norm.shape
- batch_size2, num_speakers2, _ = embeddings.shape
- assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
- with torch.no_grad():
- norm_embs = self.norm * embeddings / torch.norm(embeddings, p=2, dim=-1, keepdim=True)
- return norm_embs.squeeze()
-
-
-class OverlapAwareSpeakerEmbedding:
- """
- Extract overlap-aware speaker embeddings given an audio chunk and its segmentation.
-
- Parameters
- ----------
- model: EmbeddingModel
- A pre-trained embedding model.
- gamma: float, optional
- Exponent to lower low-confidence predictions.
- Defaults to 3.
- beta: float, optional
- Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations.
- Defaults to 10.
- norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional
- The target norm for the embeddings. It can be different for each speaker.
- Defaults to 1.
- device: Optional[torch.device]
- The device on which to run the embedding model.
- Defaults to GPU if available or CPU if not.
- """
- def __init__(
- self,
- model: EmbeddingModel,
- gamma: float = 3,
- beta: float = 10,
- norm: Union[float, torch.Tensor] = 1,
- device: Optional[torch.device] = None,
- ):
- self.embedding = SpeakerEmbedding(model, device)
- self.osp = OverlappedSpeechPenalty(gamma, beta)
- self.normalize = EmbeddingNormalization(norm)
-
- def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor:
- return self.normalize(self.embedding(waveform, self.osp(segmentation)))
-
-
-class AggregationStrategy:
- """Abstract class representing a strategy to aggregate overlapping buffers"""
-
- @staticmethod
- def build(name: Literal["mean", "hamming", "first"]) -> 'AggregationStrategy':
- """Build an AggregationStrategy instance based on its name"""
- assert name in ("mean", "hamming", "first")
- if name == "mean":
- return AverageStrategy()
- elif name == "hamming":
- return HammingWeightedAverageStrategy()
- else:
- return FirstOnlyStrategy()
-
- def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature:
- """Aggregate chunks over a specific region.
-
- Parameters
- ----------
- buffers: list of SlidingWindowFeature, shapes (frames, speakers)
- Buffers to aggregate
- focus: Segment
- Region to aggregate that is shared among the buffers
-
- Returns
- -------
- aggregation: SlidingWindowFeature, shape (cropped_frames, speakers)
- Aggregated values over the focus region
- """
- aggregation = self.aggregate(buffers, focus)
- resolution = focus.duration / aggregation.shape[0]
- resolution = SlidingWindow(
- start=focus.start,
- duration=resolution,
- step=resolution
- )
- return SlidingWindowFeature(aggregation, resolution)
-
- def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
- raise NotImplementedError
-
-
-class HammingWeightedAverageStrategy(AggregationStrategy):
- """Compute the average weighted by the corresponding Hamming-window aligned to each buffer"""
-
- def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
- num_frames, num_speakers = buffers[0].data.shape
- hamming, intersection = [], []
- for buffer in buffers:
- # Crop buffer to focus region
- b = buffer.crop(focus, fixed=focus.duration)
- # Crop Hamming window to focus region
- h = np.expand_dims(np.hamming(num_frames), axis=-1)
- h = SlidingWindowFeature(h, buffer.sliding_window)
- h = h.crop(focus, fixed=focus.duration)
- hamming.append(h.data)
- intersection.append(b.data)
- hamming, intersection = np.stack(hamming), np.stack(intersection)
- # Calculate weighted mean
- return np.sum(hamming * intersection, axis=0) / np.sum(hamming, axis=0)
-
-
-class AverageStrategy(AggregationStrategy):
- """Compute a simple average over the focus region"""
-
- def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
- # Stack all overlapping regions
- intersection = np.stack([
- buffer.crop(focus, fixed=focus.duration)
- for buffer in buffers
- ])
- return np.mean(intersection, axis=0)
-
-
-class FirstOnlyStrategy(AggregationStrategy):
- """Instead of aggregating, keep the first focus region in the buffer list"""
-
- def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
- return buffers[0].crop(focus, fixed=focus.duration)
-
-
-class DelayedAggregation:
- """Aggregate aligned overlapping windows of the same duration
- across sliding buffers with a specific step and latency.
-
- Parameters
- ----------
- step: float
- Shift between two consecutive buffers, in seconds.
- latency: float, optional
- Desired latency, in seconds. Defaults to step.
- The higher the latency, the more overlapping windows to aggregate.
- strategy: ("mean", "hamming", "any"), optional
- Specifies how to aggregate overlapping windows. Defaults to "hamming".
- "mean": simple average
- "hamming": average weighted by the Hamming window values (aligned to the buffer)
- "any": no aggregation, pick the first overlapping window
- stream_end: float, optional
- Stream end time (in seconds). Defaults to None.
- If the stream end time is known, then append remaining outputs at the end,
- otherwise the last `latency - step` seconds are ignored.
-
- Example
- --------
- >>> duration = 5
- >>> frames = 500
- >>> step = 0.5
- >>> speakers = 2
- >>> start_time = 10
- >>> resolution = duration / frames
- >>> dagg = DelayedAggregation(step=step, latency=2, strategy="mean")
- >>> buffers = [
- >>> SlidingWindowFeature(
- >>> np.random.rand(frames, speakers),
- >>> SlidingWindow(start=(i + start_time) * step, duration=resolution, step=resolution)
- >>> )
- >>> for i in range(dagg.num_overlapping_windows)
- >>> ]
- >>> dagg.num_overlapping_windows
- ... 4
- >>> dagg(buffers).data.shape
- ... (51, 2) # Rounding errors are possible when cropping the buffers
- """
-
- def __init__(
- self,
- step: float,
- latency: Optional[float] = None,
- strategy: Literal["mean", "hamming", "first"] = "hamming",
- stream_end: Optional[float] = None
- ):
- self.step = step
- self.latency = latency
- self.strategy = strategy
- self.stream_end = stream_end
-
- if self.latency is None:
- self.latency = self.step
-
- assert self.step <= self.latency, "Invalid latency requested"
-
- self.num_overlapping_windows = int(round(self.latency / self.step))
- self.aggregate = AggregationStrategy.build(self.strategy)
-
- def _prepend_or_append(
- self,
- output_window: SlidingWindowFeature,
- output_region: Segment,
- buffers: List[SlidingWindowFeature]
- ):
- last_buffer = buffers[-1].extent
- # Prepend prediction until we match the latency in case of first buffer
- if len(buffers) == 1 and last_buffer.start == 0:
- num_frames = output_window.data.shape[0]
- first_region = Segment(0, output_region.end)
- first_output = buffers[0].crop(
- first_region, fixed=first_region.duration
- )
- first_output[-num_frames:] = output_window.data
- resolution = output_region.end / first_output.shape[0]
- output_window = SlidingWindowFeature(
- first_output,
- SlidingWindow(start=0, duration=resolution, step=resolution)
- )
- # Append rest of the outputs
- elif self.stream_end is not None and last_buffer.end == self.stream_end:
- # FIXME instead of appending a larger chunk than expected when latency > step,
- # keep emitting windows until the signal ends.
- # This should be fixed at the observable level and not within the aggregation block.
- num_frames = output_window.data.shape[0]
- last_region = Segment(output_region.start, last_buffer.end)
- last_output = buffers[-1].crop(
- last_region, fixed=last_region.duration
- )
- last_output[:num_frames] = output_window.data
- resolution = self.latency / last_output.shape[0]
- output_window = SlidingWindowFeature(
- last_output,
- SlidingWindow(
- start=output_region.start,
- duration=resolution,
- step=resolution
- )
- )
- return output_window
-
- def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature:
- # Determine overlapping region to aggregate
- start = buffers[-1].extent.end - self.latency
- region = Segment(start, start + self.step)
- return self._prepend_or_append(self.aggregate(buffers, region), region, buffers)
-
-
-class OnlineSpeakerClustering:
- def __init__(
- self,
- tau_active: float,
- rho_update: float,
- delta_new: float,
- metric: Optional[str] = "cosine",
- max_speakers: int = 20
- ):
- self.tau_active = tau_active
- self.rho_update = rho_update
- self.delta_new = delta_new
- self.metric = metric
- self.max_speakers = max_speakers
- self.centers: Optional[np.ndarray] = None
- self.active_centers = set()
- self.blocked_centers = set()
-
- @property
- def num_free_centers(self) -> int:
- return self.max_speakers - self.num_known_speakers - self.num_blocked_speakers
-
- @property
- def num_known_speakers(self) -> int:
- return len(self.active_centers)
-
- @property
- def num_blocked_speakers(self) -> int:
- return len(self.blocked_centers)
-
- @property
- def inactive_centers(self) -> List[int]:
- return [
- c
- for c in range(self.max_speakers)
- if c not in self.active_centers or c in self.blocked_centers
- ]
-
- def get_next_center_position(self) -> Optional[int]:
- for center in range(self.max_speakers):
- if center not in self.active_centers and center not in self.blocked_centers:
- return center
-
- def init_centers(self, dimension: int):
- self.centers = np.zeros((self.max_speakers, dimension))
- self.active_centers = set()
- self.blocked_centers = set()
-
- def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray):
- if self.centers is not None:
- for l_spk, g_spk in assignments:
- assert g_spk in self.active_centers, "Cannot update unknown centers"
- self.centers[g_spk] += embeddings[l_spk]
-
- def add_center(self, embedding: np.ndarray) -> int:
- center = self.get_next_center_position()
- self.centers[center] = embedding
- self.active_centers.add(center)
- return center
-
- def identify(
- self,
- segmentation: SlidingWindowFeature,
- embeddings: torch.Tensor
- ) -> SpeakerMap:
- embeddings = embeddings.detach().cpu().numpy()
- active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0]
- long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0]
- num_local_speakers = segmentation.data.shape[1]
-
- if self.centers is None:
- self.init_centers(embeddings.shape[1])
- assignments = [
- (spk, self.add_center(embeddings[spk]))
- for spk in active_speakers
- ]
- return SpeakerMapBuilder.hard_map(
- shape=(num_local_speakers, self.max_speakers),
- assignments=assignments,
- maximize=False,
- )
-
- # Obtain a mapping based on distances between embeddings and centers
- dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric)
- # Remove any assignments containing invalid speakers
- inactive_speakers = np.array([
- spk for spk in range(num_local_speakers)
- if spk not in active_speakers
- ])
- dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers)
- # Keep assignments under the distance threshold
- valid_map = dist_map.unmap_threshold(self.delta_new)
-
- # Some speakers might be unidentified
- missed_speakers = [
- s for s in active_speakers
- if not valid_map.is_source_speaker_mapped(s)
- ]
-
- # Add assignments to new centers if possible
- new_center_speakers = []
- for spk in missed_speakers:
- has_space = len(new_center_speakers) < self.num_free_centers
- if has_space and spk in long_speakers:
- # Flag as a new center
- new_center_speakers.append(spk)
- else:
- # Cannot create a new center
- # Get global speakers in order of preference
- preferences = np.argsort(dist_map.mapping_matrix[spk, :])
- preferences = [
- g_spk for g_spk in preferences if g_spk in self.active_centers
- ]
- # Get the free global speakers among the preferences
- _, g_assigned = valid_map.valid_assignments()
- free = [g_spk for g_spk in preferences if g_spk not in g_assigned]
- if free:
- # The best global speaker is the closest free one
- valid_map = valid_map.set_source_speaker(spk, free[0])
-
- # Update known centers
- to_update = [
- (ls, gs)
- for ls, gs in zip(*valid_map.valid_assignments())
- if ls not in missed_speakers and ls in long_speakers
- ]
- self.update(to_update, embeddings)
-
- # Add new centers
- for spk in new_center_speakers:
- valid_map = valid_map.set_source_speaker(
- spk, self.add_center(embeddings[spk])
- )
-
- return valid_map
-
- def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) -> SlidingWindowFeature:
- return SlidingWindowFeature(
- self.identify(segmentation, embeddings).apply(segmentation.data),
- segmentation.sliding_window
- )
-
-
-class Binarize:
- """
- Transform a speaker segmentation from the discrete-time domain
- into a continuous-time speaker segmentation.
-
- Parameters
- ----------
- uri: Text
- Uri of the audio stream.
- threshold: float
- Probability threshold to determine if a speaker is active at a given frame.
- """
-
- def __init__(self, uri: Text, threshold: float):
- self.uri = uri
- self.threshold = threshold
-
- def __call__(self, segmentation: SlidingWindowFeature) -> Annotation:
- """
- Return the continuous-time segmentation
- corresponding to the discrete-time input segmentation.
-
- Parameters
- ----------
- segmentation: SlidingWindowFeature
- Discrete-time speaker segmentation.
-
- Returns
- -------
- annotation: Annotation
- Continuous-time speaker segmentation.
- """
- num_frames, num_speakers = segmentation.data.shape
- timestamps = segmentation.sliding_window
- is_active = segmentation.data > self.threshold
- # Artificially add last inactive frame to close any remaining speaker turns
- is_active = np.append(is_active, [[False] * num_speakers], axis=0)
- start_times = np.zeros(num_speakers) + timestamps[0].middle
- annotation = Annotation(uri=self.uri, modality="speech")
- for t in range(num_frames):
- # Any (False, True) starts a speaker turn at "True" index
- onsets = np.logical_and(np.logical_not(is_active[t]), is_active[t + 1])
- start_times[onsets] = timestamps[t + 1].middle
- # Any (True, False) ends a speaker turn at "False" index
- offsets = np.logical_and(is_active[t], np.logical_not(is_active[t + 1]))
- for spk in np.where(offsets)[0]:
- region = Segment(start_times[spk], timestamps[t + 1].middle)
- annotation[region, spk] = f"speaker{spk}"
- return annotation
diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py
new file mode 100644
index 00000000..a2d88c00
--- /dev/null
+++ b/src/diart/blocks/__init__.py
@@ -0,0 +1,16 @@
+from .aggregation import (
+ AggregationStrategy,
+ HammingWeightedAverageStrategy,
+ AverageStrategy,
+ FirstOnlyStrategy,
+ DelayedAggregation,
+)
+from .clustering import OnlineSpeakerClustering
+from .embedding import (
+ SpeakerEmbedding,
+ OverlappedSpeechPenalty,
+ EmbeddingNormalization,
+ OverlapAwareSpeakerEmbedding,
+)
+from .segmentation import SpeakerSegmentation
+from .utils import Binarize
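
As a quick sanity check (not part of the patch), the re-exports above mean existing code can keep importing every block from `diart.blocks` exactly as before the reorganization; a minimal sketch:

```python
# Import surface is unchanged by the package split: all blocks remain
# available at the diart.blocks level thanks to the re-exports above.
from diart.blocks import (
    SpeakerSegmentation,
    OverlapAwareSpeakerEmbedding,
    DelayedAggregation,
    OnlineSpeakerClustering,
    Binarize,
)
```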
diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py
new file mode 100644
index 00000000..d0379711
--- /dev/null
+++ b/src/diart/blocks/aggregation.py
@@ -0,0 +1,198 @@
+from typing import Optional, List
+
+import numpy as np
+from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature
+from typing_extensions import Literal
+
+
+class AggregationStrategy:
+ """Abstract class representing a strategy to aggregate overlapping buffers"""
+
+ @staticmethod
+ def build(name: Literal["mean", "hamming", "first"]) -> 'AggregationStrategy':
+ """Build an AggregationStrategy instance based on its name"""
+ assert name in ("mean", "hamming", "first")
+ if name == "mean":
+ return AverageStrategy()
+ elif name == "hamming":
+ return HammingWeightedAverageStrategy()
+ else:
+ return FirstOnlyStrategy()
+
+ def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature:
+ """Aggregate chunks over a specific region.
+
+ Parameters
+ ----------
+ buffers: list of SlidingWindowFeature, shapes (frames, speakers)
+ Buffers to aggregate
+ focus: Segment
+ Region to aggregate that is shared among the buffers
+
+ Returns
+ -------
+ aggregation: SlidingWindowFeature, shape (cropped_frames, speakers)
+ Aggregated values over the focus region
+ """
+ aggregation = self.aggregate(buffers, focus)
+ resolution = focus.duration / aggregation.shape[0]
+ resolution = SlidingWindow(
+ start=focus.start,
+ duration=resolution,
+ step=resolution
+ )
+ return SlidingWindowFeature(aggregation, resolution)
+
+ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+ raise NotImplementedError
+
+
+class HammingWeightedAverageStrategy(AggregationStrategy):
+ """Compute the average weighted by the corresponding Hamming-window aligned to each buffer"""
+
+ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+ num_frames, num_speakers = buffers[0].data.shape
+ hamming, intersection = [], []
+ for buffer in buffers:
+ # Crop buffer to focus region
+ b = buffer.crop(focus, fixed=focus.duration)
+ # Crop Hamming window to focus region
+ h = np.expand_dims(np.hamming(num_frames), axis=-1)
+ h = SlidingWindowFeature(h, buffer.sliding_window)
+ h = h.crop(focus, fixed=focus.duration)
+ hamming.append(h.data)
+ intersection.append(b.data)
+ hamming, intersection = np.stack(hamming), np.stack(intersection)
+ # Calculate weighted mean
+ return np.sum(hamming * intersection, axis=0) / np.sum(hamming, axis=0)
+
+
+class AverageStrategy(AggregationStrategy):
+ """Compute a simple average over the focus region"""
+
+ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+ # Stack all overlapping regions
+ intersection = np.stack([
+ buffer.crop(focus, fixed=focus.duration)
+ for buffer in buffers
+ ])
+ return np.mean(intersection, axis=0)
+
+
+class FirstOnlyStrategy(AggregationStrategy):
+ """Instead of aggregating, keep the first focus region in the buffer list"""
+
+ def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray:
+ return buffers[0].crop(focus, fixed=focus.duration)
+
+
+class DelayedAggregation:
+ """Aggregate aligned overlapping windows of the same duration
+ across sliding buffers with a specific step and latency.
+
+ Parameters
+ ----------
+ step: float
+ Shift between two consecutive buffers, in seconds.
+ latency: float, optional
+ Desired latency, in seconds. Defaults to step.
+ The higher the latency, the more overlapping windows to aggregate.
+ strategy: ("mean", "hamming", "any"), optional
+ Specifies how to aggregate overlapping windows. Defaults to "hamming".
+ "mean": simple average
+ "hamming": average weighted by the Hamming window values (aligned to the buffer)
+ "any": no aggregation, pick the first overlapping window
+ stream_end: float, optional
+ Stream end time (in seconds). Defaults to None.
+ If the stream end time is known, then append remaining outputs at the end,
+ otherwise the last `latency - step` seconds are ignored.
+
+ Example
+ --------
+ >>> duration = 5
+ >>> frames = 500
+ >>> step = 0.5
+ >>> speakers = 2
+ >>> start_time = 10
+ >>> resolution = duration / frames
+ >>> dagg = DelayedAggregation(step=step, latency=2, strategy="mean")
+ >>> buffers = [
+ >>> SlidingWindowFeature(
+ >>> np.random.rand(frames, speakers),
+ >>> SlidingWindow(start=(i + start_time) * step, duration=resolution, step=resolution)
+ >>> )
+ >>> for i in range(dagg.num_overlapping_windows)
+ >>> ]
+ >>> dagg.num_overlapping_windows
+ ... 4
+ >>> dagg(buffers).data.shape
+ ... (51, 2) # Rounding errors are possible when cropping the buffers
+ """
+
+ def __init__(
+ self,
+ step: float,
+ latency: Optional[float] = None,
+ strategy: Literal["mean", "hamming", "first"] = "hamming",
+ stream_end: Optional[float] = None
+ ):
+ self.step = step
+ self.latency = latency
+ self.strategy = strategy
+ self.stream_end = stream_end
+
+ if self.latency is None:
+ self.latency = self.step
+
+ assert self.step <= self.latency, "Invalid latency requested"
+
+ self.num_overlapping_windows = int(round(self.latency / self.step))
+ self.aggregate = AggregationStrategy.build(self.strategy)
+
+ def _prepend_or_append(
+ self,
+ output_window: SlidingWindowFeature,
+ output_region: Segment,
+ buffers: List[SlidingWindowFeature]
+ ):
+ last_buffer = buffers[-1].extent
+ # Prepend prediction until we match the latency in case of first buffer
+ if len(buffers) == 1 and last_buffer.start == 0:
+ num_frames = output_window.data.shape[0]
+ first_region = Segment(0, output_region.end)
+ first_output = buffers[0].crop(
+ first_region, fixed=first_region.duration
+ )
+ first_output[-num_frames:] = output_window.data
+ resolution = output_region.end / first_output.shape[0]
+ output_window = SlidingWindowFeature(
+ first_output,
+ SlidingWindow(start=0, duration=resolution, step=resolution)
+ )
+ # Append rest of the outputs
+ elif self.stream_end is not None and last_buffer.end == self.stream_end:
+ # FIXME instead of appending a larger chunk than expected when latency > step,
+ # keep emitting windows until the signal ends.
+ # This should be fixed at the observable level and not within the aggregation block.
+ num_frames = output_window.data.shape[0]
+ last_region = Segment(output_region.start, last_buffer.end)
+ last_output = buffers[-1].crop(
+ last_region, fixed=last_region.duration
+ )
+ last_output[:num_frames] = output_window.data
+ resolution = self.latency / last_output.shape[0]
+ output_window = SlidingWindowFeature(
+ last_output,
+ SlidingWindow(
+ start=output_region.start,
+ duration=resolution,
+ step=resolution
+ )
+ )
+ return output_window
+
+ def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature:
+ # Determine overlapping region to aggregate
+ start = buffers[-1].extent.end - self.latency
+ region = Segment(start, start + self.step)
+ return self._prepend_or_append(self.aggregate(buffers, region), region, buffers)
diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py
new file mode 100644
index 00000000..1842ecfa
--- /dev/null
+++ b/src/diart/blocks/clustering.py
@@ -0,0 +1,150 @@
+from typing import Optional, List, Iterable, Tuple
+
+import numpy as np
+import torch
+from pyannote.core import SlidingWindowFeature
+
+from mapping import SpeakerMap, SpeakerMapBuilder
+
+
+class OnlineSpeakerClustering:
+ def __init__(
+ self,
+ tau_active: float,
+ rho_update: float,
+ delta_new: float,
+ metric: Optional[str] = "cosine",
+ max_speakers: int = 20
+ ):
+ self.tau_active = tau_active
+ self.rho_update = rho_update
+ self.delta_new = delta_new
+ self.metric = metric
+ self.max_speakers = max_speakers
+ self.centers: Optional[np.ndarray] = None
+ self.active_centers = set()
+ self.blocked_centers = set()
+
+ @property
+ def num_free_centers(self) -> int:
+ return self.max_speakers - self.num_known_speakers - self.num_blocked_speakers
+
+ @property
+ def num_known_speakers(self) -> int:
+ return len(self.active_centers)
+
+ @property
+ def num_blocked_speakers(self) -> int:
+ return len(self.blocked_centers)
+
+ @property
+ def inactive_centers(self) -> List[int]:
+ return [
+ c
+ for c in range(self.max_speakers)
+ if c not in self.active_centers or c in self.blocked_centers
+ ]
+
+ def get_next_center_position(self) -> Optional[int]:
+ for center in range(self.max_speakers):
+ if center not in self.active_centers and center not in self.blocked_centers:
+ return center
+
+ def init_centers(self, dimension: int):
+ self.centers = np.zeros((self.max_speakers, dimension))
+ self.active_centers = set()
+ self.blocked_centers = set()
+
+ def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray):
+ if self.centers is not None:
+ for l_spk, g_spk in assignments:
+ assert g_spk in self.active_centers, "Cannot update unknown centers"
+ self.centers[g_spk] += embeddings[l_spk]
+
+ def add_center(self, embedding: np.ndarray) -> int:
+ center = self.get_next_center_position()
+ self.centers[center] = embedding
+ self.active_centers.add(center)
+ return center
+
+ def identify(
+ self,
+ segmentation: SlidingWindowFeature,
+ embeddings: torch.Tensor
+ ) -> SpeakerMap:
+ embeddings = embeddings.detach().cpu().numpy()
+ active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0]
+ long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0]
+ num_local_speakers = segmentation.data.shape[1]
+
+ if self.centers is None:
+ self.init_centers(embeddings.shape[1])
+ assignments = [
+ (spk, self.add_center(embeddings[spk]))
+ for spk in active_speakers
+ ]
+ return SpeakerMapBuilder.hard_map(
+ shape=(num_local_speakers, self.max_speakers),
+ assignments=assignments,
+ maximize=False,
+ )
+
+ # Obtain a mapping based on distances between embeddings and centers
+ dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric)
+ # Remove any assignments containing invalid speakers
+ inactive_speakers = np.array([
+ spk for spk in range(num_local_speakers)
+ if spk not in active_speakers
+ ])
+ dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers)
+ # Keep assignments under the distance threshold
+ valid_map = dist_map.unmap_threshold(self.delta_new)
+
+ # Some speakers might be unidentified
+ missed_speakers = [
+ s for s in active_speakers
+ if not valid_map.is_source_speaker_mapped(s)
+ ]
+
+ # Add assignments to new centers if possible
+ new_center_speakers = []
+ for spk in missed_speakers:
+ has_space = len(new_center_speakers) < self.num_free_centers
+ if has_space and spk in long_speakers:
+ # Flag as a new center
+ new_center_speakers.append(spk)
+ else:
+ # Cannot create a new center
+ # Get global speakers in order of preference
+ preferences = np.argsort(dist_map.mapping_matrix[spk, :])
+ preferences = [
+ g_spk for g_spk in preferences if g_spk in self.active_centers
+ ]
+ # Get the free global speakers among the preferences
+ _, g_assigned = valid_map.valid_assignments()
+ free = [g_spk for g_spk in preferences if g_spk not in g_assigned]
+ if free:
+ # The best global speaker is the closest free one
+ valid_map = valid_map.set_source_speaker(spk, free[0])
+
+ # Update known centers
+ to_update = [
+ (ls, gs)
+ for ls, gs in zip(*valid_map.valid_assignments())
+ if ls not in missed_speakers and ls in long_speakers
+ ]
+ self.update(to_update, embeddings)
+
+ # Add new centers
+ for spk in new_center_speakers:
+ valid_map = valid_map.set_source_speaker(
+ spk, self.add_center(embeddings[spk])
+ )
+
+ return valid_map
+
+ def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) -> SlidingWindowFeature:
+ return SlidingWindowFeature(
+ self.identify(segmentation, embeddings).apply(segmentation.data),
+ segmentation.sliding_window
+ )
diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py
new file mode 100644
index 00000000..8598daa2
--- /dev/null
+++ b/src/diart/blocks/embedding.py
@@ -0,0 +1,138 @@
+from typing import Optional, Union
+
+import torch
+from einops import rearrange
+
+from features import TemporalFeatures, TemporalFeatureFormatter
+from models import EmbeddingModel
+
+
+class SpeakerEmbedding:
+ def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None):
+ self.model = model
+ self.model.eval()
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cpu")
+ self.model.to(self.device)
+ self.waveform_formatter = TemporalFeatureFormatter()
+ self.weights_formatter = TemporalFeatureFormatter()
+
+ def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor:
+ """
+ Calculate speaker embeddings of input audio.
+ If weights are given, calculate many speaker embeddings from the same waveform.
+
+ Parameters
+ ----------
+ waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
+ weights: Optional[TemporalFeatures], shape (frames, speakers) or (batch, frames, speakers)
+ Per-speaker and per-frame weights. Defaults to no weights.
+
+ Returns
+ -------
+ embeddings: torch.Tensor
+ If weights are provided, the shape is (batch, speakers, embedding_dim),
+ otherwise the shape is (batch, embedding_dim).
+ If batch size == 1, the batch dimension is omitted.
+ """
+ with torch.no_grad():
+ inputs = self.waveform_formatter.cast(waveform).to(self.device)
+ inputs = rearrange(inputs, "batch sample channel -> batch channel sample")
+ if weights is not None:
+ weights = self.weights_formatter.cast(weights).to(self.device)
+ batch_size, _, num_speakers = weights.shape
+ inputs = inputs.repeat(1, num_speakers, 1)
+ weights = rearrange(weights, "batch frame spk -> (batch spk) frame")
+ inputs = rearrange(inputs, "batch spk sample -> (batch spk) 1 sample")
+ output = rearrange(
+ self.model(inputs, weights),
+ "(batch spk) feat -> batch spk feat",
+ batch=batch_size,
+ spk=num_speakers
+ )
+ else:
+ output = self.model(inputs)
+ return output.squeeze().cpu()
+
+
+class OverlappedSpeechPenalty:
+ """
+ Parameters
+ ----------
+ gamma: float, optional
+ Exponent to lower low-confidence predictions.
+ Defaults to 3.
+ beta: float, optional
+ Temperature parameter (actually 1/beta) to lower joint speaker activations.
+ Defaults to 10.
+ """
+ def __init__(self, gamma: float = 3, beta: float = 10):
+ self.gamma = gamma
+ self.beta = beta
+ self.formatter = TemporalFeatureFormatter()
+
+ def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
+ weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers)
+ with torch.no_grad():
+ probs = torch.softmax(self.beta * weights, dim=-1)
+ weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma)
+ weights[weights < 1e-8] = 1e-8
+ return self.formatter.restore_type(weights)
+
+
+class EmbeddingNormalization:
+ def __init__(self, norm: Union[float, torch.Tensor] = 1):
+ self.norm = norm
+ # Add batch dimension if missing
+ if isinstance(self.norm, torch.Tensor) and self.norm.ndim == 2:
+ self.norm = self.norm.unsqueeze(0)
+
+ def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
+ # Add batch dimension if missing
+ if embeddings.ndim == 2:
+ embeddings = embeddings.unsqueeze(0)
+ if isinstance(self.norm, torch.Tensor):
+ batch_size1, num_speakers1, _ = self.norm.shape
+ batch_size2, num_speakers2, _ = embeddings.shape
+ assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
+ with torch.no_grad():
+ norm_embs = self.norm * embeddings / torch.norm(embeddings, p=2, dim=-1, keepdim=True)
+ return norm_embs.squeeze()
+
+
+class OverlapAwareSpeakerEmbedding:
+ """
+ Extract overlap-aware speaker embeddings given an audio chunk and its segmentation.
+
+ Parameters
+ ----------
+ model: EmbeddingModel
+ A pre-trained embedding model.
+ gamma: float, optional
+ Exponent to lower low-confidence predictions.
+ Defaults to 3.
+ beta: float, optional
+ Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations.
+ Defaults to 10.
+ norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional
+ The target norm for the embeddings. It can be different for each speaker.
+ Defaults to 1.
+ device: Optional[torch.device]
+ The device on which to run the embedding model.
+ Defaults to GPU if available or CPU if not.
+ """
+ def __init__(
+ self,
+ model: EmbeddingModel,
+ gamma: float = 3,
+ beta: float = 10,
+ norm: Union[float, torch.Tensor] = 1,
+ device: Optional[torch.device] = None,
+ ):
+ self.embedding = SpeakerEmbedding(model, device)
+ self.osp = OverlappedSpeechPenalty(gamma, beta)
+ self.normalize = EmbeddingNormalization(norm)
+
+ def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor:
+ return self.normalize(self.embedding(waveform, self.osp(segmentation)))
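
A hedged usage sketch for the block above (not part of the patch): it assumes a pyannote-based embedding model, and the chunk length, frame count and speaker count of the dummy tensors are illustrative, shaped as described in the docstrings.

```python
import torch
from diart.blocks import OverlapAwareSpeakerEmbedding
from diart.models import EmbeddingModel

# Assumed setup: any EmbeddingModel works, a pyannote checkpoint is used here for illustration
emb_model = EmbeddingModel.from_pyannote("pyannote/embedding")
embedding = OverlapAwareSpeakerEmbedding(emb_model, gamma=3, beta=10, norm=1)

# Dummy 5s chunk at 16kHz with 3 local speakers
waveform = torch.randn(1, 80000, 1)   # (batch, samples, channels)
segmentation = torch.rand(1, 500, 3)  # (batch, frames, speakers)

# Overlap penalty -> weighted embeddings -> normalization, in a single call
embeddings = embedding(waveform, segmentation)
```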
diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py
new file mode 100644
index 00000000..d3015806
--- /dev/null
+++ b/src/diart/blocks/segmentation.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+import torch
+from einops import rearrange
+
+from features import TemporalFeatures, TemporalFeatureFormatter
+from models import SegmentationModel
+
+
+class SpeakerSegmentation:
+ def __init__(self, model: SegmentationModel, device: Optional[torch.device] = None):
+ self.model = model
+ self.model.eval()
+ self.device = device
+ if self.device is None:
+ self.device = torch.device("cpu")
+ self.model.to(self.device)
+ self.formatter = TemporalFeatureFormatter()
+
+ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
+ """
+ Calculate the speaker segmentation of input audio.
+
+ Parameters
+ ----------
+ waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
+
+ Returns
+ -------
+ speaker_segmentation: TemporalFeatures, shape (batch, frames, speakers)
+ The batch dimension is omitted if waveform is a `SlidingWindowFeature`.
+ """
+ with torch.no_grad():
+ wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample")
+ output = self.model(wave.to(self.device)).cpu()
+ return self.formatter.restore_type(output)
\ No newline at end of file
diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py
new file mode 100644
index 00000000..9a495ca8
--- /dev/null
+++ b/src/diart/blocks/utils.py
@@ -0,0 +1,55 @@
+from typing import Text
+
+import numpy as np
+from pyannote.core import Annotation, Segment, SlidingWindowFeature
+
+
+class Binarize:
+ """
+ Transform a speaker segmentation from the discrete-time domain
+ into a continuous-time speaker segmentation.
+
+ Parameters
+ ----------
+ uri: Text
+ Uri of the audio stream.
+ threshold: float
+ Probability threshold to determine if a speaker is active at a given frame.
+ """
+
+ def __init__(self, uri: Text, threshold: float):
+ self.uri = uri
+ self.threshold = threshold
+
+ def __call__(self, segmentation: SlidingWindowFeature) -> Annotation:
+ """
+ Return the continuous-time segmentation
+ corresponding to the discrete-time input segmentation.
+
+ Parameters
+ ----------
+ segmentation: SlidingWindowFeature
+ Discrete-time speaker segmentation.
+
+ Returns
+ -------
+ annotation: Annotation
+ Continuous-time speaker segmentation.
+ """
+ num_frames, num_speakers = segmentation.data.shape
+ timestamps = segmentation.sliding_window
+ is_active = segmentation.data > self.threshold
+ # Artificially add last inactive frame to close any remaining speaker turns
+ is_active = np.append(is_active, [[False] * num_speakers], axis=0)
+ start_times = np.zeros(num_speakers) + timestamps[0].middle
+ annotation = Annotation(uri=self.uri, modality="speech")
+ for t in range(num_frames):
+ # Any (False, True) starts a speaker turn at "True" index
+ onsets = np.logical_and(np.logical_not(is_active[t]), is_active[t + 1])
+ start_times[onsets] = timestamps[t + 1].middle
+ # Any (True, False) ends a speaker turn at "False" index
+ offsets = np.logical_and(is_active[t], np.logical_not(is_active[t + 1]))
+ for spk in np.where(offsets)[0]:
+ region = Segment(start_times[spk], timestamps[t + 1].middle)
+ annotation[region, spk] = f"speaker{spk}"
+ return annotation
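
A minimal sketch of how `Binarize` is meant to be used (not part of the patch); the random scores, frame resolution and threshold below are illustrative assumptions:

```python
import numpy as np
from pyannote.core import SlidingWindow, SlidingWindowFeature
from diart.blocks import Binarize

# Fake diarization scores: 100 frames, 2 speakers, 10ms frame resolution
scores = SlidingWindowFeature(
    np.random.rand(100, 2),
    SlidingWindow(start=0.0, duration=0.01, step=0.01),
)
# Turn frame-level probabilities into continuous speaker turns
annotation = Binarize(uri="example", threshold=0.5)(scores)
```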
From 8b47aebd6c7fdaded6d44ed77b235570e19a8029 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Mon, 13 Jun 2022 18:45:31 +0200
Subject: [PATCH 16/29] Add relative imports to blocks package
---
src/diart/blocks/clustering.py | 2 +-
src/diart/blocks/embedding.py | 4 ++--
src/diart/blocks/segmentation.py | 4 ++--
3 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py
index 1842ecfa..57a3ab2c 100644
--- a/src/diart/blocks/clustering.py
+++ b/src/diart/blocks/clustering.py
@@ -4,7 +4,7 @@
import torch
from pyannote.core import SlidingWindowFeature
-from mapping import SpeakerMap, SpeakerMapBuilder
+from ..mapping import SpeakerMap, SpeakerMapBuilder
class OnlineSpeakerClustering:
diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py
index 8598daa2..19288876 100644
--- a/src/diart/blocks/embedding.py
+++ b/src/diart/blocks/embedding.py
@@ -3,8 +3,8 @@
import torch
from einops import rearrange
-from features import TemporalFeatures, TemporalFeatureFormatter
-from models import EmbeddingModel
+from ..features import TemporalFeatures, TemporalFeatureFormatter
+from ..models import EmbeddingModel
class SpeakerEmbedding:
diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py
index d3015806..33064207 100644
--- a/src/diart/blocks/segmentation.py
+++ b/src/diart/blocks/segmentation.py
@@ -3,8 +3,8 @@
import torch
from einops import rearrange
-from features import TemporalFeatures, TemporalFeatureFormatter
-from models import SegmentationModel
+from ..features import TemporalFeatures, TemporalFeatureFormatter
+from ..models import SegmentationModel
class SpeakerSegmentation:
From b91550bb83c959670905307174503ed32f0e3758 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Mon, 13 Jun 2022 21:00:01 +0200
Subject: [PATCH 17/29] Add README instructions needed for a no-pyannote
installation
---
README.md | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index eb90e123..4a32803c 100644
--- a/README.md
+++ b/README.md
@@ -27,9 +27,16 @@ conda create -n diart python=3.8
conda activate diart
```
-2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)
+2) Install `PortAudio` and `soundfile`:
-3) Install pyannote.audio 2.0 (currently in development)
+```shell
+conda install portaudio
+conda install pysoundfile -c conda-forge
+```
+
+3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally)
+
+4) Install pyannote.audio 2.0 (currently in development)
```shell
pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
@@ -37,7 +44,7 @@ pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyann
**Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored.
-4) Install diart:
+5) Install diart:
```shell
pip install diart
```
From a6ddb07907afe81dc9919c216e1beff4a791086a Mon Sep 17 00:00:00 2001
From: Khaled Zaouk
Date: Mon, 20 Jun 2022 11:17:28 +0200
Subject: [PATCH 18/29] Add documentation for some classes and methods (#31)
* Adds documentation for some of the classes and methods under functional.py and mapping.py
Co-authored-by: Juan Coria
---
src/diart/blocks/clustering.py | 62 ++++++++++++++++++++++++++++++++++
src/diart/mapping.py | 34 +++++++++++++++++++
src/diart/pipelines.py | 3 +-
3 files changed, 97 insertions(+), 2 deletions(-)
diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py
index 57a3ab2c..882001b9 100644
--- a/src/diart/blocks/clustering.py
+++ b/src/diart/blocks/clustering.py
@@ -8,6 +8,25 @@
class OnlineSpeakerClustering:
+ """Implements constrained incremental online clustering of speakers and manages cluster centers.
+
+ Parameters
+ ----------
+ tau_active: float
+ Threshold for detecting active speakers. It is applied to the maximum value of the
+ per-speaker activations of the local segmentation model.
+ rho_update: float
+ Threshold for updating a speaker centroid. The centroid to which a local speaker is mapped
+ is only updated if that speaker's speech/chunk duration ratio is greater than this threshold.
+ delta_new: float
+ Threshold on the distance between a speaker embedding and a centroid. If the distance between
+ a local speaker and every centroid is larger than delta_new, a new centroid is created for that speaker.
+ metric: str, optional
+ The distance metric to use. Defaults to "cosine".
+ max_speakers: int
+ Maximum number of global speakers to track through a conversation. Defaults to 20.
+ """
def __init__(
self,
tau_active: float,
@@ -51,17 +70,46 @@ def get_next_center_position(self) -> Optional[int]:
return center
def init_centers(self, dimension: int):
+ """Initializes the speaker centroid matrix
+
+ Parameters
+ ----------
+ dimension: int
+ Dimension of embeddings used for representing a speaker.
+ """
self.centers = np.zeros((self.max_speakers, dimension))
self.active_centers = set()
self.blocked_centers = set()
def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray):
+ """Updates the speaker centroids given a list of assignments and local speaker embeddings
+
+ Parameters
+ ----------
+ assignments: Iterable[Tuple[int, int]]
+ An iterable of (local_speaker, global_speaker) index pairs.
+ embeddings: np.ndarray, shape (local_speakers, embedding_dim)
+ Matrix containing embeddings for all local speakers.
+ """
if self.centers is not None:
for l_spk, g_spk in assignments:
assert g_spk in self.active_centers, "Cannot update unknown centers"
self.centers[g_spk] += embeddings[l_spk]
def add_center(self, embedding: np.ndarray) -> int:
+ """Add a new speaker centroid initialized to a given embedding
+
+ Parameters
+ ----------
+ embedding: np.ndarray
+ Embedding vector of some local speaker
+
+ Returns
+ -------
+ center_index: int
+ Index of the created center
+ """
center = self.get_next_center_position()
self.centers[center] = embedding
self.active_centers.add(center)
@@ -72,6 +120,20 @@ def identify(
segmentation: SlidingWindowFeature,
embeddings: torch.Tensor
) -> SpeakerMap:
+ """Identify the centroids to which the input speaker embeddings belong.
+
+ Parameters
+ ----------
+ segmentation: SlidingWindowFeature, shape (frames, local_speakers)
+ Segmentation scores of the local speakers
+ embeddings: torch.Tensor, shape (local_speakers, embedding_dim)
+ Embeddings of the local speakers
+
+ Returns
+ -------
+ speaker_map: SpeakerMap
+ A mapping from local speakers to global speakers.
+ """
embeddings = embeddings.detach().cpu().numpy()
active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0]
long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0]
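
A hedged usage sketch of the clustering block documented above (not part of the patch); the thresholds, frame resolution and 512-dimensional random embeddings are illustrative assumptions:

```python
import numpy as np
import torch
from pyannote.core import SlidingWindow, SlidingWindowFeature
from diart.blocks import OnlineSpeakerClustering

clustering = OnlineSpeakerClustering(
    tau_active=0.6, rho_update=0.3, delta_new=1.0, metric="cosine", max_speakers=20
)

# One chunk: local segmentation scores plus one embedding per local speaker
local_segmentation = SlidingWindowFeature(
    np.random.rand(500, 3),
    SlidingWindow(start=0.0, duration=0.01, step=0.01),
)
local_embeddings = torch.randn(3, 512)

# Assigns local speakers to global centroids and relabels the chunk accordingly
global_segmentation = clustering(local_segmentation, local_embeddings)
```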
diff --git a/src/diart/mapping.py b/src/diart/mapping.py
index 01465086..2795ba0b 100644
--- a/src/diart/mapping.py
+++ b/src/diart/mapping.py
@@ -22,6 +22,23 @@ def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]:
def hard_speaker_map(
self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]]
) -> SpeakerMap:
+ """Create a hard map object where the highest cost is put
+ everywhere except on hard assignments from ``assignments``.
+
+ Parameters
+ ----------
+ num_src: int
+ Number of source speakers
+ num_tgt: int
+ Number of target speakers
+ assignments: Iterable[Tuple[int, int]]
+ An iterable of (source_speaker, target_speaker) index pairs
+
+ Returns
+ -------
+ SpeakerMap
+ """
mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt))
for src, tgt in assignments:
mapping_matrix[src, tgt] = self.best_possible_value
@@ -82,6 +99,23 @@ class SpeakerMapBuilder:
def hard_map(
shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool
) -> SpeakerMap:
+ """Create a ``SpeakerMap`` object based on the given assignments. This is a "hard" map, meaning that the
+ highest cost is put everywhere except on hard assignments from ``assignments``.
+
+ Parameters
+ ----------
+ shape: Tuple[int, int]
+ Shape of the mapping matrix
+ assignments: Iterable[Tuple[int, int]]
+ An iterable of (source_speaker, target_speaker) index pairs
+ maximize: bool
+ Whether higher mapping values are better (True) or lower values are better (False)
+
+ Returns
+ -------
+ SpeakerMap
+ """
num_src, num_tgt = shape
objective = MaximizationObjective if maximize else MinimizationObjective
return objective().hard_speaker_map(num_src, num_tgt, assignments)
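
For reference (not part of the patch), `hard_map` is used by the clustering block roughly as follows; the shape and assignments are illustrative:

```python
from diart.mapping import SpeakerMapBuilder

# Hard-assign local speakers 0 and 1 to global speakers 2 and 5.
# maximize=False because the mapping values are treated as distances (lower is better).
speaker_map = SpeakerMapBuilder.hard_map(
    shape=(3, 20), assignments=[(0, 2), (1, 5)], maximize=False
)
```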
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index 4ccbf1d4..dea360d6 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -81,8 +81,7 @@ def from_namespace(args: Namespace) -> 'PipelineConfig':
)
def last_chunk_end_time(self, conv_duration: float) -> Optional[float]:
- """
- Return the end time of the last chunk for a given conversation duration.
+ """Return the end time of the last chunk for a given conversation duration.
Parameters
----------
From 8b91773c7f8207a08205e943dff7c37bc3398e4d Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Mon, 13 Jun 2022 22:29:51 +0200
Subject: [PATCH 19/29] Add initial implementation of hyper-parameter tuning
with optuna
---
requirements.txt | 2 +-
setup.cfg | 1 +
src/diart/benchmark.py | 7 ++++
src/diart/inference.py | 29 +++++++++-----
src/diart/optim.py | 90 ++++++++++++++++++++++++++++++++++++++++++
src/diart/pipelines.py | 16 ++++----
src/diart/stream.py | 7 ++++
7 files changed, 135 insertions(+), 17 deletions(-)
create mode 100644 src/diart/optim.py
diff --git a/requirements.txt b/requirements.txt
index 2a267b34..d2259fb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,4 @@ torchaudio>=0.10,<1.0
pyannote.core>=4.4
pyannote.database>=4.1.1
pyannote.metrics>=3.2
-git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio
+optuna>=2.10
diff --git a/setup.cfg b/setup.cfg
index d2d5c5b3..a0f7c776 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,6 +32,7 @@ install_requires =
pyannote.core>=4.4
pyannote.database>=4.1.1
pyannote.metrics>=3.2
+ optuna>=2.10
[options.packages.find]
diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index 1ac0bafd..c83d9e54 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -1,5 +1,7 @@
import argparse
+import torch
+
import diart.argdoc as argdoc
from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
@@ -21,6 +23,11 @@
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
args = parser.parse_args()
+ args.device = torch.device("cpu") if args.cpu else None
+ args.tau_active = args.tau
+ args.rho_update = args.rho
+ args.delta_new = args.delta
+
Benchmark(args.root, args.reference, args.output)(
OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True),
args.batch_size
diff --git a/src/diart/inference.py b/src/diart/inference.py
index 6e204247..ec31a2fe 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -107,7 +107,12 @@ def __init__(
self.output_path = Path(output_path).expanduser()
self.output_path.mkdir(parents=True, exist_ok=True)
- def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> Optional[pd.DataFrame]:
+ def __call__(
+ self,
+ pipeline: OnlineSpeakerDiarization,
+ batch_size: int = 32,
+ verbose: bool = True
+ ) -> Optional[pd.DataFrame]:
"""
Run a given pipeline on a set of audio files using
pre-calculated segmentation and embeddings in batches.
@@ -118,6 +123,8 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
Configured speaker diarization pipeline.
batch_size: int
Batch size. Defaults to 32.
+ verbose: bool
+ Whether to log progress. Defaults to True.
Returns
-------
@@ -143,6 +150,7 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
)
observable = pipeline.from_audio_source(source)
else:
+ msg = f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})"
source = src.PrecalculatedFeaturesAudioSource(
filepath,
pipeline.config.sample_rate,
@@ -151,17 +159,20 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
pipeline.config.duration,
pipeline.config.step,
batch_size,
- progress_msg=f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})",
+ progress_msg=msg if verbose else None,
)
observable = pipeline.from_feature_source(source)
- observable.pipe(
- dops.progress(
- desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})",
- total=num_chunks,
- leave=isinstance(source, src.PrecalculatedFeaturesAudioSource)
+ if verbose:
+ observable = observable.pipe(
+ dops.progress(
+ desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})",
+ total=num_chunks,
+ leave=isinstance(source, src.PrecalculatedFeaturesAudioSource)
+ )
)
- ).subscribe(RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm"))
+
+ observable.subscribe(RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm"))
source.read()
@@ -173,4 +184,4 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) ->
hyp = load_rttm(self.output_path / ref_path.name).popitem()[1]
metric(ref, hyp)
- return metric.report(display=True)
+ return metric.report(display=verbose)
diff --git a/src/diart/optim.py b/src/diart/optim.py
new file mode 100644
index 00000000..e1138ffa
--- /dev/null
+++ b/src/diart/optim.py
@@ -0,0 +1,90 @@
+from pathlib import Path
+from typing import Dict, Text, Tuple, Optional, Callable
+
+import optuna
+from optuna.pruners._base import BasePruner
+from optuna.samplers import TPESampler
+from optuna.samplers._base import BaseSampler
+from optuna.trial import Trial
+
+from .audio import FilePath
+from .inference import Benchmark
+from .pipelines import PipelineConfig, OnlineSpeakerDiarization
+
+
+class HyperParameterOptimizer:
+ def __init__(
+ self,
+ speech_path: FilePath,
+ reference_path: FilePath,
+ output_path: FilePath,
+ base_config: PipelineConfig,
+ hparams: Dict[Text, Tuple[float, float]],
+ batch_size: int = 32,
+ ):
+ self.base_config = base_config
+ self.hparams = hparams
+ self.batch_size = batch_size
+ self.output_path = Path(output_path).expanduser()
+ self.output_path.mkdir(parents=True, exist_ok=False)
+ self.tmp_path = self.output_path / "current_iter"
+ self.benchmark = Benchmark(speech_path, reference_path, self.tmp_path)
+
+ def _objective(self, trial: Trial) -> float:
+ # Set suggested values for optimized hyper-parameters
+ trial_config = vars(self.base_config)
+ for hp_name, (low, high) in self.hparams.items():
+ trial_config[hp_name] = trial.suggest_uniform(hp_name, low, high)
+
+ # Instantiate pipeline with the new configuration
+ pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config))
+
+ # Run pipeline over the dataset
+ report = self.benchmark(pipeline, self.batch_size, verbose=False)
+
+ # Clean RTTM files
+ for tmp_file in self.tmp_path.iterdir():
+ tmp_file.unlink()
+
+ # Extract DER from report
+ return report.loc["TOTAL", "diarization error rate %"]
+
+ def optimize(
+ self,
+ sampler: Optional[BaseSampler] = None,
+ pruner: Optional[BasePruner] = None,
+ num_iter: int = 100,
+ experiment_name: Optional[Text] = None,
+ ) -> Tuple[float, Dict[Text, float]]:
+ """Optimize the given hyper-parameters on the given dataset.
+
+ Parameters
+ ----------
+ sampler: Optional[optuna.BaseSampler]
+ The Optuna sampler to use during optimization. Defaults to TPESampler.
+ pruner: Optional[optuna.BasePruner]
+ The Optuna pruner to use during optimization. Defaults to None.
+ num_iter: int
+ Number of iterations over the dataset. Defaults to 100.
+ experiment_name: Optional[Text]
+ Name of the optimization run. Defaults to None.
+
+ Returns
+ -------
+ der_value: float
+ Diarization error rate of the best iteration.
+ best_hparams: Dict[Text, float]
+ Hyper-parameters of the best iteration.
+ """
+ sampler = TPESampler() if sampler is None else sampler
+ storage = self.output_path / "trials.db"
+ study = optuna.create_study(
+ storage=f"sqlite://{str(storage)}",
+ sampler=sampler,
+ pruner=pruner,
+ study_name=experiment_name,
+ direction="minimize",
+ load_if_exists=True,
+ )
+ study.optimize(self._objective, n_trials=num_iter)
+ return study.best_value, study.best_params
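
A hedged usage sketch of the optimizer as introduced in this patch (paths, parameter ranges and config values below are illustrative; the next patch reworks this API into `diart.optim.Optimizer`):

```python
from diart.optim import HyperParameterOptimizer
from diart.pipelines import PipelineConfig

optimizer = HyperParameterOptimizer(
    speech_path="/wav/dir",
    reference_path="/rttm/dir",
    output_path="/out/dir",
    base_config=PipelineConfig(step=0.5, latency=0.5),
    # Search ranges for the clustering thresholds
    hparams={"tau_active": (0.0, 1.0), "rho_update": (0.0, 1.0), "delta_new": (0.0, 2.0)},
    batch_size=32,
)
best_der, best_hparams = optimizer.optimize(num_iter=100)
```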
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index dea360d6..fa27c087 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -1,5 +1,4 @@
-from argparse import Namespace
-from typing import Optional, List
+from typing import Optional, List, Any
import rx
import rx.operators as ops
@@ -64,22 +63,25 @@ def __init__(
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@staticmethod
- def from_namespace(args: Namespace) -> 'PipelineConfig':
+ def from_namespace(args: Any) -> 'PipelineConfig':
return PipelineConfig(
segmentation=getattr(args, "segmentation", None),
embedding=getattr(args, "embedding", None),
duration=getattr(args, "duration", None),
step=args.step,
latency=args.latency,
- tau_active=args.tau,
- rho_update=args.rho,
- delta_new=args.delta,
+ tau_active=args.tau_active,
+ rho_update=args.rho_update,
+ delta_new=args.delta_new,
gamma=args.gamma,
beta=args.beta,
max_speakers=args.max_speakers,
- device=torch.device("cpu") if args.cpu else None,
+ device=args.device,
)
+ def copy(self) -> 'PipelineConfig':
+ return PipelineConfig.from_namespace(self)
+
def last_chunk_end_time(self, conv_duration: float) -> Optional[float]:
"""Return the end time of the last chunk for a given conversation duration.
diff --git a/src/diart/stream.py b/src/diart/stream.py
index 0ef8a5a6..1e548bf6 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -1,6 +1,8 @@
import argparse
from pathlib import Path
+import torch
+
import diart.argdoc as argdoc
import diart.sources as src
from diart.inference import RealTimeInference
@@ -22,6 +24,11 @@
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
args = parser.parse_args()
+ args.device = torch.device("cpu") if args.cpu else None
+ args.tau_active = args.tau
+ args.rho_update = args.rho
+ args.delta_new = args.delta
+
# Define online speaker diarization pipeline
config = PipelineConfig.from_namespace(args)
pipeline = OnlineSpeakerDiarization(config, profile=True)
From e1a647091e8a79b42dc71c52ae800c5f8575d8ef Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 15 Jun 2022 17:08:27 +0200
Subject: [PATCH 20/29] Improve hyper-parameter optimization
---
README.md | 43 +++++++++++-
src/diart/benchmark.py | 16 +++--
src/diart/inference.py | 30 ++++-----
src/diart/optim.py | 150 +++++++++++++++++++++++++----------------
src/diart/pipelines.py | 10 ++-
src/diart/stream.py | 4 --
6 files changed, 160 insertions(+), 93 deletions(-)
diff --git a/README.md b/README.md
index 4a32803c..2a52b0d5 100644
--- a/README.md
+++ b/README.md
@@ -84,9 +84,50 @@ inference(pipeline, audio_source)
For faster inference and evaluation on a dataset we recommend to use `Benchmark` (see our notes on [reproducibility](#reproducibility))
+## Optimize hyper-parameters to your own dataset
+
+Diart implements a hyper-parameter optimizer based on [optuna](https://github.com/optuna/optuna).
+`diart.optim.Optimizer` allows you to tune any pipeline to a custom dataset.
+More information on Optuna can be found [here](https://optuna.readthedocs.io/en/stable/index.html).
+
+### A simple example
+
+```python
+from diart.optim import OptimizationObjective, Optimizer, TauActive, RhoUpdate, DeltaNew
+from diart.pipelines import PipelineConfig
+from diart.inference import Benchmark
+
+# Benchmark runs and evaluates the pipeline with each configuration
+benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir", show_report=False)
+# Base configuration for the pipeline we're going to tune
+base_config = PipelineConfig(duration=5, step=0.5, latency=5)
+# Hyper-parameters to optimize
+hparams = [TauActive, RhoUpdate, DeltaNew]
+# The objective implements an optimization step
+objective = OptimizationObjective(benchmark, base_config, hparams)
+# Run optimization for 100 iterations
+Optimizer(objective).optimize(num_iter=100, show_progress=True)
+```
+
+### Distributed optimization
+
+For bigger datasets, it is sometimes more convenient to run optimization in parallel.
+If the same `study_name` and `storage` are given to the optimizer, all optimization processes will share the information from previous runs.
+More information on distributed optimization can be found [here](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py).
+
+```python
+from diart.optim import Optimizer
+
+objective = ...
+study_name = "my_study"
+storage = "mysql://root@localhost/example"
+optimizer = Optimizer(objective, study_name, storage)
+optimizer.optimize(num_iter=100, show_progress=True)
+```
+
## Build your own pipeline
-Diart also provides building blocks that can be combined to create your own pipeline.
+For more advanced usage, diart also provides building blocks that can be combined to create your own pipelines.
Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately.
### Example
diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index c83d9e54..b5fcc567 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -22,13 +22,15 @@
parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
args = parser.parse_args()
-
args.device = torch.device("cpu") if args.cpu else None
- args.tau_active = args.tau
- args.rho_update = args.rho
- args.delta_new = args.delta
- Benchmark(args.root, args.reference, args.output)(
- OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True),
- args.batch_size
+ benchmark = Benchmark(
+ args.root,
+ args.reference,
+ args.output,
+ show_progress=True,
+ show_report=True,
+ batch_size=args.batch_size
)
+
+ benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True))
diff --git a/src/diart/inference.py b/src/diart/inference.py
index ec31a2fe..e58240e9 100644
--- a/src/diart/inference.py
+++ b/src/diart/inference.py
@@ -92,6 +92,9 @@ def __init__(
speech_path: Union[Text, Path],
reference_path: Optional[Union[Text, Path]] = None,
output_path: Optional[Union[Text, Path]] = None,
+ show_progress: bool = True,
+ show_report: bool = True,
+ batch_size: int = 32,
):
self.speech_path = Path(speech_path).expanduser()
assert self.speech_path.is_dir(), "Speech path must be a directory"
@@ -107,12 +110,11 @@ def __init__(
self.output_path = Path(output_path).expanduser()
self.output_path.mkdir(parents=True, exist_ok=True)
- def __call__(
- self,
- pipeline: OnlineSpeakerDiarization,
- batch_size: int = 32,
- verbose: bool = True
- ) -> Optional[pd.DataFrame]:
+ self.show_progress = show_progress
+ self.show_report = show_report
+ self.batch_size = batch_size
+
+ def __call__(self, pipeline: OnlineSpeakerDiarization) -> Optional[pd.DataFrame]:
"""
Run a given pipeline on a set of audio files using
pre-calculated segmentation and embeddings in batches.
@@ -121,10 +123,6 @@ def __call__(
----------
pipeline: OnlineSpeakerDiarization
Configured speaker diarization pipeline.
- batch_size: int
- Batch size. Defaults to 32.
- verbose: bool
- Whether to log its progress. Defaults to True.
Returns
-------
@@ -141,7 +139,7 @@ def __call__(
)
# Stream fully online if batch size is 1 or lower
- if batch_size < 2:
+ if self.batch_size < 2:
source = src.FileAudioSource(
filepath,
pipeline.config.sample_rate,
@@ -158,17 +156,17 @@ def __call__(
pipeline.embedding,
pipeline.config.duration,
pipeline.config.step,
- batch_size,
- progress_msg=msg if verbose else None,
+ self.batch_size,
+ progress_msg=msg if self.show_progress else None,
)
observable = pipeline.from_feature_source(source)
- if verbose:
+ if self.show_progress:
observable = observable.pipe(
dops.progress(
desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})",
total=num_chunks,
- leave=isinstance(source, src.PrecalculatedFeaturesAudioSource)
+ leave=False,
)
)
@@ -184,4 +182,4 @@ def __call__(
hyp = load_rttm(self.output_path / ref_path.name).popitem()[1]
metric(ref, hyp)
- return metric.report(display=verbose)
+ return metric.report(display=self.show_report)
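
With this change, batch size and verbosity are configured once on the `Benchmark` object instead of being passed to `__call__`. A minimal sketch of the updated call site (directory paths are placeholders):

```python
from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig

# Batch size and verbosity are now constructor arguments
benchmark = Benchmark(
    "/wav/dir",
    "/rttm/dir",
    "/out/dir",
    show_progress=True,
    show_report=True,
    batch_size=32,
)
# __call__ now only takes the pipeline; a report DataFrame is returned when references are available
report = benchmark(OnlineSpeakerDiarization(PipelineConfig()))
```
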
diff --git a/src/diart/optim.py b/src/diart/optim.py
index e1138ffa..4260f952 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -1,90 +1,122 @@
+from collections import OrderedDict
+from dataclasses import dataclass
from pathlib import Path
-from typing import Dict, Text, Tuple, Optional, Callable
+from typing import Iterable, Text, Optional
-import optuna
-from optuna.pruners._base import BasePruner
-from optuna.samplers import TPESampler
-from optuna.samplers._base import BaseSampler
-from optuna.trial import Trial
+from optuna import TrialPruned, Study, create_study
+from optuna.pruners import BasePruner
+from optuna.samplers import TPESampler, BaseSampler
+from optuna.trial import Trial, FrozenTrial
+from tqdm import trange, tqdm
-from audio import FilePath
-from benchmark import Benchmark
-from pipelines import PipelineConfig, OnlineSpeakerDiarization
+from .benchmark import Benchmark
+from .pipelines import PipelineConfig, OnlineSpeakerDiarization
-class HyperParameterOptimizer:
+@dataclass
+class HyperParameter:
+ name: Text
+ low: float
+ high: float
+
+
+TauActive = HyperParameter("tau_active", low=0, high=1)
+RhoUpdate = HyperParameter("rho_update", low=0, high=1)
+DeltaNew = HyperParameter("delta_new", low=0, high=2)
+
+
+class OptimizationObjective:
def __init__(
self,
- speech_path: FilePath,
- reference_path: FilePath,
- output_path: FilePath,
+ benchmark: Benchmark,
base_config: PipelineConfig,
- hparams: Dict[Text, Tuple[float, float]],
- batch_size: int = 32,
+ hparams: Iterable[HyperParameter],
):
+ self.benchmark = benchmark
self.base_config = base_config
self.hparams = hparams
- self.batch_size = batch_size
- self.output_path = Path(output_path).expanduser()
- self.output_path.mkdir(parents=True, exist_ok=False)
- self.tmp_path = self.output_path / "current_iter"
- self.benchmark = Benchmark(speech_path, reference_path, self.tmp_path)
- def _objective(self, trial: Trial) -> float:
+ def __call__(self, trial: Trial) -> float:
# Set suggested values for optimized hyper-parameters
trial_config = vars(self.base_config)
- for hp_name, (low, high) in self.hparams.items():
- trial_config[hp_name] = trial.suggest_uniform(hp_name, low, high)
+ for hparam in self.hparams:
+ trial_config[hparam.name] = trial.suggest_uniform(
+ hparam.name, hparam.low, hparam.high
+ )
# Instantiate pipeline with the new configuration
pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config))
+ # Prune trial if required
+ if trial.should_prune():
+ raise TrialPruned()
+
# Run pipeline over the dataset
- report = self.benchmark(pipeline, self.batch_size, verbose=False)
+ report = self.benchmark(pipeline)
# Clean RTTM files
- for tmp_file in self.tmp_path.iterdir():
- tmp_file.unlink()
+ for tmp_file in self.benchmark.output_path.iterdir():
+ if tmp_file.name.endswith(".rttm"):
+ tmp_file.unlink()
# Extract DER from report
- return report.loc["TOTAL", "diarization error rate %"]
+ return report.loc["TOTAL", "diarization error rate"]["%"]
+
- def optimize(
+class Optimizer:
+ def __init__(
self,
+ objective: OptimizationObjective,
+ study_name: Optional[Text] = None,
+ storage: Optional[Text] = None,
sampler: Optional[BaseSampler] = None,
pruner: Optional[BasePruner] = None,
- num_iter: int = 100,
- experiment_name: Optional[Text] = None,
- ) -> Tuple[float, Dict[Text, float]]:
- """Optimize the given hyper-parameters on the given dataset.
-
- Parameters
- ----------
- sampler: Optional[optuna.BaseSampler]
- The Optuna sampler to use during optimization. Defaults to TPESampler.
- pruner: Optional[optuna.BasePruner]
- The Optuna pruner to use during optimization. Defaults to None.
- num_iter: int
- Number of iterations over the dataset. Defaults to 100.
- experiment_name: Optional[Text]
- Name of the optimization run. Defaults to None.
-
- Returns
- -------
- der_value: float
- Diarization error rate of the best iteration.
- best_hparams: Dict[Text, float]
- Hyper-parameters of the best iteration.
- """
- sampler = TPESampler() if sampler is None else sampler
- storage = self.output_path / "trials.db"
- study = optuna.create_study(
- storage=f"sqlite://{str(storage)}",
- sampler=sampler,
+ ):
+ self.objective = objective
+ self.study = create_study(
+ storage=self.default_storage if storage is None else storage,
+ sampler=TPESampler() if sampler is None else sampler,
pruner=pruner,
- study_name=experiment_name,
+ study_name=self.default_study_name if study_name is None else study_name,
direction="minimize",
load_if_exists=True,
)
- study.optimize(self._objective, n_trials=num_iter)
- return study.best_value, study.best_params
+ self._progress: Optional[tqdm] = None
+
+ @property
+ def default_output_path(self) -> Path:
+ return self.objective.benchmark.output_path.parent
+
+ @property
+ def default_study_name(self) -> Text:
+ return self.default_output_path.name
+
+ @property
+ def default_storage(self) -> Text:
+ return "sqlite:///" + str(self.default_output_path / "trials.db")
+
+ @property
+ def best_performance(self):
+ return self.study.best_value
+
+ @property
+ def best_hparams(self):
+ return self.study.best_params
+
+ def _callback(self, study: Study, trial: FrozenTrial):
+ if self._progress is None:
+ return
+ self._progress.update(1)
+ self._progress.set_description(f"Trial {trial.number + 1}")
+ values = {"best_der": study.best_value}
+ for name, value in study.best_params.items():
+ values[f"best_{name}"] = value
+ self._progress.set_postfix(OrderedDict(values))
+
+ def optimize(self, num_iter: int, show_progress: bool = True):
+ self._progress = None
+ if show_progress:
+ self._progress = trange(num_iter)
+ last_trial = self.study.trials[-1].number
+ self._progress.set_description(f"Trial {last_trial + 1}")
+ self.study.optimize(self.objective, num_iter, callbacks=[self._callback])
diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py
index fa27c087..391d60b3 100644
--- a/src/diart/pipelines.py
+++ b/src/diart/pipelines.py
@@ -26,6 +26,7 @@ def __init__(
beta: float = 10,
max_speakers: int = 20,
device: Optional[torch.device] = None,
+ **kwargs,
):
# Default segmentation model is pyannote/segmentation
self.segmentation = segmentation
@@ -70,18 +71,15 @@ def from_namespace(args: Any) -> 'PipelineConfig':
duration=getattr(args, "duration", None),
step=args.step,
latency=args.latency,
- tau_active=args.tau_active,
- rho_update=args.rho_update,
- delta_new=args.delta_new,
+ tau_active=args.tau,
+ rho_update=args.rho,
+ delta_new=args.delta,
gamma=args.gamma,
beta=args.beta,
max_speakers=args.max_speakers,
device=args.device,
)
- def copy(self) -> 'PipelineConfig':
- return PipelineConfig.from_namespace(self)
-
def last_chunk_end_time(self, conv_duration: float) -> Optional[float]:
"""Return the end time of the last chunk for a given conversation duration.
diff --git a/src/diart/stream.py b/src/diart/stream.py
index 1e548bf6..c263364d 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -23,11 +23,7 @@
parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
args = parser.parse_args()
-
args.device = torch.device("cpu") if args.cpu else None
- args.tau_active = args.tau
- args.rho_update = args.rho
- args.delta_new = args.delta
# Define online speaker diarization pipeline
config = PipelineConfig.from_namespace(args)
From 74c2d12583f124e6191e4fc74e83ffe561562889 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 15 Jun 2022 17:16:33 +0200
Subject: [PATCH 21/29] Remove OptimizationObjective
---
README.md | 16 +++++------
src/diart/optim.py | 70 ++++++++++++++++++++--------------------------
2 files changed, 39 insertions(+), 47 deletions(-)
diff --git a/README.md b/README.md
index 2a52b0d5..c0aee4f6 100644
--- a/README.md
+++ b/README.md
@@ -93,20 +93,20 @@ More information on Optuna can be found [here](https://optuna.readthedocs.io/en/
### A simple example
```python
-from diart.optim import OptimizationObjective, Optimizer, TauActive, RhoUpdate, DeltaNew
+from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew
from diart.pipelines import PipelineConfig
from diart.inference import Benchmark
-# Benchmark runs and evaluates the pipeline with each configuration
+# Benchmark runs and evaluates the pipeline on a dataset
benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir", show_report=False)
# Base configuration for the pipeline we're going to tune
base_config = PipelineConfig(duration=5, step=0.5, latency=5)
# Hyper-parameters to optimize
hparams = [TauActive, RhoUpdate, DeltaNew]
-# The objective implements an optimization step
-objective = OptimizationObjective(benchmark, base_config, hparams)
-# Run optimization for 100 iterations
-Optimizer(objective).optimize(num_iter=100, show_progress=True)
+# Optimizer implements the optimization loop
+optimizer = Optimizer(benchmark, base_config, hparams)
+# Run optimization
+optimizer.optimize(num_iter=100, show_progress=True)
```
### Distributed optimization
@@ -118,10 +118,10 @@ More information on distributed optimization can be found [here](https://optuna.
```python
from diart.optim import Optimizer
-objective = ...
+benchmark, base_config, hparams = ...
study_name = "my_study"
storage = "mysql://root@localhost/example"
-optimizer = Optimizer(objective, study_name, storage)
+optimizer = Optimizer(benchmark, base_config, hparams, study_name, storage)
optimizer.optimize(num_iter=100, show_progress=True)
```
diff --git a/src/diart/optim.py b/src/diart/optim.py
index 4260f952..200fb642 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -25,54 +25,20 @@ class HyperParameter:
DeltaNew = HyperParameter("delta_new", low=0, high=2)
-class OptimizationObjective:
+class Optimizer:
def __init__(
self,
benchmark: Benchmark,
base_config: PipelineConfig,
hparams: Iterable[HyperParameter],
- ):
- self.benchmark = benchmark
- self.base_config = base_config
- self.hparams = hparams
-
- def __call__(self, trial: Trial) -> float:
- # Set suggested values for optimized hyper-parameters
- trial_config = vars(self.base_config)
- for hparam in self.hparams:
- trial_config[hparam.name] = trial.suggest_uniform(
- hparam.name, hparam.low, hparam.high
- )
-
- # Instantiate pipeline with the new configuration
- pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config))
-
- # Prune trial if required
- if trial.should_prune():
- raise TrialPruned()
-
- # Run pipeline over the dataset
- report = self.benchmark(pipeline)
-
- # Clean RTTM files
- for tmp_file in self.benchmark.output_path.iterdir():
- if tmp_file.name.endswith(".rttm"):
- tmp_file.unlink()
-
- # Extract DER from report
- return report.loc["TOTAL", "diarization error rate"]["%"]
-
-
-class Optimizer:
- def __init__(
- self,
- objective: OptimizationObjective,
study_name: Optional[Text] = None,
storage: Optional[Text] = None,
sampler: Optional[BaseSampler] = None,
pruner: Optional[BasePruner] = None,
):
- self.objective = objective
+ self.benchmark = benchmark
+ self.base_config = base_config
+ self.hparams = hparams
self.study = create_study(
storage=self.default_storage if storage is None else storage,
sampler=TPESampler() if sampler is None else sampler,
@@ -85,7 +51,7 @@ def __init__(
@property
def default_output_path(self) -> Path:
- return self.objective.benchmark.output_path.parent
+ return self.benchmark.output_path.parent
@property
def default_study_name(self) -> Text:
@@ -113,6 +79,32 @@ def _callback(self, study: Study, trial: FrozenTrial):
values[f"best_{name}"] = value
self._progress.set_postfix(OrderedDict(values))
+ def objective(self, trial: Trial) -> float:
+ # Set suggested values for optimized hyper-parameters
+ trial_config = vars(self.base_config)
+ for hparam in self.hparams:
+ trial_config[hparam.name] = trial.suggest_uniform(
+ hparam.name, hparam.low, hparam.high
+ )
+
+ # Instantiate pipeline with the new configuration
+ pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config))
+
+ # Prune trial if required
+ if trial.should_prune():
+ raise TrialPruned()
+
+ # Run pipeline over the dataset
+ report = self.benchmark(pipeline)
+
+ # Clean RTTM files
+ for tmp_file in self.benchmark.output_path.iterdir():
+ if tmp_file.name.endswith(".rttm"):
+ tmp_file.unlink()
+
+ # Extract DER from report
+ return report.loc["TOTAL", "diarization error rate"]["%"]
+
def optimize(self, num_iter: int, show_progress: bool = True):
self._progress = None
if show_progress:
From 69789b25dc3f23b9bc5057159c286f14db438e04 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 15 Jun 2022 17:48:52 +0200
Subject: [PATCH 22/29] Improve README.md
---
README.md | 55 ++++++++++++++++++++++++++++++++++++++++++++-----------
1 file changed, 44 insertions(+), 11 deletions(-)
diff --git a/README.md b/README.md
index c0aee4f6..5ce6a4c5 100644
--- a/README.md
+++ b/README.md
@@ -12,13 +12,45 @@
+
+
-## Install
+## Installation
1) Create environment:
@@ -51,13 +83,15 @@ pip install diart
## Stream your own audio
-### A recorded conversation
+### From the command line
+
+A recorded conversation:
```shell
python -m diart.stream /path/to/audio.wav
```
-### From your microphone
+A live conversation:
```shell
python -m diart.stream microphone
@@ -65,9 +99,9 @@ python -m diart.stream microphone
See `python -m diart.stream -h` for more options.
-## Inference API
+### From python
-Run a customized real-time speaker diarization pipeline over an audio stream with `diart.inference.RealTimeInference`:
+Run a real-time speaker diarization pipeline over an audio stream with `RealTimeInference`:
```python
from diart.sources import MicrophoneAudioSource
@@ -78,17 +112,16 @@ config = PipelineConfig() # Default parameters
pipeline = OnlineSpeakerDiarization(config)
audio_source = MicrophoneAudioSource(config.sample_rate)
inference = RealTimeInference("/output/path", do_plot=True)
-
inference(pipeline, audio_source)
```
-For faster inference and evaluation on a dataset we recommend to use `Benchmark` (see our notes on [reproducibility](#reproducibility))
+For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)).
+
+## Optimize hyper-parameters to your dataset
-## Optimize hyper-parameters to your own dataset
+Diart implements a hyper-parameter optimizer based on optuna that allows you to tune any pipeline to any dataset.
-Diart implements a hyper-parameter optimizer based on [optuna](https://github.com/optuna/optuna).
-`diart.optim.Optimizer` allows you to tune any pipeline to a custom dataset.
-More information on Optuna can be found [here](https://optuna.readthedocs.io/en/stable/index.html).
+[More about optuna](https://optuna.readthedocs.io/en/stable/index.html).
### A simple example
From 278b1dcbf82df344a8cc64df9a0fa79bd9e58a88 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 16 Jun 2022 12:06:45 +0200
Subject: [PATCH 23/29] Improve optimization API
---
README.md | 36 ++++++++++++++++++-------------
src/diart/optim.py | 53 ++++++++++++++++++++--------------------------
2 files changed, 44 insertions(+), 45 deletions(-)
diff --git a/README.md b/README.md
index 5ce6a4c5..1b36099a 100644
--- a/README.md
+++ b/README.md
@@ -18,15 +18,15 @@
Installation
|
-
+
Stream audio
|
-
+
Tune hyper-parameters
|
-
+
Build pipelines
|
@@ -81,7 +81,7 @@ pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyann
pip install diart
```
-## Stream your own audio
+## Stream audio
### From the command line
@@ -117,7 +117,7 @@ inference(pipeline, audio_source)
For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)).
-## Optimize hyper-parameters to your dataset
+## Tune hyper-parameters
Diart implements a hyper-parameter optimizer based on optuna that allows you to tune any pipeline to any dataset.
@@ -137,30 +137,36 @@ base_config = PipelineConfig(duration=5, step=0.5, latency=5)
# Hyper-parameters to optimize
hparams = [TauActive, RhoUpdate, DeltaNew]
# Optimizer implements the optimization loop
-optimizer = Optimizer(benchmark, base_config, hparams)
+optimizer = Optimizer(benchmark, base_config, hparams, "/db/out/dir")
# Run optimization
optimizer.optimize(num_iter=100, show_progress=True)
```
+This will store temporary predictions in `/out/dir` and write the results of each trial to the sqlite database `/db/out/dir/trials.db`.
+
### Distributed optimization
-For bigger datasets, it is sometimes more convenient to run optimization in parallel.
-If the same `study_name` and `storage` are given to the optimizer, all optimization processes will share the information from previous runs.
-More information on distributed optimization can be found [here](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py).
+For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel.
+To do this, make sure that the output directory or the study given to the optimizer is the same in each process.
+Notice that the output directories of `Benchmark` MUST be different to avoid concurrency issues.
```python
from diart.optim import Optimizer
+from diart.inference import Benchmark
-benchmark, base_config, hparams = ...
-study_name = "my_study"
-storage = "mysql://root@localhost/example"
-optimizer = Optimizer(benchmark, base_config, hparams, study_name, storage)
+ID = 0 # Worker identifier
+base_config, hparams = ...
+benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker_{ID}", show_report=False)
+optimizer = Optimizer(benchmark, base_config, hparams, "/db/out/dir")
optimizer.optimize(num_iter=100, show_progress=True)
```
-## Build your own pipeline
+It is recommended to use other databases like mysql instead of sqlite in distributed optimization.
+More on this [here](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py).
+
+## Build pipelines
-For more advanced usage, diart also provides building blocks that can be combined to create your own pipelines.
+For more advanced usage, diart also provides building blocks that can be combined to create your own pipeline.
Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately.
### Example
diff --git a/src/diart/optim.py b/src/diart/optim.py
index 200fb642..a71bd571 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -1,14 +1,14 @@
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
-from typing import Iterable, Text, Optional
+from typing import Iterable, Text, Optional, Union
from optuna import TrialPruned, Study, create_study
-from optuna.pruners import BasePruner
-from optuna.samplers import TPESampler, BaseSampler
+from optuna.samplers import TPESampler
from optuna.trial import Trial, FrozenTrial
from tqdm import trange, tqdm
+from .audio import FilePath
from .benchmark import Benchmark
from .pipelines import PipelineConfig, OnlineSpeakerDiarization
@@ -31,35 +31,26 @@ def __init__(
benchmark: Benchmark,
base_config: PipelineConfig,
hparams: Iterable[HyperParameter],
- study_name: Optional[Text] = None,
- storage: Optional[Text] = None,
- sampler: Optional[BaseSampler] = None,
- pruner: Optional[BasePruner] = None,
+ study_or_path: Union[FilePath, Study],
):
self.benchmark = benchmark
self.base_config = base_config
self.hparams = hparams
- self.study = create_study(
- storage=self.default_storage if storage is None else storage,
- sampler=TPESampler() if sampler is None else sampler,
- pruner=pruner,
- study_name=self.default_study_name if study_name is None else study_name,
- direction="minimize",
- load_if_exists=True,
- )
self._progress: Optional[tqdm] = None
- @property
- def default_output_path(self) -> Path:
- return self.benchmark.output_path.parent
-
- @property
- def default_study_name(self) -> Text:
- return self.default_output_path.name
-
- @property
- def default_storage(self) -> Text:
- return "sqlite:///" + str(self.default_output_path / "trials.db")
+ if isinstance(study_or_path, Study):
+ self.study = study_or_path
+ elif isinstance(study_or_path, str) or isinstance(study_or_path, Path):
+ self.study = create_study(
+ storage="sqlite:///" + str(study_or_path / "trials.db"),
+ sampler=TPESampler(),
+ study_name=study_or_path.name,
+ direction="minimize",
+ load_if_exists=True,
+ )
+ else:
+ msg = f"Expected Study object or path-like, but got {type(study_or_path).__name__}"
+ raise ValueError(msg)
@property
def best_performance(self):
@@ -87,13 +78,13 @@ def objective(self, trial: Trial) -> float:
hparam.name, hparam.low, hparam.high
)
- # Instantiate pipeline with the new configuration
- pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config))
-
# Prune trial if required
if trial.should_prune():
raise TrialPruned()
+ # Instantiate pipeline with the new configuration
+ pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config), profile=False)
+
# Run pipeline over the dataset
report = self.benchmark(pipeline)
@@ -109,6 +100,8 @@ def optimize(self, num_iter: int, show_progress: bool = True):
self._progress = None
if show_progress:
self._progress = trange(num_iter)
- last_trial = self.study.trials[-1].number
+ last_trial = -1
+ if self.study.trials:
+ last_trial = self.study.trials[-1].number
self._progress.set_description(f"Trial {last_trial + 1}")
self.study.optimize(self.objective, num_iter, callbacks=[self._callback])
From 780681c04f28b56aba227c71f4e06f77414eb88d Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 16 Jun 2022 16:43:41 +0200
Subject: [PATCH 24/29] Add diart.tune script
---
README.md | 79 +++++++++++++++++++++++++-----------------
src/diart/argdoc.py | 1 +
src/diart/benchmark.py | 2 +-
src/diart/optim.py | 14 ++++++--
src/diart/tune.py | 68 ++++++++++++++++++++++++++++++++++++
5 files changed, 129 insertions(+), 35 deletions(-)
create mode 100644 src/diart/tune.py
diff --git a/README.md b/README.md
index 1b36099a..204a3a2f 100644
--- a/README.md
+++ b/README.md
@@ -119,11 +119,17 @@ For faster inference and evaluation on a dataset we recommend to use `Benchmark`
## Tune hyper-parameters
-Diart implements a hyper-parameter optimizer based on optuna that allows you to tune any pipeline to any dataset.
+Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
-[More about optuna](https://optuna.readthedocs.io/en/stable/index.html).
+### From the command line
-### A simple example
+```shell
+python -m diart.tune /wav/dir --reference /rttm/dir --output /out/dir
+```
+
+See `python -m diart.tune -h` for more options.
+
+### From python
```python
from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew
@@ -131,39 +137,51 @@ from diart.pipelines import PipelineConfig
from diart.inference import Benchmark
# Benchmark runs and evaluates the pipeline on a dataset
-benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir", show_report=False)
+benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)
# Base configuration for the pipeline we're going to tune
-base_config = PipelineConfig(duration=5, step=0.5, latency=5)
+base_config = PipelineConfig()
# Hyper-parameters to optimize
hparams = [TauActive, RhoUpdate, DeltaNew]
# Optimizer implements the optimization loop
-optimizer = Optimizer(benchmark, base_config, hparams, "/db/out/dir")
+optimizer = Optimizer(benchmark, base_config, hparams, "/out/dir")
# Run optimization
optimizer.optimize(num_iter=100, show_progress=True)
```
-This will store temporary predictions in `/out/dir` and write the results of each trial to the sqlite database `/db/out/dir/trials.db`.
+This will use `/out/dir/tmp` as a working directory and write results to an sqlite database in `/out/dir`.
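
Once `optimize()` has finished, the best trial can be read back directly from the optimizer through its `best_performance` and `best_hparams` properties, for example:

```python
# Continuing the example above, after optimizer.optimize(...) has returned
print(optimizer.best_performance)  # lowest diarization error rate found across trials
print(optimizer.best_hparams)      # corresponding values of tau_active, rho_update and delta_new
```
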
### Distributed optimization
For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel.
-To do this, make sure that the output directory or the study given to the optimizer is the same in each process.
-Notice that the output directories of `Benchmark` MUST be different to avoid concurrency issues.
+To do this, create a study on a [recommended DBMS](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py) (e.g. MySQL or PostgreSQL), making sure that the study and database names match:
+
+```shell
+mysql -u root -e "CREATE DATABASE IF NOT EXISTS example"
+optuna create-study --study-name "example" --storage "mysql://root@localhost/example"
+```
+
+Then you can run multiple identical optimizers pointing to the database:
+
+```shell
+python -m diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example
+```
+
+If you are using the python API, make sure that worker directories are different to avoid concurrency issues:
```python
from diart.optim import Optimizer
from diart.inference import Benchmark
+from optuna.samplers import TPESampler
+import optuna
ID = 0 # Worker identifier
base_config, hparams = ...
-benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker_{ID}", show_report=False)
-optimizer = Optimizer(benchmark, base_config, hparams, "/db/out/dir")
+benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker-{ID}", show_report=False)
+study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler())
+optimizer = Optimizer(benchmark, base_config, hparams, study)
optimizer.optimize(num_iter=100, show_progress=True)
```
-It is recommended to use other databases like mysql instead of sqlite in distributed optimization.
-More on this [here](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py).
-
## Build pipelines
For more advanced usage, diart also provides building blocks that can be combined to create your own pipeline.
@@ -174,30 +192,27 @@ Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `bloc
Obtain overlap-aware speaker embeddings from a microphone stream:
```python
-import rx
import rx.operators as ops
import diart.operators as dops
from diart.sources import MicrophoneAudioSource
from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
from diart.models import SegmentationModel, EmbeddingModel
-# Initialize independent modules
-seg_model = SegmentationModel.from_pyannote("pyannote/segmentation")
-segmentation = SpeakerSegmentation(seg_model)
-emb_model = EmbeddingModel.from_pyannote("pyannote/embedding")
-embedding = OverlapAwareSpeakerEmbedding(emb_model)
-mic = MicrophoneAudioSource(seg_model.get_sample_rate())
-
-# Reformat microphone stream. Defaults to 5s duration and 500ms shift
-regular_stream = mic.stream.pipe(dops.regularize_audio_stream(seg_model.get_sample_rate()))
-# Branch the microphone stream to calculate segmentation
-segmentation_stream = regular_stream.pipe(ops.map(segmentation))
-# Join audio and segmentation stream to calculate speaker embeddings
-embedding_stream = rx.zip(
- regular_stream, segmentation_stream
-).pipe(ops.starmap(embedding))
+segmentation = SpeakerSegmentation(
+ SegmentationModel.from_pyannote("pyannote/segmentation")
+)
+embedding = OverlapAwareSpeakerEmbedding(
+ EmbeddingModel.from_pyannote("pyannote/embedding")
+)
+sample_rate = segmentation.model.get_sample_rate()
+mic = MicrophoneAudioSource(sample_rate)
-embedding_stream.subscribe(on_next=lambda emb: print(emb.shape))
+stream = mic.stream.pipe(
+ # Reformat stream to 5s duration and 500ms shift
+ dops.regularize_audio_stream(sample_rate),
+ ops.map(lambda wav: (wav, segmentation(wav))),
+ ops.starmap(embedding)
+).subscribe(on_next=lambda emb: print(emb.shape))
mic.read()
```
@@ -276,7 +291,7 @@ config = PipelineConfig(
pipeline = OnlineSpeakerDiarization(config)
benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir")
-benchmark(pipeline, batch_size=32)
+benchmark(pipeline)
```
This runs a faster inference by pre-calculating model outputs in batches.
diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py
index 640dae14..0fb2a1a6 100644
--- a/src/diart/argdoc.py
+++ b/src/diart/argdoc.py
@@ -7,4 +7,5 @@
BETA = "Parameter beta for overlapped speech penalty"
MAX_SPEAKERS = "Maximum number of speakers"
CPU = "Force models to run on CPU"
+BATCH_SIZE = "For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency"
OUTPUT = "Directory to store the system's output in RTTM format"
diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index b5fcc567..674942a3 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -18,7 +18,7 @@
parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
- parser.add_argument("--batch-size", default=32, type=int, help="For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency. Defaults to 32")
+ parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
args = parser.parse_args()
diff --git a/src/diart/optim.py b/src/diart/optim.py
index a71bd571..db616db5 100644
--- a/src/diart/optim.py
+++ b/src/diart/optim.py
@@ -19,6 +19,16 @@ class HyperParameter:
low: float
high: float
+ @staticmethod
+ def from_name(name: Text) -> 'HyperParameter':
+ if name == "tau_active":
+ return TauActive
+ if name == "rho_update":
+ return RhoUpdate
+ if name == "delta_new":
+ return DeltaNew
+ raise ValueError(f"Hyper-parameter '{name}' not recognized")
+
TauActive = HyperParameter("tau_active", low=0, high=1)
RhoUpdate = HyperParameter("rho_update", low=0, high=1)
@@ -42,9 +52,9 @@ def __init__(
self.study = study_or_path
elif isinstance(study_or_path, str) or isinstance(study_or_path, Path):
self.study = create_study(
- storage="sqlite:///" + str(study_or_path / "trials.db"),
+ storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"),
sampler=TPESampler(),
- study_name=study_or_path.name,
+ study_name=study_or_path.stem,
direction="minimize",
load_if_exists=True,
)
diff --git a/src/diart/tune.py b/src/diart/tune.py
new file mode 100644
index 00000000..a35060fd
--- /dev/null
+++ b/src/diart/tune.py
@@ -0,0 +1,68 @@
+import argparse
+from pathlib import Path
+from uuid import uuid4
+
+import optuna
+import torch
+from optuna.samplers import TPESampler
+
+import diart.argdoc as argdoc
+from diart.inference import Benchmark
+from diart.optim import Optimizer, HyperParameter
+from diart.pipelines import PipelineConfig
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
+ parser.add_argument("--reference", required=True, type=str, help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+ parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
+ parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
+ parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
+ parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3")
+ parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1")
+ parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3")
+ parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
+ parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
+ parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
+ parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
+ parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"), help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new")
+ parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials")
+ parser.add_argument("--storage", type=str, help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name")
+ parser.add_argument("--output", required=True, type=str, help="Working directory")
+ args = parser.parse_args()
+ args.output = Path(args.output)
+ args.output.mkdir(parents=True, exist_ok=True)
+ args.device = torch.device("cpu") if args.cpu else None
+
+ # Assign unique worker ID
+ idx = uuid4()
+
+ # Create benchmark object to run the pipeline on a set of files
+ work_path = args.output / f"worker-{idx}"
+ benchmark = Benchmark(
+ args.root,
+ args.reference,
+ work_path,
+ show_progress=True,
+ show_report=False,
+ batch_size=args.batch_size
+ )
+
+ # Create the base configuration for each trial
+ base_config = PipelineConfig.from_namespace(args)
+
+ # Create hyper-parameters to optimize
+ hparams = [HyperParameter.from_name(name) for name in args.hparams]
+
+ # Use a custom storage if given
+ study_or_path = args.output
+ if args.storage is not None:
+ db_name = Path(args.storage).stem
+ study_or_path = optuna.load_study(db_name, args.storage, TPESampler())
+
+ # Run optimization
+ optimizer = Optimizer(benchmark, base_config, hparams, study_or_path)
+ optimizer.optimize(num_iter=args.num_iter, show_progress=True)
+
+ # Clean temporary directory
+ work_path.rmdir()
From eec293678d53f3ce604d55070da0500277da1bdf Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 16 Jun 2022 16:54:45 +0200
Subject: [PATCH 25/29] Simplify loading segmentation and embedding models from
pyannote
---
README.md | 9 ++-------
src/diart/blocks/embedding.py | 15 +++++++++++++++
src/diart/blocks/segmentation.py | 4 ++++
3 files changed, 21 insertions(+), 7 deletions(-)
diff --git a/README.md b/README.md
index 204a3a2f..acf3abfa 100644
--- a/README.md
+++ b/README.md
@@ -196,14 +196,9 @@ import rx.operators as ops
import diart.operators as dops
from diart.sources import MicrophoneAudioSource
from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
-from diart.models import SegmentationModel, EmbeddingModel
-segmentation = SpeakerSegmentation(
- SegmentationModel.from_pyannote("pyannote/segmentation")
-)
-embedding = OverlapAwareSpeakerEmbedding(
- EmbeddingModel.from_pyannote("pyannote/embedding")
-)
+segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation")
+embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding")
sample_rate = segmentation.model.get_sample_rate()
mic = MicrophoneAudioSource(sample_rate)
diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py
index 19288876..8855ae08 100644
--- a/src/diart/blocks/embedding.py
+++ b/src/diart/blocks/embedding.py
@@ -18,6 +18,10 @@ def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None)
self.waveform_formatter = TemporalFeatureFormatter()
self.weights_formatter = TemporalFeatureFormatter()
+ @staticmethod
+ def from_pyannote(model, device: Optional[torch.device] = None) -> 'SpeakerEmbedding':
+ return SpeakerEmbedding(EmbeddingModel.from_pyannote(model), device)
+
def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor:
"""
Calculate speaker embeddings of input audio.
@@ -134,5 +138,16 @@ def __init__(
self.osp = OverlappedSpeechPenalty(gamma, beta)
self.normalize = EmbeddingNormalization(norm)
+ @staticmethod
+ def from_pyannote(
+ model,
+ gamma: float = 3,
+ beta: float = 10,
+ norm: Union[float, torch.Tensor] = 1,
+ device: Optional[torch.device] = None,
+ ):
+ model = EmbeddingModel.from_pyannote(model)
+ return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device)
+
def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor:
return self.normalize(self.embedding(waveform, self.osp(segmentation)))
diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py
index 33064207..1b310944 100644
--- a/src/diart/blocks/segmentation.py
+++ b/src/diart/blocks/segmentation.py
@@ -17,6 +17,10 @@ def __init__(self, model: SegmentationModel, device: Optional[torch.device] = No
self.model.to(self.device)
self.formatter = TemporalFeatureFormatter()
+ @staticmethod
+ def from_pyannote(model, device: Optional[torch.device] = None) -> 'SpeakerSegmentation':
+ return SpeakerSegmentation(SegmentationModel.from_pyannote(model), device)
+
def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
"""
Calculate the speaker segmentation of input audio.
From ad041525fdf768b474021ab71093077566c3c5f2 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 16 Jun 2022 17:15:23 +0200
Subject: [PATCH 26/29] Register stream, benchmark and tune as entry points
during installation
---
README.md | 16 ++++++++--------
setup.cfg | 9 +++++++--
src/diart/benchmark.py | 13 ++++++++++---
src/diart/stream.py | 13 ++++++++++---
src/diart/tune.py | 19 ++++++++++++++-----
5 files changed, 49 insertions(+), 21 deletions(-)
diff --git a/README.md b/README.md
index acf3abfa..7efd2c47 100644
--- a/README.md
+++ b/README.md
@@ -88,16 +88,16 @@ pip install diart
A recorded conversation:
```shell
-python -m diart.stream /path/to/audio.wav
+diart.stream /path/to/audio.wav
```
A live conversation:
```shell
-python -m diart.stream microphone
+diart.stream microphone
```
-See `python -m diart.stream -h` for more options.
+See `diart.stream -h` for more options.
### From python
@@ -124,10 +124,10 @@ Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.re
### From the command line
```shell
-python -m diart.tune /wav/dir --reference /rttm/dir --output /out/dir
+diart.tune /wav/dir --reference /rttm/dir --output /out/dir
```
-See `python -m diart.tune -h` for more options.
+See `diart.tune -h` for more options.
### From python
@@ -163,7 +163,7 @@ optuna create-study --study-name "example" --storage "mysql://root@localhost/exa
Then you can run multiple identical optimizers pointing to the database:
```shell
-python -m diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example
+diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example
```
If you are using the python API, make sure that worker directories are different to avoid concurrency issues:
@@ -267,7 +267,7 @@ To obtain the best results, make sure to use the following hyper-parameters:
`diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration:
```shell
-python -m diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
+diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir
```
or using the inference API:
@@ -290,7 +290,7 @@ benchmark(pipeline)
```
This runs a faster inference by pre-calculating model outputs in batches.
-See `python -m diart.benchmark -h` for more options.
+See `diart.benchmark -h` for more options.
For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
diff --git a/setup.cfg b/setup.cfg
index a0f7c776..214ca57b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,7 +19,7 @@ classifiers=
package_dir=
=src
packages=find:
-install_requires =
+install_requires=
numpy>=1.20.2
matplotlib>=3.3.3
rx>=3.2.0
@@ -34,6 +34,11 @@ install_requires =
pyannote.metrics>=3.2
optuna>=2.10
-
[options.packages.find]
where=src
+
+[options.entry_points]
+console_scripts=
+ diart.stream=diart.stream:run
+ diart.benchmark=diart.benchmark:run
+ diart.tune=diart.tune:run
diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py
index 674942a3..ad0969fb 100644
--- a/src/diart/benchmark.py
+++ b/src/diart/benchmark.py
@@ -6,10 +6,12 @@
from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
-if __name__ == "__main__":
+
+def run():
parser = argparse.ArgumentParser()
parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
- parser.add_argument("--reference", type=str, help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+ parser.add_argument("--reference", type=str,
+ help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -19,7 +21,8 @@
parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
- parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
+ parser.add_argument("--cpu", dest="cpu", action="store_true",
+ help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`")
args = parser.parse_args()
args.device = torch.device("cpu") if args.cpu else None
@@ -34,3 +37,7 @@
)
benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True))
+
+
+if __name__ == "__main__":
+ run()
diff --git a/src/diart/stream.py b/src/diart/stream.py
index c263364d..63232e43 100644
--- a/src/diart/stream.py
+++ b/src/diart/stream.py
@@ -8,7 +8,8 @@
from diart.inference import RealTimeInference
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
-if __name__ == "__main__":
+
+def run():
parser = argparse.ArgumentParser()
parser.add_argument("source", type=str, help="Path to an audio file | 'microphone'")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
@@ -20,8 +21,10 @@
parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference")
- parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
- parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
+ parser.add_argument("--cpu", dest="cpu", action="store_true",
+ help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
+ parser.add_argument("--output", type=str,
+ help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file")
args = parser.parse_args()
args.device = torch.device("cpu") if args.cpu else None
@@ -42,3 +45,7 @@
# Run online inference
RealTimeInference(args.output, do_plot=not args.no_plot)(pipeline, audio_source)
+
+
+if __name__ == "__main__":
+ run()
diff --git a/src/diart/tune.py b/src/diart/tune.py
index a35060fd..000e3da8 100644
--- a/src/diart/tune.py
+++ b/src/diart/tune.py
@@ -11,10 +11,12 @@
from diart.optim import Optimizer, HyperParameter
from diart.pipelines import PipelineConfig
-if __name__ == "__main__":
+
+def run():
parser = argparse.ArgumentParser()
parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)")
- parser.add_argument("--reference", required=True, type=str, help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
+ parser.add_argument("--reference", required=True, type=str,
+ help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files")
parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5")
parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5")
parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5")
@@ -24,10 +26,13 @@
parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10")
parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20")
parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32")
- parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
- parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"), help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new")
+ parser.add_argument("--cpu", dest="cpu", action="store_true",
+ help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise")
+ parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"),
+ help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new")
parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials")
- parser.add_argument("--storage", type=str, help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name")
+ parser.add_argument("--storage", type=str,
+ help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name")
parser.add_argument("--output", required=True, type=str, help="Working directory")
args = parser.parse_args()
args.output = Path(args.output)
@@ -66,3 +71,7 @@
# Clean temporary directory
work_path.rmdir()
+
+
+if __name__ == "__main__":
+ run()
From 676b2b515296eff2a19cba6dd25f6c0a27c056f8 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Thu, 16 Jun 2022 18:02:20 +0200
Subject: [PATCH 27/29] Add custom embedding model example
---
README.md | 35 +++++++++++++++++++++++++++++++++++
1 file changed, 35 insertions(+)
diff --git a/README.md b/README.md
index 7efd2c47..53f5cdba 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,10 @@
Stream audio
|
+
+ Add your model
+
+ |
Tune hyper-parameters
@@ -117,6 +121,37 @@ inference(pipeline, audio_source)
For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)).
+## Add your model
+
+Third-party segmentation and embedding models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`:
+
+```python
+import torch
+from typing import Optional
+from diart.models import EmbeddingModel
+from diart.pipelines import PipelineConfig, OnlineSpeakerDiarization
+from diart.sources import MicrophoneAudioSource
+from diart.inference import RealTimeInference
+
+class MyEmbeddingModel(EmbeddingModel):
+ def __init__(self):
+ super().__init__()
+ self.my_pretrained_model = load("my_model.ckpt")
+
+ def __call__(
+ self,
+ waveform: torch.Tensor,
+ weights: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ return self.my_pretrained_model(waveform, weights)
+
+config = PipelineConfig(embedding=MyEmbeddingModel())
+pipeline = OnlineSpeakerDiarization(config)
+mic = MicrophoneAudioSource(config.sample_rate)
+inference = RealTimeInference("/out/dir")
+inference(pipeline, mic)
+```
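
The segmentation side can be extended in the same way. Below is a minimal sketch under the same assumptions as the example above (a hypothetical `load` function and checkpoint); besides the sample rate and forward call shown here, any other abstract methods declared by `SegmentationModel` would also need to be implemented:

```python
import torch
from diart.models import SegmentationModel

class MySegmentationModel(SegmentationModel):
    def __init__(self):
        super().__init__()
        # Hypothetical loading function and checkpoint, as in the embedding example
        self.my_pretrained_model = load("my_segmentation.ckpt")

    def get_sample_rate(self) -> int:
        # Sample rate expected by the model
        return 16000

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        # Per-frame speaker activations, e.g. shape (batch, frames, speakers)
        return self.my_pretrained_model(waveform)
```

Such a model could then be plugged into the pipeline via `PipelineConfig(segmentation=MySegmentationModel())`, mirroring the embedding example.
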
+
## Tune hyper-parameters
Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset.
From d551d9806f6e98e92dfb202fc755b486eef7c3f3 Mon Sep 17 00:00:00 2001
From: Juan Manuel Coria
Date: Wed, 13 Jul 2022 12:18:25 +0200
Subject: [PATCH 28/29] Update version
---
setup.cfg | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index 214ca57b..5ba6a171 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name=diart
-version=0.3.0
+version=0.4.0
author=Juan Manuel Coria
description=Speaker diarization in real time
long_description=file: README.md
@@ -9,7 +9,7 @@ keywords=speaker diarization, streaming, online, real time, rxpy
url=https://github.com/juanmc2005/StreamingSpeakerDiarization
license=MIT
classifiers=
- Development Status :: 3 - Alpha
+ Development Status :: 4 - Beta
License :: OSI Approved :: MIT License
Topic :: Multimedia :: Sound/Audio :: Analysis
Topic :: Multimedia :: Sound/Audio :: Speech
From b40f091c27db65e990daf4fd419c882f79b164cb Mon Sep 17 00:00:00 2001
From: Juan Coria
Date: Wed, 13 Jul 2022 14:46:03 +0200
Subject: [PATCH 29/29] Split menu in README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 53f5cdba..7cf7030b 100644
--- a/README.md
+++ b/README.md
@@ -33,7 +33,7 @@
Build pipelines
- |
+
Research