From 6a43b7619499150b4e568f40d057610d49684feb Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Tue, 31 May 2022 19:05:39 +0200 Subject: [PATCH 01/29] Add base classes for custom segmentation and embedding models. Unify pipeline specifications in PipelineConfig --- README.md | 20 +++++----- src/diart/argdoc.py | 2 +- src/diart/benchmark.py | 4 +- src/diart/blocks.py | 90 ++++++++++++++++++++++++++---------------- src/diart/inference.py | 18 ++++++--- src/diart/models.py | 83 ++++++++++++++++++++++++++++++++++++++ src/diart/pipelines.py | 60 +++++++++++++++------------- src/diart/stream.py | 10 +++-- 8 files changed, 205 insertions(+), 82 deletions(-) create mode 100644 src/diart/models.py diff --git a/README.md b/README.md index f742bd09..5faf55d5 100644 --- a/README.md +++ b/README.md @@ -63,8 +63,9 @@ from diart.sources import MicrophoneAudioSource from diart.inference import RealTimeInference from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig -pipeline = OnlineSpeakerDiarization(PipelineConfig()) -audio_source = MicrophoneAudioSource(pipeline.sample_rate) +config = PipelineConfig() # Default parameters +pipeline = OnlineSpeakerDiarization(config) +audio_source = MicrophoneAudioSource(config.sample_rate) inference = RealTimeInference("/output/path", do_plot=True) inference(pipeline, audio_source) @@ -86,17 +87,18 @@ import rx import rx.operators as ops import diart.operators as dops from diart.sources import MicrophoneAudioSource -from diart.blocks import FramewiseModel, OverlapAwareSpeakerEmbedding - -sample_rate = 16000 -mic = MicrophoneAudioSource(sample_rate) +from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding +from diart.models import SegmentationModel, EmbeddingModel # Initialize independent modules -segmentation = FramewiseModel("pyannote/segmentation") -embedding = OverlapAwareSpeakerEmbedding("pyannote/embedding") +seg_model = SegmentationModel.from_pyannote("pyannote/segmentation") +segmentation = SpeakerSegmentation(seg_model) +emb_model = EmbeddingModel.from_pyannote("pyannote/embedding") +embedding = OverlapAwareSpeakerEmbedding(emb_model) +mic = MicrophoneAudioSource(seg_model.get_sample_rate()) # Reformat microphone stream. Defaults to 5s duration and 500ms shift -regular_stream = mic.stream.pipe(dops.regularize_stream(sample_rate)) +regular_stream = mic.stream.pipe(dops.regularize_stream(seg_model.get_sample_rate())) # Branch the microphone stream to calculate segmentation segmentation_stream = regular_stream.pipe(ops.map(segmentation)) # Join audio and segmentation stream to calculate speaker embeddings diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py index 933f96ca..640dae14 100644 --- a/src/diart/argdoc.py +++ b/src/diart/argdoc.py @@ -6,5 +6,5 @@ GAMMA = "Parameter gamma for overlapped speech penalty" BETA = "Parameter beta for overlapped speech penalty" MAX_SPEAKERS = "Maximum number of speakers" -GPU = "Run on GPU" +CPU = "Force models to run on CPU" OUTPUT = "Directory to store the system's output in RTTM format" diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py index fc07f754..b161b708 100644 --- a/src/diart/benchmark.py +++ b/src/diart/benchmark.py @@ -20,7 +20,7 @@ parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") parser.add_argument("--batch-size", default=32, type=int, help="For segmentation and embedding pre-calculation. 
If BATCH_SIZE < 2, run fully online and estimate real-time latency. Defaults to 32") - parser.add_argument("--gpu", dest="gpu", action="store_true", help=argdoc.GPU) + parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`") args = parser.parse_args() @@ -37,7 +37,7 @@ gamma=args.gamma, beta=args.beta, max_speakers=args.max_speakers, - device=torch.device("cuda") if args.gpu else None, + device=torch.device("cpu") if args.cpu else None, )) benchmark(pipeline, args.batch_size) diff --git a/src/diart/blocks.py b/src/diart/blocks.py index 2da18939..b5117d07 100644 --- a/src/diart/blocks.py +++ b/src/diart/blocks.py @@ -1,15 +1,14 @@ from typing import Union, Optional, List, Iterable, Tuple -from typing_extensions import Literal import numpy as np import torch -from pyannote.audio.pipelines.utils import PipelineModel, get_model, get_devices +from einops import rearrange from pyannote.audio.utils.signal import Binarize as PyanBinarize from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature -from einops import rearrange +from typing_extensions import Literal from .mapping import SpeakerMap, SpeakerMapBuilder - +from .models import SegmentationModel, EmbeddingModel TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor] @@ -42,26 +41,31 @@ def resolve_features(features: TemporalFeatures) -> torch.Tensor: return data.float() -class FramewiseModel: - def __init__(self, model: PipelineModel, device: Optional[torch.device] = None): - self.model = get_model(model) +class SpeakerSegmentation: + def __init__(self, model: SegmentationModel, device: Optional[torch.device] = None): + self.model = model self.model.eval() - if device is None: - device = get_devices(needs=1)[0] - self.model.to(device) + self.device = device + if self.device is None: + self.device = torch.device("cpu") + self.model.to(self.device) - @property - def sample_rate(self) -> int: - return self.model.audio.sample_rate + def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: + """ + Calculate the speaker segmentation of input audio. - @property - def duration(self) -> float: - return self.model.specifications.duration + Parameters + ---------- + waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) - def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: + Returns + ------- + speaker_segmentation: TemporalFeatures, shape (batch, frames, speakers) + The batch dimension is omitted if waveform is a `SlidingWindowFeature`. 
+ """ with torch.no_grad(): wave = rearrange(resolve_features(waveform), "batch sample channel -> batch channel sample") - output = self.model(wave.to(self.model.device)).cpu() + output = self.model(wave.to(self.device)).cpu() batch_size, num_frames, _ = output.shape @@ -72,7 +76,7 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: # Wrap if a SlidingWindowFeature was given as input if isinstance(waveform, SlidingWindowFeature): # Temporal resolution of the output - duration = wave.shape[-1] / self.sample_rate + duration = wave.shape[-1] / self.model.get_sample_rate() resolution = duration / num_frames # Temporal shift to keep track of current start time resolution = SlidingWindow( @@ -88,26 +92,44 @@ def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: return output -class ChunkwiseModel: - def __init__(self, model: PipelineModel, device: Optional[torch.device] = None): - self.model = get_model(model) +class SpeakerEmbedding: + def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None): + self.model = model self.model.eval() - if device is None: - device = get_devices(needs=1)[0] - self.model.to(device) + self.device = device + if self.device is None: + self.device = torch.device("cpu") + self.model.to(self.device) + + def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor: + """ + Calculate speaker embeddings of input audio. + If weights are given, calculate many speaker embeddings from the same waveform. + + Parameters + ---------- + waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) + weights: Optional[TemporalFeatures], shape (frames, speakers) or (batch, frames, speakers) + Per-speaker and per-frame weights. Defaults to no weights. - def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures]) -> torch.Tensor: + Returns + ------- + embeddings: torch.Tensor + If weights are provided, the shape is (batch, speakers, embedding_dim), + otherwise the shape is (batch, embedding_dim). + If batch size == 1, the batch dimension is omitted. + """ with torch.no_grad(): - inputs = resolve_features(waveform).to(self.model.device) + inputs = resolve_features(waveform).to(self.device) inputs = rearrange(inputs, "batch sample channel -> batch channel sample") if weights is not None: - weights = resolve_features(weights).to(self.model.device) + weights = resolve_features(weights).to(self.device) batch_size, _, num_speakers = weights.shape inputs = inputs.repeat(1, num_speakers, 1) weights = rearrange(weights, "batch frame spk -> (batch spk) frame") inputs = rearrange(inputs, "batch spk sample -> (batch spk) 1 sample") output = rearrange( - self.model(inputs, weights=weights), + self.model(inputs, weights), "(batch spk) feat -> batch spk feat", batch=batch_size, spk=num_speakers @@ -125,7 +147,7 @@ class OverlappedSpeechPenalty: Exponent to lower low-confidence predictions. Defaults to 3. beta: float, optional - Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations. + Temperature parameter (actually 1/beta) to lower joint speaker activations. Defaults to 10. """ def __init__(self, gamma: float = 3, beta: float = 10): @@ -171,8 +193,8 @@ class OverlapAwareSpeakerEmbedding: Parameters ---------- - model: pyannote.audio.Model, Text or Dict - The embedding model. It must take a waveform and weights as input. + model: EmbeddingModel + A pre-trained embedding model. 
gamma: float, optional Exponent to lower low-confidence predictions. Defaults to 3. @@ -188,13 +210,13 @@ class OverlapAwareSpeakerEmbedding: """ def __init__( self, - model: PipelineModel, + model: EmbeddingModel, gamma: float = 3, beta: float = 10, norm: Union[float, torch.Tensor] = 1, device: Optional[torch.device] = None, ): - self.embedding = ChunkwiseModel(model, device) + self.embedding = SpeakerEmbedding(model, device) self.osp = OverlappedSpeechPenalty(gamma, beta) self.normalize = EmbeddingNormalization(norm) diff --git a/src/diart/inference.py b/src/diart/inference.py index a9c79589..4574279b 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -59,12 +59,12 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, source: src.AudioSource) observable.pipe( ops.do(rttm_writer), dops.buffer_output( - duration=pipeline.duration, + duration=pipeline.config.duration, step=pipeline.config.step, latency=pipeline.config.latency, - sample_rate=pipeline.sample_rate + sample_rate=pipeline.config.sample_rate ), - ).subscribe(RealTimePlot(pipeline.duration, pipeline.config.latency)) + ).subscribe(RealTimePlot(pipeline.config.duration, pipeline.config.latency)) # Stream audio through the pipeline source.read() @@ -125,7 +125,11 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> DataFrame with detailed performance on each file, as well as average performance. None if the reference is not provided. """ - chunk_loader = src.ChunkLoader(pipeline.sample_rate, pipeline.duration, pipeline.config.step) + chunk_loader = src.ChunkLoader( + pipeline.config.sample_rate, + pipeline.config.duration, + pipeline.config.step + ) audio_file_paths = list(self.speech_path.iterdir()) num_audio_files = len(audio_file_paths) for i, filepath in enumerate(audio_file_paths): @@ -137,7 +141,11 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> source = src.FileAudioSource( filepath, filepath.stem, - src.RegularAudioFileReader(pipeline.sample_rate, pipeline.duration, pipeline.config.step), + src.RegularAudioFileReader( + pipeline.config.sample_rate, + pipeline.config.duration, + pipeline.config.step + ), # Benchmark the processing time of a single chunk profile=True, ) diff --git a/src/diart/models.py b/src/diart/models.py new file mode 100644 index 00000000..74ae5c9b --- /dev/null +++ b/src/diart/models.py @@ -0,0 +1,83 @@ +from typing import Optional + +import torch +import torch.nn as nn +from pyannote.audio.pipelines.utils import PipelineModel, get_model + + +class SegmentationModel(nn.Module): + """ + Minimal interface for a segmentation model. 
+ """ + + @staticmethod + def from_pyannote(model: PipelineModel) -> 'SegmentationModel': + class PyannoteSegmentationModel(SegmentationModel): + def __init__(self, pyannote_model: PipelineModel): + super().__init__() + self.model = get_model(pyannote_model) + + def get_sample_rate(self) -> int: + return self.model.audio.sample_rate + + def get_duration(self) -> float: + return self.model.specifications.duration + + def __call__(self, waveform: torch.Tensor) -> torch.Tensor: + return self.model(waveform) + + return PyannoteSegmentationModel(model) + + def get_sample_rate(self) -> int: + """Return the sample rate expected for model inputs""" + raise NotImplementedError + + def get_duration(self) -> float: + """Return the input duration by default (usually the one used during training)""" + raise NotImplementedError + + def __call__(self, waveform: torch.Tensor) -> torch.Tensor: + """ + Forward pass of a segmentation model. + + Parameters + ---------- + waveform: torch.Tensor, shape (batch, channels, samples) + + Returns + ------- + speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) + """ + raise NotImplementedError + + +class EmbeddingModel(nn.Module): + """Minimal interface for an embedding model.""" + + @staticmethod + def from_pyannote(model: PipelineModel) -> 'EmbeddingModel': + class PyannoteEmbeddingModel(EmbeddingModel): + def __init__(self, pyannote_model: PipelineModel): + super().__init__() + self.model = get_model(pyannote_model) + + def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + return self.model(waveform, weights=weights) + + return PyannoteEmbeddingModel(model) + + def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Forward pass of an embedding model with optional weights. + + Parameters + ---------- + waveform: torch.Tensor, shape (batch, channels, samples) + weights: Optional[torch.Tensor], shape (batch, frames) + Temporal weights for each sample in the batch. Defaults to no weights. + + Returns + ------- + speaker_embeddings: torch.Tensor, shape (batch, speakers, embedding_dim) + """ + raise NotImplementedError diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index b56b235a..2678d182 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -7,11 +7,12 @@ import rx.operators as ops import torch from einops import rearrange -from pyannote.audio.pipelines.utils import PipelineModel from pyannote.core import SlidingWindowFeature, SlidingWindow +from pyannote.audio.pipelines.utils import get_devices from tqdm import tqdm from . import blocks +from . import models as m from . import operators as dops from . 
import sources as src @@ -19,8 +20,8 @@ class PipelineConfig: def __init__( self, - segmentation: PipelineModel = "pyannote/segmentation", - embedding: PipelineModel = "pyannote/embedding", + segmentation: Optional[m.SegmentationModel] = None, + embedding: Optional[m.EmbeddingModel] = None, duration: Optional[float] = None, step: float = 0.5, latency: Optional[float] = None, @@ -32,22 +33,40 @@ def __init__( max_speakers: int = 20, device: Optional[torch.device] = None, ): + # Default segmentation model is pyannote/segmentation self.segmentation = segmentation + if self.segmentation is None: + self.segmentation = m.SegmentationModel.from_pyannote("pyannote/segmentation") + + # Default duration is the one given by the segmentation model + self.duration = duration + if self.duration is None: + self.duration = self.segmentation.get_duration() + + # Expected sample rate is given by the segmentation model + self.sample_rate = self.segmentation.get_sample_rate() + + # Default embedding model is pyannote/embedding self.embedding = embedding - self.requested_duration = duration + if self.embedding is None: + self.embedding = m.EmbeddingModel.from_pyannote("pyannote/embedding") + + # Latency defaults to the step duration self.step = step self.latency = latency if self.latency is None: self.latency = self.step + self.tau_active = tau_active self.rho_update = rho_update self.delta_new = delta_new self.gamma = gamma self.beta = beta self.max_speakers = max_speakers + self.device = device if self.device is None: - self.device = torch.device("cpu") + self.device = get_devices(needs=1)[0] def last_chunk_end_time(self, conv_duration: float) -> Optional[float]: """ @@ -117,39 +136,26 @@ def from_model_streams( class OnlineSpeakerDiarization: def __init__(self, config: PipelineConfig): self.config = config - self.segmentation = blocks.FramewiseModel(config.segmentation, config.device) + self.segmentation = blocks.SpeakerSegmentation(config.segmentation, config.device) self.embedding = blocks.OverlapAwareSpeakerEmbedding( config.embedding, config.gamma, config.beta, norm=1, device=config.device ) self.speaker_tracking = OnlineSpeakerTracking(config) - msg = "Invalid latency requested" - assert config.step <= config.latency <= self.duration, msg - - @property - def sample_rate(self) -> int: - """Sample rate expected by the segmentation model""" - return self.segmentation.sample_rate - - @property - def duration(self) -> float: - """Chunk duration (in seconds). 
Defaults to segmentation model duration""" - duration = self.config.requested_duration - if duration is None: - duration = self.segmentation.duration - return duration + msg = f"Latency should be in the range [{config.step}, {config.duration}]" + assert config.step <= config.latency <= config.duration, msg def from_source( self, source: src.AudioSource, output_waveform: bool = True ) -> rx.Observable: - msg = f"Audio source has sample rate {source.sample_rate}, expected {self.sample_rate}" - assert source.sample_rate == self.sample_rate, msg + msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}" + assert source.sample_rate == self.config.sample_rate, msg # Regularize the stream to a specific chunk duration and step regular_stream = source.stream if not source.is_regular: regular_stream = source.stream.pipe( - dops.regularize_stream(self.duration, self.config.step, source.sample_rate) + dops.regularize_stream(self.config.duration, self.config.step, source.sample_rate) ) # Branch the stream to calculate chunk segmentation seg_stream = regular_stream.pipe(ops.map(self.segmentation)) @@ -170,7 +176,7 @@ def from_file( # Audio file information file = Path(file) chunk_loader = src.ChunkLoader( - self.sample_rate, self.duration, self.config.step + self.config.sample_rate, self.config.duration, self.config.step ) # Split audio into chunks @@ -207,14 +213,14 @@ def from_file( embeddings = torch.vstack(embeddings) # Stream pre-calculated segmentation, embeddings and chunks - resolution = self.duration / segmentation.shape[1] + resolution = self.config.duration / segmentation.shape[1] seg_stream = rx.range(0, num_chunks).pipe( ops.map(lambda i: SlidingWindowFeature( segmentation[i], SlidingWindow(resolution, resolution, i * self.config.step) )) ) emb_stream = rx.range(0, num_chunks).pipe(ops.map(lambda i: embeddings[i])) - wav_resolution = 1 / self.sample_rate + wav_resolution = 1 / self.config.sample_rate chunk_stream = None if output_waveform: chunk_stream = rx.range(0, num_chunks).pipe( diff --git a/src/diart/stream.py b/src/diart/stream.py index d98d2c8f..4d4bb819 100644 --- a/src/diart/stream.py +++ b/src/diart/stream.py @@ -21,7 +21,7 @@ parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference") - parser.add_argument("--gpu", dest="gpu", action="store_true", help=argdoc.GPU) + parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. 
Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file") args = parser.parse_args() @@ -35,7 +35,7 @@ gamma=args.gamma, beta=args.beta, max_speakers=args.max_speakers, - device=torch.device("cuda") if args.gpu else None, + device=torch.device("cpu") if args.cpu else None, )) # Manage audio source @@ -46,12 +46,14 @@ file=args.source, uri=args.source.stem, reader=src.RegularAudioFileReader( - pipeline.sample_rate, pipeline.duration, pipeline.config.step + pipeline.config.sample_rate, + pipeline.config.duration, + pipeline.config.step, ), ) else: args.output = Path("~/").expanduser() if args.output is None else Path(args.output) - audio_source = src.MicrophoneAudioSource(pipeline.sample_rate) + audio_source = src.MicrophoneAudioSource(pipeline.config.sample_rate) # Run online inference RealTimeInference(args.output, do_plot=not args.no_plot)(pipeline, audio_source) From 6c0f5713b6ab0650d70506a9e20f61a8095910c8 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 1 Jun 2022 00:19:39 +0200 Subject: [PATCH 02/29] Add own binarize implementation --- src/diart/blocks.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/src/diart/blocks.py b/src/diart/blocks.py index b5117d07..b79fcdd9 100644 --- a/src/diart/blocks.py +++ b/src/diart/blocks.py @@ -3,7 +3,6 @@ import numpy as np import torch from einops import rearrange -from pyannote.audio.utils.signal import Binarize as PyanBinarize from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature from typing_extensions import Literal @@ -561,26 +560,25 @@ def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) class Binarize: - def __init__(self, uri: str, tau_active: float): + def __init__(self, uri: str, threshold: float): self.uri = uri - self._binarize = PyanBinarize( - onset=tau_active, - offset=tau_active, - min_duration_on=0, - min_duration_off=0, - ) - - def _select( - self, scores: SlidingWindowFeature, speaker: int - ) -> SlidingWindowFeature: - return SlidingWindowFeature( - scores[:, speaker].reshape(-1, 1), scores.sliding_window - ) + self.threshold = threshold def __call__(self, segmentation: SlidingWindowFeature) -> Annotation: + num_frames, num_speakers = segmentation.data.shape + timestamps = segmentation.sliding_window + is_active = segmentation.data.T > self.threshold # shape (speakers, frames) + # Artificially add last inactive frame to close any remaining speaker turns + is_active = np.append(is_active, [[False]] * num_speakers, axis=1) annotation = Annotation(uri=self.uri, modality="speech") - for speaker in range(segmentation.data.shape[1]): - turns = self._binarize(self._select(segmentation, speaker)) - for speaker_turn in turns.itersegments(): - annotation[speaker_turn, speaker] = f"speaker{speaker}" + for spk in range(num_speakers): + start = timestamps[0].middle + for t in range(num_frames): + # Any (False, True) start a speaker turn at "True" index + if not is_active[spk, t] and is_active[spk, t + 1]: + start = timestamps[t + 1].middle + # Any (True, False) end a speaker turn at "False" index + elif is_active[spk, t] and not is_active[spk, t + 1]: + region = Segment(start, timestamps[t + 1].middle) + annotation[region, spk] = f"speaker{spk}" return annotation From 7f771731b5cb9e99f80bf933d116d2137f5f72df Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 1 Jun 2022 00:48:54 +0200 Subject: [PATCH 03/29] Make Binarize more efficient by looking at all speakers at the same 
time. Add docstring to Binarize --- src/diart/blocks.py | 54 +++++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/src/diart/blocks.py b/src/diart/blocks.py index b79fcdd9..83f099fb 100644 --- a/src/diart/blocks.py +++ b/src/diart/blocks.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, List, Iterable, Tuple +from typing import Union, Optional, List, Iterable, Tuple, Text import numpy as np import torch @@ -560,25 +560,51 @@ def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) class Binarize: - def __init__(self, uri: str, threshold: float): + """ + Transform a speaker segmentation from the discrete-time domain + into a continuous-time speaker segmentation. + + Parameters + ---------- + uri: Text + Uri of the audio stream. + threshold: float + Probability threshold to determine if a speaker is active at a given frame. + """ + + def __init__(self, uri: Text, threshold: float): self.uri = uri self.threshold = threshold def __call__(self, segmentation: SlidingWindowFeature) -> Annotation: + """ + Return the continuous-time segmentation + corresponding to the discrete-time input segmentation. + + Parameters + ---------- + segmentation: SlidingWindowFeature + Discrete-time speaker segmentation. + + Returns + ------- + annotation: Annotation + Continuous-time speaker segmentation. + """ num_frames, num_speakers = segmentation.data.shape timestamps = segmentation.sliding_window - is_active = segmentation.data.T > self.threshold # shape (speakers, frames) + is_active = segmentation.data > self.threshold # Artificially add last inactive frame to close any remaining speaker turns - is_active = np.append(is_active, [[False]] * num_speakers, axis=1) + is_active = np.append(is_active, [[False] * num_speakers], axis=0) + start_times = np.zeros(num_speakers) + timestamps[0].middle annotation = Annotation(uri=self.uri, modality="speech") - for spk in range(num_speakers): - start = timestamps[0].middle - for t in range(num_frames): - # Any (False, True) start a speaker turn at "True" index - if not is_active[spk, t] and is_active[spk, t + 1]: - start = timestamps[t + 1].middle - # Any (True, False) end a speaker turn at "False" index - elif is_active[spk, t] and not is_active[spk, t + 1]: - region = Segment(start, timestamps[t + 1].middle) - annotation[region, spk] = f"speaker{spk}" + for t in range(num_frames): + # Any (False, True) starts a speaker turn at "True" index + onsets = np.logical_and(np.logical_not(is_active[t]), is_active[t + 1]) + start_times[onsets] = timestamps[t + 1].middle + # Any (True, False) ends a speaker turn at "False" index + offsets = np.logical_and(is_active[t], np.logical_not(is_active[t + 1])) + for spk in np.where(offsets)[0]: + region = Segment(start_times[spk], timestamps[t + 1].middle) + annotation[region, spk] = f"speaker{spk}" return annotation From f2866c515deaaf8724501ad33d93d3188dc33ba9 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 1 Jun 2022 00:54:06 +0200 Subject: [PATCH 04/29] Simplify default device logic --- src/diart/pipelines.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 2678d182..2087f986 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -8,7 +8,6 @@ import torch from einops import rearrange from pyannote.core import SlidingWindowFeature, SlidingWindow -from pyannote.audio.pipelines.utils import get_devices from tqdm import tqdm from . 
import blocks @@ -66,7 +65,7 @@ def __init__( self.device = device if self.device is None: - self.device = get_devices(needs=1)[0] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def last_chunk_end_time(self, conv_duration: float) -> Optional[float]: """ From 3f8bd6b50522f80b72177fb46fb5c49723bc8565 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Fri, 3 Jun 2022 10:39:01 +0200 Subject: [PATCH 05/29] Replace `resolve_features` with `TemporalFeatureFormatter` (#59) * Add TemporalFeatureFormatter to cast temporal inputs to pytorch tensors and then restore outputs to the original type * Update docstrings * Add cleaner implementation of TemporalFeatureFormatter * Bug fixes --- src/diart/blocks.py | 75 ++++------------------- src/diart/features.py | 135 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 64 deletions(-) create mode 100644 src/diart/features.py diff --git a/src/diart/blocks.py b/src/diart/blocks.py index 2da18939..df2edf1c 100644 --- a/src/diart/blocks.py +++ b/src/diart/blocks.py @@ -9,37 +9,7 @@ from einops import rearrange from .mapping import SpeakerMap, SpeakerMapBuilder - - -TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor] - - -def resolve_features(features: TemporalFeatures) -> torch.Tensor: - """ - Transform features into a `torch.Tensor` and add batch dimension if missing. - - Parameters - ---------- - features: Union[SlidingWindowFeature, np.ndarray, torch.Tensor] - Shape (frames, channels) or (batch, frames, channels) - - Returns - ------- - transformed_features: torch.Tensor, shape (batch, frames, channels) - """ - # As torch.Tensor with shape (..., channels, frames) - if isinstance(features, SlidingWindowFeature): - data = torch.from_numpy(features.data) - elif isinstance(features, np.ndarray): - data = torch.from_numpy(features) - else: - data = features - # Make sure there's a batch dimension - msg = "Temporal features must be 2D or 3D" - assert data.ndim in (2, 3), msg - if data.ndim == 2: - data = data.unsqueeze(0) - return data.float() +from .features import TemporalFeatures, TemporalFeatureFormatter class FramewiseModel: @@ -49,6 +19,7 @@ def __init__(self, model: PipelineModel, device: Optional[torch.device] = None): if device is None: device = get_devices(needs=1)[0] self.model.to(device) + self.formatter = TemporalFeatureFormatter() @property def sample_rate(self) -> int: @@ -60,32 +31,9 @@ def duration(self) -> float: def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: with torch.no_grad(): - wave = rearrange(resolve_features(waveform), "batch sample channel -> batch channel sample") + wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample") output = self.model(wave.to(self.model.device)).cpu() - - batch_size, num_frames, _ = output.shape - - # Remove batch dimension if batch size is 1 - if output.shape[0] == 1: - output = output[0] - - # Wrap if a SlidingWindowFeature was given as input - if isinstance(waveform, SlidingWindowFeature): - # Temporal resolution of the output - duration = wave.shape[-1] / self.sample_rate - resolution = duration / num_frames - # Temporal shift to keep track of current start time - resolution = SlidingWindow( - start=waveform.sliding_window.start, - duration=resolution, - step=resolution - ) - return SlidingWindowFeature(output.numpy(), resolution) - - if isinstance(waveform, np.ndarray): - return output.numpy() - - return output + return self.formatter.restore_type(output) class ChunkwiseModel: 
@@ -95,13 +43,15 @@ def __init__(self, model: PipelineModel, device: Optional[torch.device] = None): if device is None: device = get_devices(needs=1)[0] self.model.to(device) + self.waveform_formatter = TemporalFeatureFormatter() + self.weights_formatter = TemporalFeatureFormatter() def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures]) -> torch.Tensor: with torch.no_grad(): - inputs = resolve_features(waveform).to(self.model.device) + inputs = self.waveform_formatter.cast(waveform).to(self.model.device) inputs = rearrange(inputs, "batch sample channel -> batch channel sample") if weights is not None: - weights = resolve_features(weights).to(self.model.device) + weights = self.weights_formatter.cast(weights).to(self.model.device) batch_size, _, num_speakers = weights.shape inputs = inputs.repeat(1, num_speakers, 1) weights = rearrange(weights, "batch frame spk -> (batch spk) frame") @@ -131,18 +81,15 @@ class OverlappedSpeechPenalty: def __init__(self, gamma: float = 3, beta: float = 10): self.gamma = gamma self.beta = beta + self.formatter = TemporalFeatureFormatter() def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures: - weights = resolve_features(segmentation) # shape (batch, frames, speakers) + weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers) with torch.no_grad(): probs = torch.softmax(self.beta * weights, dim=-1) weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma) weights[weights < 1e-8] = 1e-8 - if isinstance(segmentation, SlidingWindowFeature): - return SlidingWindowFeature(weights.cpu().numpy(), segmentation.sliding_window) - if isinstance(segmentation, np.ndarray): - return weights.cpu().numpy() - return weights + return self.formatter.restore_type(weights) class EmbeddingNormalization: diff --git a/src/diart/features.py b/src/diart/features.py new file mode 100644 index 00000000..eaf9d55e --- /dev/null +++ b/src/diart/features.py @@ -0,0 +1,135 @@ +from typing import Union, Optional + +import numpy as np +import torch +from pyannote.core import SlidingWindow, SlidingWindowFeature + + +TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor] + + +class TemporalFeatureFormatterState: + """ + Represents the recorded type of a temporal feature formatter. + Its job is to transform temporal features into tensors and + recover the original format on other features. + """ + def to_tensor(self, features: TemporalFeatures) -> torch.Tensor: + raise NotImplementedError + + def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: + """ + Cast `features` to the representing type and remove batch dimension if required. + + Parameters + ---------- + features: torch.Tensor, shape (batch, frames, dim) + Batched temporal features. 
+ Returns + ------- + new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim) + """ + raise NotImplementedError + + +class SlidingWindowFeatureFormatterState(TemporalFeatureFormatterState): + def __init__(self, duration: float): + self.duration = duration + self._cur_start_time = 0 + + def to_tensor(self, features: SlidingWindowFeature) -> torch.Tensor: + msg = "Features sliding window duration and step must be equal" + assert features.sliding_window.duration == features.sliding_window.step, msg + self._cur_start_time = features.sliding_window.start + return torch.from_numpy(features.data) + + def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: + batch_size, num_frames, _ = features.shape + assert batch_size == 1, "Batched SlidingWindowFeature objects are not supported" + # Calculate resolution + resolution = self.duration / num_frames + # Temporal shift to keep track of current start time + resolution = SlidingWindow(start=self._cur_start_time, duration=resolution, step=resolution) + return SlidingWindowFeature(features.squeeze(dim=0).cpu().numpy(), resolution) + + +class NumpyArrayFormatterState(TemporalFeatureFormatterState): + def to_tensor(self, features: np.ndarray) -> torch.Tensor: + return torch.from_numpy(features) + + def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: + return features.cpu().numpy() + + +class PytorchTensorFormatterState(TemporalFeatureFormatterState): + def to_tensor(self, features: torch.Tensor) -> torch.Tensor: + return features + + def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: + return features + + +class TemporalFeatureFormatter: + """ + Manages the typing and format of temporal features. + When casting temporal features as torch.Tensor, it remembers its + type and format so it can lately restore it on other temporal features. + """ + def __init__(self): + self.state: Optional[TemporalFeatureFormatterState] = None + + def set_state(self, features: TemporalFeatures): + if self.state is not None: + return + + if isinstance(features, SlidingWindowFeature): + msg = "Features sliding window duration and step must be equal" + assert features.sliding_window.duration == features.sliding_window.step, msg + self.state = SlidingWindowFeatureFormatterState( + features.data.shape[0] * features.sliding_window.duration, + ) + elif isinstance(features, np.ndarray): + self.state = NumpyArrayFormatterState() + elif isinstance(features, torch.Tensor): + self.state = PytorchTensorFormatterState() + else: + msg = "Unknown format. Provide one of SlidingWindowFeature, numpy.ndarray, torch.Tensor" + raise ValueError(msg) + + def cast(self, features: TemporalFeatures) -> torch.Tensor: + """ + Transform features into a `torch.Tensor` and add batch dimension if missing. + + Parameters + ---------- + features: SlidingWindowFeature or numpy.ndarray or torch.Tensor + Shape (frames, dim) or (batch, frames, dim) + + Returns + ------- + features: torch.Tensor, shape (batch, frames, dim) + """ + # Set state if not initialized + self.set_state(features) + # Convert features to tensor + data = self.state.to_tensor(features) + # Make sure there's a batch dimension + msg = "Temporal features must be 2D or 3D" + assert data.ndim in (2, 3), msg + if data.ndim == 2: + data = data.unsqueeze(0) + return data.float() + + def restore_type(self, features: torch.Tensor) -> TemporalFeatures: + """ + Cast `features` to the internal type and remove batch dimension if required. 
+ + Parameters + ---------- + features: torch.Tensor, shape (batch, frames, dim) + Batched temporal features. + Returns + ------- + new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim) + """ + return self.state.to_internal_type(features) From 7842258a3cf3b196364409590a6bb6520681cbc8 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Fri, 3 Jun 2022 10:42:27 +0200 Subject: [PATCH 06/29] Start reducing strong dependency on pyannote.audio non-model classes --- src/diart/models.py | 57 +++++++++++++++++++++++++++++++++++++------- src/diart/sinks.py | 2 +- src/diart/sources.py | 31 ++++++++++++++---------- 3 files changed, 67 insertions(+), 23 deletions(-) diff --git a/src/diart/models.py b/src/diart/models.py index 74ae5c9b..db777d9d 100644 --- a/src/diart/models.py +++ b/src/diart/models.py @@ -2,7 +2,12 @@ import torch import torch.nn as nn -from pyannote.audio.pipelines.utils import PipelineModel, get_model + +try: + import pyannote.audio.pipelines.utils as pyannote + _has_pyannote = True +except ImportError: + _has_pyannote = False class SegmentationModel(nn.Module): @@ -11,11 +16,24 @@ class SegmentationModel(nn.Module): """ @staticmethod - def from_pyannote(model: PipelineModel) -> 'SegmentationModel': + def from_pyannote(model) -> 'SegmentationModel': + """ + Returns a `SegmentationModel` wrapping a pyannote model. + + Parameters + ---------- + model: pyannote.PipelineModel + + Returns + ------- + wrapper: SegmentationModel + """ + assert _has_pyannote, "No pyannote.audio installation found" + class PyannoteSegmentationModel(SegmentationModel): - def __init__(self, pyannote_model: PipelineModel): + def __init__(self, pyannote_model): super().__init__() - self.model = get_model(pyannote_model) + self.model = pyannote.get_model(pyannote_model) def get_sample_rate(self) -> int: return self.model.audio.sample_rate @@ -55,18 +73,39 @@ class EmbeddingModel(nn.Module): """Minimal interface for an embedding model.""" @staticmethod - def from_pyannote(model: PipelineModel) -> 'EmbeddingModel': + def from_pyannote(model) -> 'EmbeddingModel': + """ + Returns an `EmbeddingModel` wrapping a pyannote model. + + Parameters + ---------- + model: pyannote.PipelineModel + + Returns + ------- + wrapper: EmbeddingModel + """ + assert _has_pyannote, "No pyannote.audio installation found" + class PyannoteEmbeddingModel(EmbeddingModel): - def __init__(self, pyannote_model: PipelineModel): + def __init__(self, pyannote_model): super().__init__() - self.model = get_model(pyannote_model) + self.model = pyannote.get_model(pyannote_model) - def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + def __call__( + self, + waveform: torch.Tensor, + weights: Optional[torch.Tensor] = None, + ) -> torch.Tensor: return self.model(waveform, weights=weights) return PyannoteEmbeddingModel(model) - def __call__(self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None) -> torch.Tensor: + def __call__( + self, + waveform: torch.Tensor, + weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: """ Forward pass of an embedding model with optional weights. 
diff --git a/src/diart/sinks.py b/src/diart/sinks.py index 7d0bb8d2..0d33bf01 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -1,13 +1,13 @@ from pathlib import Path from traceback import print_exc from typing import Union, Text, Optional, Tuple -from typing_extensions import Literal import matplotlib.pyplot as plt from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook from pyannote.database.util import load_rttm from pyannote.metrics.diarization import DiarizationErrorRate from rx.core import Observer +from typing_extensions import Literal class RTTMWriter(Observer): diff --git a/src/diart/sources.py b/src/diart/sources.py index 472c2a09..88012794 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,16 +1,21 @@ import random import time from queue import SimpleQueue -from typing import Tuple, Text, Optional, Iterable, List +from typing import Tuple, Text, Optional, Iterable, List, Union +from pathlib import Path import numpy as np import sounddevice as sd from einops import rearrange -from pyannote.audio.core.io import Audio, AudioFile +# TODO replace with torchaudio +from pyannote.audio.core.io import Audio from pyannote.core import SlidingWindowFeature, SlidingWindow from rx.subject import Subject +FilePath = Union[Text, Path] + + class ChunkLoader: """Loads an audio file and chunks it according to a given window and step size. @@ -36,7 +41,7 @@ def __init__( self.window_samples = int(round(window_duration * sample_rate)) self.step_samples = int(round(step_duration * sample_rate)) - def get_chunks(self, file: AudioFile) -> np.ndarray: + def get_chunks(self, file: FilePath) -> np.ndarray: waveform, _ = self.audio(file) _, num_samples = waveform.shape chunks = rearrange( @@ -51,7 +56,7 @@ def get_chunks(self, file: AudioFile) -> np.ndarray: return np.vstack([chunks, last_chunk]) return chunks - def num_chunks(self, file: AudioFile) -> int: + def num_chunks(self, file: FilePath) -> int: numerator = self.audio.get_duration(file) - self.window_duration + self.step_duration return int(np.ceil(numerator / self.step_duration)) @@ -115,13 +120,13 @@ def is_regular(self) -> bool: A regular reading method always yields the same amount of samples.""" return False - def get_duration(self, file: AudioFile) -> float: + def get_duration(self, file: FilePath) -> float: return self.audio.get_duration(file) - def get_num_chunks(self, file: AudioFile) -> Optional[int]: + def get_num_chunks(self, file: FilePath) -> Optional[int]: return None - def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]: + def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]: """Return an iterable over the file's samples""" raise NotImplementedError @@ -153,11 +158,11 @@ def __init__( def is_regular(self) -> bool: return True - def get_num_chunks(self, file: AudioFile) -> Optional[int]: + def get_num_chunks(self, file: FilePath) -> Optional[int]: """Return the number of chunks emitted for `file`""" return self.chunk_loader.num_chunks(file) - def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]: + def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]: chunks = self.chunk_loader.get_chunks(file) for i, chunk in enumerate(chunks): w = SlidingWindow( @@ -192,7 +197,7 @@ def __init__( self.start, self.end = refresh_rate_range self.delay = simulate_delay - def iterate(self, file: AudioFile) -> Iterable[SlidingWindowFeature]: + def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]: waveform, _ = self.audio(file) 
total_samples = waveform.shape[1] i = 0 @@ -211,8 +216,8 @@ class FileAudioSource(AudioSource): Parameters ---------- - file: AudioFile - The file to stream. + file: FilePath + Path to the file to stream. uri: Text Unique identifier of the audio source. reader: AudioFileReader @@ -222,7 +227,7 @@ class FileAudioSource(AudioSource): """ def __init__( self, - file: AudioFile, + file: FilePath, uri: Text, reader: AudioFileReader, profile: bool = False, From f521a9e22ac97c489be84fc5ffd0ee383472795f Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Fri, 3 Jun 2022 18:11:16 +0200 Subject: [PATCH 07/29] Replace pyannote's Audio class with simpler custom implementation to avoid hard dependency on pyannote.audio --- requirements.txt | 4 +++ setup.cfg | 6 +++- src/diart/pipelines.py | 2 +- src/diart/sources.py | 66 +++++++++++++++++++++++++++++++++++------- 4 files changed, 66 insertions(+), 12 deletions(-) diff --git a/requirements.txt b/requirements.txt index bc7472a7..2a267b34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,8 @@ sounddevice>=0.4.2 einops>=0.3.0 tqdm>=4.64.0 pandas>=1.4.2 +torchaudio>=0.10,<1.0 +pyannote.core>=4.4 +pyannote.database>=4.1.1 +pyannote.metrics>=3.2 git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio diff --git a/setup.cfg b/setup.cfg index ee69d733..9b6f5490 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,7 +28,11 @@ install_requires = einops>=0.3.0 tqdm>=4.64.0 pandas>=1.4.2 - pyannote-audio @ git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio + torchaudio>=0.10,<1.0 + pyannote.core>=4.4 + pyannote.database>=4.1.1 + pyannote.metrics>=3.2 + git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio [options.packages.find] diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 2087f986..01d0d6a4 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -229,7 +229,7 @@ def from_file( ) # Build speaker tracking pipeline - duration = chunk_loader.audio.get_duration(file) + duration = chunk_loader.loader.get_duration(file) return self.speaker_tracking.from_model_streams( file.stem, duration, seg_stream, emb_stream, chunk_stream ) diff --git a/src/diart/sources.py b/src/diart/sources.py index 88012794..15d2f535 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,21 +1,67 @@ import random import time +from pathlib import Path from queue import SimpleQueue from typing import Tuple, Text, Optional, Iterable, List, Union -from pathlib import Path import numpy as np import sounddevice as sd +import torch +import torchaudio from einops import rearrange -# TODO replace with torchaudio -from pyannote.audio.core.io import Audio from pyannote.core import SlidingWindowFeature, SlidingWindow from rx.subject import Subject +from torchaudio.functional import resample + +torchaudio.set_audio_backend("soundfile") FilePath = Union[Text, Path] +class AudioLoader: + def __init__(self, sample_rate: int, mono: bool = True): + self.sample_rate = sample_rate + self.mono = mono + + def load(self, filepath: FilePath) -> torch.Tensor: + """ + Load an audio file into a torch.Tensor. 
+ + Parameters + ---------- + filepath : FilePath + + Returns + ------- + waveform : torch.Tensor, shape (channels, samples) + """ + waveform, sample_rate = torchaudio.load(filepath) + # Get channel mean if mono + if self.mono and waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + # Resample if needed + if self.sample_rate != sample_rate: + waveform = resample(waveform, sample_rate, self.sample_rate) + return waveform + + @staticmethod + def get_duration(filepath: FilePath) -> float: + """Get audio file duration in seconds. + + Parameters + ---------- + filepath : FilePath + + Returns + ------- + duration : float + Duration in seconds. + """ + info = torchaudio.info(filepath) + return info.num_frames / info.sample_rate + + class ChunkLoader: """Loads an audio file and chunks it according to a given window and step size. @@ -35,14 +81,14 @@ def __init__( window_duration: float, step_duration: float, ): - self.audio = Audio(sample_rate, mono=True) + self.loader = AudioLoader(sample_rate, mono=True) self.window_duration = window_duration self.step_duration = step_duration self.window_samples = int(round(window_duration * sample_rate)) self.step_samples = int(round(step_duration * sample_rate)) def get_chunks(self, file: FilePath) -> np.ndarray: - waveform, _ = self.audio(file) + waveform = self.loader.load(file) _, num_samples = waveform.shape chunks = rearrange( waveform.unfold(1, self.window_samples, self.step_samples), @@ -57,7 +103,7 @@ def get_chunks(self, file: FilePath) -> np.ndarray: return chunks def num_chunks(self, file: FilePath) -> int: - numerator = self.audio.get_duration(file) - self.window_duration + self.step_duration + numerator = self.loader.get_duration(file) - self.window_duration + self.step_duration return int(np.ceil(numerator / self.step_duration)) @@ -107,12 +153,12 @@ class AudioFileReader: Sample rate of the audio file. 
""" def __init__(self, sample_rate: int): - self.audio = Audio(sample_rate=sample_rate, mono=True) + self.loader = AudioLoader(sample_rate, mono=True) self.resolution = 1 / sample_rate @property def sample_rate(self) -> int: - return self.audio.sample_rate + return self.loader.sample_rate @property def is_regular(self) -> bool: @@ -121,7 +167,7 @@ def is_regular(self) -> bool: return False def get_duration(self, file: FilePath) -> float: - return self.audio.get_duration(file) + return self.loader.get_duration(file) def get_num_chunks(self, file: FilePath) -> Optional[int]: return None @@ -198,7 +244,7 @@ def __init__( self.delay = simulate_delay def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]: - waveform, _ = self.audio(file) + waveform = self.loader.load(file) total_samples = waveform.shape[1] i = 0 while i < total_samples: From 230cda7f0d7fb5e3b85414c9cdc60bdd9093505d Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Mon, 6 Jun 2022 17:22:00 +0200 Subject: [PATCH 08/29] Merge AudioLoader and ChunkLoader --- README.md | 3 ++ setup.cfg | 1 - src/diart/audio.py | 101 ++++++++++++++++++++++++++++++++++ src/diart/inference.py | 10 ++-- src/diart/pipelines.py | 15 ++---- src/diart/sources.py | 119 ++++------------------------------------- 6 files changed, 124 insertions(+), 125 deletions(-) create mode 100644 src/diart/audio.py diff --git a/README.md b/README.md index 5faf55d5..f12e0753 100644 --- a/README.md +++ b/README.md @@ -29,10 +29,13 @@ conda activate diart 2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally) 3) Install pyannote.audio 2.0 (currently in development) + ```shell pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio ``` +*Note:* starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored. + 4) Install diart: ```shell pip install diart diff --git a/setup.cfg b/setup.cfg index 9b6f5490..d2d5c5b3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,6 @@ install_requires = pyannote.core>=4.4 pyannote.database>=4.1.1 pyannote.metrics>=3.2 - git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio [options.packages.find] diff --git a/src/diart/audio.py b/src/diart/audio.py new file mode 100644 index 00000000..e607ba5f --- /dev/null +++ b/src/diart/audio.py @@ -0,0 +1,101 @@ +from pathlib import Path +from typing import Text, Union + +import numpy as np +import torch +import torchaudio +from einops import rearrange +from torchaudio.functional import resample + +torchaudio.set_audio_backend("soundfile") + + +FilePath = Union[Text, Path] + + +class AudioLoader: + def __init__(self, sample_rate: int, mono: bool = True): + self.sample_rate = sample_rate + self.mono = mono + + def load(self, filepath: FilePath) -> torch.Tensor: + """Load an audio file into a torch.Tensor. + + Parameters + ---------- + filepath : FilePath + Path to an audio file + + Returns + ------- + waveform : torch.Tensor, shape (channels, samples) + """ + waveform, sample_rate = torchaudio.load(filepath) + # Get channel mean if mono + if self.mono and waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + # Resample if needed + if self.sample_rate != sample_rate: + waveform = resample(waveform, sample_rate, self.sample_rate) + return waveform + + @staticmethod + def get_duration(filepath: FilePath) -> float: + """Get audio file duration in seconds. 
+ + Parameters + ---------- + filepath : FilePath + Path to an audio file + + Returns + ------- + duration : float + Duration in seconds. + """ + info = torchaudio.info(filepath) + return info.num_frames / info.sample_rate + + def load_sliding_chunks(self, filepath: FilePath, chunk_duration: float, step_duration: float) -> np.ndarray: + """Load an audio file and extract sliding chunks of a given duration with a given step duration. + + Parameters + ---------- + filepath : FilePath + Path to an audio file + chunk_duration: float + Duration of the chunk in seconds. + step_duration: float + Duration of the step between chunks in seconds. + """ + chunk_samples = int(round(chunk_duration * self.sample_rate)) + step_samples = int(round(step_duration * self.sample_rate)) + waveform = self.load(filepath) + _, num_samples = waveform.shape + chunks = rearrange( + waveform.unfold(1, chunk_samples, step_samples), + "channel chunk sample -> chunk channel sample", + ).numpy() + # Add padded last chunk + if num_samples - chunk_samples % step_samples > 0: + last_chunk = waveform[:, chunks.shape[0] * step_samples:].unsqueeze(0).numpy() + diff_samples = chunk_samples - last_chunk.shape[-1] + last_chunk = np.concatenate([last_chunk, np.zeros((1, 1, diff_samples))], axis=-1) + return np.vstack([chunks, last_chunk]) + return chunks + + def get_num_sliding_chunks(self, filepath: FilePath, chunk_duration: float, step_duration: float) -> int: + """Estimate the number of sliding chunks of a + given chunk duration and step without loading the audio. + + Parameters + ---------- + filepath : FilePath + Path to an audio file + chunk_duration: float + Duration of the chunk in seconds. + step_duration: float + Duration of the step between chunks in seconds. + """ + numerator = self.get_duration(filepath) - chunk_duration + step_duration + return int(np.ceil(numerator / step_duration)) \ No newline at end of file diff --git a/src/diart/inference.py b/src/diart/inference.py index 4574279b..da9dfbc2 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -125,15 +125,13 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> DataFrame with detailed performance on each file, as well as average performance. None if the reference is not provided. 
""" - chunk_loader = src.ChunkLoader( - pipeline.config.sample_rate, - pipeline.config.duration, - pipeline.config.step - ) + loader = src.AudioLoader(pipeline.config.sample_rate, mono=True) audio_file_paths = list(self.speech_path.iterdir()) num_audio_files = len(audio_file_paths) for i, filepath in enumerate(audio_file_paths): - num_chunks = chunk_loader.num_chunks(filepath) + num_chunks = loader.get_num_sliding_chunks( + filepath, pipeline.config.duration, pipeline.config.step + ) # Stream fully online if batch size is 1 or lower source = None diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 01d0d6a4..4ff453ea 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -1,6 +1,6 @@ import math from pathlib import Path -from typing import Optional, Union, Text +from typing import Optional, Text import numpy as np import rx @@ -167,20 +167,16 @@ def from_source( def from_file( self, - file: Union[Text, Path], + filepath: src.FilePath, output_waveform: bool = False, batch_size: int = 32, desc: Optional[Text] = None, ) -> rx.Observable: - # Audio file information - file = Path(file) - chunk_loader = src.ChunkLoader( - self.config.sample_rate, self.config.duration, self.config.step - ) + loader = src.AudioLoader(self.config.sample_rate, mono=True) # Split audio into chunks chunks = rearrange( - chunk_loader.get_chunks(file), + loader.load_sliding_chunks(filepath, self.config.duration, self.config.step), "chunk channel sample -> chunk sample channel" ) num_chunks = chunks.shape[0] @@ -229,7 +225,6 @@ def from_file( ) # Build speaker tracking pipeline - duration = chunk_loader.loader.get_duration(file) return self.speaker_tracking.from_model_streams( - file.stem, duration, seg_stream, emb_stream, chunk_stream + Path(filepath).stem, loader.get_duration(filepath), seg_stream, emb_stream, chunk_stream ) diff --git a/src/diart/sources.py b/src/diart/sources.py index 15d2f535..93a33b3a 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,110 +1,14 @@ import random import time -from pathlib import Path from queue import SimpleQueue -from typing import Tuple, Text, Optional, Iterable, List, Union +from typing import Tuple, Text, Optional, Iterable, List import numpy as np import sounddevice as sd -import torch -import torchaudio -from einops import rearrange from pyannote.core import SlidingWindowFeature, SlidingWindow from rx.subject import Subject -from torchaudio.functional import resample -torchaudio.set_audio_backend("soundfile") - - -FilePath = Union[Text, Path] - - -class AudioLoader: - def __init__(self, sample_rate: int, mono: bool = True): - self.sample_rate = sample_rate - self.mono = mono - - def load(self, filepath: FilePath) -> torch.Tensor: - """ - Load an audio file into a torch.Tensor. - - Parameters - ---------- - filepath : FilePath - - Returns - ------- - waveform : torch.Tensor, shape (channels, samples) - """ - waveform, sample_rate = torchaudio.load(filepath) - # Get channel mean if mono - if self.mono and waveform.shape[0] > 1: - waveform = waveform.mean(dim=0, keepdim=True) - # Resample if needed - if self.sample_rate != sample_rate: - waveform = resample(waveform, sample_rate, self.sample_rate) - return waveform - - @staticmethod - def get_duration(filepath: FilePath) -> float: - """Get audio file duration in seconds. - - Parameters - ---------- - filepath : FilePath - - Returns - ------- - duration : float - Duration in seconds. 
- """ - info = torchaudio.info(filepath) - return info.num_frames / info.sample_rate - - -class ChunkLoader: - """Loads an audio file and chunks it according to a given window and step size. - - Parameters - ---------- - sample_rate: int - Sample rate to load audio. - window_duration: float - Duration of the chunk in seconds. - step_duration: float - Duration of the step between chunks in seconds. - """ - - def __init__( - self, - sample_rate: int, - window_duration: float, - step_duration: float, - ): - self.loader = AudioLoader(sample_rate, mono=True) - self.window_duration = window_duration - self.step_duration = step_duration - self.window_samples = int(round(window_duration * sample_rate)) - self.step_samples = int(round(step_duration * sample_rate)) - - def get_chunks(self, file: FilePath) -> np.ndarray: - waveform = self.loader.load(file) - _, num_samples = waveform.shape - chunks = rearrange( - waveform.unfold(1, self.window_samples, self.step_samples), - "channel chunk frame -> chunk channel frame", - ).numpy() - # Add padded last chunk - if num_samples - self.window_samples % self.step_samples > 0: - last_chunk = waveform[:, chunks.shape[0] * self.step_samples:].unsqueeze(0).numpy() - diff_samples = self.window_samples - last_chunk.shape[-1] - last_chunk = np.concatenate([last_chunk, np.zeros((1, 1, diff_samples))], axis=-1) - return np.vstack([chunks, last_chunk]) - return chunks - - def num_chunks(self, file: FilePath) -> int: - numerator = self.loader.get_duration(file) - self.window_duration + self.step_duration - return int(np.ceil(numerator / self.step_duration)) +from .audio import FilePath, AudioLoader class AudioSource: @@ -184,7 +88,7 @@ class RegularAudioFileReader(AudioFileReader): ---------- sample_rate: int Sample rate of the audio file. - window_duration: float + chunk_duration: float Duration of each chunk of samples (window) in seconds. step_duration: float Step duration between chunks in seconds. 
@@ -192,27 +96,26 @@ class RegularAudioFileReader(AudioFileReader): def __init__( self, sample_rate: int, - window_duration: float, + chunk_duration: float, step_duration: float, ): super().__init__(sample_rate) - self.chunk_loader = ChunkLoader( - sample_rate, window_duration, step_duration - ) + self.chunk_duration = chunk_duration + self.step_duration = step_duration @property def is_regular(self) -> bool: return True - def get_num_chunks(self, file: FilePath) -> Optional[int]: - """Return the number of chunks emitted for `file`""" - return self.chunk_loader.num_chunks(file) + def get_num_chunks(self, filepath: FilePath) -> Optional[int]: + """Return the number of chunks that will be emitted for a given file""" + return self.loader.get_num_sliding_chunks(filepath, self.chunk_duration, self.step_duration) def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]: - chunks = self.chunk_loader.get_chunks(file) + chunks = self.loader.load_sliding_chunks(file, self.chunk_duration, self.step_duration) for i, chunk in enumerate(chunks): w = SlidingWindow( - start=i * self.chunk_loader.step_duration, + start=i * self.step_duration, duration=self.resolution, step=self.resolution ) From 0e9e718cfed41d177df93ebef390ade129d35a15 Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Thu, 9 Jun 2022 14:05:00 +0200 Subject: [PATCH 09/29] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f12e0753..bbab03fa 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ conda activate diart pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio ``` -*Note:* starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored. +**Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored. 4) Install diart: ```shell From 9110d957be7cd70e8f39884f4f6485b6ab3c5a25 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 9 Jun 2022 15:33:11 +0200 Subject: [PATCH 10/29] Remove file audio readers and make all file reading regular --- src/diart/inference.py | 8 +-- src/diart/sources.py | 160 +++++++++-------------------------------- src/diart/stream.py | 12 ++-- 3 files changed, 41 insertions(+), 139 deletions(-) diff --git a/src/diart/inference.py b/src/diart/inference.py index da9dfbc2..b45de89f 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -139,11 +139,9 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> source = src.FileAudioSource( filepath, filepath.stem, - src.RegularAudioFileReader( - pipeline.config.sample_rate, - pipeline.config.duration, - pipeline.config.step - ), + pipeline.config.sample_rate, + pipeline.config.duration, + pipeline.config.step, # Benchmark the processing time of a single chunk profile=True, ) diff --git a/src/diart/sources.py b/src/diart/sources.py index 93a33b3a..8af379e9 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,7 +1,6 @@ -import random import time from queue import SimpleQueue -from typing import Tuple, Text, Optional, Iterable, List +from typing import Text, Optional, List import numpy as np import sounddevice as sd @@ -48,118 +47,6 @@ def read(self): raise NotImplementedError -class AudioFileReader: - """Represents a method for reading an audio file. 
- - Parameters - ---------- - sample_rate: int - Sample rate of the audio file. - """ - def __init__(self, sample_rate: int): - self.loader = AudioLoader(sample_rate, mono=True) - self.resolution = 1 / sample_rate - - @property - def sample_rate(self) -> int: - return self.loader.sample_rate - - @property - def is_regular(self) -> bool: - """Whether the reading is regular. Defaults to False. - A regular reading method always yields the same amount of samples.""" - return False - - def get_duration(self, file: FilePath) -> float: - return self.loader.get_duration(file) - - def get_num_chunks(self, file: FilePath) -> Optional[int]: - return None - - def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]: - """Return an iterable over the file's samples""" - raise NotImplementedError - - -class RegularAudioFileReader(AudioFileReader): - """Reads a file always yielding the same number of samples with a given step. - - Parameters - ---------- - sample_rate: int - Sample rate of the audio file. - chunk_duration: float - Duration of each chunk of samples (window) in seconds. - step_duration: float - Step duration between chunks in seconds. - """ - def __init__( - self, - sample_rate: int, - chunk_duration: float, - step_duration: float, - ): - super().__init__(sample_rate) - self.chunk_duration = chunk_duration - self.step_duration = step_duration - - @property - def is_regular(self) -> bool: - return True - - def get_num_chunks(self, filepath: FilePath) -> Optional[int]: - """Return the number of chunks that will be emitted for a given file""" - return self.loader.get_num_sliding_chunks(filepath, self.chunk_duration, self.step_duration) - - def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]: - chunks = self.loader.load_sliding_chunks(file, self.chunk_duration, self.step_duration) - for i, chunk in enumerate(chunks): - w = SlidingWindow( - start=i * self.step_duration, - duration=self.resolution, - step=self.resolution - ) - yield SlidingWindowFeature(chunk.T, w) - - -class IrregularAudioFileReader(AudioFileReader): - """Reads an audio file yielding a different number of non-overlapping samples in each event. - This class is useful to simulate how a system would work in unreliable reading conditions. - - Parameters - ---------- - sample_rate: int - Sample rate of the audio file. - refresh_rate_range: (float, float) - Duration range within which to determine the number of samples to yield (in seconds). - simulate_delay: bool - Whether to simulate that the samples are being read in real time before they are yielded. - Defaults to False (no delay). - """ - def __init__( - self, - sample_rate: int, - refresh_rate_range: Tuple[float, float], - simulate_delay: bool = False, - ): - super().__init__(sample_rate) - self.start, self.end = refresh_rate_range - self.delay = simulate_delay - - def iterate(self, file: FilePath) -> Iterable[SlidingWindowFeature]: - waveform = self.loader.load(file) - total_samples = waveform.shape[1] - i = 0 - while i < total_samples: - rnd_duration = random.uniform(self.start, self.end) - if self.delay: - time.sleep(rnd_duration) - num_samples = int(round(rnd_duration * self.sample_rate)) - last_i = i - i += num_samples - yield waveform[:, last_i:i] - - class FileAudioSource(AudioSource): """Represents an audio source tied to a file. @@ -169,8 +56,12 @@ class FileAudioSource(AudioSource): Path to the file to stream. uri: Text Unique identifier of the audio source. - reader: AudioFileReader - Determines how the file will be read. 
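With the reader classes removed, the chunking parameters documented just below are passed directly to `FileAudioSource`. A rough usage sketch at this point in the series (not from the patch; the path is a placeholder, and the `uri` argument is dropped in a later patch):

```python
# Illustrative sketch: FileAudioSource after the reader indirection is removed.
from diart.sources import FileAudioSource

source = FileAudioSource(
    file="conversation.wav",
    uri="conversation",
    sample_rate=16000,
    chunk_duration=5.0,
    step_duration=0.5,
)
source.stream.subscribe(on_next=lambda chunk: print(chunk.extent))
source.read()  # emits SlidingWindowFeature chunks, then completes the stream
```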
+ sample_rate: int + Sample rate of the chunks emitted. + chunk_duration: float + Duration of each chunk in seconds. Defaults to 5s. + step_duration: float + Duration of the step between consecutive chunks in seconds. Defaults to 500ms. profile: bool If True, prints the average processing time of emitting a chunk. Defaults to False. """ @@ -178,19 +69,24 @@ def __init__( self, file: FilePath, uri: Text, - reader: AudioFileReader, + sample_rate: int, + chunk_duration: float = 5, + step_duration: float = 0.5, profile: bool = False, ): - super().__init__(uri, reader.sample_rate) - self.reader = reader - self._duration = self.reader.get_duration(file) + super().__init__(uri, sample_rate) + self.loader = AudioLoader(sample_rate, mono=True) + self._duration = self.loader.get_duration(file) self.file = file + self.chunk_duration = chunk_duration + self.step_duration = step_duration self.profile = profile + self.resolution = 1 / sample_rate @property def is_regular(self) -> bool: - # The regularity depends on the reader - return self.reader.is_regular + # An audio file is always a regular source + return True @property def duration(self) -> Optional[float]: @@ -199,8 +95,9 @@ def duration(self) -> Optional[float]: @property def length(self) -> Optional[int]: - # Only the reader can know how many chunks are going to be emitted - return self.reader.get_num_chunks(self.file) + return self.loader.get_num_sliding_chunks( + self.file, self.chunk_duration, self.step_duration + ) def _check_print_time(self, times: List[float]): if self.profile: @@ -213,15 +110,24 @@ def _check_print_time(self, times: List[float]): def read(self): """Send each chunk of samples through the stream""" times = [] - for waveform in self.reader.iterate(self.file): + chunks = self.loader.load_sliding_chunks( + self.file, self.chunk_duration, self.step_duration + ) + for i, waveform in enumerate(chunks): + window = SlidingWindow( + start=i * self.step_duration, + duration=self.resolution, + step=self.resolution + ) + chunk = SlidingWindowFeature(waveform.T, window) try: if self.profile: # Profiling assumes that on_next is blocking start_time = time.monotonic() - self.stream.on_next(waveform) + self.stream.on_next(chunk) times.append(time.monotonic() - start_time) else: - self.stream.on_next(waveform) + self.stream.on_next(chunk) except Exception as e: self._check_print_time(times) self.stream.on_error(e) diff --git a/src/diart/stream.py b/src/diart/stream.py index 4d4bb819..da6ba97c 100644 --- a/src/diart/stream.py +++ b/src/diart/stream.py @@ -43,13 +43,11 @@ args.source = Path(args.source).expanduser() args.output = args.source.parent if args.output is None else Path(args.output) audio_source = src.FileAudioSource( - file=args.source, - uri=args.source.stem, - reader=src.RegularAudioFileReader( - pipeline.config.sample_rate, - pipeline.config.duration, - pipeline.config.step, - ), + args.source, + args.source.stem, + pipeline.config.sample_rate, + pipeline.config.duration, + pipeline.config.step, ) else: args.output = Path("~/").expanduser() if args.output is None else Path(args.output) From 935f8f45d7bef3ad8056837f98d60fb00fb8761d Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 9 Jun 2022 17:40:32 +0200 Subject: [PATCH 11/29] Move real time profiling from audio source to OnlineSpeakerDiarization pipeline --- src/diart/benchmark.py | 2 +- src/diart/inference.py | 2 - src/diart/operators.py | 34 ++++++++++++++ src/diart/pipelines.py | 103 ++++++++++++++++++++--------------------- src/diart/sources.py | 27 +---------- 
src/diart/stream.py | 2 +- src/diart/utils.py | 6 ++- 7 files changed, 92 insertions(+), 84 deletions(-) diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py index b161b708..713b54b6 100644 --- a/src/diart/benchmark.py +++ b/src/diart/benchmark.py @@ -38,6 +38,6 @@ beta=args.beta, max_speakers=args.max_speakers, device=torch.device("cpu") if args.cpu else None, - )) + ), profile=True) benchmark(pipeline, args.batch_size) diff --git a/src/diart/inference.py b/src/diart/inference.py index b45de89f..e1cd95a3 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -142,8 +142,6 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> pipeline.config.sample_rate, pipeline.config.duration, pipeline.config.step, - # Benchmark the processing time of a single chunk - profile=True, ) observable = pipeline.from_source(source, output_waveform=False) else: diff --git a/src/diart/operators.py b/src/diart/operators.py index 646add85..cdd92dba 100644 --- a/src/diart/operators.py +++ b/src/diart/operators.py @@ -1,3 +1,4 @@ +import time from dataclasses import dataclass from typing import Callable, Optional, List, Any, Tuple, Text @@ -300,3 +301,36 @@ def progress( on_completed=lambda: pbar.close(), ) ) + + +class _Chronometer: + def __init__(self): + self.current_start_time = None + self.history = [] + + def start(self): + self.current_start_time = time.monotonic() + + def stop(self): + end_time = time.monotonic() - self.current_start_time + self.current_start_time = None + self.history.append(end_time) + + def report(self): + print( + f"Stream took {np.mean(self.history).item():.3f} " + f"(+/-{np.std(self.history).item():.3f}) seconds/chunk " + f"-- based on {len(self.history)} chunks" + ) + + +def profile(observable: rx.Observable, operations: List[Operator]) -> rx.Observable: + chronometer = _Chronometer() + return observable.pipe( + ops.do_action(lambda _: chronometer.start()), + *operations, + ops.do_action( + on_next=lambda _: chronometer.stop(), + on_completed=chronometer.report, + ) + ) diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 4ff453ea..1c61de96 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -1,6 +1,6 @@ import math from pathlib import Path -from typing import Optional, Text +from typing import Optional, Text, List, Tuple import numpy as np import rx @@ -14,6 +14,7 @@ from . import models as m from . import operators as dops from . import sources as src +from . 
import utils class PipelineConfig: @@ -83,18 +84,15 @@ class OnlineSpeakerTracking: def __init__(self, config: PipelineConfig): self.config = config - def from_model_streams( + def get_end_time(self, duration: Optional[float]) -> Optional[float]: + return None if duration is None else self.config.last_chunk_end_time(duration) + + def get_operators( self, - uri: Text, + source_uri: Text, source_duration: Optional[float], - segmentation_stream: rx.Observable, - embedding_stream: rx.Observable, - audio_chunk_stream: Optional[rx.Observable] = None, - ) -> rx.Observable: - end_time = None - if source_duration is not None: - end_time = self.config.last_chunk_end_time(source_duration) - # Initialize clustering and aggregation modules + output_waveform: bool = True + ) -> List[dops.Operator]: clustering = blocks.OnlineSpeakerClustering( self.config.tau_active, self.config.rho_update, @@ -102,39 +100,33 @@ def from_model_streams( "cosine", self.config.max_speakers, ) - aggregation = blocks.DelayedAggregation( + end_time = self.get_end_time(source_duration) + pred_aggregation = blocks.DelayedAggregation( self.config.step, self.config.latency, strategy="hamming", stream_end=end_time ) - binarize = blocks.Binarize(uri, self.config.tau_active) - - # Join segmentation and embedding streams to update a background clustering model - # while regulating latency and binarizing the output - pipeline = rx.zip(segmentation_stream, embedding_stream).pipe( - ops.starmap(clustering), + audio_aggregation = blocks.DelayedAggregation( + self.config.step, self.config.latency, strategy="first", stream_end=end_time + ) + binarize = blocks.Binarize(source_uri, self.config.tau_active) + return [ + # Identify global speakers with online clustering + ops.starmap(lambda wav, seg, emb: (wav, clustering(seg, emb))), # Buffer 'num_overlapping' sliding chunks with a step of 1 chunk - dops.buffer_slide(aggregation.num_overlapping_windows), + dops.buffer_slide(pred_aggregation.num_overlapping_windows), # Aggregate overlapping output windows - ops.map(aggregation), + ops.map(lambda buffers: utils.unzip(buffers)), + ops.starmap(lambda wav_buffer, pred_buffer: ( + audio_aggregation(wav_buffer), pred_aggregation(pred_buffer) + )), # Binarize output - ops.map(binarize), - ) - # Add corresponding waveform to the output - if audio_chunk_stream is not None: - window_selector = blocks.DelayedAggregation( - self.config.step, self.config.latency, strategy="first", stream_end=end_time - ) - waveform_stream = audio_chunk_stream.pipe( - dops.buffer_slide(window_selector.num_overlapping_windows), - ops.map(window_selector), - ) - return rx.zip(pipeline, waveform_stream) - # No waveform needed, add None for consistency - return pipeline.pipe(ops.map(lambda ann: (ann, None))) + ops.starmap(lambda wav, pred: (binarize(pred), wav if output_waveform else None)), + ] class OnlineSpeakerDiarization: - def __init__(self, config: PipelineConfig): + def __init__(self, config: PipelineConfig, profile: bool = False): self.config = config + self.profile = profile self.segmentation = blocks.SpeakerSegmentation(config.segmentation, config.device) self.embedding = blocks.OverlapAwareSpeakerEmbedding( config.embedding, config.gamma, config.beta, norm=1, device=config.device @@ -143,27 +135,26 @@ def __init__(self, config: PipelineConfig): msg = f"Latency should be in the range [{config.step}, {config.duration}]" assert config.step <= config.latency <= config.duration, msg - def from_source( - self, - source: src.AudioSource, - output_waveform: bool = True - ) -> 
rx.Observable: + def from_source(self, source: src.AudioSource, output_waveform: bool = True) -> rx.Observable: msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}" assert source.sample_rate == self.config.sample_rate, msg + operators = [] # Regularize the stream to a specific chunk duration and step - regular_stream = source.stream if not source.is_regular: - regular_stream = source.stream.pipe( - dops.regularize_stream(self.config.duration, self.config.step, source.sample_rate) - ) - # Branch the stream to calculate chunk segmentation - seg_stream = regular_stream.pipe(ops.map(self.segmentation)) - # Join audio and segmentation stream to calculate overlap-aware speaker embeddings - emb_stream = rx.zip(regular_stream, seg_stream).pipe(ops.starmap(self.embedding)) - chunk_stream = regular_stream if output_waveform else None - return self.speaker_tracking.from_model_streams( - source.uri, source.duration, seg_stream, emb_stream, chunk_stream - ) + operators.append(dops.regularize_stream( + self.config.duration, self.config.step, source.sample_rate + )) + operators += [ + # Extract segmentation and keep audio + ops.map(lambda wav: (wav, self.segmentation(wav))), + # Extract embeddings and keep segmentation + ops.starmap(lambda wav, seg: (wav, seg, self.embedding(wav, seg))), + ] + # Add speaker tracking + operators += self.speaker_tracking.get_operators(source.uri, source.duration, output_waveform) + if self.profile: + return dops.profile(source.stream, operators) + return source.stream.pipe(*operators) def from_file( self, @@ -225,6 +216,10 @@ def from_file( ) # Build speaker tracking pipeline - return self.speaker_tracking.from_model_streams( - Path(filepath).stem, loader.get_duration(filepath), seg_stream, emb_stream, chunk_stream + return rx.zip(chunk_stream, seg_stream, emb_stream).pipe( + *self.speaker_tracking.get_operators( + Path(filepath).stem, + loader.get_duration(filepath), + output_waveform, + ) ) diff --git a/src/diart/sources.py b/src/diart/sources.py index 8af379e9..ad6c961a 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,8 +1,6 @@ -import time from queue import SimpleQueue -from typing import Text, Optional, List +from typing import Text, Optional -import numpy as np import sounddevice as sd from pyannote.core import SlidingWindowFeature, SlidingWindow from rx.subject import Subject @@ -62,8 +60,6 @@ class FileAudioSource(AudioSource): Duration of each chunk in seconds. Defaults to 5s. step_duration: float Duration of the step between consecutive chunks in seconds. Defaults to 500ms. - profile: bool - If True, prints the average processing time of emitting a chunk. Defaults to False. 
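Since profiling now happens inside the pipeline rather than in the audio source, enabling it looks roughly like this (a sketch with default parameters and a placeholder path; `from_source` is renamed `from_audio_source` later in this series):

```python
# Illustrative sketch: per-chunk profiling is now enabled on the pipeline
# itself instead of on the audio source. "conversation.wav" is a placeholder.
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig
from diart.sources import FileAudioSource

config = PipelineConfig()  # default models and parameters
pipeline = OnlineSpeakerDiarization(config, profile=True)
source = FileAudioSource("conversation.wav", "conversation",
                         config.sample_rate, config.duration, config.step)
pipeline.from_source(source).subscribe(on_next=lambda outputs: None)
source.read()  # the chronometer reports mean/std seconds per chunk when the stream completes
```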
""" def __init__( self, @@ -72,7 +68,6 @@ def __init__( sample_rate: int, chunk_duration: float = 5, step_duration: float = 0.5, - profile: bool = False, ): super().__init__(uri, sample_rate) self.loader = AudioLoader(sample_rate, mono=True) @@ -80,7 +75,6 @@ def __init__( self.file = file self.chunk_duration = chunk_duration self.step_duration = step_duration - self.profile = profile self.resolution = 1 / sample_rate @property @@ -99,17 +93,8 @@ def length(self) -> Optional[int]: self.file, self.chunk_duration, self.step_duration ) - def _check_print_time(self, times: List[float]): - if self.profile: - print( - f"File {self.uri}: took {np.mean(times).item():.2f} seconds/chunk " - f"(+/- {np.std(times).item():.2f} seconds/chunk) " - f"-- based on {len(times)} inputs" - ) - def read(self): """Send each chunk of samples through the stream""" - times = [] chunks = self.loader.load_sliding_chunks( self.file, self.chunk_duration, self.step_duration ) @@ -121,18 +106,10 @@ def read(self): ) chunk = SlidingWindowFeature(waveform.T, window) try: - if self.profile: - # Profiling assumes that on_next is blocking - start_time = time.monotonic() - self.stream.on_next(chunk) - times.append(time.monotonic() - start_time) - else: - self.stream.on_next(chunk) + self.stream.on_next(chunk) except Exception as e: - self._check_print_time(times) self.stream.on_error(e) self.stream.on_completed() - self._check_print_time(times) class MicrophoneAudioSource(AudioSource): diff --git a/src/diart/stream.py b/src/diart/stream.py index da6ba97c..e843ed08 100644 --- a/src/diart/stream.py +++ b/src/diart/stream.py @@ -36,7 +36,7 @@ beta=args.beta, max_speakers=args.max_speakers, device=torch.device("cpu") if args.cpu else None, - )) + ), profile=True) # Manage audio source if args.source != "microphone": diff --git a/src/diart/utils.py b/src/diart/utils.py index 7e8214a9..ff843193 100644 --- a/src/diart/utils.py +++ b/src/diart/utils.py @@ -1,9 +1,13 @@ -from typing import Optional +from typing import Optional, List, Tuple import matplotlib.pyplot as plt from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook +def unzip(zipped: List[Tuple]) -> Tuple: + return tuple(zip(*zipped)) + + def visualize_feature(duration: Optional[float] = None): def apply(feature: SlidingWindowFeature): if duration is None: From f841b5643d384d8f35f83ee230facd2c704dcbf2 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 9 Jun 2022 17:48:58 +0200 Subject: [PATCH 12/29] Make pipeline always output the waveform alongside the prediction --- src/diart/inference.py | 2 +- src/diart/pipelines.py | 7 +++---- src/diart/sinks.py | 27 +++++++++------------------ 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/src/diart/inference.py b/src/diart/inference.py index e1cd95a3..9d384f73 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -143,7 +143,7 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> pipeline.config.duration, pipeline.config.step, ) - observable = pipeline.from_source(source, output_waveform=False) + observable = pipeline.from_source(source) else: observable = pipeline.from_file( filepath, diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 1c61de96..ddaf484f 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -91,7 +91,6 @@ def get_operators( self, source_uri: Text, source_duration: Optional[float], - output_waveform: bool = True ) -> List[dops.Operator]: clustering = blocks.OnlineSpeakerClustering( 
self.config.tau_active, @@ -119,7 +118,7 @@ def get_operators( audio_aggregation(wav_buffer), pred_aggregation(pred_buffer) )), # Binarize output - ops.starmap(lambda wav, pred: (binarize(pred), wav if output_waveform else None)), + ops.starmap(lambda wav, pred: (binarize(pred), wav)), ] @@ -135,7 +134,7 @@ def __init__(self, config: PipelineConfig, profile: bool = False): msg = f"Latency should be in the range [{config.step}, {config.duration}]" assert config.step <= config.latency <= config.duration, msg - def from_source(self, source: src.AudioSource, output_waveform: bool = True) -> rx.Observable: + def from_source(self, source: src.AudioSource) -> rx.Observable: msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}" assert source.sample_rate == self.config.sample_rate, msg operators = [] @@ -151,7 +150,7 @@ def from_source(self, source: src.AudioSource, output_waveform: bool = True) -> ops.starmap(lambda wav, seg: (wav, seg, self.embedding(wav, seg))), ] # Add speaker tracking - operators += self.speaker_tracking.get_operators(source.uri, source.duration, output_waveform) + operators += self.speaker_tracking.get_operators(source.uri, source.duration) if self.profile: return dops.profile(source.stream, operators) return source.stream.pipe(*operators) diff --git a/src/diart/sinks.py b/src/diart/sinks.py index 0d33bf01..993e95a8 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -24,7 +24,7 @@ def patch_rttm(self): with open(self.path, 'w') as file: annotation.support(self.patch_collar).write_rttm(file) - def on_next(self, value: Tuple[Annotation, Optional[SlidingWindowFeature]]): + def on_next(self, value: Tuple[Annotation, SlidingWindowFeature]): with open(self.path, 'a') as file: value[0].write_rttm(file) @@ -58,16 +58,14 @@ def __init__( self.latency = latency self.figure, self.axs, self.num_axs = None, None, -1 - def _init_num_axs(self, waveform: Optional[SlidingWindowFeature]): + def _init_num_axs(self): if self.num_axs == -1: - self.num_axs = 1 - if waveform is not None: - self.num_axs += 1 + self.num_axs = 2 if self.reference is not None: self.num_axs += 1 - def _init_figure(self, waveform: Optional[SlidingWindowFeature]): - self._init_num_axs(waveform) + def _init_figure(self): + self._init_num_axs() self.figure, self.axs = plt.subplots(self.num_axs, 1, figsize=(10, 2 * self.num_axs)) if self.num_axs == 1: self.axs = [self.axs] @@ -87,7 +85,7 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): prediction, waveform, real_time = values # Initialize figure if first call if self.figure is None: - self._init_figure(waveform) + self._init_figure() # Clear previous plots self._clear_axs() # Set plot bounds @@ -100,16 +98,9 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): prediction.rename_labels(mapping=mapping, copy=False) notebook.plot_annotation(prediction, self.axs[0]) self.axs[0].set_title("Output") - if self.num_axs == 2: - if waveform is not None: - notebook.plot_feature(waveform, self.axs[1]) - self.axs[1].set_title("Audio") - elif self.reference is not None: - notebook.plot_annotation(self.reference, self.axs[1]) - self.axs[1].set_title("Reference") - elif self.num_axs == 3: - notebook.plot_feature(waveform, self.axs[1]) - self.axs[1].set_title("Audio") + notebook.plot_feature(waveform, self.axs[1]) + self.axs[1].set_title("Audio") + if self.num_axs == 3: notebook.plot_annotation(self.reference, self.axs[2]) self.axs[2].set_title("Reference") From 
0b0116b7b9d447711fbfb63a7e689437b960c867 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 9 Jun 2022 18:24:08 +0200 Subject: [PATCH 13/29] Move feature pre-calculation from OnlineSpeakerDiarization to a new type of streaming source: PrecalculatedFeaturesAudioSource --- src/diart/inference.py | 28 +++++++------ src/diart/pipelines.py | 83 ++++----------------------------------- src/diart/sources.py | 89 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 111 insertions(+), 89 deletions(-) diff --git a/src/diart/inference.py b/src/diart/inference.py index 9d384f73..3e98e51f 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -48,7 +48,7 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, source: src.AudioSource) """ rttm_path = self.output_path / f"{source.uri}.rttm" rttm_writer = RTTMWriter(path=rttm_path) - observable = pipeline.from_source(source).pipe( + observable = pipeline.from_audio_source(source).pipe( dops.progress(f"Streaming {source.uri}", total=source.length, leave=True) ) if not self.do_plot: @@ -134,7 +134,6 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> ) # Stream fully online if batch size is 1 or lower - source = None if batch_size < 2: source = src.FileAudioSource( filepath, @@ -143,26 +142,31 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> pipeline.config.duration, pipeline.config.step, ) - observable = pipeline.from_source(source) + observable = pipeline.from_audio_source(source) else: - observable = pipeline.from_file( + source = src.PrecalculatedFeaturesAudioSource( filepath, - batch_size=batch_size, - desc=f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})", + filepath.stem, + pipeline.config.sample_rate, + pipeline.segmentation, + pipeline.embedding, + pipeline.config.duration, + pipeline.config.step, + batch_size, + # TODO decouple progress bar + progress_msg=f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})", ) + observable = pipeline.from_feature_source(source) observable.pipe( dops.progress( desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})", total=num_chunks, - leave=source is None + leave=isinstance(source, src.PrecalculatedFeaturesAudioSource) ) - ).subscribe( - RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm") - ) + ).subscribe(RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm")) - if source is not None: - source.read() + source.read() # Run evaluation if self.reference_path is not None: diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index ddaf484f..e4c9a987 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -87,11 +87,7 @@ def __init__(self, config: PipelineConfig): def get_end_time(self, duration: Optional[float]) -> Optional[float]: return None if duration is None else self.config.last_chunk_end_time(duration) - def get_operators( - self, - source_uri: Text, - source_duration: Optional[float], - ) -> List[dops.Operator]: + def get_operators(self, source: src.AudioSource) -> List[dops.Operator]: clustering = blocks.OnlineSpeakerClustering( self.config.tau_active, self.config.rho_update, @@ -99,14 +95,14 @@ def get_operators( "cosine", self.config.max_speakers, ) - end_time = self.get_end_time(source_duration) + end_time = self.get_end_time(source.duration) pred_aggregation = blocks.DelayedAggregation( self.config.step, self.config.latency, strategy="hamming", stream_end=end_time ) audio_aggregation = blocks.DelayedAggregation( self.config.step, self.config.latency, 
strategy="first", stream_end=end_time ) - binarize = blocks.Binarize(source_uri, self.config.tau_active) + binarize = blocks.Binarize(source.uri, self.config.tau_active) return [ # Identify global speakers with online clustering ops.starmap(lambda wav, seg, emb: (wav, clustering(seg, emb))), @@ -134,7 +130,7 @@ def __init__(self, config: PipelineConfig, profile: bool = False): msg = f"Latency should be in the range [{config.step}, {config.duration}]" assert config.step <= config.latency <= config.duration, msg - def from_source(self, source: src.AudioSource) -> rx.Observable: + def from_audio_source(self, source: src.AudioSource) -> rx.Observable: msg = f"Audio source has sample rate {source.sample_rate}, expected {self.config.sample_rate}" assert source.sample_rate == self.config.sample_rate, msg operators = [] @@ -150,75 +146,10 @@ def from_source(self, source: src.AudioSource) -> rx.Observable: ops.starmap(lambda wav, seg: (wav, seg, self.embedding(wav, seg))), ] # Add speaker tracking - operators += self.speaker_tracking.get_operators(source.uri, source.duration) + operators += self.speaker_tracking.get_operators(source) if self.profile: return dops.profile(source.stream, operators) return source.stream.pipe(*operators) - def from_file( - self, - filepath: src.FilePath, - output_waveform: bool = False, - batch_size: int = 32, - desc: Optional[Text] = None, - ) -> rx.Observable: - loader = src.AudioLoader(self.config.sample_rate, mono=True) - - # Split audio into chunks - chunks = rearrange( - loader.load_sliding_chunks(filepath, self.config.duration, self.config.step), - "chunk channel sample -> chunk sample channel" - ) - num_chunks = chunks.shape[0] - - # Set progress if needed - iterator = range(0, num_chunks, batch_size) - if desc is not None: - total = int(math.ceil(num_chunks / batch_size)) - iterator = tqdm(iterator, desc=desc, total=total, unit="batch", leave=False) - - # Pre-calculate segmentation and embeddings - segmentation, embeddings = [], [] - for i in iterator: - i_end = i + batch_size - if i_end > num_chunks: - i_end = num_chunks - batch = chunks[i:i_end] - seg = self.segmentation(batch) - # Edge case: add batch dimension if i == i_end + 1 - if seg.ndim == 2: - seg = seg[np.newaxis] - emb = self.embedding(batch, seg) - # Edge case: add batch dimension if i == i_end + 1 - if emb.ndim == 2: - emb = emb.unsqueeze(0) - segmentation.append(seg) - embeddings.append(emb) - segmentation = np.vstack(segmentation) - embeddings = torch.vstack(embeddings) - - # Stream pre-calculated segmentation, embeddings and chunks - resolution = self.config.duration / segmentation.shape[1] - seg_stream = rx.range(0, num_chunks).pipe( - ops.map(lambda i: SlidingWindowFeature( - segmentation[i], SlidingWindow(resolution, resolution, i * self.config.step) - )) - ) - emb_stream = rx.range(0, num_chunks).pipe(ops.map(lambda i: embeddings[i])) - wav_resolution = 1 / self.config.sample_rate - chunk_stream = None - if output_waveform: - chunk_stream = rx.range(0, num_chunks).pipe( - ops.map(lambda i: SlidingWindowFeature( - chunks[i], SlidingWindow(wav_resolution, wav_resolution, i * self.config.step) - )) - ) - - # Build speaker tracking pipeline - return rx.zip(chunk_stream, seg_stream, emb_stream).pipe( - *self.speaker_tracking.get_operators( - Path(filepath).stem, - loader.get_duration(filepath), - output_waveform, - ) - ) + def from_feature_source(self, source: src.AudioSource) -> rx.Observable: + return source.stream.pipe(*self.speaker_tracking.get_operators(source)) diff --git 
a/src/diart/sources.py b/src/diart/sources.py index ad6c961a..a87b8ce7 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,13 +1,20 @@ +import math from queue import SimpleQueue -from typing import Text, Optional +from typing import Text, Optional, Callable +import numpy as np import sounddevice as sd +import torch +from einops import rearrange from pyannote.core import SlidingWindowFeature, SlidingWindow from rx.subject import Subject +from tqdm import tqdm from .audio import FilePath, AudioLoader +from .features import TemporalFeatures +# TODO rename this to something else since the same API is also used to stream features class AudioSource: """Represents a source of audio that can start streaming via the `stream` property. @@ -112,6 +119,86 @@ def read(self): self.stream.on_completed() +class PrecalculatedFeaturesAudioSource(FileAudioSource): + def __init__( + self, + file: FilePath, + uri: Text, + sample_rate: int, + segmentation: Callable[[TemporalFeatures], TemporalFeatures], + embedding: Callable[[TemporalFeatures, TemporalFeatures], TemporalFeatures], + chunk_duration: float = 5, + step_duration: float = 0.5, + batch_size: int = 32, + progress_msg: Optional[Text] = None, + ): + super().__init__(file, uri, sample_rate, chunk_duration, step_duration) + self.segmentation = segmentation + self.embedding = embedding + self.batch_size = batch_size + self.progress_msg = progress_msg + + def read(self): + # Split audio into chunks + chunks = rearrange( + self.loader.load_sliding_chunks( + self.file, self.chunk_duration, self.step_duration + ), + "chunk channel sample -> chunk sample channel" + ) + num_chunks = chunks.shape[0] + + # Set progress if needed + iterator = range(0, num_chunks, self.batch_size) + if self.progress_msg is not None: + total = int(math.ceil(num_chunks / self.batch_size)) + iterator = tqdm(iterator, desc=self.progress_msg, total=total, unit="batch", leave=False) + + # Pre-calculate segmentation and embeddings + segmentation, embeddings = [], [] + for i in iterator: + i_end = i + self.batch_size + if i_end > num_chunks: + i_end = num_chunks + batch = chunks[i:i_end] + seg = self.segmentation(batch) + # Edge case: add batch dimension if i == i_end + 1 + if seg.ndim == 2: + seg = seg[np.newaxis] + emb = self.embedding(batch, seg) + # Edge case: add batch dimension if i == i_end + 1 + if emb.ndim == 2: + emb = emb.unsqueeze(0) + segmentation.append(seg) + embeddings.append(emb) + segmentation = np.vstack(segmentation) + embeddings = torch.vstack(embeddings) + + # Stream pre-calculated segmentation, embeddings and chunks + seg_resolution = self.chunk_duration / segmentation.shape[1] + for i in range(num_chunks): + chunk_window = SlidingWindow( + start=i * self.step_duration, + duration=self.resolution, + step=self.resolution, + ) + seg_window = SlidingWindow( + start=i * self.step_duration, + duration=seg_resolution, + step=seg_resolution, + ) + try: + self.stream.on_next(( + SlidingWindowFeature(chunks[i], chunk_window), + SlidingWindowFeature(segmentation[i], seg_window), + embeddings[i] + )) + except Exception as e: + self.stream.on_error(e) + + self.stream.on_completed() + + class MicrophoneAudioSource(AudioSource): """Represents an audio source tied to the default microphone available""" From ba3a53b41e3782288fd75789e51ae7e406645681 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Fri, 10 Jun 2022 10:43:27 +0200 Subject: [PATCH 14/29] Refactoring and minor improvements --- README.md | 2 +- src/diart/audio.py | 2 +- src/diart/benchmark.py | 24 
++++-------------------- src/diart/inference.py | 3 --- src/diart/mapping.py | 24 ++---------------------- src/diart/operators.py | 3 ++- src/diart/pipelines.py | 30 +++++++++++++++++++++--------- src/diart/sinks.py | 16 ++++++++++++++-- src/diart/sources.py | 10 +++------- src/diart/stream.py | 24 ++++-------------------- 10 files changed, 52 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index 29012f29..eb90e123 100644 --- a/README.md +++ b/README.md @@ -102,7 +102,7 @@ embedding = OverlapAwareSpeakerEmbedding(emb_model) mic = MicrophoneAudioSource(seg_model.get_sample_rate()) # Reformat microphone stream. Defaults to 5s duration and 500ms shift -regular_stream = mic.stream.pipe(dops.regularize_stream(seg_model.get_sample_rate())) +regular_stream = mic.stream.pipe(dops.regularize_audio_stream(seg_model.get_sample_rate())) # Branch the microphone stream to calculate segmentation segmentation_stream = regular_stream.pipe(ops.map(segmentation)) # Join audio and segmentation stream to calculate speaker embeddings diff --git a/src/diart/audio.py b/src/diart/audio.py index e607ba5f..764ac3ee 100644 --- a/src/diart/audio.py +++ b/src/diart/audio.py @@ -98,4 +98,4 @@ def get_num_sliding_chunks(self, filepath: FilePath, chunk_duration: float, step Duration of the step between chunks in seconds. """ numerator = self.get_duration(filepath) - chunk_duration + step_duration - return int(np.ceil(numerator / step_duration)) \ No newline at end of file + return int(np.ceil(numerator / step_duration)) diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py index 713b54b6..1ac0bafd 100644 --- a/src/diart/benchmark.py +++ b/src/diart/benchmark.py @@ -1,13 +1,10 @@ import argparse -import torch - import diart.argdoc as argdoc from diart.inference import Benchmark from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig if __name__ == "__main__": - # Define script arguments parser = argparse.ArgumentParser() parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") parser.add_argument("--reference", type=str, help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files") @@ -24,20 +21,7 @@ parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. 
Defaults to `root`") args = parser.parse_args() - # Set benchmark configuration - benchmark = Benchmark(args.root, args.reference, args.output) - - # Define online speaker diarization pipeline - pipeline = OnlineSpeakerDiarization(PipelineConfig( - step=args.step, - latency=args.latency, - tau_active=args.tau, - rho_update=args.rho, - delta_new=args.delta, - gamma=args.gamma, - beta=args.beta, - max_speakers=args.max_speakers, - device=torch.device("cpu") if args.cpu else None, - ), profile=True) - - benchmark(pipeline, args.batch_size) + Benchmark(args.root, args.reference, args.output)( + OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True), + args.batch_size + ) diff --git a/src/diart/inference.py b/src/diart/inference.py index 3e98e51f..6e204247 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -137,7 +137,6 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> if batch_size < 2: source = src.FileAudioSource( filepath, - filepath.stem, pipeline.config.sample_rate, pipeline.config.duration, pipeline.config.step, @@ -146,14 +145,12 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> else: source = src.PrecalculatedFeaturesAudioSource( filepath, - filepath.stem, pipeline.config.sample_rate, pipeline.segmentation, pipeline.embedding, pipeline.config.duration, pipeline.config.step, batch_size, - # TODO decouple progress bar progress_msg=f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})", ) observable = pipeline.from_feature_source(source) diff --git a/src/diart/mapping.py b/src/diart/mapping.py index 7327ac6e..01465086 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -4,7 +4,7 @@ import numpy as np from pyannote.core.utils.distance import cdist -from scipy.optimize import linear_sum_assignment +from scipy.optimize import linear_sum_assignment as lsap class MappingMatrixObjective: @@ -12,7 +12,7 @@ def invalid_tensor(self, shape: Union[Tuple, int]) -> np.ndarray: return np.ones(shape) * self.invalid_value def optimal_assignments(self, matrix: np.ndarray) -> List[int]: - return list(linear_sum_assignment(matrix, self.maximize)[1]) + return list(lsap(matrix, self.maximize)[1]) def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]: # Entries full of invalid_value are not mapped @@ -137,26 +137,6 @@ def dist( dist_matrix = cdist(embeddings1, embeddings2, metric=metric) return SpeakerMap(dist_matrix, MinimizationObjective()) - @staticmethod - def clf_output(predictions: np.ndarray, pad_to: Optional[int] = None) -> SpeakerMap: - """ - Parameters - ---------- - predictions : np.ndarray, (num_local_speakers, num_global_speakers) - Probability outputs of a speaker embedding classifier - pad_to : int, optional - Pad num_global_speakers to this value. - Useful to deal with unknown speakers that may appear in the future. 
- Defaults to no padding - """ - num_locals, num_globals = predictions.shape - objective = MaximizationObjective(max_value=1) - if pad_to is not None and num_globals < pad_to: - padding = np.ones((num_locals, pad_to - num_globals)) - padding = objective.invalid_value * padding - predictions = np.concatenate([predictions, padding], axis=1) - return SpeakerMap(predictions, objective) - class SpeakerMap: def __init__(self, mapping_matrix: np.ndarray, objective: MappingMatrixObjective): diff --git a/src/diart/operators.py b/src/diart/operators.py index cdd92dba..c8575866 100644 --- a/src/diart/operators.py +++ b/src/diart/operators.py @@ -37,7 +37,7 @@ def call_fn(state) -> SlidingWindowFeature: return call_fn -def regularize_stream( +def regularize_audio_stream( duration: float = 5, step: float = 0.5, sample_rate: int = 16000 @@ -331,6 +331,7 @@ def profile(observable: rx.Observable, operations: List[Operator]) -> rx.Observa *operations, ops.do_action( on_next=lambda _: chronometer.stop(), + on_error=lambda _: chronometer.report(), on_completed=chronometer.report, ) ) diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index e4c9a987..4ccbf1d4 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -1,14 +1,9 @@ -import math -from pathlib import Path -from typing import Optional, Text, List, Tuple +from argparse import Namespace +from typing import Optional, List -import numpy as np import rx import rx.operators as ops import torch -from einops import rearrange -from pyannote.core import SlidingWindowFeature, SlidingWindow -from tqdm import tqdm from . import blocks from . import models as m @@ -68,6 +63,23 @@ def __init__( if self.device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + @staticmethod + def from_namespace(args: Namespace) -> 'PipelineConfig': + return PipelineConfig( + segmentation=getattr(args, "segmentation", None), + embedding=getattr(args, "embedding", None), + duration=getattr(args, "duration", None), + step=args.step, + latency=args.latency, + tau_active=args.tau, + rho_update=args.rho, + delta_new=args.delta, + gamma=args.gamma, + beta=args.beta, + max_speakers=args.max_speakers, + device=torch.device("cpu") if args.cpu else None, + ) + def last_chunk_end_time(self, conv_duration: float) -> Optional[float]: """ Return the end time of the last chunk for a given conversation duration. 
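The `from_namespace` constructor added above centralizes the CLI-to-config mapping shared by `benchmark.py` and `stream.py`. A minimal sketch (the field values below are illustrative, not the actual CLI defaults):

```python
# Illustrative sketch: building a PipelineConfig from an argparse-style namespace.
from argparse import Namespace
from diart.pipelines import PipelineConfig

args = Namespace(
    step=0.5, latency=0.5, tau=0.5, rho=0.3, delta=1.0,
    gamma=3, beta=10, max_speakers=20, cpu=False,
)
config = PipelineConfig.from_namespace(args)  # segmentation/embedding/duration fall back to defaults
```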
@@ -109,7 +121,7 @@ def get_operators(self, source: src.AudioSource) -> List[dops.Operator]: # Buffer 'num_overlapping' sliding chunks with a step of 1 chunk dops.buffer_slide(pred_aggregation.num_overlapping_windows), # Aggregate overlapping output windows - ops.map(lambda buffers: utils.unzip(buffers)), + ops.map(utils.unzip), ops.starmap(lambda wav_buffer, pred_buffer: ( audio_aggregation(wav_buffer), pred_aggregation(pred_buffer) )), @@ -136,7 +148,7 @@ def from_audio_source(self, source: src.AudioSource) -> rx.Observable: operators = [] # Regularize the stream to a specific chunk duration and step if not source.is_regular: - operators.append(dops.regularize_stream( + operators.append(dops.regularize_audio_stream( self.config.duration, self.config.step, source.sample_rate )) operators += [ diff --git a/src/diart/sinks.py b/src/diart/sinks.py index 993e95a8..e859aa77 100644 --- a/src/diart/sinks.py +++ b/src/diart/sinks.py @@ -10,6 +10,10 @@ from typing_extensions import Literal +class WindowClosedException(Exception): + pass + + class RTTMWriter(Observer): def __init__(self, path: Union[Path, Text], patch_collar: float = 0.05): super().__init__() @@ -57,6 +61,10 @@ def __init__( self.window_duration = duration self.latency = latency self.figure, self.axs, self.num_axs = None, None, -1 + self.window_closed = False + + def _on_window_closed(self, event): + self.window_closed = True def _init_num_axs(self): if self.num_axs == -1: @@ -69,6 +77,7 @@ def _init_figure(self): self.figure, self.axs = plt.subplots(self.num_axs, 1, figsize=(10, 2 * self.num_axs)) if self.num_axs == 1: self.axs = [self.axs] + self.figure.canvas.mpl_connect('close_event', self._on_window_closed) def _clear_axs(self): for i in range(self.num_axs): @@ -82,6 +91,9 @@ def get_plot_bounds(self, real_time: float) -> Segment: return Segment(start_time, end_time) def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): + if self.window_closed: + raise WindowClosedException + prediction, waveform, real_time = values # Initialize figure if first call if self.figure is None: @@ -111,5 +123,5 @@ def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): plt.pause(0.05) def on_error(self, error: Exception): - print_exc() - exit(1) + if not isinstance(error, WindowClosedException): + print_exc() diff --git a/src/diart/sources.py b/src/diart/sources.py index a87b8ce7..89fe2634 100644 --- a/src/diart/sources.py +++ b/src/diart/sources.py @@ -1,4 +1,5 @@ import math +from pathlib import Path from queue import SimpleQueue from typing import Text, Optional, Callable @@ -14,7 +15,6 @@ from .features import TemporalFeatures -# TODO rename this to something else since the same API is also used to stream features class AudioSource: """Represents a source of audio that can start streaming via the `stream` property. @@ -59,8 +59,6 @@ class FileAudioSource(AudioSource): ---------- file: FilePath Path to the file to stream. - uri: Text - Unique identifier of the audio source. sample_rate: int Sample rate of the chunks emitted. 
chunk_duration: float @@ -71,12 +69,11 @@ class FileAudioSource(AudioSource): def __init__( self, file: FilePath, - uri: Text, sample_rate: int, chunk_duration: float = 5, step_duration: float = 0.5, ): - super().__init__(uri, sample_rate) + super().__init__(Path(file).stem, sample_rate) self.loader = AudioLoader(sample_rate, mono=True) self._duration = self.loader.get_duration(file) self.file = file @@ -123,7 +120,6 @@ class PrecalculatedFeaturesAudioSource(FileAudioSource): def __init__( self, file: FilePath, - uri: Text, sample_rate: int, segmentation: Callable[[TemporalFeatures], TemporalFeatures], embedding: Callable[[TemporalFeatures, TemporalFeatures], TemporalFeatures], @@ -132,7 +128,7 @@ def __init__( batch_size: int = 32, progress_msg: Optional[Text] = None, ): - super().__init__(file, uri, sample_rate, chunk_duration, step_duration) + super().__init__(file, sample_rate, chunk_duration, step_duration) self.segmentation = segmentation self.embedding = embedding self.batch_size = batch_size diff --git a/src/diart/stream.py b/src/diart/stream.py index e843ed08..0ef8a5a6 100644 --- a/src/diart/stream.py +++ b/src/diart/stream.py @@ -1,15 +1,12 @@ import argparse from pathlib import Path -import torch - import diart.argdoc as argdoc import diart.sources as src from diart.inference import RealTimeInference from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig if __name__ == "__main__": - # Define script arguments parser = argparse.ArgumentParser() parser.add_argument("source", type=str, help="Path to an audio file | 'microphone'") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") @@ -26,32 +23,19 @@ args = parser.parse_args() # Define online speaker diarization pipeline - pipeline = OnlineSpeakerDiarization(PipelineConfig( - step=args.step, - latency=args.latency, - tau_active=args.tau, - rho_update=args.rho, - delta_new=args.delta, - gamma=args.gamma, - beta=args.beta, - max_speakers=args.max_speakers, - device=torch.device("cpu") if args.cpu else None, - ), profile=True) + config = PipelineConfig.from_namespace(args) + pipeline = OnlineSpeakerDiarization(config, profile=True) # Manage audio source if args.source != "microphone": args.source = Path(args.source).expanduser() args.output = args.source.parent if args.output is None else Path(args.output) audio_source = src.FileAudioSource( - args.source, - args.source.stem, - pipeline.config.sample_rate, - pipeline.config.duration, - pipeline.config.step, + args.source, config.sample_rate, config.duration, config.step ) else: args.output = Path("~/").expanduser() if args.output is None else Path(args.output) - audio_source = src.MicrophoneAudioSource(pipeline.config.sample_rate) + audio_source = src.MicrophoneAudioSource(config.sample_rate) # Run online inference RealTimeInference(args.output, do_plot=not args.no_plot)(pipeline, audio_source) From afd3448307fbfa127a4e00e31772188c42e5f864 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Mon, 13 Jun 2022 18:23:59 +0200 Subject: [PATCH 15/29] Reorganize blocks into a package --- src/diart/blocks.py | 558 ------------------------------- src/diart/blocks/__init__.py | 16 + src/diart/blocks/aggregation.py | 198 +++++++++++ src/diart/blocks/clustering.py | 150 +++++++++ src/diart/blocks/embedding.py | 138 ++++++++ src/diart/blocks/segmentation.py | 36 ++ src/diart/blocks/utils.py | 55 +++ 7 files changed, 593 insertions(+), 558 deletions(-) delete mode 100644 src/diart/blocks.py create mode 100644 src/diart/blocks/__init__.py 
create mode 100644 src/diart/blocks/aggregation.py create mode 100644 src/diart/blocks/clustering.py create mode 100644 src/diart/blocks/embedding.py create mode 100644 src/diart/blocks/segmentation.py create mode 100644 src/diart/blocks/utils.py diff --git a/src/diart/blocks.py b/src/diart/blocks.py deleted file mode 100644 index 33d2b72a..00000000 --- a/src/diart/blocks.py +++ /dev/null @@ -1,558 +0,0 @@ -from typing import Union, Optional, List, Iterable, Tuple, Text - -import numpy as np -import torch -from einops import rearrange -from pyannote.core import Annotation, Segment, SlidingWindow, SlidingWindowFeature -from typing_extensions import Literal - -from .features import TemporalFeatures, TemporalFeatureFormatter -from .mapping import SpeakerMap, SpeakerMapBuilder -from .models import SegmentationModel, EmbeddingModel - - -class SpeakerSegmentation: - def __init__(self, model: SegmentationModel, device: Optional[torch.device] = None): - self.model = model - self.model.eval() - self.device = device - if self.device is None: - self.device = torch.device("cpu") - self.model.to(self.device) - self.formatter = TemporalFeatureFormatter() - - def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: - """ - Calculate the speaker segmentation of input audio. - - Parameters - ---------- - waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) - - Returns - ------- - speaker_segmentation: TemporalFeatures, shape (batch, frames, speakers) - The batch dimension is omitted if waveform is a `SlidingWindowFeature`. - """ - with torch.no_grad(): - wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample") - output = self.model(wave.to(self.device)).cpu() - return self.formatter.restore_type(output) - - -class SpeakerEmbedding: - def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None): - self.model = model - self.model.eval() - self.device = device - if self.device is None: - self.device = torch.device("cpu") - self.model.to(self.device) - self.waveform_formatter = TemporalFeatureFormatter() - self.weights_formatter = TemporalFeatureFormatter() - - def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor: - """ - Calculate speaker embeddings of input audio. - If weights are given, calculate many speaker embeddings from the same waveform. - - Parameters - ---------- - waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) - weights: Optional[TemporalFeatures], shape (frames, speakers) or (batch, frames, speakers) - Per-speaker and per-frame weights. Defaults to no weights. - - Returns - ------- - embeddings: torch.Tensor - If weights are provided, the shape is (batch, speakers, embedding_dim), - otherwise the shape is (batch, embedding_dim). - If batch size == 1, the batch dimension is omitted. 
- """ - with torch.no_grad(): - inputs = self.waveform_formatter.cast(waveform).to(self.device) - inputs = rearrange(inputs, "batch sample channel -> batch channel sample") - if weights is not None: - weights = self.weights_formatter.cast(weights).to(self.device) - batch_size, _, num_speakers = weights.shape - inputs = inputs.repeat(1, num_speakers, 1) - weights = rearrange(weights, "batch frame spk -> (batch spk) frame") - inputs = rearrange(inputs, "batch spk sample -> (batch spk) 1 sample") - output = rearrange( - self.model(inputs, weights), - "(batch spk) feat -> batch spk feat", - batch=batch_size, - spk=num_speakers - ) - else: - output = self.model(inputs) - return output.squeeze().cpu() - - -class OverlappedSpeechPenalty: - """ - Parameters - ---------- - gamma: float, optional - Exponent to lower low-confidence predictions. - Defaults to 3. - beta: float, optional - Temperature parameter (actually 1/beta) to lower joint speaker activations. - Defaults to 10. - """ - def __init__(self, gamma: float = 3, beta: float = 10): - self.gamma = gamma - self.beta = beta - self.formatter = TemporalFeatureFormatter() - - def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures: - weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers) - with torch.no_grad(): - probs = torch.softmax(self.beta * weights, dim=-1) - weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma) - weights[weights < 1e-8] = 1e-8 - return self.formatter.restore_type(weights) - - -class EmbeddingNormalization: - def __init__(self, norm: Union[float, torch.Tensor] = 1): - self.norm = norm - # Add batch dimension if missing - if isinstance(self.norm, torch.Tensor) and self.norm.ndim == 2: - self.norm = self.norm.unsqueeze(0) - - def __call__(self, embeddings: torch.Tensor) -> torch.Tensor: - # Add batch dimension if missing - if embeddings.ndim == 2: - embeddings = embeddings.unsqueeze(0) - if isinstance(self.norm, torch.Tensor): - batch_size1, num_speakers1, _ = self.norm.shape - batch_size2, num_speakers2, _ = embeddings.shape - assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2 - with torch.no_grad(): - norm_embs = self.norm * embeddings / torch.norm(embeddings, p=2, dim=-1, keepdim=True) - return norm_embs.squeeze() - - -class OverlapAwareSpeakerEmbedding: - """ - Extract overlap-aware speaker embeddings given an audio chunk and its segmentation. - - Parameters - ---------- - model: EmbeddingModel - A pre-trained embedding model. - gamma: float, optional - Exponent to lower low-confidence predictions. - Defaults to 3. - beta: float, optional - Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations. - Defaults to 10. - norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional - The target norm for the embeddings. It can be different for each speaker. - Defaults to 1. - device: Optional[torch.device] - The device on which to run the embedding model. - Defaults to GPU if available or CPU if not. 
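The `OverlappedSpeechPenalty` block deleted a few lines above reduces to a couple of tensor operations; a standalone sketch of the same computation on dummy scores:

```python
# Illustrative sketch: the overlapped speech penalty on dummy segmentation
# scores, mirroring the deleted OverlappedSpeechPenalty block.
import torch

seg = torch.rand(1, 500, 3)  # (batch, frames, speakers)
gamma, beta = 3, 10
probs = torch.softmax(beta * seg, dim=-1)
weights = torch.pow(seg, gamma) * torch.pow(probs, gamma)
weights[weights < 1e-8] = 1e-8  # numerical floor, as in the original block
```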
- """ - def __init__( - self, - model: EmbeddingModel, - gamma: float = 3, - beta: float = 10, - norm: Union[float, torch.Tensor] = 1, - device: Optional[torch.device] = None, - ): - self.embedding = SpeakerEmbedding(model, device) - self.osp = OverlappedSpeechPenalty(gamma, beta) - self.normalize = EmbeddingNormalization(norm) - - def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor: - return self.normalize(self.embedding(waveform, self.osp(segmentation))) - - -class AggregationStrategy: - """Abstract class representing a strategy to aggregate overlapping buffers""" - - @staticmethod - def build(name: Literal["mean", "hamming", "first"]) -> 'AggregationStrategy': - """Build an AggregationStrategy instance based on its name""" - assert name in ("mean", "hamming", "first") - if name == "mean": - return AverageStrategy() - elif name == "hamming": - return HammingWeightedAverageStrategy() - else: - return FirstOnlyStrategy() - - def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature: - """Aggregate chunks over a specific region. - - Parameters - ---------- - buffers: list of SlidingWindowFeature, shapes (frames, speakers) - Buffers to aggregate - focus: Segment - Region to aggregate that is shared among the buffers - - Returns - ------- - aggregation: SlidingWindowFeature, shape (cropped_frames, speakers) - Aggregated values over the focus region - """ - aggregation = self.aggregate(buffers, focus) - resolution = focus.duration / aggregation.shape[0] - resolution = SlidingWindow( - start=focus.start, - duration=resolution, - step=resolution - ) - return SlidingWindowFeature(aggregation, resolution) - - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: - raise NotImplementedError - - -class HammingWeightedAverageStrategy(AggregationStrategy): - """Compute the average weighted by the corresponding Hamming-window aligned to each buffer""" - - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: - num_frames, num_speakers = buffers[0].data.shape - hamming, intersection = [], [] - for buffer in buffers: - # Crop buffer to focus region - b = buffer.crop(focus, fixed=focus.duration) - # Crop Hamming window to focus region - h = np.expand_dims(np.hamming(num_frames), axis=-1) - h = SlidingWindowFeature(h, buffer.sliding_window) - h = h.crop(focus, fixed=focus.duration) - hamming.append(h.data) - intersection.append(b.data) - hamming, intersection = np.stack(hamming), np.stack(intersection) - # Calculate weighted mean - return np.sum(hamming * intersection, axis=0) / np.sum(hamming, axis=0) - - -class AverageStrategy(AggregationStrategy): - """Compute a simple average over the focus region""" - - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: - # Stack all overlapping regions - intersection = np.stack([ - buffer.crop(focus, fixed=focus.duration) - for buffer in buffers - ]) - return np.mean(intersection, axis=0) - - -class FirstOnlyStrategy(AggregationStrategy): - """Instead of aggregating, keep the first focus region in the buffer list""" - - def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: - return buffers[0].crop(focus, fixed=focus.duration) - - -class DelayedAggregation: - """Aggregate aligned overlapping windows of the same duration - across sliding buffers with a specific step and latency. 
- - Parameters - ---------- - step: float - Shift between two consecutive buffers, in seconds. - latency: float, optional - Desired latency, in seconds. Defaults to step. - The higher the latency, the more overlapping windows to aggregate. - strategy: ("mean", "hamming", "any"), optional - Specifies how to aggregate overlapping windows. Defaults to "hamming". - "mean": simple average - "hamming": average weighted by the Hamming window values (aligned to the buffer) - "any": no aggregation, pick the first overlapping window - stream_end: float, optional - Stream end time (in seconds). Defaults to None. - If the stream end time is known, then append remaining outputs at the end, - otherwise the last `latency - step` seconds are ignored. - - Example - -------- - >>> duration = 5 - >>> frames = 500 - >>> step = 0.5 - >>> speakers = 2 - >>> start_time = 10 - >>> resolution = duration / frames - >>> dagg = DelayedAggregation(step=step, latency=2, strategy="mean") - >>> buffers = [ - >>> SlidingWindowFeature( - >>> np.random.rand(frames, speakers), - >>> SlidingWindow(start=(i + start_time) * step, duration=resolution, step=resolution) - >>> ) - >>> for i in range(dagg.num_overlapping_windows) - >>> ] - >>> dagg.num_overlapping_windows - ... 4 - >>> dagg(buffers).data.shape - ... (51, 2) # Rounding errors are possible when cropping the buffers - """ - - def __init__( - self, - step: float, - latency: Optional[float] = None, - strategy: Literal["mean", "hamming", "first"] = "hamming", - stream_end: Optional[float] = None - ): - self.step = step - self.latency = latency - self.strategy = strategy - self.stream_end = stream_end - - if self.latency is None: - self.latency = self.step - - assert self.step <= self.latency, "Invalid latency requested" - - self.num_overlapping_windows = int(round(self.latency / self.step)) - self.aggregate = AggregationStrategy.build(self.strategy) - - def _prepend_or_append( - self, - output_window: SlidingWindowFeature, - output_region: Segment, - buffers: List[SlidingWindowFeature] - ): - last_buffer = buffers[-1].extent - # Prepend prediction until we match the latency in case of first buffer - if len(buffers) == 1 and last_buffer.start == 0: - num_frames = output_window.data.shape[0] - first_region = Segment(0, output_region.end) - first_output = buffers[0].crop( - first_region, fixed=first_region.duration - ) - first_output[-num_frames:] = output_window.data - resolution = output_region.end / first_output.shape[0] - output_window = SlidingWindowFeature( - first_output, - SlidingWindow(start=0, duration=resolution, step=resolution) - ) - # Append rest of the outputs - elif self.stream_end is not None and last_buffer.end == self.stream_end: - # FIXME instead of appending a larger chunk than expected when latency > step, - # keep emitting windows until the signal ends. - # This should be fixed at the observable level and not within the aggregation block. 
- num_frames = output_window.data.shape[0] - last_region = Segment(output_region.start, last_buffer.end) - last_output = buffers[-1].crop( - last_region, fixed=last_region.duration - ) - last_output[:num_frames] = output_window.data - resolution = self.latency / last_output.shape[0] - output_window = SlidingWindowFeature( - last_output, - SlidingWindow( - start=output_region.start, - duration=resolution, - step=resolution - ) - ) - return output_window - - def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature: - # Determine overlapping region to aggregate - start = buffers[-1].extent.end - self.latency - region = Segment(start, start + self.step) - return self._prepend_or_append(self.aggregate(buffers, region), region, buffers) - - -class OnlineSpeakerClustering: - def __init__( - self, - tau_active: float, - rho_update: float, - delta_new: float, - metric: Optional[str] = "cosine", - max_speakers: int = 20 - ): - self.tau_active = tau_active - self.rho_update = rho_update - self.delta_new = delta_new - self.metric = metric - self.max_speakers = max_speakers - self.centers: Optional[np.ndarray] = None - self.active_centers = set() - self.blocked_centers = set() - - @property - def num_free_centers(self) -> int: - return self.max_speakers - self.num_known_speakers - self.num_blocked_speakers - - @property - def num_known_speakers(self) -> int: - return len(self.active_centers) - - @property - def num_blocked_speakers(self) -> int: - return len(self.blocked_centers) - - @property - def inactive_centers(self) -> List[int]: - return [ - c - for c in range(self.max_speakers) - if c not in self.active_centers or c in self.blocked_centers - ] - - def get_next_center_position(self) -> Optional[int]: - for center in range(self.max_speakers): - if center not in self.active_centers and center not in self.blocked_centers: - return center - - def init_centers(self, dimension: int): - self.centers = np.zeros((self.max_speakers, dimension)) - self.active_centers = set() - self.blocked_centers = set() - - def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray): - if self.centers is not None: - for l_spk, g_spk in assignments: - assert g_spk in self.active_centers, "Cannot update unknown centers" - self.centers[g_spk] += embeddings[l_spk] - - def add_center(self, embedding: np.ndarray) -> int: - center = self.get_next_center_position() - self.centers[center] = embedding - self.active_centers.add(center) - return center - - def identify( - self, - segmentation: SlidingWindowFeature, - embeddings: torch.Tensor - ) -> SpeakerMap: - embeddings = embeddings.detach().cpu().numpy() - active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0] - long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0] - num_local_speakers = segmentation.data.shape[1] - - if self.centers is None: - self.init_centers(embeddings.shape[1]) - assignments = [ - (spk, self.add_center(embeddings[spk])) - for spk in active_speakers - ] - return SpeakerMapBuilder.hard_map( - shape=(num_local_speakers, self.max_speakers), - assignments=assignments, - maximize=False, - ) - - # Obtain a mapping based on distances between embeddings and centers - dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric) - # Remove any assignments containing invalid speakers - inactive_speakers = np.array([ - spk for spk in range(num_local_speakers) - if spk not in active_speakers - ]) - dist_map = dist_map.unmap_speakers(inactive_speakers, 
self.inactive_centers) - # Keep assignments under the distance threshold - valid_map = dist_map.unmap_threshold(self.delta_new) - - # Some speakers might be unidentified - missed_speakers = [ - s for s in active_speakers - if not valid_map.is_source_speaker_mapped(s) - ] - - # Add assignments to new centers if possible - new_center_speakers = [] - for spk in missed_speakers: - has_space = len(new_center_speakers) < self.num_free_centers - if has_space and spk in long_speakers: - # Flag as a new center - new_center_speakers.append(spk) - else: - # Cannot create a new center - # Get global speakers in order of preference - preferences = np.argsort(dist_map.mapping_matrix[spk, :]) - preferences = [ - g_spk for g_spk in preferences if g_spk in self.active_centers - ] - # Get the free global speakers among the preferences - _, g_assigned = valid_map.valid_assignments() - free = [g_spk for g_spk in preferences if g_spk not in g_assigned] - if free: - # The best global speaker is the closest free one - valid_map = valid_map.set_source_speaker(spk, free[0]) - - # Update known centers - to_update = [ - (ls, gs) - for ls, gs in zip(*valid_map.valid_assignments()) - if ls not in missed_speakers and ls in long_speakers - ] - self.update(to_update, embeddings) - - # Add new centers - for spk in new_center_speakers: - valid_map = valid_map.set_source_speaker( - spk, self.add_center(embeddings[spk]) - ) - - return valid_map - - def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) -> SlidingWindowFeature: - return SlidingWindowFeature( - self.identify(segmentation, embeddings).apply(segmentation.data), - segmentation.sliding_window - ) - - -class Binarize: - """ - Transform a speaker segmentation from the discrete-time domain - into a continuous-time speaker segmentation. - - Parameters - ---------- - uri: Text - Uri of the audio stream. - threshold: float - Probability threshold to determine if a speaker is active at a given frame. - """ - - def __init__(self, uri: Text, threshold: float): - self.uri = uri - self.threshold = threshold - - def __call__(self, segmentation: SlidingWindowFeature) -> Annotation: - """ - Return the continuous-time segmentation - corresponding to the discrete-time input segmentation. - - Parameters - ---------- - segmentation: SlidingWindowFeature - Discrete-time speaker segmentation. - - Returns - ------- - annotation: Annotation - Continuous-time speaker segmentation. 
- """ - num_frames, num_speakers = segmentation.data.shape - timestamps = segmentation.sliding_window - is_active = segmentation.data > self.threshold - # Artificially add last inactive frame to close any remaining speaker turns - is_active = np.append(is_active, [[False] * num_speakers], axis=0) - start_times = np.zeros(num_speakers) + timestamps[0].middle - annotation = Annotation(uri=self.uri, modality="speech") - for t in range(num_frames): - # Any (False, True) starts a speaker turn at "True" index - onsets = np.logical_and(np.logical_not(is_active[t]), is_active[t + 1]) - start_times[onsets] = timestamps[t + 1].middle - # Any (True, False) ends a speaker turn at "False" index - offsets = np.logical_and(is_active[t], np.logical_not(is_active[t + 1])) - for spk in np.where(offsets)[0]: - region = Segment(start_times[spk], timestamps[t + 1].middle) - annotation[region, spk] = f"speaker{spk}" - return annotation diff --git a/src/diart/blocks/__init__.py b/src/diart/blocks/__init__.py new file mode 100644 index 00000000..a2d88c00 --- /dev/null +++ b/src/diart/blocks/__init__.py @@ -0,0 +1,16 @@ +from .aggregation import ( + AggregationStrategy, + HammingWeightedAverageStrategy, + AverageStrategy, + FirstOnlyStrategy, + DelayedAggregation, +) +from .clustering import OnlineSpeakerClustering +from .embedding import ( + SpeakerEmbedding, + OverlappedSpeechPenalty, + EmbeddingNormalization, + OverlapAwareSpeakerEmbedding, +) +from .segmentation import SpeakerSegmentation +from .utils import Binarize diff --git a/src/diart/blocks/aggregation.py b/src/diart/blocks/aggregation.py new file mode 100644 index 00000000..d0379711 --- /dev/null +++ b/src/diart/blocks/aggregation.py @@ -0,0 +1,198 @@ +from typing import Optional, List + +import numpy as np +from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature +from typing_extensions import Literal + + +class AggregationStrategy: + """Abstract class representing a strategy to aggregate overlapping buffers""" + + @staticmethod + def build(name: Literal["mean", "hamming", "first"]) -> 'AggregationStrategy': + """Build an AggregationStrategy instance based on its name""" + assert name in ("mean", "hamming", "first") + if name == "mean": + return AverageStrategy() + elif name == "hamming": + return HammingWeightedAverageStrategy() + else: + return FirstOnlyStrategy() + + def __call__(self, buffers: List[SlidingWindowFeature], focus: Segment) -> SlidingWindowFeature: + """Aggregate chunks over a specific region. 
+ + Parameters + ---------- + buffers: list of SlidingWindowFeature, shapes (frames, speakers) + Buffers to aggregate + focus: Segment + Region to aggregate that is shared among the buffers + + Returns + ------- + aggregation: SlidingWindowFeature, shape (cropped_frames, speakers) + Aggregated values over the focus region + """ + aggregation = self.aggregate(buffers, focus) + resolution = focus.duration / aggregation.shape[0] + resolution = SlidingWindow( + start=focus.start, + duration=resolution, + step=resolution + ) + return SlidingWindowFeature(aggregation, resolution) + + def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + raise NotImplementedError + + +class HammingWeightedAverageStrategy(AggregationStrategy): + """Compute the average weighted by the corresponding Hamming-window aligned to each buffer""" + + def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + num_frames, num_speakers = buffers[0].data.shape + hamming, intersection = [], [] + for buffer in buffers: + # Crop buffer to focus region + b = buffer.crop(focus, fixed=focus.duration) + # Crop Hamming window to focus region + h = np.expand_dims(np.hamming(num_frames), axis=-1) + h = SlidingWindowFeature(h, buffer.sliding_window) + h = h.crop(focus, fixed=focus.duration) + hamming.append(h.data) + intersection.append(b.data) + hamming, intersection = np.stack(hamming), np.stack(intersection) + # Calculate weighted mean + return np.sum(hamming * intersection, axis=0) / np.sum(hamming, axis=0) + + +class AverageStrategy(AggregationStrategy): + """Compute a simple average over the focus region""" + + def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + # Stack all overlapping regions + intersection = np.stack([ + buffer.crop(focus, fixed=focus.duration) + for buffer in buffers + ]) + return np.mean(intersection, axis=0) + + +class FirstOnlyStrategy(AggregationStrategy): + """Instead of aggregating, keep the first focus region in the buffer list""" + + def aggregate(self, buffers: List[SlidingWindowFeature], focus: Segment) -> np.ndarray: + return buffers[0].crop(focus, fixed=focus.duration) + + +class DelayedAggregation: + """Aggregate aligned overlapping windows of the same duration + across sliding buffers with a specific step and latency. + + Parameters + ---------- + step: float + Shift between two consecutive buffers, in seconds. + latency: float, optional + Desired latency, in seconds. Defaults to step. + The higher the latency, the more overlapping windows to aggregate. + strategy: ("mean", "hamming", "any"), optional + Specifies how to aggregate overlapping windows. Defaults to "hamming". + "mean": simple average + "hamming": average weighted by the Hamming window values (aligned to the buffer) + "any": no aggregation, pick the first overlapping window + stream_end: float, optional + Stream end time (in seconds). Defaults to None. + If the stream end time is known, then append remaining outputs at the end, + otherwise the last `latency - step` seconds are ignored. 
+ + Example + -------- + >>> duration = 5 + >>> frames = 500 + >>> step = 0.5 + >>> speakers = 2 + >>> start_time = 10 + >>> resolution = duration / frames + >>> dagg = DelayedAggregation(step=step, latency=2, strategy="mean") + >>> buffers = [ + >>> SlidingWindowFeature( + >>> np.random.rand(frames, speakers), + >>> SlidingWindow(start=(i + start_time) * step, duration=resolution, step=resolution) + >>> ) + >>> for i in range(dagg.num_overlapping_windows) + >>> ] + >>> dagg.num_overlapping_windows + ... 4 + >>> dagg(buffers).data.shape + ... (51, 2) # Rounding errors are possible when cropping the buffers + """ + + def __init__( + self, + step: float, + latency: Optional[float] = None, + strategy: Literal["mean", "hamming", "first"] = "hamming", + stream_end: Optional[float] = None + ): + self.step = step + self.latency = latency + self.strategy = strategy + self.stream_end = stream_end + + if self.latency is None: + self.latency = self.step + + assert self.step <= self.latency, "Invalid latency requested" + + self.num_overlapping_windows = int(round(self.latency / self.step)) + self.aggregate = AggregationStrategy.build(self.strategy) + + def _prepend_or_append( + self, + output_window: SlidingWindowFeature, + output_region: Segment, + buffers: List[SlidingWindowFeature] + ): + last_buffer = buffers[-1].extent + # Prepend prediction until we match the latency in case of first buffer + if len(buffers) == 1 and last_buffer.start == 0: + num_frames = output_window.data.shape[0] + first_region = Segment(0, output_region.end) + first_output = buffers[0].crop( + first_region, fixed=first_region.duration + ) + first_output[-num_frames:] = output_window.data + resolution = output_region.end / first_output.shape[0] + output_window = SlidingWindowFeature( + first_output, + SlidingWindow(start=0, duration=resolution, step=resolution) + ) + # Append rest of the outputs + elif self.stream_end is not None and last_buffer.end == self.stream_end: + # FIXME instead of appending a larger chunk than expected when latency > step, + # keep emitting windows until the signal ends. + # This should be fixed at the observable level and not within the aggregation block. 
+ num_frames = output_window.data.shape[0] + last_region = Segment(output_region.start, last_buffer.end) + last_output = buffers[-1].crop( + last_region, fixed=last_region.duration + ) + last_output[:num_frames] = output_window.data + resolution = self.latency / last_output.shape[0] + output_window = SlidingWindowFeature( + last_output, + SlidingWindow( + start=output_region.start, + duration=resolution, + step=resolution + ) + ) + return output_window + + def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature: + # Determine overlapping region to aggregate + start = buffers[-1].extent.end - self.latency + region = Segment(start, start + self.step) + return self._prepend_or_append(self.aggregate(buffers, region), region, buffers) diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py new file mode 100644 index 00000000..1842ecfa --- /dev/null +++ b/src/diart/blocks/clustering.py @@ -0,0 +1,150 @@ +from typing import Optional, List, Iterable, Tuple + +import numpy as np +import torch +from pyannote.core import SlidingWindowFeature + +from mapping import SpeakerMap, SpeakerMapBuilder + + +class OnlineSpeakerClustering: + def __init__( + self, + tau_active: float, + rho_update: float, + delta_new: float, + metric: Optional[str] = "cosine", + max_speakers: int = 20 + ): + self.tau_active = tau_active + self.rho_update = rho_update + self.delta_new = delta_new + self.metric = metric + self.max_speakers = max_speakers + self.centers: Optional[np.ndarray] = None + self.active_centers = set() + self.blocked_centers = set() + + @property + def num_free_centers(self) -> int: + return self.max_speakers - self.num_known_speakers - self.num_blocked_speakers + + @property + def num_known_speakers(self) -> int: + return len(self.active_centers) + + @property + def num_blocked_speakers(self) -> int: + return len(self.blocked_centers) + + @property + def inactive_centers(self) -> List[int]: + return [ + c + for c in range(self.max_speakers) + if c not in self.active_centers or c in self.blocked_centers + ] + + def get_next_center_position(self) -> Optional[int]: + for center in range(self.max_speakers): + if center not in self.active_centers and center not in self.blocked_centers: + return center + + def init_centers(self, dimension: int): + self.centers = np.zeros((self.max_speakers, dimension)) + self.active_centers = set() + self.blocked_centers = set() + + def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray): + if self.centers is not None: + for l_spk, g_spk in assignments: + assert g_spk in self.active_centers, "Cannot update unknown centers" + self.centers[g_spk] += embeddings[l_spk] + + def add_center(self, embedding: np.ndarray) -> int: + center = self.get_next_center_position() + self.centers[center] = embedding + self.active_centers.add(center) + return center + + def identify( + self, + segmentation: SlidingWindowFeature, + embeddings: torch.Tensor + ) -> SpeakerMap: + embeddings = embeddings.detach().cpu().numpy() + active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0] + long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0] + num_local_speakers = segmentation.data.shape[1] + + if self.centers is None: + self.init_centers(embeddings.shape[1]) + assignments = [ + (spk, self.add_center(embeddings[spk])) + for spk in active_speakers + ] + return SpeakerMapBuilder.hard_map( + shape=(num_local_speakers, self.max_speakers), + assignments=assignments, + maximize=False, + 
) + + # Obtain a mapping based on distances between embeddings and centers + dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric) + # Remove any assignments containing invalid speakers + inactive_speakers = np.array([ + spk for spk in range(num_local_speakers) + if spk not in active_speakers + ]) + dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers) + # Keep assignments under the distance threshold + valid_map = dist_map.unmap_threshold(self.delta_new) + + # Some speakers might be unidentified + missed_speakers = [ + s for s in active_speakers + if not valid_map.is_source_speaker_mapped(s) + ] + + # Add assignments to new centers if possible + new_center_speakers = [] + for spk in missed_speakers: + has_space = len(new_center_speakers) < self.num_free_centers + if has_space and spk in long_speakers: + # Flag as a new center + new_center_speakers.append(spk) + else: + # Cannot create a new center + # Get global speakers in order of preference + preferences = np.argsort(dist_map.mapping_matrix[spk, :]) + preferences = [ + g_spk for g_spk in preferences if g_spk in self.active_centers + ] + # Get the free global speakers among the preferences + _, g_assigned = valid_map.valid_assignments() + free = [g_spk for g_spk in preferences if g_spk not in g_assigned] + if free: + # The best global speaker is the closest free one + valid_map = valid_map.set_source_speaker(spk, free[0]) + + # Update known centers + to_update = [ + (ls, gs) + for ls, gs in zip(*valid_map.valid_assignments()) + if ls not in missed_speakers and ls in long_speakers + ] + self.update(to_update, embeddings) + + # Add new centers + for spk in new_center_speakers: + valid_map = valid_map.set_source_speaker( + spk, self.add_center(embeddings[spk]) + ) + + return valid_map + + def __call__(self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor) -> SlidingWindowFeature: + return SlidingWindowFeature( + self.identify(segmentation, embeddings).apply(segmentation.data), + segmentation.sliding_window + ) diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py new file mode 100644 index 00000000..8598daa2 --- /dev/null +++ b/src/diart/blocks/embedding.py @@ -0,0 +1,138 @@ +from typing import Optional, Union + +import torch +from einops import rearrange + +from features import TemporalFeatures, TemporalFeatureFormatter +from models import EmbeddingModel + + +class SpeakerEmbedding: + def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None): + self.model = model + self.model.eval() + self.device = device + if self.device is None: + self.device = torch.device("cpu") + self.model.to(self.device) + self.waveform_formatter = TemporalFeatureFormatter() + self.weights_formatter = TemporalFeatureFormatter() + + def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor: + """ + Calculate speaker embeddings of input audio. + If weights are given, calculate many speaker embeddings from the same waveform. + + Parameters + ---------- + waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) + weights: Optional[TemporalFeatures], shape (frames, speakers) or (batch, frames, speakers) + Per-speaker and per-frame weights. Defaults to no weights. + + Returns + ------- + embeddings: torch.Tensor + If weights are provided, the shape is (batch, speakers, embedding_dim), + otherwise the shape is (batch, embedding_dim). + If batch size == 1, the batch dimension is omitted. 
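+
+        Example
+        --------
+        A minimal sketch rather than a strict doctest: it assumes `pyannote.audio`
+        is installed, the "pyannote/embedding" model is accessible, and 16kHz audio.
+
+        >>> import torch
+        >>> from diart.models import EmbeddingModel
+        >>> embedding = SpeakerEmbedding(EmbeddingModel.from_pyannote("pyannote/embedding"))
+        >>> chunk = torch.randn(80000, 1)  # 5 seconds of random mono "audio" at 16kHz
+        >>> emb = embedding(chunk)         # torch.Tensor of shape (embedding_dim,)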
+ """ + with torch.no_grad(): + inputs = self.waveform_formatter.cast(waveform).to(self.device) + inputs = rearrange(inputs, "batch sample channel -> batch channel sample") + if weights is not None: + weights = self.weights_formatter.cast(weights).to(self.device) + batch_size, _, num_speakers = weights.shape + inputs = inputs.repeat(1, num_speakers, 1) + weights = rearrange(weights, "batch frame spk -> (batch spk) frame") + inputs = rearrange(inputs, "batch spk sample -> (batch spk) 1 sample") + output = rearrange( + self.model(inputs, weights), + "(batch spk) feat -> batch spk feat", + batch=batch_size, + spk=num_speakers + ) + else: + output = self.model(inputs) + return output.squeeze().cpu() + + +class OverlappedSpeechPenalty: + """ + Parameters + ---------- + gamma: float, optional + Exponent to lower low-confidence predictions. + Defaults to 3. + beta: float, optional + Temperature parameter (actually 1/beta) to lower joint speaker activations. + Defaults to 10. + """ + def __init__(self, gamma: float = 3, beta: float = 10): + self.gamma = gamma + self.beta = beta + self.formatter = TemporalFeatureFormatter() + + def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures: + weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers) + with torch.no_grad(): + probs = torch.softmax(self.beta * weights, dim=-1) + weights = torch.pow(weights, self.gamma) * torch.pow(probs, self.gamma) + weights[weights < 1e-8] = 1e-8 + return self.formatter.restore_type(weights) + + +class EmbeddingNormalization: + def __init__(self, norm: Union[float, torch.Tensor] = 1): + self.norm = norm + # Add batch dimension if missing + if isinstance(self.norm, torch.Tensor) and self.norm.ndim == 2: + self.norm = self.norm.unsqueeze(0) + + def __call__(self, embeddings: torch.Tensor) -> torch.Tensor: + # Add batch dimension if missing + if embeddings.ndim == 2: + embeddings = embeddings.unsqueeze(0) + if isinstance(self.norm, torch.Tensor): + batch_size1, num_speakers1, _ = self.norm.shape + batch_size2, num_speakers2, _ = embeddings.shape + assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2 + with torch.no_grad(): + norm_embs = self.norm * embeddings / torch.norm(embeddings, p=2, dim=-1, keepdim=True) + return norm_embs.squeeze() + + +class OverlapAwareSpeakerEmbedding: + """ + Extract overlap-aware speaker embeddings given an audio chunk and its segmentation. + + Parameters + ---------- + model: EmbeddingModel + A pre-trained embedding model. + gamma: float, optional + Exponent to lower low-confidence predictions. + Defaults to 3. + beta: float, optional + Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations. + Defaults to 10. + norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional + The target norm for the embeddings. It can be different for each speaker. + Defaults to 1. + device: Optional[torch.device] + The device on which to run the embedding model. + Defaults to GPU if available or CPU if not. 
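+
+    Example
+    --------
+    A minimal sketch rather than a strict doctest: it assumes `pyannote.audio` is
+    installed, the "pyannote/embedding" model is accessible, and a 5-second 16kHz
+    chunk whose segmentation has 293 frames and 3 speakers (illustrative values).
+
+    >>> import torch
+    >>> from diart.models import EmbeddingModel
+    >>> emb_model = EmbeddingModel.from_pyannote("pyannote/embedding")
+    >>> embedding = OverlapAwareSpeakerEmbedding(emb_model, gamma=3, beta=10, norm=1)
+    >>> chunk = torch.randn(80000, 1)      # waveform with shape (samples, channels)
+    >>> segmentation = torch.rand(293, 3)  # activations with shape (frames, speakers)
+    >>> embeddings = embedding(chunk, segmentation)  # one embedding per speaker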
+ """ + def __init__( + self, + model: EmbeddingModel, + gamma: float = 3, + beta: float = 10, + norm: Union[float, torch.Tensor] = 1, + device: Optional[torch.device] = None, + ): + self.embedding = SpeakerEmbedding(model, device) + self.osp = OverlappedSpeechPenalty(gamma, beta) + self.normalize = EmbeddingNormalization(norm) + + def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor: + return self.normalize(self.embedding(waveform, self.osp(segmentation))) diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py new file mode 100644 index 00000000..d3015806 --- /dev/null +++ b/src/diart/blocks/segmentation.py @@ -0,0 +1,36 @@ +from typing import Optional + +import torch +from einops import rearrange + +from features import TemporalFeatures, TemporalFeatureFormatter +from models import SegmentationModel + + +class SpeakerSegmentation: + def __init__(self, model: SegmentationModel, device: Optional[torch.device] = None): + self.model = model + self.model.eval() + self.device = device + if self.device is None: + self.device = torch.device("cpu") + self.model.to(self.device) + self.formatter = TemporalFeatureFormatter() + + def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: + """ + Calculate the speaker segmentation of input audio. + + Parameters + ---------- + waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) + + Returns + ------- + speaker_segmentation: TemporalFeatures, shape (batch, frames, speakers) + The batch dimension is omitted if waveform is a `SlidingWindowFeature`. + """ + with torch.no_grad(): + wave = rearrange(self.formatter.cast(waveform), "batch sample channel -> batch channel sample") + output = self.model(wave.to(self.device)).cpu() + return self.formatter.restore_type(output) \ No newline at end of file diff --git a/src/diart/blocks/utils.py b/src/diart/blocks/utils.py new file mode 100644 index 00000000..9a495ca8 --- /dev/null +++ b/src/diart/blocks/utils.py @@ -0,0 +1,55 @@ +from typing import Text + +import numpy as np +from pyannote.core import Annotation, Segment, SlidingWindowFeature + + +class Binarize: + """ + Transform a speaker segmentation from the discrete-time domain + into a continuous-time speaker segmentation. + + Parameters + ---------- + uri: Text + Uri of the audio stream. + threshold: float + Probability threshold to determine if a speaker is active at a given frame. + """ + + def __init__(self, uri: Text, threshold: float): + self.uri = uri + self.threshold = threshold + + def __call__(self, segmentation: SlidingWindowFeature) -> Annotation: + """ + Return the continuous-time segmentation + corresponding to the discrete-time input segmentation. + + Parameters + ---------- + segmentation: SlidingWindowFeature + Discrete-time speaker segmentation. + + Returns + ------- + annotation: Annotation + Continuous-time speaker segmentation. 
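+
+        Example
+        --------
+        Illustrative only, with random frame-wise scores:
+
+        >>> import numpy as np
+        >>> from pyannote.core import SlidingWindow, SlidingWindowFeature
+        >>> binarize = Binarize(uri="stream", threshold=0.5)
+        >>> scores = SlidingWindowFeature(
+        ...     np.random.rand(100, 2),  # 100 frames, 2 speakers
+        ...     SlidingWindow(start=0.0, duration=0.016, step=0.016)
+        ... )
+        >>> annotation = binarize(scores)  # Annotation with one label per speaker turn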
+ """ + num_frames, num_speakers = segmentation.data.shape + timestamps = segmentation.sliding_window + is_active = segmentation.data > self.threshold + # Artificially add last inactive frame to close any remaining speaker turns + is_active = np.append(is_active, [[False] * num_speakers], axis=0) + start_times = np.zeros(num_speakers) + timestamps[0].middle + annotation = Annotation(uri=self.uri, modality="speech") + for t in range(num_frames): + # Any (False, True) starts a speaker turn at "True" index + onsets = np.logical_and(np.logical_not(is_active[t]), is_active[t + 1]) + start_times[onsets] = timestamps[t + 1].middle + # Any (True, False) ends a speaker turn at "False" index + offsets = np.logical_and(is_active[t], np.logical_not(is_active[t + 1])) + for spk in np.where(offsets)[0]: + region = Segment(start_times[spk], timestamps[t + 1].middle) + annotation[region, spk] = f"speaker{spk}" + return annotation From 8b47aebd6c7fdaded6d44ed77b235570e19a8029 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Mon, 13 Jun 2022 18:45:31 +0200 Subject: [PATCH 16/29] Add relative imports to blocks package --- src/diart/blocks/clustering.py | 2 +- src/diart/blocks/embedding.py | 4 ++-- src/diart/blocks/segmentation.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py index 1842ecfa..57a3ab2c 100644 --- a/src/diart/blocks/clustering.py +++ b/src/diart/blocks/clustering.py @@ -4,7 +4,7 @@ import torch from pyannote.core import SlidingWindowFeature -from mapping import SpeakerMap, SpeakerMapBuilder +from ..mapping import SpeakerMap, SpeakerMapBuilder class OnlineSpeakerClustering: diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py index 8598daa2..19288876 100644 --- a/src/diart/blocks/embedding.py +++ b/src/diart/blocks/embedding.py @@ -3,8 +3,8 @@ import torch from einops import rearrange -from features import TemporalFeatures, TemporalFeatureFormatter -from models import EmbeddingModel +from ..features import TemporalFeatures, TemporalFeatureFormatter +from ..models import EmbeddingModel class SpeakerEmbedding: diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py index d3015806..33064207 100644 --- a/src/diart/blocks/segmentation.py +++ b/src/diart/blocks/segmentation.py @@ -3,8 +3,8 @@ import torch from einops import rearrange -from features import TemporalFeatures, TemporalFeatureFormatter -from models import SegmentationModel +from ..features import TemporalFeatures, TemporalFeatureFormatter +from ..models import SegmentationModel class SpeakerSegmentation: From b91550bb83c959670905307174503ed32f0e3758 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Mon, 13 Jun 2022 21:00:01 +0200 Subject: [PATCH 17/29] Add README instructions needed for a no-pyannote installation --- README.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index eb90e123..4a32803c 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,16 @@ conda create -n diart python=3.8 conda activate diart ``` -2) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally) +2) Install `PortAudio` and `soundfile`: -3) Install pyannote.audio 2.0 (currently in development) +```shell +conda install portaudio +conda install pysoundfile -c conda-forge +``` + +3) [Install PyTorch](https://pytorch.org/get-started/locally/#start-locally) + +4) Install pyannote.audio 2.0 (currently in development) ```shell pip install 
git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio @@ -37,7 +44,7 @@ pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyann **Note:** starting from version 0.4, installing pyannote.audio is mandatory to run the default system or to use pyannote-based models. In any other case, this step can be ignored. -4) Install diart: +5) Install diart: ```shell pip install diart ``` From a6ddb07907afe81dc9919c216e1beff4a791086a Mon Sep 17 00:00:00 2001 From: Khaled Zaouk Date: Mon, 20 Jun 2022 11:17:28 +0200 Subject: [PATCH 18/29] Add documentation for some classes and methods (#31) * Adds documentation for some of the classes and methods under functional.py and mapping.py Co-authored-by: Juan Coria --- src/diart/blocks/clustering.py | 62 ++++++++++++++++++++++++++++++++++ src/diart/mapping.py | 34 +++++++++++++++++++ src/diart/pipelines.py | 3 +- 3 files changed, 97 insertions(+), 2 deletions(-) diff --git a/src/diart/blocks/clustering.py b/src/diart/blocks/clustering.py index 57a3ab2c..882001b9 100644 --- a/src/diart/blocks/clustering.py +++ b/src/diart/blocks/clustering.py @@ -8,6 +8,25 @@ class OnlineSpeakerClustering: + """Implements constrained incremental online clustering of speakers and manages cluster centers. + + Parameters + ---------- + tau_active:float + Threshold for detecting active speakers. This threshold is applied on the maximum value of per-speaker output + activation of the local segmentation model. + rho_update: float + Threshold for considering the extracted embedding when updating the centroid of the local speaker. + The centroid to which a local speaker is mapped is only updated if the ratio of speech/chunk duration + of a given local speaker is greater than this threshold. + delta_new: float + Threshold on the distance between a speaker embedding and a centroid. If the distance between a local speaker and all + centroids is larger than delta_new, then a new centroid is created for the current speaker. + metric: str. Defaults to "cosine". + The distance metric to use. + max_speakers: int + Maximum number of global speakers to track through a conversation. Defaults to 20. + """ def __init__( self, tau_active: float, @@ -51,17 +70,46 @@ def get_next_center_position(self) -> Optional[int]: return center def init_centers(self, dimension: int): + """Initializes the speaker centroid matrix + + Parameters + ---------- + dimension: int + Dimension of embeddings used for representing a speaker. + """ self.centers = np.zeros((self.max_speakers, dimension)) self.active_centers = set() self.blocked_centers = set() def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray): + """Updates the speaker centroids given a list of assignments and local speaker embeddings + + Parameters + ---------- + assignments: Iterable[Tuple[int, int]]) + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker. + embeddings: np.ndarray, shape (local_speakers, embedding_dim) + Matrix containing embeddings for all local speakers. 
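+
+        Example
+        --------
+        Illustrative only, with a toy 4-dimensional embedding space:
+
+        >>> import numpy as np
+        >>> clustering = OnlineSpeakerClustering(tau_active=0.5, rho_update=0.3, delta_new=1)
+        >>> clustering.init_centers(dimension=4)
+        >>> first = clustering.add_center(np.array([1.0, 0.0, 0.0, 0.0]))
+        >>> # Accumulate the embedding of local speaker 0 into global center `first`
+        >>> clustering.update([(0, first)], np.array([[0.5, 0.5, 0.0, 0.0]]))
+        >>> # clustering.centers[first] is now [1.5, 0.5, 0.0, 0.0]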
+ """ if self.centers is not None: for l_spk, g_spk in assignments: assert g_spk in self.active_centers, "Cannot update unknown centers" self.centers[g_spk] += embeddings[l_spk] def add_center(self, embedding: np.ndarray) -> int: + """Add a new speaker centroid initialized to a given embedding + + Parameters + ---------- + embedding: np.ndarray + Embedding vector of some local speaker + + Returns + ------- + center_index: int + Index of the created center + """ center = self.get_next_center_position() self.centers[center] = embedding self.active_centers.add(center) @@ -72,6 +120,20 @@ def identify( segmentation: SlidingWindowFeature, embeddings: torch.Tensor ) -> SpeakerMap: + """Identify the centroids to which the input speaker embeddings belong. + + Parameters + ---------- + segmentation: np.ndarray, shape (frames, local_speakers) + Matrix of segmentation outputs + embeddings: np.ndarray, shape (local_speakers, embedding_dim) + Matrix of embeddings + + Returns + ------- + speaker_map: SpeakerMap + A mapping from local speakers to global speakers. + """ embeddings = embeddings.detach().cpu().numpy() active_speakers = np.where(np.max(segmentation.data, axis=0) >= self.tau_active)[0] long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[0] diff --git a/src/diart/mapping.py b/src/diart/mapping.py index 01465086..2795ba0b 100644 --- a/src/diart/mapping.py +++ b/src/diart/mapping.py @@ -22,6 +22,23 @@ def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]: def hard_speaker_map( self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]] ) -> SpeakerMap: + """Create a hard map object where the highest cost is put + everywhere except on hard assignments from ``assignments``. + + Parameters + ---------- + num_src: int + Number of source speakers + num_tgt: int + Number of target speakers + assignments: Iterable[Tuple[int, int]] + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker + + Returns + ------- + SpeakerMap + """ mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt)) for src, tgt in assignments: mapping_matrix[src, tgt] = self.best_possible_value @@ -82,6 +99,23 @@ class SpeakerMapBuilder: def hard_map( shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool ) -> SpeakerMap: + """Create a ``SpeakerMap`` object based on the given assignments. This is a "hard" map, meaning that the + highest cost is put everywhere except on hard assignments from ``assignments``. + + Parameters + ---------- + shape: Tuple[int, int]) + Shape of the mapping matrix + assignments: Iterable[Tuple[int, int]] + An iterable of tuples with two elements having the first element as the source speaker + and the second element as the target speaker + maximize: bool + whether to use scores where higher is better (true) or where lower is better (false) + + Returns + ------- + SpeakerMap + """ num_src, num_tgt = shape objective = MaximizationObjective if maximize else MinimizationObjective return objective().hard_speaker_map(num_src, num_tgt, assignments) diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index 4ccbf1d4..dea360d6 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -81,8 +81,7 @@ def from_namespace(args: Namespace) -> 'PipelineConfig': ) def last_chunk_end_time(self, conv_duration: float) -> Optional[float]: - """ - Return the end time of the last chunk for a given conversation duration. 
+ """Return the end time of the last chunk for a given conversation duration. Parameters ---------- From 8b91773c7f8207a08205e943dff7c37bc3398e4d Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Mon, 13 Jun 2022 22:29:51 +0200 Subject: [PATCH 19/29] Add initial implementation of hyper-parameter tuning with optuna --- requirements.txt | 2 +- setup.cfg | 1 + src/diart/benchmark.py | 7 ++++ src/diart/inference.py | 29 +++++++++----- src/diart/optim.py | 90 ++++++++++++++++++++++++++++++++++++++++++ src/diart/pipelines.py | 16 ++++---- src/diart/stream.py | 7 ++++ 7 files changed, 135 insertions(+), 17 deletions(-) create mode 100644 src/diart/optim.py diff --git a/requirements.txt b/requirements.txt index 2a267b34..d2259fb5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ torchaudio>=0.10,<1.0 pyannote.core>=4.4 pyannote.database>=4.1.1 pyannote.metrics>=3.2 -git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyannote-audio +optuna>=2.10 diff --git a/setup.cfg b/setup.cfg index d2d5c5b3..a0f7c776 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,7 @@ install_requires = pyannote.core>=4.4 pyannote.database>=4.1.1 pyannote.metrics>=3.2 + optuna>=2.10 [options.packages.find] diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py index 1ac0bafd..c83d9e54 100644 --- a/src/diart/benchmark.py +++ b/src/diart/benchmark.py @@ -1,5 +1,7 @@ import argparse +import torch + import diart.argdoc as argdoc from diart.inference import Benchmark from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig @@ -21,6 +23,11 @@ parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`") args = parser.parse_args() + args.device = torch.device("cpu") if args.cpu else None + args.tau_active = args.tau + args.rho_update = args.rho + args.delta_new = args.delta + Benchmark(args.root, args.reference, args.output)( OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True), args.batch_size diff --git a/src/diart/inference.py b/src/diart/inference.py index 6e204247..ec31a2fe 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -107,7 +107,12 @@ def __init__( self.output_path = Path(output_path).expanduser() self.output_path.mkdir(parents=True, exist_ok=True) - def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> Optional[pd.DataFrame]: + def __call__( + self, + pipeline: OnlineSpeakerDiarization, + batch_size: int = 32, + verbose: bool = True + ) -> Optional[pd.DataFrame]: """ Run a given pipeline on a set of audio files using pre-calculated segmentation and embeddings in batches. @@ -118,6 +123,8 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> Configured speaker diarization pipeline. batch_size: int Batch size. Defaults to 32. + verbose: bool + Whether to log its progress. Defaults to True. 
Returns ------- @@ -143,6 +150,7 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> ) observable = pipeline.from_audio_source(source) else: + msg = f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})" source = src.PrecalculatedFeaturesAudioSource( filepath, pipeline.config.sample_rate, @@ -151,17 +159,20 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> pipeline.config.duration, pipeline.config.step, batch_size, - progress_msg=f"Pre-calculating {filepath.stem} ({i + 1}/{num_audio_files})", + progress_msg=msg if verbose else None, ) observable = pipeline.from_feature_source(source) - observable.pipe( - dops.progress( - desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})", - total=num_chunks, - leave=isinstance(source, src.PrecalculatedFeaturesAudioSource) + if verbose: + observable = observable.pipe( + dops.progress( + desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})", + total=num_chunks, + leave=isinstance(source, src.PrecalculatedFeaturesAudioSource) + ) ) - ).subscribe(RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm")) + + observable.subscribe(RTTMWriter(path=self.output_path / f"{filepath.stem}.rttm")) source.read() @@ -173,4 +184,4 @@ def __call__(self, pipeline: OnlineSpeakerDiarization, batch_size: int = 32) -> hyp = load_rttm(self.output_path / ref_path.name).popitem()[1] metric(ref, hyp) - return metric.report(display=True) + return metric.report(display=verbose) diff --git a/src/diart/optim.py b/src/diart/optim.py new file mode 100644 index 00000000..e1138ffa --- /dev/null +++ b/src/diart/optim.py @@ -0,0 +1,90 @@ +from pathlib import Path +from typing import Dict, Text, Tuple, Optional, Callable + +import optuna +from optuna.pruners._base import BasePruner +from optuna.samplers import TPESampler +from optuna.samplers._base import BaseSampler +from optuna.trial import Trial + +from audio import FilePath +from benchmark import Benchmark +from pipelines import PipelineConfig, OnlineSpeakerDiarization + + +class HyperParameterOptimizer: + def __init__( + self, + speech_path: FilePath, + reference_path: FilePath, + output_path: FilePath, + base_config: PipelineConfig, + hparams: Dict[Text, Tuple[float, float]], + batch_size: int = 32, + ): + self.base_config = base_config + self.hparams = hparams + self.batch_size = batch_size + self.output_path = Path(output_path).expanduser() + self.output_path.mkdir(parents=True, exist_ok=False) + self.tmp_path = self.output_path / "current_iter" + self.benchmark = Benchmark(speech_path, reference_path, self.tmp_path) + + def _objective(self, trial: Trial) -> float: + # Set suggested values for optimized hyper-parameters + trial_config = vars(self.base_config) + for hp_name, (low, high) in self.hparams.items(): + trial_config[hp_name] = trial.suggest_uniform(hp_name, low, high) + + # Instantiate pipeline with the new configuration + pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config)) + + # Run pipeline over the dataset + report = self.benchmark(pipeline, self.batch_size, verbose=False) + + # Clean RTTM files + for tmp_file in self.tmp_path.iterdir(): + tmp_file.unlink() + + # Extract DER from report + return report.loc["TOTAL", "diarization error rate %"] + + def optimize( + self, + sampler: Optional[BaseSampler] = None, + pruner: Optional[BasePruner] = None, + num_iter: int = 100, + experiment_name: Optional[Text] = None, + ) -> Tuple[float, Dict[Text, float]]: + """Optimize the given hyper-parameters on the given dataset. 
+ + Parameters + ---------- + sampler: Optional[optuna.BaseSampler] + The Optuna sampler to use during optimization. Defaults to TPESampler. + pruner: Optional[optuna.BasePruner] + The Optuna pruner to use during optimization. Defaults to None. + num_iter: int + Number of iterations over the dataset. Defaults to 100. + experiment_name: Optional[Text] + Name of the optimization run. Defaults to None. + + Returns + ------- + der_value: float + Diarization error rate of the best iteration. + best_hparams: Dict[Text, float] + Hyper-parameters of the best iteration. + """ + sampler = TPESampler() if sampler is None else sampler + storage = self.output_path / "trials.db" + study = optuna.create_study( + storage=f"sqlite://{str(storage)}", + sampler=sampler, + pruner=pruner, + study_name=experiment_name, + direction="minimize", + load_if_exists=True, + ) + study.optimize(self._objective, n_trials=num_iter) + return study.best_value, study.best_params diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index dea360d6..fa27c087 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -1,5 +1,4 @@ -from argparse import Namespace -from typing import Optional, List +from typing import Optional, List, Any import rx import rx.operators as ops @@ -64,22 +63,25 @@ def __init__( self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @staticmethod - def from_namespace(args: Namespace) -> 'PipelineConfig': + def from_namespace(args: Any) -> 'PipelineConfig': return PipelineConfig( segmentation=getattr(args, "segmentation", None), embedding=getattr(args, "embedding", None), duration=getattr(args, "duration", None), step=args.step, latency=args.latency, - tau_active=args.tau, - rho_update=args.rho, - delta_new=args.delta, + tau_active=args.tau_active, + rho_update=args.rho_update, + delta_new=args.delta_new, gamma=args.gamma, beta=args.beta, max_speakers=args.max_speakers, - device=torch.device("cpu") if args.cpu else None, + device=args.device, ) + def copy(self) -> 'PipelineConfig': + return PipelineConfig.from_namespace(self) + def last_chunk_end_time(self, conv_duration: float) -> Optional[float]: """Return the end time of the last chunk for a given conversation duration. diff --git a/src/diart/stream.py b/src/diart/stream.py index 0ef8a5a6..1e548bf6 100644 --- a/src/diart/stream.py +++ b/src/diart/stream.py @@ -1,6 +1,8 @@ import argparse from pathlib import Path +import torch + import diart.argdoc as argdoc import diart.sources as src from diart.inference import RealTimeInference @@ -22,6 +24,11 @@ parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. 
Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file") args = parser.parse_args() + args.device = torch.device("cpu") if args.cpu else None + args.tau_active = args.tau + args.rho_update = args.rho + args.delta_new = args.delta + # Define online speaker diarization pipeline config = PipelineConfig.from_namespace(args) pipeline = OnlineSpeakerDiarization(config, profile=True) From e1a647091e8a79b42dc71c52ae800c5f8575d8ef Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 15 Jun 2022 17:08:27 +0200 Subject: [PATCH 20/29] Improve hyper-parameter optimization --- README.md | 43 +++++++++++- src/diart/benchmark.py | 16 +++-- src/diart/inference.py | 30 ++++----- src/diart/optim.py | 150 +++++++++++++++++++++++++---------------- src/diart/pipelines.py | 10 ++- src/diart/stream.py | 4 -- 6 files changed, 160 insertions(+), 93 deletions(-) diff --git a/README.md b/README.md index 4a32803c..2a52b0d5 100644 --- a/README.md +++ b/README.md @@ -84,9 +84,50 @@ inference(pipeline, audio_source) For faster inference and evaluation on a dataset we recommend to use `Benchmark` (see our notes on [reproducibility](#reproducibility)) +## Optimize hyper-parameters to your own dataset + +Diart implements a hyper-parameter optimizer based on [optuna](https://github.com/optuna/optuna). +`diart.optim.Optimizer` allows you to tune any pipeline to a custom dataset. +More information on Optuna can be found [here](https://optuna.readthedocs.io/en/stable/index.html). + +### A simple example + +```python +from diart.optim import OptimizationObjective, Optimizer, TauActive, RhoUpdate, DeltaNew +from diart.pipelines import PipelineConfig +from diart.inference import Benchmark + +# Benchmark runs and evaluates the pipeline with each configuration +benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir", show_report=False) +# Base configuration for the pipeline we're going to tune +base_config = PipelineConfig(duration=5, step=0.5, latency=5) +# Hyper-parameters to optimize +hparams = [TauActive, RhoUpdate, DeltaNew] +# The objective implements an optimization step +objective = OptimizationObjective(benchmark, base_config, hparams) +# Run optimization for 100 iterations +Optimizer(objective).optimize(num_iter=100, show_progress=True) +``` + +### Distributed optimization + +For bigger datasets, it is sometimes more convenient to run optimization in parallel. +If the same `study_name` and `storage` are given to the optimizer, all optimization processes will share the information from previous runs. +More information on distributed optimization can be found [here](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py). + +```python +from diart.optim import Optimizer + +objective = ... +study_name = "my_study" +storage = "mysql://root@localhost/example" +optimizer = Optimizer(objective, study_name, storage) +optimizer.optimize(num_iter=100, show_progress=True) +``` + ## Build your own pipeline -Diart also provides building blocks that can be combined to create your own pipeline. +For a more advanced usage, diart also provides building blocks that can be combined to create your own pipelines. Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately. 
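+For instance, a single block can be applied directly to in-memory arrays,
+without any audio source or pretrained model. The snippet below is only an
+illustration and uses random values:
+
+```python
+import numpy as np
+import torch
+from diart.blocks import OverlappedSpeechPenalty, EmbeddingNormalization
+
+# Fake frame-wise speaker activations for a single chunk (293 frames, 3 speakers)
+segmentation = np.random.rand(293, 3)
+
+# Lower the weight of frames where speakers overlap
+osp = OverlappedSpeechPenalty(gamma=3, beta=10)
+weights = osp(segmentation)  # frame-wise weights penalizing overlapped speech
+
+# Scale 3 random 512-dimensional embeddings to unit norm
+normalize = EmbeddingNormalization(norm=1)
+embeddings = normalize(torch.randn(3, 512))
+```
+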
### Example diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py index c83d9e54..b5fcc567 100644 --- a/src/diart/benchmark.py +++ b/src/diart/benchmark.py @@ -22,13 +22,15 @@ parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`") args = parser.parse_args() - args.device = torch.device("cpu") if args.cpu else None - args.tau_active = args.tau - args.rho_update = args.rho - args.delta_new = args.delta - Benchmark(args.root, args.reference, args.output)( - OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True), - args.batch_size + benchmark = Benchmark( + args.root, + args.reference, + args.output, + show_progress=True, + show_report=True, + batch_size=args.batch_size ) + + benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True)) diff --git a/src/diart/inference.py b/src/diart/inference.py index ec31a2fe..e58240e9 100644 --- a/src/diart/inference.py +++ b/src/diart/inference.py @@ -92,6 +92,9 @@ def __init__( speech_path: Union[Text, Path], reference_path: Optional[Union[Text, Path]] = None, output_path: Optional[Union[Text, Path]] = None, + show_progress: bool = True, + show_report: bool = True, + batch_size: int = 32, ): self.speech_path = Path(speech_path).expanduser() assert self.speech_path.is_dir(), "Speech path must be a directory" @@ -107,12 +110,11 @@ def __init__( self.output_path = Path(output_path).expanduser() self.output_path.mkdir(parents=True, exist_ok=True) - def __call__( - self, - pipeline: OnlineSpeakerDiarization, - batch_size: int = 32, - verbose: bool = True - ) -> Optional[pd.DataFrame]: + self.show_progress = show_progress + self.show_report = show_report + self.batch_size = batch_size + + def __call__(self, pipeline: OnlineSpeakerDiarization) -> Optional[pd.DataFrame]: """ Run a given pipeline on a set of audio files using pre-calculated segmentation and embeddings in batches. @@ -121,10 +123,6 @@ def __call__( ---------- pipeline: OnlineSpeakerDiarization Configured speaker diarization pipeline. - batch_size: int - Batch size. Defaults to 32. - verbose: bool - Whether to log its progress. Defaults to True. 
Returns ------- @@ -141,7 +139,7 @@ def __call__( ) # Stream fully online if batch size is 1 or lower - if batch_size < 2: + if self.batch_size < 2: source = src.FileAudioSource( filepath, pipeline.config.sample_rate, @@ -158,17 +156,17 @@ def __call__( pipeline.embedding, pipeline.config.duration, pipeline.config.step, - batch_size, - progress_msg=msg if verbose else None, + self.batch_size, + progress_msg=msg if self.show_progress else None, ) observable = pipeline.from_feature_source(source) - if verbose: + if self.show_progress: observable = observable.pipe( dops.progress( desc=f"Streaming {filepath.stem} ({i + 1}/{num_audio_files})", total=num_chunks, - leave=isinstance(source, src.PrecalculatedFeaturesAudioSource) + leave=False, ) ) @@ -184,4 +182,4 @@ def __call__( hyp = load_rttm(self.output_path / ref_path.name).popitem()[1] metric(ref, hyp) - return metric.report(display=verbose) + return metric.report(display=self.show_report) diff --git a/src/diart/optim.py b/src/diart/optim.py index e1138ffa..4260f952 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -1,90 +1,122 @@ +from collections import OrderedDict +from dataclasses import dataclass from pathlib import Path -from typing import Dict, Text, Tuple, Optional, Callable +from typing import Iterable, Text, Optional -import optuna -from optuna.pruners._base import BasePruner -from optuna.samplers import TPESampler -from optuna.samplers._base import BaseSampler -from optuna.trial import Trial +from optuna import TrialPruned, Study, create_study +from optuna.pruners import BasePruner +from optuna.samplers import TPESampler, BaseSampler +from optuna.trial import Trial, FrozenTrial +from tqdm import trange, tqdm -from audio import FilePath -from benchmark import Benchmark -from pipelines import PipelineConfig, OnlineSpeakerDiarization +from .benchmark import Benchmark +from .pipelines import PipelineConfig, OnlineSpeakerDiarization -class HyperParameterOptimizer: +@dataclass +class HyperParameter: + name: Text + low: float + high: float + + +TauActive = HyperParameter("tau_active", low=0, high=1) +RhoUpdate = HyperParameter("rho_update", low=0, high=1) +DeltaNew = HyperParameter("delta_new", low=0, high=2) + + +class OptimizationObjective: def __init__( self, - speech_path: FilePath, - reference_path: FilePath, - output_path: FilePath, + benchmark: Benchmark, base_config: PipelineConfig, - hparams: Dict[Text, Tuple[float, float]], - batch_size: int = 32, + hparams: Iterable[HyperParameter], ): + self.benchmark = benchmark self.base_config = base_config self.hparams = hparams - self.batch_size = batch_size - self.output_path = Path(output_path).expanduser() - self.output_path.mkdir(parents=True, exist_ok=False) - self.tmp_path = self.output_path / "current_iter" - self.benchmark = Benchmark(speech_path, reference_path, self.tmp_path) - def _objective(self, trial: Trial) -> float: + def __call__(self, trial: Trial) -> float: # Set suggested values for optimized hyper-parameters trial_config = vars(self.base_config) - for hp_name, (low, high) in self.hparams.items(): - trial_config[hp_name] = trial.suggest_uniform(hp_name, low, high) + for hparam in self.hparams: + trial_config[hparam.name] = trial.suggest_uniform( + hparam.name, hparam.low, hparam.high + ) # Instantiate pipeline with the new configuration pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config)) + # Prune trial if required + if trial.should_prune(): + raise TrialPruned() + # Run pipeline over the dataset - report = self.benchmark(pipeline, 
self.batch_size, verbose=False) + report = self.benchmark(pipeline) # Clean RTTM files - for tmp_file in self.tmp_path.iterdir(): - tmp_file.unlink() + for tmp_file in self.benchmark.output_path.iterdir(): + if tmp_file.name.endswith(".rttm"): + tmp_file.unlink() # Extract DER from report - return report.loc["TOTAL", "diarization error rate %"] + return report.loc["TOTAL", "diarization error rate"]["%"] + - def optimize( +class Optimizer: + def __init__( self, + objective: OptimizationObjective, + study_name: Optional[Text] = None, + storage: Optional[Text] = None, sampler: Optional[BaseSampler] = None, pruner: Optional[BasePruner] = None, - num_iter: int = 100, - experiment_name: Optional[Text] = None, - ) -> Tuple[float, Dict[Text, float]]: - """Optimize the given hyper-parameters on the given dataset. - - Parameters - ---------- - sampler: Optional[optuna.BaseSampler] - The Optuna sampler to use during optimization. Defaults to TPESampler. - pruner: Optional[optuna.BasePruner] - The Optuna pruner to use during optimization. Defaults to None. - num_iter: int - Number of iterations over the dataset. Defaults to 100. - experiment_name: Optional[Text] - Name of the optimization run. Defaults to None. - - Returns - ------- - der_value: float - Diarization error rate of the best iteration. - best_hparams: Dict[Text, float] - Hyper-parameters of the best iteration. - """ - sampler = TPESampler() if sampler is None else sampler - storage = self.output_path / "trials.db" - study = optuna.create_study( - storage=f"sqlite://{str(storage)}", - sampler=sampler, + ): + self.objective = objective + self.study = create_study( + storage=self.default_storage if storage is None else storage, + sampler=TPESampler() if sampler is None else sampler, pruner=pruner, - study_name=experiment_name, + study_name=self.default_study_name if study_name is None else study_name, direction="minimize", load_if_exists=True, ) - study.optimize(self._objective, n_trials=num_iter) - return study.best_value, study.best_params + self._progress: Optional[tqdm] = None + + @property + def default_output_path(self) -> Path: + return self.objective.benchmark.output_path.parent + + @property + def default_study_name(self) -> Text: + return self.default_output_path.name + + @property + def default_storage(self) -> Text: + return "sqlite:///" + str(self.default_output_path / "trials.db") + + @property + def best_performance(self): + return self.study.best_value + + @property + def best_hparams(self): + return self.study.best_params + + def _callback(self, study: Study, trial: FrozenTrial): + if self._progress is None: + return + self._progress.update(1) + self._progress.set_description(f"Trial {trial.number + 1}") + values = {"best_der": study.best_value} + for name, value in study.best_params.items(): + values[f"best_{name}"] = value + self._progress.set_postfix(OrderedDict(values)) + + def optimize(self, num_iter: int, show_progress: bool = True): + self._progress = None + if show_progress: + self._progress = trange(num_iter) + last_trial = self.study.trials[-1].number + self._progress.set_description(f"Trial {last_trial + 1}") + self.study.optimize(self.objective, num_iter, callbacks=[self._callback]) diff --git a/src/diart/pipelines.py b/src/diart/pipelines.py index fa27c087..391d60b3 100644 --- a/src/diart/pipelines.py +++ b/src/diart/pipelines.py @@ -26,6 +26,7 @@ def __init__( beta: float = 10, max_speakers: int = 20, device: Optional[torch.device] = None, + **kwargs, ): # Default segmentation model is pyannote/segmentation 
self.segmentation = segmentation @@ -70,18 +71,15 @@ def from_namespace(args: Any) -> 'PipelineConfig': duration=getattr(args, "duration", None), step=args.step, latency=args.latency, - tau_active=args.tau_active, - rho_update=args.rho_update, - delta_new=args.delta_new, + tau_active=args.tau, + rho_update=args.rho, + delta_new=args.delta, gamma=args.gamma, beta=args.beta, max_speakers=args.max_speakers, device=args.device, ) - def copy(self) -> 'PipelineConfig': - return PipelineConfig.from_namespace(self) - def last_chunk_end_time(self, conv_duration: float) -> Optional[float]: """Return the end time of the last chunk for a given conversation duration. diff --git a/src/diart/stream.py b/src/diart/stream.py index 1e548bf6..c263364d 100644 --- a/src/diart/stream.py +++ b/src/diart/stream.py @@ -23,11 +23,7 @@ parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file") args = parser.parse_args() - args.device = torch.device("cpu") if args.cpu else None - args.tau_active = args.tau - args.rho_update = args.rho - args.delta_new = args.delta # Define online speaker diarization pipeline config = PipelineConfig.from_namespace(args) From 74c2d12583f124e6191e4fc74e83ffe561562889 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 15 Jun 2022 17:16:33 +0200 Subject: [PATCH 21/29] Remove OptimizationObjective --- README.md | 16 +++++------ src/diart/optim.py | 70 ++++++++++++++++++++-------------------------- 2 files changed, 39 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 2a52b0d5..c0aee4f6 100644 --- a/README.md +++ b/README.md @@ -93,20 +93,20 @@ More information on Optuna can be found [here](https://optuna.readthedocs.io/en/ ### A simple example ```python -from diart.optim import OptimizationObjective, Optimizer, TauActive, RhoUpdate, DeltaNew +from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew from diart.pipelines import PipelineConfig from diart.inference import Benchmark -# Benchmark runs and evaluates the pipeline with each configuration +# Benchmark runs and evaluates the pipeline on a dataset benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir", show_report=False) # Base configuration for the pipeline we're going to tune base_config = PipelineConfig(duration=5, step=0.5, latency=5) # Hyper-parameters to optimize hparams = [TauActive, RhoUpdate, DeltaNew] -# The objective implements an optimization step -objective = OptimizationObjective(benchmark, base_config, hparams) -# Run optimization for 100 iterations -Optimizer(objective).optimize(num_iter=100, show_progress=True) +# Optimizer implements the optimization loop +optimizer = Optimizer(benchmark, base_config, hparams) +# Run optimization +optimizer.optimize(num_iter=100, show_progress=True) ``` ### Distributed optimization @@ -118,10 +118,10 @@ More information on distributed optimization can be found [here](https://optuna. ```python from diart.optim import Optimizer -objective = ... +benchmark, base_config, hparams = ... 
study_name = "my_study" storage = "mysql://root@localhost/example" -optimizer = Optimizer(objective, study_name, storage) +optimizer = Optimizer(benchmark, base_config, hparams, study_name, storage) optimizer.optimize(num_iter=100, show_progress=True) ``` diff --git a/src/diart/optim.py b/src/diart/optim.py index 4260f952..200fb642 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -25,54 +25,20 @@ class HyperParameter: DeltaNew = HyperParameter("delta_new", low=0, high=2) -class OptimizationObjective: +class Optimizer: def __init__( self, benchmark: Benchmark, base_config: PipelineConfig, hparams: Iterable[HyperParameter], - ): - self.benchmark = benchmark - self.base_config = base_config - self.hparams = hparams - - def __call__(self, trial: Trial) -> float: - # Set suggested values for optimized hyper-parameters - trial_config = vars(self.base_config) - for hparam in self.hparams: - trial_config[hparam.name] = trial.suggest_uniform( - hparam.name, hparam.low, hparam.high - ) - - # Instantiate pipeline with the new configuration - pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config)) - - # Prune trial if required - if trial.should_prune(): - raise TrialPruned() - - # Run pipeline over the dataset - report = self.benchmark(pipeline) - - # Clean RTTM files - for tmp_file in self.benchmark.output_path.iterdir(): - if tmp_file.name.endswith(".rttm"): - tmp_file.unlink() - - # Extract DER from report - return report.loc["TOTAL", "diarization error rate"]["%"] - - -class Optimizer: - def __init__( - self, - objective: OptimizationObjective, study_name: Optional[Text] = None, storage: Optional[Text] = None, sampler: Optional[BaseSampler] = None, pruner: Optional[BasePruner] = None, ): - self.objective = objective + self.benchmark = benchmark + self.base_config = base_config + self.hparams = hparams self.study = create_study( storage=self.default_storage if storage is None else storage, sampler=TPESampler() if sampler is None else sampler, @@ -85,7 +51,7 @@ def __init__( @property def default_output_path(self) -> Path: - return self.objective.benchmark.output_path.parent + return self.benchmark.output_path.parent @property def default_study_name(self) -> Text: @@ -113,6 +79,32 @@ def _callback(self, study: Study, trial: FrozenTrial): values[f"best_{name}"] = value self._progress.set_postfix(OrderedDict(values)) + def objective(self, trial: Trial) -> float: + # Set suggested values for optimized hyper-parameters + trial_config = vars(self.base_config) + for hparam in self.hparams: + trial_config[hparam.name] = trial.suggest_uniform( + hparam.name, hparam.low, hparam.high + ) + + # Instantiate pipeline with the new configuration + pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config)) + + # Prune trial if required + if trial.should_prune(): + raise TrialPruned() + + # Run pipeline over the dataset + report = self.benchmark(pipeline) + + # Clean RTTM files + for tmp_file in self.benchmark.output_path.iterdir(): + if tmp_file.name.endswith(".rttm"): + tmp_file.unlink() + + # Extract DER from report + return report.loc["TOTAL", "diarization error rate"]["%"] + def optimize(self, num_iter: int, show_progress: bool = True): self._progress = None if show_progress: From 69789b25dc3f23b9bc5057159c286f14db438e04 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 15 Jun 2022 17:48:52 +0200 Subject: [PATCH 22/29] Improve README.md --- README.md | 55 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git 
a/README.md b/README.md index c0aee4f6..5ce6a4c5 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,45 @@ License [badge and navigation-menu HTML omitted]
-## Install +## Installation 1) Create environment: @@ -51,13 +83,15 @@ pip install diart ## Stream your own audio -### A recorded conversation +### From the command line + +A recorded conversation: ```shell python -m diart.stream /path/to/audio.wav ``` -### From your microphone +A live conversation: ```shell python -m diart.stream microphone @@ -65,9 +99,9 @@ python -m diart.stream microphone See `python -m diart.stream -h` for more options. -## Inference API +### From python -Run a customized real-time speaker diarization pipeline over an audio stream with `diart.inference.RealTimeInference`: +Run a real-time speaker diarization pipeline over an audio stream with `RealTimeInference`: ```python from diart.sources import MicrophoneAudioSource @@ -78,17 +112,16 @@ config = PipelineConfig() # Default parameters pipeline = OnlineSpeakerDiarization(config) audio_source = MicrophoneAudioSource(config.sample_rate) inference = RealTimeInference("/output/path", do_plot=True) - inference(pipeline, audio_source) ``` -For faster inference and evaluation on a dataset we recommend to use `Benchmark` (see our notes on [reproducibility](#reproducibility)) +For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)). + +## Optimize hyper-parameters to your dataset -## Optimize hyper-parameters to your own dataset +Diart implements a hyper-parameter optimizer based on optuna that allows you to tune any pipeline to any dataset. -Diart implements a hyper-parameter optimizer based on [optuna](https://github.com/optuna/optuna). -`diart.optim.Optimizer` allows you to tune any pipeline to a custom dataset. -More information on Optuna can be found [here](https://optuna.readthedocs.io/en/stable/index.html). +[More about optuna](https://optuna.readthedocs.io/en/stable/index.html). ### A simple example From 278b1dcbf82df344a8cc64df9a0fa79bd9e58a88 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 16 Jun 2022 12:06:45 +0200 Subject: [PATCH 23/29] Improve optimization API --- README.md | 36 ++++++++++++++++++------------- src/diart/optim.py | 53 ++++++++++++++++++++-------------------------- 2 files changed, 44 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 5ce6a4c5..1b36099a 100644 --- a/README.md +++ b/README.md @@ -18,15 +18,15 @@ Installation | - + Stream audio | - + Tune hyper-parameters | - + Build pipelines | @@ -81,7 +81,7 @@ pip install git+https://github.com/pyannote/pyannote-audio.git@develop#egg=pyann pip install diart ``` -## Stream your own audio +## Stream audio ### From the command line @@ -117,7 +117,7 @@ inference(pipeline, audio_source) For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)). -## Optimize hyper-parameters to your dataset +## Tune hyper-parameters Diart implements a hyper-parameter optimizer based on optuna that allows you to tune any pipeline to any dataset. 
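Under the hood, each trial amounts to suggesting values for the chosen hyper-parameters, rebuilding the pipeline configuration, running `Benchmark`, and reporting the resulting DER back to optuna. Roughly, and leaving out pruning and RTTM cleanup (paths are placeholders):

```python
from optuna import create_study
from optuna.trial import Trial

from diart.inference import Benchmark
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig

benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)

def objective(trial: Trial) -> float:
    config = PipelineConfig(
        tau_active=trial.suggest_uniform("tau_active", 0, 1),
        rho_update=trial.suggest_uniform("rho_update", 0, 1),
        delta_new=trial.suggest_uniform("delta_new", 0, 2),
    )
    report = benchmark(OnlineSpeakerDiarization(config))
    # pyannote.metrics reports DER under ("diarization error rate", "%")
    return report.loc["TOTAL", "diarization error rate"]["%"]

study = create_study(direction="minimize")
study.optimize(objective, n_trials=100)
```
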
@@ -137,30 +137,36 @@ base_config = PipelineConfig(duration=5, step=0.5, latency=5) # Hyper-parameters to optimize hparams = [TauActive, RhoUpdate, DeltaNew] # Optimizer implements the optimization loop -optimizer = Optimizer(benchmark, base_config, hparams) +optimizer = Optimizer(benchmark, base_config, hparams, "/db/out/dir") # Run optimization optimizer.optimize(num_iter=100, show_progress=True) ``` +This will store temporary predictions in `/out/dir` and write results of each trial in an sqlite database `/db/out/dir/trials.db`. + ### Distributed optimization -For bigger datasets, it is sometimes more convenient to run optimization in parallel. -If the same `study_name` and `storage` are given to the optimizer, all optimization processes will share the information from previous runs. -More information on distributed optimization can be found [here](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py). +For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel. +To do this, make sure that the output directory or the study given to the optimizer is the same in each process. +Notice that the output directories of `Benchmark` MUST be different to avoid concurrency issues. ```python from diart.optim import Optimizer +from diart.inference import Benchmark -benchmark, base_config, hparams = ... -study_name = "my_study" -storage = "mysql://root@localhost/example" -optimizer = Optimizer(benchmark, base_config, hparams, study_name, storage) +ID = 0 # Worker identifier +base_config, hparams = ... +benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker_{ID}", show_report=False) +optimizer = Optimizer(benchmark, base_config, hparams, "/db/out/dir") optimizer.optimize(num_iter=100, show_progress=True) ``` -## Build your own pipeline +It is recommended to use other databases like mysql instead of sqlite in distributed optimization. +More on this [here](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py). + +## Build pipelines -For a more advanced usage, diart also provides building blocks that can be combined to create your own pipelines. +For a more advanced usage, diart also provides building blocks that can be combined to create your own pipeline. Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately. 
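Because the study is persisted to a database and created with `load_if_exists=True`, a new `Optimizer` pointing at the same directory reloads previous trials, so tuned values can be read back without re-running anything. A hedged sketch, assuming a study was already run under `/db/out/dir` as in the distributed example above:

```python
from pathlib import Path

from diart.inference import Benchmark
from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew
from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig

benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)
optimizer = Optimizer(benchmark, PipelineConfig(), [TauActive, RhoUpdate, DeltaNew], Path("/db/out/dir"))

print("Best DER so far:", optimizer.best_performance)    # study.best_value
print("Best hyper-parameters:", optimizer.best_hparams)  # study.best_params

# Plug the tuned values back into a pipeline for inference or benchmarking
pipeline = OnlineSpeakerDiarization(PipelineConfig(**optimizer.best_hparams))
```
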
### Example diff --git a/src/diart/optim.py b/src/diart/optim.py index 200fb642..a71bd571 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -1,14 +1,14 @@ from collections import OrderedDict from dataclasses import dataclass from pathlib import Path -from typing import Iterable, Text, Optional +from typing import Iterable, Text, Optional, Union from optuna import TrialPruned, Study, create_study -from optuna.pruners import BasePruner -from optuna.samplers import TPESampler, BaseSampler +from optuna.samplers import TPESampler from optuna.trial import Trial, FrozenTrial from tqdm import trange, tqdm +from .audio import FilePath from .benchmark import Benchmark from .pipelines import PipelineConfig, OnlineSpeakerDiarization @@ -31,35 +31,26 @@ def __init__( benchmark: Benchmark, base_config: PipelineConfig, hparams: Iterable[HyperParameter], - study_name: Optional[Text] = None, - storage: Optional[Text] = None, - sampler: Optional[BaseSampler] = None, - pruner: Optional[BasePruner] = None, + study_or_path: Union[FilePath, Study], ): self.benchmark = benchmark self.base_config = base_config self.hparams = hparams - self.study = create_study( - storage=self.default_storage if storage is None else storage, - sampler=TPESampler() if sampler is None else sampler, - pruner=pruner, - study_name=self.default_study_name if study_name is None else study_name, - direction="minimize", - load_if_exists=True, - ) self._progress: Optional[tqdm] = None - @property - def default_output_path(self) -> Path: - return self.benchmark.output_path.parent - - @property - def default_study_name(self) -> Text: - return self.default_output_path.name - - @property - def default_storage(self) -> Text: - return "sqlite:///" + str(self.default_output_path / "trials.db") + if isinstance(study_or_path, Study): + self.study = study_or_path + elif isinstance(study_or_path, str) or isinstance(study_or_path, Path): + self.study = create_study( + storage="sqlite:///" + str(study_or_path / "trials.db"), + sampler=TPESampler(), + study_name=study_or_path.name, + direction="minimize", + load_if_exists=True, + ) + else: + msg = f"Expected Study object or path-like, but got {type(study_or_path).__name__}" + raise ValueError(msg) @property def best_performance(self): @@ -87,13 +78,13 @@ def objective(self, trial: Trial) -> float: hparam.name, hparam.low, hparam.high ) - # Instantiate pipeline with the new configuration - pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config)) - # Prune trial if required if trial.should_prune(): raise TrialPruned() + # Instantiate pipeline with the new configuration + pipeline = OnlineSpeakerDiarization(PipelineConfig(**trial_config), profile=False) + # Run pipeline over the dataset report = self.benchmark(pipeline) @@ -109,6 +100,8 @@ def optimize(self, num_iter: int, show_progress: bool = True): self._progress = None if show_progress: self._progress = trange(num_iter) - last_trial = self.study.trials[-1].number + last_trial = -1 + if self.study.trials: + last_trial = self.study.trials[-1].number self._progress.set_description(f"Trial {last_trial + 1}") self.study.optimize(self.objective, num_iter, callbacks=[self._callback]) From 780681c04f28b56aba227c71f4e06f77414eb88d Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 16 Jun 2022 16:43:41 +0200 Subject: [PATCH 24/29] Add diart.tune script --- README.md | 79 +++++++++++++++++++++++++----------------- src/diart/argdoc.py | 1 + src/diart/benchmark.py | 2 +- src/diart/optim.py | 14 ++++++-- src/diart/tune.py | 68 
++++++++++++++++++++++++++++++++++++ 5 files changed, 129 insertions(+), 35 deletions(-) create mode 100644 src/diart/tune.py diff --git a/README.md b/README.md index 1b36099a..204a3a2f 100644 --- a/README.md +++ b/README.md @@ -119,11 +119,17 @@ For faster inference and evaluation on a dataset we recommend to use `Benchmark` ## Tune hyper-parameters -Diart implements a hyper-parameter optimizer based on optuna that allows you to tune any pipeline to any dataset. +Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset. -[More about optuna](https://optuna.readthedocs.io/en/stable/index.html). +### From the command line -### A simple example +```shell +python -m diart.tune /wav/dir --reference /rttm/dir --output /out/dir +``` + +See `python -m diart.tune -h` for more options. + +### From python ```python from diart.optim import Optimizer, TauActive, RhoUpdate, DeltaNew @@ -131,39 +137,51 @@ from diart.pipelines import PipelineConfig from diart.inference import Benchmark # Benchmark runs and evaluates the pipeline on a dataset -benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir", show_report=False) +benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False) # Base configuration for the pipeline we're going to tune -base_config = PipelineConfig(duration=5, step=0.5, latency=5) +base_config = PipelineConfig() # Hyper-parameters to optimize hparams = [TauActive, RhoUpdate, DeltaNew] # Optimizer implements the optimization loop -optimizer = Optimizer(benchmark, base_config, hparams, "/db/out/dir") +optimizer = Optimizer(benchmark, base_config, hparams, "/out/dir") # Run optimization optimizer.optimize(num_iter=100, show_progress=True) ``` -This will store temporary predictions in `/out/dir` and write results of each trial in an sqlite database `/db/out/dir/trials.db`. +This will use `/out/dir/tmp` as a working directory and write results to an sqlite database in `/out/dir`. ### Distributed optimization For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel. -To do this, make sure that the output directory or the study given to the optimizer is the same in each process. -Notice that the output directories of `Benchmark` MUST be different to avoid concurrency issues. +To do this, create a study on a [recommended DBMS](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py) (e.g. MySQL or PostgreSQL) making sure that the study and database names match: + +```shell +mysql -u root -e "CREATE DATABASE IF NOT EXISTS example" +optuna create-study --study-name "example" --storage "mysql://root@localhost/example" +``` + +Then you can run multiple identical optimizers pointing to the database: + +```shell +python -m diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example +``` + +If you are using the python API, make sure that worker directories are different to avoid concurrency issues: ```python from diart.optim import Optimizer from diart.inference import Benchmark +from optuna.samplers import TPESampler +import optuna ID = 0 # Worker identifier base_config, hparams = ... 
-benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker_{ID}", show_report=False) -optimizer = Optimizer(benchmark, base_config, hparams, "/db/out/dir") +benchmark = Benchmark("/wav/dir", "/rttm/dir", f"/out/dir/worker-{ID}", show_report=False) +study = optuna.load_study("example", "mysql://root@localhost/example", TPESampler()) +optimizer = Optimizer(benchmark, base_config, hparams, study) optimizer.optimize(num_iter=100, show_progress=True) ``` -It is recommended to use other databases like mysql instead of sqlite in distributed optimization. -More on this [here](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py). - ## Build pipelines For a more advanced usage, diart also provides building blocks that can be combined to create your own pipeline. @@ -174,30 +192,27 @@ Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `bloc Obtain overlap-aware speaker embeddings from a microphone stream: ```python -import rx import rx.operators as ops import diart.operators as dops from diart.sources import MicrophoneAudioSource from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding from diart.models import SegmentationModel, EmbeddingModel -# Initialize independent modules -seg_model = SegmentationModel.from_pyannote("pyannote/segmentation") -segmentation = SpeakerSegmentation(seg_model) -emb_model = EmbeddingModel.from_pyannote("pyannote/embedding") -embedding = OverlapAwareSpeakerEmbedding(emb_model) -mic = MicrophoneAudioSource(seg_model.get_sample_rate()) - -# Reformat microphone stream. Defaults to 5s duration and 500ms shift -regular_stream = mic.stream.pipe(dops.regularize_audio_stream(seg_model.get_sample_rate())) -# Branch the microphone stream to calculate segmentation -segmentation_stream = regular_stream.pipe(ops.map(segmentation)) -# Join audio and segmentation stream to calculate speaker embeddings -embedding_stream = rx.zip( - regular_stream, segmentation_stream -).pipe(ops.starmap(embedding)) +segmentation = SpeakerSegmentation( + SegmentationModel.from_pyannote("pyannote/segmentation") +) +embedding = OverlapAwareSpeakerEmbedding( + EmbeddingModel.from_pyannote("pyannote/embedding") +) +sample_rate = segmentation.model.get_sample_rate() +mic = MicrophoneAudioSource(sample_rate) -embedding_stream.subscribe(on_next=lambda emb: print(emb.shape)) +stream = mic.stream.pipe( + # Reformat stream to 5s duration and 500ms shift + dops.regularize_audio_stream(sample_rate), + ops.map(lambda wav: (wav, segmentation(wav))), + ops.starmap(embedding) +).subscribe(on_next=lambda emb: print(emb.shape)) mic.read() ``` @@ -276,7 +291,7 @@ config = PipelineConfig( pipeline = OnlineSpeakerDiarization(config) benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir") -benchmark(pipeline, batch_size=32) +benchmark(pipeline) ``` This runs a faster inference by pre-calculating model outputs in batches. diff --git a/src/diart/argdoc.py b/src/diart/argdoc.py index 640dae14..0fb2a1a6 100644 --- a/src/diart/argdoc.py +++ b/src/diart/argdoc.py @@ -7,4 +7,5 @@ BETA = "Parameter beta for overlapped speech penalty" MAX_SPEAKERS = "Maximum number of speakers" CPU = "Force models to run on CPU" +BATCH_SIZE = "For segmentation and embedding pre-calculation. 
If BATCH_SIZE < 2, run fully online and estimate real-time latency" OUTPUT = "Directory to store the system's output in RTTM format" diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py index b5fcc567..674942a3 100644 --- a/src/diart/benchmark.py +++ b/src/diart/benchmark.py @@ -18,7 +18,7 @@ parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3") parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") - parser.add_argument("--batch-size", default=32, type=int, help="For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency. Defaults to 32") + parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32") parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`") args = parser.parse_args() diff --git a/src/diart/optim.py b/src/diart/optim.py index a71bd571..db616db5 100644 --- a/src/diart/optim.py +++ b/src/diart/optim.py @@ -19,6 +19,16 @@ class HyperParameter: low: float high: float + @staticmethod + def from_name(name: Text) -> 'HyperParameter': + if name == "tau_active": + return TauActive + if name == "rho_update": + return RhoUpdate + if name == "delta_new": + return DeltaNew + raise ValueError(f"Hyper-parameter '{name}' not recognized") + TauActive = HyperParameter("tau_active", low=0, high=1) RhoUpdate = HyperParameter("rho_update", low=0, high=1) @@ -42,9 +52,9 @@ def __init__( self.study = study_or_path elif isinstance(study_or_path, str) or isinstance(study_or_path, Path): self.study = create_study( - storage="sqlite:///" + str(study_or_path / "trials.db"), + storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"), sampler=TPESampler(), - study_name=study_or_path.name, + study_name=study_or_path.stem, direction="minimize", load_if_exists=True, ) diff --git a/src/diart/tune.py b/src/diart/tune.py new file mode 100644 index 00000000..a35060fd --- /dev/null +++ b/src/diart/tune.py @@ -0,0 +1,68 @@ +import argparse +from pathlib import Path +from uuid import uuid4 + +import optuna +import torch +from optuna.samplers import TPESampler + +import diart.argdoc as argdoc +from diart.inference import Benchmark +from diart.optim import Optimizer, HyperParameter +from diart.pipelines import PipelineConfig + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") + parser.add_argument("--reference", required=True, type=str, help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files") + parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") + parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") + parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") + parser.add_argument("--rho", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3") + parser.add_argument("--delta", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1") + parser.add_argument("--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. 
Defaults to 3") + parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") + parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") + parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32") + parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") + parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"), help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new") + parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials") + parser.add_argument("--storage", type=str, help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name") + parser.add_argument("--output", required=True, type=str, help="Working directory") + args = parser.parse_args() + args.output = Path(args.output) + args.output.mkdir(parents=True, exist_ok=True) + args.device = torch.device("cpu") if args.cpu else None + + # Assign unique worker ID + idx = uuid4() + + # Create benchmark object to run the pipeline on a set of files + work_path = args.output / f"worker-{idx}" + benchmark = Benchmark( + args.root, + args.reference, + work_path, + show_progress=True, + show_report=False, + batch_size=args.batch_size + ) + + # Create the base configuration for each trial + base_config = PipelineConfig.from_namespace(args) + + # Create hyper-parameters to optimize + hparams = [HyperParameter.from_name(name) for name in args.hparams] + + # Use a custom storage if given + study_or_path = args.output + if args.storage is not None: + db_name = Path(args.storage).stem + study_or_path = optuna.load_study(db_name, args.storage, TPESampler()) + + # Run optimization + optimizer = Optimizer(benchmark, base_config, hparams, study_or_path) + optimizer.optimize(num_iter=args.num_iter, show_progress=True) + + # Clean temporary directory + work_path.rmdir() From eec293678d53f3ce604d55070da0500277da1bdf Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 16 Jun 2022 16:54:45 +0200 Subject: [PATCH 25/29] Simplify loading segmentation and embedding models from pyannote --- README.md | 9 ++------- src/diart/blocks/embedding.py | 15 +++++++++++++++ src/diart/blocks/segmentation.py | 4 ++++ 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 204a3a2f..acf3abfa 100644 --- a/README.md +++ b/README.md @@ -196,14 +196,9 @@ import rx.operators as ops import diart.operators as dops from diart.sources import MicrophoneAudioSource from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding -from diart.models import SegmentationModel, EmbeddingModel -segmentation = SpeakerSegmentation( - SegmentationModel.from_pyannote("pyannote/segmentation") -) -embedding = OverlapAwareSpeakerEmbedding( - EmbeddingModel.from_pyannote("pyannote/embedding") -) +segmentation = SpeakerSegmentation.from_pyannote("pyannote/segmentation") +embedding = OverlapAwareSpeakerEmbedding.from_pyannote("pyannote/embedding") sample_rate = segmentation.model.get_sample_rate() mic = MicrophoneAudioSource(sample_rate) diff --git a/src/diart/blocks/embedding.py b/src/diart/blocks/embedding.py index 19288876..8855ae08 100644 --- a/src/diart/blocks/embedding.py +++ b/src/diart/blocks/embedding.py @@ -18,6 +18,10 @@ def 
__init__(self, model: EmbeddingModel, device: Optional[torch.device] = None) self.waveform_formatter = TemporalFeatureFormatter() self.weights_formatter = TemporalFeatureFormatter() + @staticmethod + def from_pyannote(model, device: Optional[torch.device] = None) -> 'SpeakerEmbedding': + return SpeakerEmbedding(EmbeddingModel.from_pyannote(model), device) + def __call__(self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None) -> torch.Tensor: """ Calculate speaker embeddings of input audio. @@ -134,5 +138,16 @@ def __init__( self.osp = OverlappedSpeechPenalty(gamma, beta) self.normalize = EmbeddingNormalization(norm) + @staticmethod + def from_pyannote( + model, + gamma: float = 3, + beta: float = 10, + norm: Union[float, torch.Tensor] = 1, + device: Optional[torch.device] = None, + ): + model = EmbeddingModel.from_pyannote(model) + return OverlapAwareSpeakerEmbedding(model, gamma, beta, norm, device) + def __call__(self, waveform: TemporalFeatures, segmentation: TemporalFeatures) -> torch.Tensor: return self.normalize(self.embedding(waveform, self.osp(segmentation))) diff --git a/src/diart/blocks/segmentation.py b/src/diart/blocks/segmentation.py index 33064207..1b310944 100644 --- a/src/diart/blocks/segmentation.py +++ b/src/diart/blocks/segmentation.py @@ -17,6 +17,10 @@ def __init__(self, model: SegmentationModel, device: Optional[torch.device] = No self.model.to(self.device) self.formatter = TemporalFeatureFormatter() + @staticmethod + def from_pyannote(model, device: Optional[torch.device] = None) -> 'SpeakerSegmentation': + return SpeakerSegmentation(SegmentationModel.from_pyannote(model), device) + def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: """ Calculate the speaker segmentation of input audio. From ad041525fdf768b474021ab71093077566c3c5f2 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 16 Jun 2022 17:15:23 +0200 Subject: [PATCH 26/29] Register stream, benchmark and tune as entry points during installation --- README.md | 16 ++++++++-------- setup.cfg | 9 +++++++-- src/diart/benchmark.py | 13 ++++++++++--- src/diart/stream.py | 13 ++++++++++--- src/diart/tune.py | 19 ++++++++++++++----- 5 files changed, 49 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index acf3abfa..7efd2c47 100644 --- a/README.md +++ b/README.md @@ -88,16 +88,16 @@ pip install diart A recorded conversation: ```shell -python -m diart.stream /path/to/audio.wav +diart.stream /path/to/audio.wav ``` A live conversation: ```shell -python -m diart.stream microphone +diart.stream microphone ``` -See `python -m diart.stream -h` for more options. +See `diart.stream -h` for more options. ### From python @@ -124,10 +124,10 @@ Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.re ### From the command line ```shell -python -m diart.tune /wav/dir --reference /rttm/dir --output /out/dir +diart.tune /wav/dir --reference /rttm/dir --output /out/dir ``` -See `python -m diart.tune -h` for more options. +See `diart.tune -h` for more options. 
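`diart.tune` resolves the `--hparams` names with `HyperParameter.from_name`, which only knows `tau_active`, `rho_update` and `delta_new`. From Python, an extra `HyperParameter` can in principle be passed for any other `PipelineConfig` argument, since suggested values are injected into the configuration by name. A sketch; the `gamma` range below is an illustrative assumption, not a recommended setting:

```python
from pathlib import Path

from diart.inference import Benchmark
from diart.optim import Optimizer, HyperParameter, TauActive, RhoUpdate, DeltaNew
from diart.pipelines import PipelineConfig

# Hypothetical extra hyper-parameter matching the `gamma` argument of PipelineConfig
Gamma = HyperParameter("gamma", low=1, high=5)

benchmark = Benchmark("/wav/dir", "/rttm/dir", "/out/dir/tmp", show_report=False)
optimizer = Optimizer(
    benchmark,
    PipelineConfig(),
    [TauActive, RhoUpdate, DeltaNew, Gamma],
    Path("/out/dir"),  # directory where the study database is created
)
optimizer.optimize(num_iter=100, show_progress=True)
```
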
### From python @@ -163,7 +163,7 @@ optuna create-study --study-name "example" --storage "mysql://root@localhost/exa Then you can run multiple identical optimizers pointing to the database: ```shell -python -m diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example +diart.tune /wav/dir --reference /rttm/dir --output /out/dir --storage mysql://root@localhost/example ``` If you are using the python API, make sure that worker directories are different to avoid concurrency issues: @@ -267,7 +267,7 @@ To obtain the best results, make sure to use the following hyper-parameters: `diart.benchmark` and `diart.inference.Benchmark` can quickly run and evaluate the pipeline, and even measure its real-time latency. For instance, for a DIHARD III configuration: ```shell -python -m diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir +diart.benchmark /wav/dir --reference /rttm/dir --tau=0.555 --rho=0.422 --delta=1.517 --output /out/dir ``` or using the inference API: @@ -290,7 +290,7 @@ benchmark(pipeline) ``` This runs a faster inference by pre-calculating model outputs in batches. -See `python -m diart.benchmark -h` for more options. +See `diart.benchmark -h` for more options. For convenience and to facilitate future comparisons, we also provide the [expected outputs](/expected_outputs) of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. This includes the VBx offline topline as well as our proposed online approach with latencies 500ms, 1s, 2s, 3s, 4s, and 5s. diff --git a/setup.cfg b/setup.cfg index a0f7c776..214ca57b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,7 +19,7 @@ classifiers= package_dir= =src packages=find: -install_requires = +install_requires= numpy>=1.20.2 matplotlib>=3.3.3 rx>=3.2.0 @@ -34,6 +34,11 @@ install_requires = pyannote.metrics>=3.2 optuna>=2.10 - [options.packages.find] where=src + +[options.entry_points] +console_scripts= + diart.stream=diart.stream:run + diart.benchmark=diart.benchmark:run + diart.tune=diart.tune:run diff --git a/src/diart/benchmark.py b/src/diart/benchmark.py index 674942a3..ad0969fb 100644 --- a/src/diart/benchmark.py +++ b/src/diart/benchmark.py @@ -6,10 +6,12 @@ from diart.inference import Benchmark from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig -if __name__ == "__main__": + +def run(): parser = argparse.ArgumentParser() parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") - parser.add_argument("--reference", type=str, help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files") + parser.add_argument("--reference", type=str, + help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") @@ -19,7 +21,8 @@ parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. Defaults to 32") - parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. 
Defaults to GPU if available, CPU otherwise") + parser.add_argument("--cpu", dest="cpu", action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to `root`") args = parser.parse_args() args.device = torch.device("cpu") if args.cpu else None @@ -34,3 +37,7 @@ ) benchmark(OnlineSpeakerDiarization(PipelineConfig.from_namespace(args), profile=True)) + + +if __name__ == "__main__": + run() diff --git a/src/diart/stream.py b/src/diart/stream.py index c263364d..63232e43 100644 --- a/src/diart/stream.py +++ b/src/diart/stream.py @@ -8,7 +8,8 @@ from diart.inference import RealTimeInference from diart.pipelines import OnlineSpeakerDiarization, PipelineConfig -if __name__ == "__main__": + +def run(): parser = argparse.ArgumentParser() parser.add_argument("source", type=str, help="Path to an audio file | 'microphone'") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") @@ -20,8 +21,10 @@ parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") parser.add_argument("--no-plot", dest="no_plot", action="store_true", help="Skip plotting for faster inference") - parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--output", type=str, help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file") + parser.add_argument("--cpu", dest="cpu", action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") + parser.add_argument("--output", type=str, + help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file") args = parser.parse_args() args.device = torch.device("cpu") if args.cpu else None @@ -42,3 +45,7 @@ # Run online inference RealTimeInference(args.output, do_plot=not args.no_plot)(pipeline, audio_source) + + +if __name__ == "__main__": + run() diff --git a/src/diart/tune.py b/src/diart/tune.py index a35060fd..000e3da8 100644 --- a/src/diart/tune.py +++ b/src/diart/tune.py @@ -11,10 +11,12 @@ from diart.optim import Optimizer, HyperParameter from diart.pipelines import PipelineConfig -if __name__ == "__main__": + +def run(): parser = argparse.ArgumentParser() parser.add_argument("root", type=str, help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)") - parser.add_argument("--reference", required=True, type=str, help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files") + parser.add_argument("--reference", required=True, type=str, + help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files") parser.add_argument("--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5") parser.add_argument("--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5") parser.add_argument("--tau", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5") @@ -24,10 +26,13 @@ parser.add_argument("--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10") parser.add_argument("--max-speakers", default=20, type=int, help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20") parser.add_argument("--batch-size", default=32, type=int, help=f"{argdoc.BATCH_SIZE}. 
Defaults to 32") - parser.add_argument("--cpu", dest="cpu", action="store_true", help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") - parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"), help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new") + parser.add_argument("--cpu", dest="cpu", action="store_true", + help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise") + parser.add_argument("--hparams", nargs="+", default=("tau_active", "rho_update", "delta_new"), + help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new") parser.add_argument("--num-iter", default=100, type=int, help="Number of optimization trials") - parser.add_argument("--storage", type=str, help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name") + parser.add_argument("--storage", type=str, + help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name") parser.add_argument("--output", required=True, type=str, help="Working directory") args = parser.parse_args() args.output = Path(args.output) @@ -66,3 +71,7 @@ # Clean temporary directory work_path.rmdir() + + +if __name__ == "__main__": + run() From 676b2b515296eff2a19cba6dd25f6c0a27c056f8 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Thu, 16 Jun 2022 18:02:20 +0200 Subject: [PATCH 27/29] Add custom embedding model example --- README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/README.md b/README.md index 7efd2c47..53f5cdba 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,10 @@ Stream audio | + + Add your model + + | Tune hyper-parameters @@ -117,6 +121,37 @@ inference(pipeline, audio_source) For faster inference and evaluation on a dataset we recommend to use `Benchmark` instead (see our notes on [reproducibility](#reproducibility)). +## Add your model + +Third-party segmentation and embedding models can be integrated seamlessly by subclassing `SegmentationModel` and `EmbeddingModel`: + +```python +import torch +from typing import Optional +from diart.models import EmbeddingModel +from diart.pipelines import PipelineConfig, OnlineSpeakerDiarization +from diart.sources import MicrophoneAudioSource +from diart.inference import RealTimeInference + +class MyEmbeddingModel(EmbeddingModel): + def __init__(self): + super().__init__() + self.my_pretrained_model = load("my_model.ckpt") + + def __call__( + self, + waveform: torch.Tensor, + weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: + return self.my_pretrained_model(waveform, weights) + +config = PipelineConfig(embedding=MyEmbeddingModel()) +pipeline = OnlineSpeakerDiarization(config) +mic = MicrophoneAudioSource(config.sample_rate) +inference = RealTimeInference("/out/dir") +inference(pipeline, mic) +``` + ## Tune hyper-parameters Diart implements a hyper-parameter optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune any pipeline to any dataset. 
From d551d9806f6e98e92dfb202fc755b486eef7c3f3 Mon Sep 17 00:00:00 2001 From: Juan Manuel Coria Date: Wed, 13 Jul 2022 12:18:25 +0200 Subject: [PATCH 28/29] Update version --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 214ca57b..5ba6a171 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name=diart -version=0.3.0 +version=0.4.0 author=Juan Manuel Coria description=Speaker diarization in real time long_description=file: README.md @@ -9,7 +9,7 @@ keywords=speaker diarization, streaming, online, real time, rxpy url=https://github.com/juanmc2005/StreamingSpeakerDiarization license=MIT classifiers= - Development Status :: 3 - Alpha + Development Status :: 4 - Beta License :: OSI Approved :: MIT License Topic :: Multimedia :: Sound/Audio :: Analysis Topic :: Multimedia :: Sound/Audio :: Speech From b40f091c27db65e990daf4fd419c882f79b164cb Mon Sep 17 00:00:00 2001 From: Juan Coria Date: Wed, 13 Jul 2022 14:46:03 +0200 Subject: [PATCH 29/29] Split menu in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 53f5cdba..7cf7030b 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Build pipelines - | +
Research [rest of the navigation-menu markup omitted]